我正在寻找一种优雅的方法来检查给定索引是否位于 numpy 数组内(例如网格上的 BFS 算法)。
以下代码可以实现我想要的功能:
import numpy as np
def isValid(np_shape: tuple, index: tuple):
if min(index) < 0:
return False
for ind,sh in zip(index,np_shape):
if ind >= sh:
return False
return True
arr = np.zeros((3,5))
print(isValid(arr.shape,(0,0))) # True
print(isValid(arr.shape,(2,4))) # True
print(isValid(arr.shape,(4,4))) # False
但我更喜欢内置的或更优雅的东西,而不是编写自己的函数,包括 python for 循环(yikes)
你可以尝试:
def isValid(np_shape: tuple, index: tuple):
index = np.array(index)
return (index >= 0).all() and (index < arr.shape).all()
arr = np.zeros((3,5))
print(isValid(arr.shape,(0,0))) # True
print(isValid(arr.shape,(2,4))) # True
print(isValid(arr.shape,(4,4))) # False
我对答案进行了相当多的基准测试,并得出结论,实际上我的代码中提供的显式 for 循环性能最好。
Dmitri 的解决方案由于多种原因是错误的(tuple1 < tuple2 just compares the first value; ideas like np.all(ni < sh for ind,sh in zip(index,np_shape)) fail as the input to all returns a generator, not a list etc).
@mozway 的解决方案是正确的,但所有的强制转换都会使其变慢很多。此外,它总是需要考虑所有用于转换的数字,而我认为显式循环可以更早停止。
这是我的基准测试(方法0是@mozway的解决方案,方法1是我的解决方案):
对于 OP 来说,优雅和性能哪个最重要并不明显。对于不同的用例,我会采用不同的解决方案。
对于低维问题,我会简单地使用
def isValid(shape : tuple, index : tuple):
return (0 <= index[0] < shape[0] and
0 <= index[1] < shape[1] and
0 <= index[2] < shape[2])
它优雅且不言自明。同样在性能方面,这种方法优于其他方法。它比OP的解决方案快大约4-5倍。
对于高维问题,我通常会选择
def isValid(shape : tuple, index : tuple):
for i in range(len(shape)):
if not (0 <= index[i] < shape[i]):
return False
return True
对于高维度来说,它比以前的解决方案更优雅,并且只慢了几纳秒。