我有一个Python代码,它使用numpy数组来存储多维数据。数组的维数是在运行时确定的,因此我无法提前知道确切的维数,可以是 3 到 6。不过,我知道的一件事是,当数组中存在某个维时,它有 100 个元素。所以现在我的分析代码看起来像这样:
ndim = myarray.ndim
if ndim == 2:
for p0 in range(100):
do_something(myarray[:, p0])
if ndim == 3:
for p0 in range(100):
for p1 in range(100):
do_something(myarray[:, p0, p1])
elif ndim == 4:
for p0 in range(100):
for p1 in range(100):
for p2 in range(100):
do_something(myarray[:, p0, p1, p2])
elif ndim == 5:
for p0 in range(100):
for p1 in range(100):
for p2 in range(100):
for p3 in range(100):
do_something(myarray[:, p0, p1, p2, p3])
elif ndim == 6:
for p0 in range(100):
for p1 in range(100):
for p2 in range(100):
for p3 in range(100):
for p4 in range(100):
do_something(myarray[:, p0, p1, p2, p3, p4])
当然这是可行的,但我发现代码不是很优雅而且很混乱。有更好的方法吗?我非常确定必须有一种方法可以在不知道先验维数的情况下对数组进行索引,但我在 numpy 文档中找不到正确的函数。
一个通用的解决方案是使用
np.ndindex
,它返回给定特定形状的所有 nd 索引的迭代器。在您的情况下,您只想迭代非批处理(除第一个之外的所有轴)。这可以通过选择形状元组的相应部分来实现。请参阅以下最小示例:
import numpy as np
myarray = np.arange(24).reshape(2, 3, 4)
shape = myarray.shape
def do_something(array):
print(array)
for idx in np.ndindex(shape[1:]):
do_something(myarray[(...,) + idx])
哪个打印:
[ 0 12]
[ 1 13]
[ 2 14]
[ 3 15]
[ 4 16]
[ 5 17]
[ 6 18]
[ 7 19]
[ 8 20]
[ 9 21]
[10 22]
[11 23]
无论数组的维度如何,代码的工作方式都是相同的。
我希望这有帮助!