我有一个像这样的 numpy 数组:
arr = np.array([
[1, 2, 3],
[4, -5, 6],
[-1, -1, -1]
])
我想对它进行 argsort,但带有
arr <= 0
掩码。输出应该是:
array([[0, 1, 2],
[0, 2], # (Note that the indices are still relative to original un-masked array)
[]])
但是,我使用
np.ma.argsort()
得到的输出是:
array([[0, 1, 2],
[0, 2, 1],
[0, 1, 2]])
该方法需要非常高效,因为实际数组有数百万列。我认为这需要综合一些操作,但我不知道是哪些操作。
输入数组
arr = np.array([
[1, 2, 3],
[4, -5, 6],
[-1, -1, -1]
])
有效元素掩码
mask = arr > 0
将结果预分配为对象数组以保存可变长度索引
result = np.empty(arr.shape[0], dtype=object)
每行高效屏蔽
argsort
for i in range(arr.shape[0]):
valid_indices = np.where(mask[i])[0] # Get indices of valid (masked) elements
result[i] = valid_indices[np.argsort(arr[i, valid_indices])] # Sort valid indices by their values
输出:
[array([0, 1, 2]) array([0, 2]) array([], dtype=int64)]