让我们有一个大小为 (batch_size, N, N) 的初始张量和一个索引张量 (batch_size, N),指定批次中每个 2D 张量中元素的新顺序。目标是根据索引张量重新排列batch中张量的元素,以获得目标张量。
目前我可以使用以下循环在 CPU 上完成此操作:
for batch in range(batch_size):
old_ids = indexes[batch]
for i in range(N):
for j in range(N):
target[batch][i][j] = initial_tensor[batch][old_ids[i]][old_ids[j]]
我正在寻找一个等效的矢量解决方案来消除 CPU 利用率。
我尝试了利用散射和切片的各种选项,但无法找出循环的等效项。
您正在寻找的是沿两个轴收集值:
out[b, i, j] = x[b, index[b,i], index[b,j]]
torch.gather
的用例进行比较,这里 x 仅在单个轴上索引:dim=1
:
out[b,i] = x[b, index[b,i]]
所以你要做的就是展平
x
,并相应地指数。这是基本设置:B, N = 2, 4
、x = torch.rand(B,N,N)
和indices = torch.randint(0,N,(B,N))
。
您可以通过以下方式轻松获得扁平化索引:
findex = indices.repeat_interleave(N,1)*N + indices.repeat(1,N)
然后简单地展平
(N,N)
的 x
尺寸并使用 dim=1
在 findex
上应用索引:
x.flatten(1).gather(1,findex).view(B,N,N)