Computing a matrix formula in PyTorch


I have the following equations:

$e_{ij} = \frac{X_i W^Q (X_j W^K + A^K_{ij}) }{\sqrt{D_z}}$
$\alpha_{ij} = softmax(e_{ij})$
$z_{i} = \sum_j \alpha_{ij} (X_j W^V + A^V_{ij})$

with dimensions:

X: [B, S, H, D]
each W: [H, D, D]
each A: [S, S, H, D]

How can I compute this with matrix operations?

I have a partial solution:

import torch
import torch.nn.functional as F

B, S, H, D = X.shape
d_z = D  # Assuming d_z is equal to D for simplicity

W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)

a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]

e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2)  # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32))  # [B, S, S, H, D]

XW_V = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
alpha = F.softmax(e_ij, dim=2)  # [B, S, S, H, D]

z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V)  # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]

But z should be [B, S, H, D].

python machine-learning math pytorch
1 Answer

So, if I understand your question correctly, you are implementing an attention mechanism between the i-th and j-th sequences in the batch. First, you linearly project the data (X) to get the queries, XW_Q, then linearly project it again to get the keys, XW_K. You then add the bias a_K and finally compute the dot product (similarity) XW_Q @ (XW_K + a_K).

In that case, every D-dimensional embedding in the queries is multiplied (in the dot-product sense) with every D-dimensional embedding in the keys. The dot product of two vectors is a scalar, so the shape of e_ij should be [B, S, S, H], not [B, S, H, D].

import torch
import torch.nn.functional as F
X = torch.randn((10, 20, 30, 40))
B, S, H, D = X.shape
d_z = D  # Assuming d_z is equal to D for simplicity

W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)

a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)

XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]

e_ij_numerator = (XW_Q.unsqueeze(2) * (XW_K.unsqueeze(1) + a_K)).sum(dim=-1)  # [B, S, S, H]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32))  # [B, S, S, H]
alpha = F.softmax(e_ij, dim=2)  # [B, S, S, H]
XW_V = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]


z_i = torch.einsum('bijh,bijhd -> bihd', alpha, (XW_V.unsqueeze(1) + a_V))  # [B, S, S, H] * [B, S, S, H, D], summed over j -> [B, S, H, D]
print(z_i.shape) # [B, S, H, D]. 

Then, after scaling by sqrt(d_z), you apply softmax so that every i-th row sums to 1, which gives the scaling tensor alpha, also of shape [B, S, S, H].
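
As an optional sanity check, you can verify that normalization on the alpha computed in the code above:

row_sums = alpha.sum(dim=2)                            # sum over j: [B, S, S, H] -> [B, S, H]
print(torch.allclose(row_sums, torch.ones(B, S, H)))   # True: each row of attention weights sums to 1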

Now, you project your input to get the values, X @ W_V, which results in a [B, S, H, D] tensor.

Finally, you get the new i-th sequence (z_i) by scaling every j-th value row of XW_V + a_V by the j-th scaling factor in alpha_i and summing over j, which gives the [B, S, H, D] tensor you expected.
See the modified code above for the full implementation.
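
If you want to double-check the final einsum, here is a small (and deliberately slow) reference sketch that recomputes z with explicit loops over the sequence positions, using the tensors defined above, and compares it to z_i:

values = XW_V.unsqueeze(1) + a_V            # biased values: [B, S, S, H, D]
z_ref = torch.zeros(B, S, H, D)
for b in range(B):
    for i in range(S):
        for j in range(S):
            # weight the j-th value row for query position i by alpha[b, i, j, :]
            z_ref[b, i] += alpha[b, i, j].unsqueeze(-1) * values[b, i, j]
print(torch.allclose(z_ref, z_i, atol=1e-5))  # True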
