如何简化3D张量切片

问题描述 投票:0回答:0

我想在 PyTorch 中对 3D 张量进行切片。 3D 张量 src_tensor 的形状是 (batch, max_len, hide_dim),我有一个形状为 (batch,) 的 1D 索引向量索引。我想沿着 src_tensor 的第二个维度进行切片。我可以通过以下代码实现此功能:

import torch
nums = 30
l = [i for i in range(nums)]
src_tensor = torch.Tensor(l).reshape((3,5,2))
indices = [1,2,3]
slice_tensor = torch.zeros((3,2,2)) 
for i in range(3):
    p1,p2 = indices[i],indices[i]+1
    slice_tensor[i,:,:]=src_tensor[i,[p1,p2],:]
print(src_tensor)
print(indices)
print(slice_tensor)
"""
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.],
         [ 6.,  7.],
         [ 8.,  9.]],

        [[10., 11.],
         [12., 13.],
         [14., 15.],
         [16., 17.],
         [18., 19.]],

        [[20., 21.],
         [22., 23.],
         [24., 25.],
         [26., 27.],
         [28., 29.]]])
[1, 2, 3]
tensor([[[ 2.,  3.],
         [ 4.,  5.]],

        [[14., 15.],
         [16., 17.]],

        [[26., 27.],
         [28., 29.]]])
"""

我的问题是上面的代码是否可以简化,比如去掉for循环。

python pytorch slice tensor
© www.soinside.com 2019 - 2024. All rights reserved.