在 numpy/torch 中 - 对于向量 v 和另一个索引向量,我们可以重新索引:
v[IX]
当我有一批向量 v 和一批索引时如何做同样的事情?
我的意思是 v - 是 v[i,:] 的二维数组 - 第 i 个向量,它应该由 IX[i,:] 重新索引。 慢Python方式就是:
for i in range(v.shape[0]):
new_v[i,:] = v[i,:][IX[i,:]]
但问题是用 numpy/torch 的方式来做——没有缓慢的 Python 循环。
我想到的想法是 - v.ravel()[ (IX + range(v.shape[0) ).ravel() ].reshape(N,-1), 但可能有更规范/可读的方式吗?
np.take_along_axis
:
import numpy as np
np.random.seed(34)
H, W = 300, 400
v = np.random.randint(0, 10, (H, W))
idxs = np.array([np.random.permutation(W) for _ in range(H)])
def reorder(v, idxs):
v1 = np.zeros_like(v)
for i in range(v.shape[0]):
v1[i, :] = v[i, :][idxs[i, :]]
return v1
v1 = reorder(v, idxs)
v2 = v[np.arange(H)[:, None], idxs]
v3 = np.take_along_axis(v, idxs, axis=1)
assert np.all(v1 == v2)
assert np.all(v2 == v3)
但速度并不快:
>>> %timeit reorder(v, idxs)
>>> %timeit v[np.arange(H)[:, None], idxs]
>>> %timeit np.take_along_axis(v, idxs, axis=1)
490 µs ± 419 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
414 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
418 µs ± 537 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)