所以我有 m 个不同的向量(比如说 x),每个向量都是 (1,n),水平堆叠,完全在一个 (m,n) 矩阵中,我们称之为 B,以及一个维度为 (n) 的矩阵 (A) ,n).
我想计算所有向量x的x@A^T@A@x^T@x@A@A^T@x^T,输出应该是(m,1)
如何为给定的 B 和 A 编写
einsum
查询?
这是一个没有
einsum
的示例:
import torch
m = 30
n = 4
B = torch.randn(m, n)
A = torch.randn(n, n)
result = torch.zeros(m,1)
for i in range(m):
x = B[i].unsqueeze(0)
result[i] = torch.matmul(x,torch.matmul(A.T,torch.matmul(A,torch.matmul(x.T,torch.matmul(x,torch.matmul(A, torch.matmul(A.T, x.T)))))))
我可以为 xAx^T 编写查询,但不能为 xA^TAx^TxBB^Tx^T 编写查询。这是 xAx^T:
torch.einsum('bi,ij,bj -> b',B,A,B)
您可以拨打以下电话:
result_2 = torch.einsum('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T)
关键是索引
a
。该索引表示 B 的行。根据您的描述,您不想在 B 和 A 之间进行矩阵乘法。例如,您希望所谓的“B @ A.T @ A @ B.T@”成为长度向量m,而如果你进行矩阵乘法,它将是一个大小为 (m,m) 的矩阵。通过每次引用 a
的行时使用相同的索引 B
,我们可以让每一行与其自身“合并”,而不是与其他行混合。
例如,考虑替代方案
result_3 = torch.einsum('ab,bc,cd,de,ef,fg,gh,ha->a', B, A.T, A, B.T, B, A, A.T, B.T)
此替代方案将仅进行矩阵乘法,并返回结果 (m,m) 矩阵的对角线元素。