我想在 PyTorch 中对 3D 张量进行切片。 3D 张量 src_tensor 的形状是 (batch, max_len, hide_dim),我有一个形状为 (batch,) 的 1D 索引向量索引。我想沿着 src_tensor 的第二个维度进行切片。我可以通过以下代码实现此功能:
import torch
nums = 30
l = [i for i in range(nums)]
src_tensor = torch.Tensor(l).reshape((3,5,2))
indices = [1,2,3]
slice_tensor = torch.zeros((3,2,2))
for i in range(3):
p1,p2 = indices[i],indices[i]+1
slice_tensor[i,:,:]=src_tensor[i,[p1,p2],:]
print(src_tensor)
print(indices)
print(slice_tensor)
"""
tensor([[[ 0., 1.],
[ 2., 3.],
[ 4., 5.],
[ 6., 7.],
[ 8., 9.]],
[[10., 11.],
[12., 13.],
[14., 15.],
[16., 17.],
[18., 19.]],
[[20., 21.],
[22., 23.],
[24., 25.],
[26., 27.],
[28., 29.]]])
[1, 2, 3]
tensor([[[ 2., 3.],
[ 4., 5.]],
[[14., 15.],
[16., 17.]],
[[26., 27.],
[28., 29.]]])
"""
我的问题是上面的代码是否可以简化,比如去掉for循环。