将张量 A 的值投影到索引 C 处的张量 B (pytorch)

问题描述 投票:0回答:1

鉴于:

A = tensor([[ 0.4821, -0.3484,  0.0915, -0.1870],
            [ 1.3817,  0.3011,  1.0704,  2.1717]])

B = torch.zeros(2,6)

C =  torch.tensor([[1,2,2,3], [3,7,2,5]]) (same shape of A)

我想在索引 C 处用 A 替换 B 中的值,其中 < 6

(B.size(-1))

-> B =[[0, 0.4821, 0.0915, -0.1870, 0, 0],
       [0, 0, 0.3011, 1.3817, 0, 2.1717]]

注意:C的第一行在A的第二个和第三个位置有两个2。这里我想得到最大值(或者如果你认为更有可能的话求和)

pytorch vectorization
1个回答
0
投票

您可以做的是使用

散射操作
来剪辑索引,使它们不会超过B.size(1),在这种情况下,最后一个元素将覆盖另一个元素(仅保留第二个
2
)。您还可以使用专门的函数来通过求和进行累加或减少到最大值。让我们试试这个:

torch.zeros(2,6).scatter_(1, C.clip(0,B.size(1)-1), A)
tensor([[ 0.0000,  0.4821,  0.0915, -0.1870,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0704,  1.3817,  0.0000,  2.1717]])

但这并不总是有效,因为超过最大长度的索引将被放置在末尾。解决方案可能是连接一个额外的缓冲区列来考虑不需要的值,然后修剪最后的张量以丢弃这些值:

torch.zeros(2,6+1).scatter_(1, C.clip(0,B.size(1)-1), A)[:,:-1]
tensor([[ 0.0000,  0.4821,  0.0915, -0.1870,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0704,  1.3817,  0.0000,  2.1717]])
© www.soinside.com 2019 - 2024. All rights reserved.