有没有一种优雅的方法来检查是否可以在 numpy 数组中请求索引?

问题描述 投票:0回答:3

我正在寻找一种优雅的方法来检查给定索引是否位于 numpy 数组内(例如网格上的 BFS 算法)。

以下代码可以实现我想要的功能:

import numpy as np

def isValid(np_shape: tuple, index: tuple):
    if min(index) < 0:
        return False
    for ind,sh in zip(index,np_shape):
        if ind >= sh:
            return False
    return True

arr = np.zeros((3,5))
print(isValid(arr.shape,(0,0))) # True
print(isValid(arr.shape,(2,4))) # True
print(isValid(arr.shape,(4,4))) # False

但我更喜欢内置的或更优雅的东西,而不是编写自己的函数,包括 python for 循环(yikes)

python numpy built-in
3个回答
2
投票

你可以尝试:

def isValid(np_shape: tuple, index: tuple):
    index = np.array(index)
    return (index >= 0).all() and (index < arr.shape).all()

arr = np.zeros((3,5))
print(isValid(arr.shape,(0,0))) # True
print(isValid(arr.shape,(2,4))) # True
print(isValid(arr.shape,(4,4))) # False

2
投票

我对答案进行了相当多的基准测试,并得出结论,实际上我的代码中提供的显式 for 循环性能最好。

Dmitri 的解决方案由于多种原因是错误的(tuple1 < tuple2 just compares the first value; ideas like np.all(ni < sh for ind,sh in zip(index,np_shape)) fail as the input to all returns a generator, not a list etc).

@mozway 的解决方案是正确的,但所有的强制转换都会使其变慢很多。此外,它总是需要考虑所有用于转换的数字,而我认为显式循环可以更早停止。

这是我的基准测试(方法0是@mozway的解决方案,方法1是我的解决方案):

enter image description here


0
投票

对于 OP 来说,优雅和性能哪个最重要并不明显。对于不同的用例,我会采用不同的解决方案。

对于低维问题,我会简单地使用

def isValid(shape : tuple, index : tuple):
    return (0 <= index[0] < shape[0] and
            0 <= index[1] < shape[1] and
            0 <= index[2] < shape[2])

它优雅且不言自明。同样在性能方面,这种方法优于其他方法。它比OP的解决方案快大约4-5倍。

对于高维问题,我通常会选择

def isValid(shape : tuple, index : tuple):
    for i in range(len(shape)):
        if not (0 <= index[i] < shape[i]):
            return False
    return True

对于高维度来说,它比以前的解决方案更优雅,并且只慢了几纳秒。

这是我的基准测试(方法0是OP的解决方案,方法1是我的解决方案): benchmark test

© www.soinside.com 2019 - 2024. All rights reserved.