所以我有 m 个不同的向量(比如 x),每个向量都是 (1,n),水平堆叠,完全在 (m,n) 矩阵中,我们称之为 B,以及矩阵 ( A) 尺寸为 (n,n).
我想计算所有向量x的xAx^T,输出应该是(m,1)
如何为给定的 B 和 A 编写 einsum 查询?
这是一个没有 einsum 的示例:
import torch
m = 30
n = 4
B = torch.randn(m, n)
A = torch.randn(n, n)
result = torch.zeros(m,1)
for i in range(m):
x = B[i].unsqueeze(0)
result[i] = torch.matmul(x, torch.matmul(A, x.T))
这相当于该循环:
result_einsum = torch.einsum('ki,ij,kj->k', B, A, B)[..., None]
[..., None]
部分的存在是为了向输出添加额外的维度。无法弄清楚如何使用纯 einsum 做到这一点,尽管额外的部分至少是零拷贝的。