我有以下几行:
torch.einsum("bfts,bhfs->bhts", decay_kernel, decay_q)
我有一个代码库,需要将这些 einsum 转换为原始方法,如
reshape
、matmul
、multiplication
、sum
等。有可能吗?我有很多其他的 einsum,我可以很好地转换它们,但我似乎无法理解这种情况下发生了什么。有人有线索吗?预先感谢。
回答有点晚了,你可能同时解决了。我目前正在研究类似的事情,并以这种方式解决了这个问题。也许我的回答对其遇到此问题的其他人有用。
我认为一般来说你可以转换任何
einsum
表达式。毕竟,内部也是如此。
要了解表达式的作用,查找求和的维度很有用。这些是出现在所有输入表达式中的表达式,但不会出现在输出中。在本例中,这是
f
。如果我们将输入称为 A
和 B
,将输出称为 C
,那么它的定义方式如下:
C_bhts = \sum_{f} A_bfts B_bhfs
在一维上求和是标准矩阵乘法的作用。然而,标准矩阵乘法适用于矩阵,即二维张量。这里我们有更多的维度,但由于它们应该保持完整,我们可以使用批处理
matmul
来解决整个问题,torch.matmul
支持开箱即用。不过,它预计“标准矩阵”维度排在最后,因此我们需要稍微重新安排一下。具体来说,bf
和ft
就是这里的,所以我们需要将s
移到其他地方。然后我们就可以进行批处理matmul
。对于最终结果,我们需要将 s
移回末尾。我使用以下代码说服了自己正确性:
import torch
b_size = 11
f_size = 13
t_size = 14
s_size = 15
h_size = 16
A_bfts = torch.rand(b_size, f_size, t_size, s_size)
B_bhfs = torch.rand(b_size, h_size, f_size, s_size)
C_bhts = torch.einsum("bfts, bhfs -> bhts", A_bfts, B_bhfs)
# move s axis from back to front
A_sbft = torch.moveaxis(A_bfts, 3, 0)
B_sbhf = torch.moveaxis(B_bhfs, 3, 0)
# perform batch matmul
C_sbht = torch.matmul(B_sbhf, A_sbft)
# move s axis back to last position
C_bhts_2 = torch.moveaxis(C_sbht, 0, -1)
assert torch.allclose(C_bhts, C_bhts_2)