我很难理解这是如何完成的。
想象我们有两个张量,如下所示:
a = torch.arange(0,9).view(3,3)
b = torch.arange(0,30).view(2,3,5)
并且想要计算
a@b
。
形状不匹配,因此需要广播张量
a
,为此它首先需要批量维度,因此需要a.unsqueeze(0)
,然后沿着0th
的维度进行复制操作以类似于张量b,所以 a = torch.cat(tuple(a.clone() for _ in range(2),dim=0)
可以处理这个问题,
现在我们有两个张量,一个的形状分别为 (2,3,3),另一个的形状为 (2,3,5)。如果我们将前两个维度 (2,3) 视为批量维度并得到 (6,3) 和 (6,5) 使用转置,我们可以将这两个张量相乘,但显然,形状将是 (3,5),这与 Pytorch 生成的 (2,3,5) 结果不同。
那么最后一部分是如何完成的?
这是批量乘法的情况。如果仔细观察,您会发现,如果忽略批量维度,则其余维度是兼容的,因此剩下的就是将每个子张量单独相乘。
print(f'a:\n{a}')
print(f'b:\n{b}')
c = a@b
print(f'c:\n{c}')
# which is equivalent to
a2 = a.view(1,*a.shape)
print(f'{a2.shape=}')
c2 = a2@b
print(f'c2==c:{torch.all(c2==c)}')
# which is equivalent to
a2 = a.view(1,*a.shape)
# replicate along the batch dimension
a3 = torch.cat(tuple(a2.clone() for i in range(2)),dim=0)
print(f'{a3.shape=}')
print(f'a3:\n{a3}')
print(f'b:\n{b}')
c3 = a3@b
print(f'c3==c2:{torch.all(c3==c2)}')
print(f'c3:\n{c3}')
# now do batch-multiplication
c4 = torch.zeros_like(b)
for i in range(b.shape[0]):
c4[i,...] = a3[i]@b[i]