pytorch如何进行最后维度不同的两个张量的乘法?

问题描述 投票:0回答:1

我很难理解这是如何完成的。
想象我们有两个张量,如下所示:

a = torch.arange(0,9).view(3,3)
b = torch.arange(0,30).view(2,3,5)

并且想要计算

a@b

形状不匹配,因此需要广播张量

a
,为此它首先需要批量维度,因此需要
a.unsqueeze(0)
,然后沿着
0th
的维度进行复制操作以类似于张量b,所以
a = torch.cat(tuple(a.clone() for _ in range(2),dim=0)
可以处理这个问题, 现在我们有两个张量,一个的形状分别为 (2,3,3),另一个的形状为 (2,3,5)。
现在我们该如何继续?最后的维度不一样,所以这里没有元素乘积。

如果我们将前两个维度 (2,3) 视为批量维度并得到 (6,3) 和 (6,5) 使用转置,我们可以将这两个张量相乘,但显然,形状将是 (3,5),这与 Pytorch 生成的 (2,3,5) 结果不同。

那么最后一部分是如何完成的?

pytorch torch
1个回答
0
投票

这是批量乘法的情况。如果仔细观察,您会发现,如果忽略批量维度,则其余维度是兼容的,因此剩下的就是将每个子张量单独相乘。

print(f'a:\n{a}')
print(f'b:\n{b}')
c = a@b 
print(f'c:\n{c}')
# which is equivalent to 
a2 = a.view(1,*a.shape)
print(f'{a2.shape=}')
c2 = a2@b 
print(f'c2==c:{torch.all(c2==c)}')
# which is equivalent to 
a2 = a.view(1,*a.shape)
# replicate along the batch dimension
a3 = torch.cat(tuple(a2.clone() for i in range(2)),dim=0)
print(f'{a3.shape=}')
print(f'a3:\n{a3}')
print(f'b:\n{b}')
c3 = a3@b 
print(f'c3==c2:{torch.all(c3==c2)}')
print(f'c3:\n{c3}')

# now do batch-multiplication
c4 = torch.zeros_like(b)
for i in range(b.shape[0]):
    c4[i,...] = a3[i]@b[i]
© www.soinside.com 2019 - 2024. All rights reserved.