为张量中的每一行选择不同的列索引

问题描述 投票:0回答:3

我有一个pytorch张量

t = torch.tensor(
            [[1.0, 1.5, 0.5, 2.0],
             [5.0, 3.0, 4.5, 5.5],
             [0.5, 1.0, 3.0, 2.0]]
)

t[:, [-1]]
给我每行的最后一列值:

tensor([[2.0000],
        [5.5000],
        [2.0000]])

但是,我想对每行不同列的值进行切片。例如,在

t
中,对于第一、第二和第三行,我想分别在 2、-1、0 索引处进行切片以获得以下张量:

tensor([[0.5],
        [5.5],
        [0.5]])

如何在

torch
中做到这一点?

python pytorch slice
3个回答
2
投票
t[[i for i in range(3)], [2, -1, 0]]

列表推导式创建一个充满行索引的列表,然后为每行指定列索引。


0
投票

您可以使用以下内容:

t = torch.tensor(
            [[1.0, 1.5, 0.5, 2.0],
             [5.0, 3.0, 4.5, 5.5],
             [0.5, 1.0, 3.0, 2.0]]
)
t
>tensor([[1.0000, 1.5000, 0.5000, 2.0000],
        [5.0000, 3.0000, 4.5000, 5.5000],
        [0.5000, 1.0000, 3.0000, 2.0000]])

rows = [0, 1, 2]
cols = [2, -1, 0]

t[rows, cols]
>tensor([0.5000, 5.5000, 0.5000])


0
投票

您可以使用

gather
功能来执行此操作。
参考:torch.gather

示例
假设我们有随机生成的张量:

a = torch.rand((3, 5)) 
print(a)

>>> tensor([[0.2646, 0.9824, 0.7346, 0.5089, 0.8017],
            [0.6044, 0.6533, 0.4774, 0.5840, 0.3478],
            [0.1689, 0.7777, 0.3727, 0.2958, 0.4059]])

我们需要索引

[0, 0, 3]
处的值(分别为第 1 行、第 2 行和第 3 行)。然后,我们有:

index = torch.Tensor([[0, 0, 3]]).type(torch.LongTensor).permute(1, 0) 
result = torch.gather(input=a, dim=1, index=index) 
print(result)

>>> tensor([[0.2646],
            [0.6044],
            [0.2958]])

**注意**:此函数不支持索引-1。
© www.soinside.com 2019 - 2024. All rights reserved.