使用不同索引的列表沿轴对数组进行切片

问题描述 投票:0回答:1

我有一个具有 3 个维度

(a, b, c)
的多维数组/张量,并且我有一个不同索引的长度
a
列表,每个索引都在
[0, b)
范围内。我想使用索引来获取大小为
(a, c)
的数组。现在我用一个丑陋的列表理解来做到这一点

z = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])

这是在神经网络的前向传递中实现的,所以我真的想避免列表理解。是否有任何 torch(或 numpy)函数可以更高效地完成我想要的事情?

还有一个小例子:

tensor = [[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],
          [[ 6,  7],
           [ 8,  9],
           [10, 11]],
          [[12, 13],
           [14, 15],
           [16, 17]],
          [[18, 19],
           [20, 21],
           [22, 23]]]  # shape: (4, 3, 2)
B = [0, 1, 2, 2]
output = [[ 0,  1],
          [ 8,  9],
          [16, 17],
          [22, 23]]  # shape (4, 2)

背景:我有时间序列数据,其中有不同长度的时间窗口。我使用火炬的

pack_padded_sequence
(和反向)来屏蔽它,但我必须在屏蔽开始之前的时间步获取
LSTM
的输出,因为这样网络的输出将变得不可用。在示例中,我将有 4 个长度为
0, 1, 2, 2
的时间步长,每个时间步长有 2 个特征。

numpy torch numpy-slicing
1个回答
0
投票

使用高级索引。为了获得所需的输出,我们需要第一个轴的相应索引,该索引是使用下面的

torch.arange()
创建的:

output = tensor[torch.arange(len(B)), B]

或使用numpy

output = tensor[np.arange(len(B)), B]

两者都产生:

tensor([[ 0,  1],
        [ 8,  9],
        [16, 17],
        [22, 23]])

使用示例的完整代码:

import torch
tensor = torch.tensor([
    [[ 0,  1],
     [ 2,  3],
     [ 4,  5]],
    [[ 6,  7],
     [ 8,  9],
     [10, 11]],
    [[12, 13],
     [14, 15],
     [16, 17]],
    [[18, 19],
     [20, 21],
     [22, 23]]])
B = [0, 1, 2, 2]
output = tensor[torch.arange(len(B)), B]
© www.soinside.com 2019 - 2024. All rights reserved.