Numpy 中的高效屏蔽参数排序

问题描述 投票:0回答:1

我有一个像这样的 numpy 数组:

arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1]
])

我想对它进行 argsort,但带有

arr <= 0
掩码。输出应该是:

array([[0, 1, 2],
       [0, 2],       # (Note that the indices are still relative to original un-masked array)
       []])

但是,我使用

np.ma.argsort()
得到的输出是:

array([[0, 1, 2],
       [0, 2, 1],
       [0, 1, 2]])

该方法需要非常高效,因为实际数组有数百万列。我认为这需要综合一些操作,但我不知道是哪些操作。

python numpy
1个回答
0
投票

输入数组

arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1]
])

有效元素掩码

mask = arr > 0

将结果预分配为对象数组以保存可变长度索引

result = np.empty(arr.shape[0], dtype=object)

每行高效屏蔽

argsort

for i in range(arr.shape[0]):
    valid_indices = np.where(mask[i])[0]  # Get indices of valid (masked) elements
    result[i] = valid_indices[np.argsort(arr[i, valid_indices])]  # Sort valid indices by their values

输出:

[array([0, 1, 2]) array([0, 2]) array([], dtype=int64)]
© www.soinside.com 2019 - 2024. All rights reserved.