4D和2D矩阵的PyTorch广播乘法?

问题描述 投票:0回答:1

如何广播以将这两个矩阵相乘?

x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)

输出应为:

(10, 120, 180, 64) == (N, H, W, Y)
python pytorch broadcast torch
1个回答
0
投票

我假设x是具有批次的某种示例,w矩阵是相应的权重。在这种情况下,您可以简单地做:

out = x @ w.T

这是张量乘法,而不是元素方式的。您无法进行逐元素乘法来获得这种形状,并且此操作没有意义。您所能做的就是以某种方式将两个矩阵unsqueeze广播,并对由于某些原因您不想要的维应用某些运算:

x : torch.Size([10, 120, 180, 30, 1])
W: torch.Size([1, 1, 1, 30, 64]) # transposition would be needed as well

在这样的unsqueezing之后,您可以沿着第三个x*w进行summeandim以获得所需的形状。

为清楚起见,两种方式均等效。

© www.soinside.com 2019 - 2024. All rights reserved.