我读过其他问题,其他人试图用 np.partition
或
np.argpartition找到 k 最小值。然而,这些看起来像是简单的一维或二维数组来操作。
我拥有的是一个 3D 数组,我在其中获取 2D 切片并尝试找到 4 个最小值:
import numpy as np
np.random.seed(556)
big_array = np.random.randint(0,200,size = (15,9,3))
for i in range(3):
min_vals = np.sort(big_array[:,:,i].flatten())[:4]
print("The smallest values for last axis #{} are: {:.1f},{:.1f},{:.1f},{:.1f}".format(i,*min_vals))
不幸的是,当第 k 个参数大于您尝试指定的任何轴的尺寸时,似乎
np.partition
不喜欢它,可能是因为我不完全理解语法。同样,它看起来不像 np.sort
或 np.argsort
有多个轴参数的选项。
有没有一种方法可以在不使用 for 循环或列表理解的情况下完成此任务?
IIUC,您可以沿第一个轴对重塑后的数组进行排序并从那里提取结果。你的循环和展平与:
big_array.reshape(-1, big_array.shape[-1])
所以,你的代码将是:
top_4_per_last_axis = np.sort(big_array.reshape(-1,big_array.shape[-1]), 0)[:4]
结果:
array([[ 0, 2, 0],
[ 1, 3, 2],
[ 1, 6, 3],
[ 2, 9, 12]])