我有以下火炬张量:
tensor([[-0.2, 0.3],
[-0.5, 0.1],
[-0.4, 0.2]])
和以下numpy数组:(如果需要,我可以将其转换为其他内容)
[1 0 1]
我想获得以下张量:
tensor([0.3, -0.5, 0.2])
即我希望numpy数组索引我的张量的每个子元素。最好不使用循环。
提前感谢
简单地说,将范围(len(index))用于第一维。
import torch
a = torch.tensor([[-0.2, 0.3],
[-0.5, 0.1],
[-0.4, 0.2]])
c = [1, 0, 1]
b = a[range(3),c]
print(b)