批量 PyTorch 张量索引

问题描述 投票:0回答:1

我有一个大小为 (3, 2) 的源张量和一个大小为 (3, 3) 的索引张量,其中包含整数值 0、1 或 2。在 pytorch 中,我可以进行张量索引

source[index]
来获取大小为的张量(3,3,2)。示例:

source: 
tensor([[1, 6],
        [2, 3],
        [8, 0]])

index: 
tensor([[2, 1, 2],
        [1, 1, 2],
        [2, 0, 0]])

source[index]: 
tensor([[[8, 0],
         [2, 3],
         [8, 0]],

        [[2, 3],
         [2, 3],
         [8, 0]],

        [[8, 0],
         [1, 6],
         [1, 6]]])

我想做以上操作但是批量了
例如,批量大小为 2:
源形状 --> (2, 3, 2)
索引形状 --> (2, 3, 3)
批量

source[index]
形状 --> (2, 3, 3, 2)
我可以通过循环轻松完成此操作,但我想知道是否可以使用 torch.gather 或其他内置函数有效地完成此操作?

pytorch tensor
1个回答
0
投票

我能够通过一些重塑来达到预期的效果:

source = torch.randint(0, 10, (2, 3, 2))
index = torch.randint(0, 3, (2, 3, 3))
index = index.flatten(-2)
index = index[:, :, None].expand(-1, -1, 2)
out = source.gather(1, index).view(2, 3, 3, 2)
© www.soinside.com 2019 - 2024. All rights reserved.