在批处理 Torch 中重新排列 2D 张量

问题描述 投票:0回答:1

让我们有一个大小为 (batch_size, N, N) 的初始张量和一个索引张量 (batch_size, N),指定批次中每个 2D 张量中元素的新顺序。目标是根据索引张量重新排列batch中张量的元素,以获得目标张量。

目前我可以使用以下循环在 CPU 上完成此操作:

    for batch in range(batch_size):
        old_ids = indexes[batch]

        for i in range(N):
            for j in range(N):
                target[batch][i][j] = initial_tensor[batch][old_ids[i]][old_ids[j]]

我正在寻找一个等效的矢量解决方案来消除 CPU 利用率。

我尝试了利用散射和切片的各种选项,但无法找出循环的等效项。

vector pytorch gpu cpu torch
1个回答
0
投票

您正在寻找的是沿两个轴收集值:

out[b, i, j] = x[b, index[b,i], index[b,j]]

没有开箱即用的功能,您需要解决它。将您的设置与

torch.gather
的用例进行比较,这里 x 仅在单个轴上索引:
dim=1
:

out[b,i] = x[b, index[b,i]]

所以你要做的就是展平

x
,并相应地指数。这是基本设置:
B, N = 2, 4
x = torch.rand(B,N,N)
indices = torch.randint(0,N,(B,N))

您可以通过以下方式轻松获得扁平化索引:

findex = indices.repeat_interleave(N,1)*N + indices.repeat(1,N)

然后简单地展平

(N,N)
x
尺寸并使用
dim=1
findex
上应用索引:

x.flatten(1).gather(1,findex).view(B,N,N)
© www.soinside.com 2019 - 2024. All rights reserved.