这个问题在这里已有答案:
我试图提取NumPy 2D数组的四个角元素:
import numpy as np
data = np.arange(16).reshape((4, -1))
#array([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15]])
预期的输出是[[0,3],[12,15]]
或[0,3,12,15]
(任何事情都有)。真正的2D花式索引仅提供主对角线的末端:
data[[0,-1],[0,-1]]
#array([ 0, 15])
伪2D花式索引(第一行,然后是列式)提供了正确的答案,但看起来很尴尬:
data[[0,-1]][:,[0,-1]]
#array([[ 0, 3],
# [12, 15]])
有没有办法使用真正的花式索引,例如data[XXX,YYY]
,其中XXX
和YYY
是列表/数组/切片,以提取所有四个角?
你可以做:
data[[0, 0, -1, -1], [0, -1, 0, -1]]
这有两种可能性。 (好吧,第一个实际上并不奇特):
>>> a = np.arange(9).reshape(3, 3)
>>>
>>> m, n = a.shape
>>> a[::m-1, ::n-1]
array([[0, 2],
[6, 8]])
>>>
>>> a[np.ix_((0,-1), (0,-1))]
array([[0, 2],
[6, 8]])
更明确地说:
>>> idx = np.ix_((0,-1), (0,-1))
>>> idx
(array([[ 0],
[-1]]), array([[ 0, -1]]))
>>> a[idx]
array([[0, 2],
[6, 8]])
诀窍是利用指数上的广播。 np.ix_
知道如何做的细节。