使用 einsum 进行转置时间矩阵时间转置:x@A^T@A@x^T@x@A@A^T@x^T

问题描述 投票:0回答:1

所以我有 m 个不同的向量(比如说 x),每个向量都是 (1,n),水平堆叠,完全在一个 (m,n) 矩阵中,我们称之为 B,以及一个维度为 (n) 的矩阵 (A) ,n).

我想计算所有向量x的x@A^T@A@x^T@x@A@A^T@x^T,输出应该是(m,1)

如何为给定的 B 和 A 编写

einsum
查询?

这是一个没有

einsum
的示例:

import torch
m = 30
n = 4
B = torch.randn(m, n)
A = torch.randn(n, n)
result = torch.zeros(m,1)
for i in range(m):
    x = B[i].unsqueeze(0)
    result[i] = torch.matmul(x,torch.matmul(A.T,torch.matmul(A,torch.matmul(x.T,torch.matmul(x,torch.matmul(A, torch.matmul(A.T, x.T)))))))

我可以为 xAx^T 编写查询,但不能为 xA^TAx^TxBB^Tx^T 编写查询。这是 xAx^T:

torch.einsum('bi,ij,bj -> b',B,A,B)
python numpy pytorch torch
1个回答
0
投票

您可以拨打以下电话:

result_2 = torch.einsum('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T)

关键是索引

a
。该索引表示 B 的行。根据您的描述,您不想在 B 和 A 之间进行矩阵乘法。例如,您希望所谓的“B @ A.T @ A @ B.T@”成为长度向量m,而如果你进行矩阵乘法,它将是一个大小为 (m,m) 的矩阵。通过每次引用
a
的行时使用相同的索引
B
,我们可以让每一行与其自身“合并”,而不是与其他行混合。

例如,考虑替代方案

result_3 = torch.einsum('ab,bc,cd,de,ef,fg,gh,ha->a', B, A.T, A, B.T, B, A, A.T, B.T)

此替代方案将仅进行矩阵乘法,并返回结果 (m,m) 矩阵的对角线元素。

© www.soinside.com 2019 - 2024. All rights reserved.