我有以下类型的整数数组:
import numpy as np
seed_idx = np.asarray([[0, 1],
[1, 2],
[2, 3],
[3, 4]], dtype=np.int_)
target_idx = np.asarray([[2,9,4,1,8],
[9,7,6,2,4],
[1,0,0,4,9],
[7,1,2,3,8]], dtype=np.int_)
对于
target_idx
的每一行,我想选择索引为而不是seed_idx
中的元素。由此产生的数组应该是:
[[4,1,8],
[9,2,4],
[1,0,9],
[7,1,2]]
换句话说,我想做类似np.take_along_axis(target_idx, seed_idx, axis=1)
的事情,但排除索引而不是保留它们。最优雅的方法是什么?我发现找到一些整洁的东西是令人惊讶的烦人。
>>> np.put_along_axis(target_idx, seed_idx, -1, axis=1)
>>> target_idx[np.where(target_idx != -1)].reshape(len(target_idx), -1)
array([[4, 1, 8],
[9, 2, 4],
[1, 0, 9],
[7, 1, 2]])