选择 numpy 数组上沿第 m 个轴的第 n 个元素

问题描述 投票:0回答:1

我的目标是使用包含整数的 numpy 数组,并使用它来选择 2D 数组相应行中的元素。二维数组沿着第一个轴单调递增,我试图选择每行等于或超过某个阈值的元素。

我有一个大小为 Nx10 的二维数组,其中第二轴上的值单调递增。例如:

sizes = np.array([[1, 3, 6, 6, 6, 7, 8, 8, 10, 10],  [2, 3, 3, 7, 7, 7, 9, 9, 10, 11],  [2, 3, 3, 5, 5, 6, 9, 9, 10, 11],  [2, 3, 3, 9, 9, 9, 9, 9, 10, 11]])

目标是找到等于或超过 5 的元素的索引,由以下代码完成:

threshold_indeces = np.argmax(sizes >= 5, axis=1)
现在我想提取所有行的值。直觉上,我会运行这个:
values = sizes[:, threshold_indeces]
并期望输出为 [6, 7, 5, 9]。但是,我得到一个大小为 NxN 的二维数组,在本例中等于
array([[6, 6, 6, 6], [3, 7, 7, 7], [3, 5, 5, 5], [3, 9, 9, 9]])

我看到我的预期输出在第 1、2 和 3 列中重复,但当我运行数百万行时,出现内存分配错误。

我做错了什么以及如何才能获得值的一维输出?

python python-3.x numpy numpy-ndarray numpy-slicing
1个回答
0
投票

您实际上是在告诉 numpy 您希望列的所有行中的值与

threshold_indeces
中的值相对应(顺便说一下,它的拼写是索引)。您还需要告诉 numpy 每列对应的行。这可以通过使用
np.arange
对行进行索引来完成。

values = sizes[np.arange(threshold_indeces.size), threshold_indeces]
© www.soinside.com 2019 - 2024. All rights reserved.