Numpy/torch:按批次索引重新索引批次向量

问题描述 投票:0回答:1

在 numpy/torch 中 - 对于向量 v 和另一个索引向量,我们可以重新索引:

v[IX] 

当我有一批向量 v 和一批索引时如何做同样的事情?

我的意思是 v - 是 v[i,:] 的二维数组 - 第 i 个向量,它应该由 IX[i,:] 重新索引。 慢Python方式就是:

for i in range(v.shape[0]):
    new_v[i,:] = v[i,:][IX[i,:]]

但问题是用 numpy/torch 的方式来做——没有缓慢的 Python 循环。

我想到的想法是 - v.ravel()[ (IX + range(v.shape[0) ).ravel() ].reshape(N,-1), 但可能有更规范/可读的方式吗?

python numpy torch
1个回答
0
投票

可读的方式是使用

np.take_along_axis
:

import numpy as np

np.random.seed(34)

H, W = 300, 400

v = np.random.randint(0, 10, (H, W))
idxs = np.array([np.random.permutation(W) for _ in range(H)])

def reorder(v, idxs):
    v1 = np.zeros_like(v)
    for i in range(v.shape[0]):
        v1[i, :] = v[i, :][idxs[i, :]]
    return v1

v1 = reorder(v, idxs)
v2 = v[np.arange(H)[:, None], idxs]
v3 = np.take_along_axis(v, idxs, axis=1)

assert np.all(v1 == v2)
assert np.all(v2 == v3)

但速度并不快:

>>> %timeit reorder(v, idxs)
>>> %timeit v[np.arange(H)[:, None], idxs]
>>> %timeit np.take_along_axis(v, idxs, axis=1)
490 µs ± 419 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
414 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
418 µs ± 537 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.