我有一个大小为 (3, 2) 的源张量和一个大小为 (3, 3) 的索引张量,其中包含整数值 0、1 或 2。在 pytorch 中,我可以进行张量索引
source[index]
来获取大小为的张量(3,3,2)。示例:
source:
tensor([[1, 6],
[2, 3],
[8, 0]])
index:
tensor([[2, 1, 2],
[1, 1, 2],
[2, 0, 0]])
source[index]:
tensor([[[8, 0],
[2, 3],
[8, 0]],
[[2, 3],
[2, 3],
[8, 0]],
[[8, 0],
[1, 6],
[1, 6]]])
我想做以上操作但是批量了
例如,批量大小为 2:
源形状 --> (2, 3, 2)
索引形状 --> (2, 3, 3)
批量
source[index]
形状 --> (2, 3, 3, 2)我能够通过一些重塑来达到预期的效果:
source = torch.randint(0, 10, (2, 3, 2))
index = torch.randint(0, 3, (2, 3, 3))
index = index.flatten(-2)
index = index[:, :, None].expand(-1, -1, 2)
out = source.gather(1, index).view(2, 3, 3, 2)