鉴于:
A = tensor([[ 0.4821, -0.3484, 0.0915, -0.1870],
[ 1.3817, 0.3011, 1.0704, 2.1717]])
B = torch.zeros(2,6)
C = torch.tensor([[1,2,2,3], [3,7,2,5]]) (same shape of A)
我想在索引 C 处用 A 替换 B 中的值,其中 < 6
(B.size(-1))
-> B =[[0, 0.4821, 0.0915, -0.1870, 0, 0],
[0, 0, 0.3011, 1.3817, 0, 2.1717]]
注意:C的第一行在A的第二个和第三个位置有两个2。这里我想得到最大值(或者如果你认为更有可能的话求和)
您可以做的是使用
散射操作来剪辑索引,使它们不会超过
B.size(1)
,在这种情况下,最后一个元素将覆盖另一个元素(仅保留第二个2
)。您还可以使用专门的函数来通过求和进行累加或减少到最大值。让我们试试这个:
torch.zeros(2,6).scatter_(1, C.clip(0,B.size(1)-1), A)
tensor([[ 0.0000, 0.4821, 0.0915, -0.1870, 0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0704, 1.3817, 0.0000, 2.1717]])
但这并不总是有效,因为超过最大长度的索引将被放置在末尾。解决方案可能是连接一个额外的缓冲区列来考虑不需要的值,然后修剪最后的张量以丢弃这些值:
torch.zeros(2,6+1).scatter_(1, C.clip(0,B.size(1)-1), A)[:,:-1]
tensor([[ 0.0000, 0.4821, 0.0915, -0.1870, 0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0704, 1.3817, 0.0000, 2.1717]])