torch.einsum 到 matmul 转换

问题描述 投票:0回答:1

我有以下几行:

torch.einsum("bfts,bhfs->bhts", decay_kernel, decay_q)

我有一个代码库,需要将这些 einsum 转换为原始方法,如

reshape
matmul
multiplication
sum
等。有可能吗?我有很多其他的 einsum,我可以很好地转换它们,但我似乎无法理解这种情况下发生了什么。有人有线索吗?预先感谢。

python torch
1个回答
0
投票

回答有点晚了,你可能同时解决了。我目前正在研究类似的事情,并以这种方式解决了这个问题。也许我的回答对其遇到此问题的其他人有用。

我认为一般来说你可以转换任何

einsum
表达式。毕竟,内部也是如此。

要了解表达式的作用,查找求和的维度很有用。这些是出现在所有输入表达式中的表达式,但不会出现在输出中。在本例中,这是

f
。如果我们将输入称为
A
B
,将输出称为
C
,那么它的定义方式如下:

C_bhts = \sum_{f} A_bfts B_bhfs

在一维上求和是标准矩阵乘法的作用。然而,标准矩阵乘法适用于矩阵,即二维张量。这里我们有更多的维度,但由于它们应该保持完整,我们可以使用批处理

matmul
来解决整个问题,
torch.matmul
支持开箱即用。不过,它预计“标准矩阵”维度排在最后,因此我们需要稍微重新安排一下。具体来说,
bf
ft
就是这里的,所以我们需要将
s
移到其他地方。然后我们就可以进行批处理
matmul
。对于最终结果,我们需要将
s
移回末尾。我使用以下代码说服了自己正确性:

import torch

b_size = 11
f_size = 13
t_size = 14
s_size = 15
h_size = 16

A_bfts = torch.rand(b_size, f_size, t_size, s_size)
B_bhfs = torch.rand(b_size, h_size, f_size, s_size)
C_bhts = torch.einsum("bfts, bhfs -> bhts", A_bfts, B_bhfs)

# move s axis from back to front
A_sbft = torch.moveaxis(A_bfts, 3, 0)
B_sbhf = torch.moveaxis(B_bhfs, 3, 0)
# perform batch matmul
C_sbht = torch.matmul(B_sbhf, A_sbft)
# move s axis back to last position
C_bhts_2 = torch.moveaxis(C_sbht, 0, -1)

assert torch.allclose(C_bhts, C_bhts_2)
© www.soinside.com 2019 - 2024. All rights reserved.