我有一个pytorch张量
t = torch.tensor(
[[1.0, 1.5, 0.5, 2.0],
[5.0, 3.0, 4.5, 5.5],
[0.5, 1.0, 3.0, 2.0]]
)
t[:, [-1]]
给我每行的最后一列值:
tensor([[2.0000],
[5.5000],
[2.0000]])
但是,我想对每行不同列的值进行切片。例如,在
t
中,对于第一、第二和第三行,我想分别在 2、-1、0 索引处进行切片以获得以下张量:
tensor([[0.5],
[5.5],
[0.5]])
如何在
torch
中做到这一点?
t[[i for i in range(3)], [2, -1, 0]]
列表推导式创建一个充满行索引的列表,然后为每行指定列索引。
您可以使用以下内容:
t = torch.tensor(
[[1.0, 1.5, 0.5, 2.0],
[5.0, 3.0, 4.5, 5.5],
[0.5, 1.0, 3.0, 2.0]]
)
t
>tensor([[1.0000, 1.5000, 0.5000, 2.0000],
[5.0000, 3.0000, 4.5000, 5.5000],
[0.5000, 1.0000, 3.0000, 2.0000]])
rows = [0, 1, 2]
cols = [2, -1, 0]
t[rows, cols]
>tensor([0.5000, 5.5000, 0.5000])
您可以使用
gather
功能来执行此操作。
a = torch.rand((3, 5))
print(a)
>>> tensor([[0.2646, 0.9824, 0.7346, 0.5089, 0.8017],
[0.6044, 0.6533, 0.4774, 0.5840, 0.3478],
[0.1689, 0.7777, 0.3727, 0.2958, 0.4059]])
我们需要索引
[0, 0, 3]
处的值(分别为第 1 行、第 2 行和第 3 行)。然后,我们有:
index = torch.Tensor([[0, 0, 3]]).type(torch.LongTensor).permute(1, 0)
result = torch.gather(input=a, dim=1, index=index)
print(result)
>>> tensor([[0.2646],
[0.6044],
[0.2958]])