删除 numpy ndarray 每行中的特定索引

问题描述 投票:0回答:1

我有以下类型的整数数组:

import numpy as np

seed_idx = np.asarray([[0, 1],
                       [1, 2],
                       [2, 3],
                       [3, 4]], dtype=np.int_)

target_idx = np.asarray([[2,9,4,1,8],
                         [9,7,6,2,4],
                         [1,0,0,4,9],
                         [7,1,2,3,8]], dtype=np.int_)

对于

target_idx
的每一行,我想选择索引为而不是seed_idx
中的元素。由此产生的数组应该是:

[[4,1,8], [9,2,4], [1,0,9], [7,1,2]]
换句话说,我想做类似

np.take_along_axis(target_idx, seed_idx, axis=1)

的事情,但排除索引而不是保留它们。

最优雅的方法是什么?我发现找到一些整洁的东西是令人惊讶的烦人。

numpy numpy-ndarray numpy-slicing
1个回答
0
投票
您可以使用

np.put_along_axis

 屏蔽掉不需要的值,然后为其他值建立索引:

>>> np.put_along_axis(target_idx, seed_idx, -1, axis=1) >>> target_idx[np.where(target_idx != -1)].reshape(len(target_idx), -1) array([[4, 1, 8], [9, 2, 4], [1, 0, 9], [7, 1, 2]])
    
© www.soinside.com 2019 - 2024. All rights reserved.