如何过滤数组以仅保留其重复元素?

问题描述 投票:1回答:1

我有一个NumPy数组。我想从中创建一个新的,只包含重复的元素。例如,在数组之前可能看起来像

[[  3.   0.   1.   0.  12.   1.]
 [ 14.   0.   2.   2.   0.   3.]
 [  3.   0.   1.   2.   0.   3.]
 [ 12.   0.  14.   0.  12.   1.]
 [ 14.   0.   2.  12.   0.  14.]
 [ 15.   4.  13.  13.  14.  15.]
 [ 14.   2.  15.  13.  14.  15.]]

操作后我希望它看起来像

[[ 1.   0.  ]
 [ 0.   2.  ]
 [  3.  0.  ]
 [ 12.  0.  ]
 [ 14.  0.  ]
 [ 15.  13. ]
 [ 14.  15. ]]

现在,我会使用for循环来做,但也许你们中的某个人知道更顺畅,更快捷的方式。

python arrays numpy
1个回答
1
投票

您无法在单个numpy步骤中执行此操作,因为重复项的长度可能会在行与行之间发生变化。

我建议你做以下事情。

定义一个函数来查找重复项:

def dups(a):
    uniques, counts = np.unique(a, return_counts=True)
    return uniques[np.where(counts > 1)]

然后将其应用于数组的每一行:

ans = [dups(row) for row in arr]

对于所有行具有相同数量的重复项的情况,您可以使用ans创建一个numpy数组:

ans = np.stack(ans)

对于您的示例案例,它会打印:

[[  0.   1.]
 [  0.   2.]
 [  0.   3.]
 [  0.  12.]
 [  0.  14.]
 [ 13.  15.]
 [ 14.  15.]]
© www.soinside.com 2019 - 2024. All rights reserved.