用数组索引火炬张量

问题描述 投票:0回答:1

我有以下火炬张量:

tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])

和以下numpy数组:(如果需要,我可以将其转换为其他内容)

[1 0 1]

我想获得以下张量:

tensor([0.3, -0.5, 0.2])

即我希望numpy数组索引我的张量的每个子元素。最好不使用循环。

提前感谢

python indexing pytorch tensor torch
1个回答
0
投票

简单地说,将范围(len(index))用于第一维。

import torch

a = torch.tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])

c = [1, 0, 1]


b = a[range(3),c]

print(b)
© www.soinside.com 2019 - 2024. All rights reserved.