我有方程式:
$e_{ij} = \frac{X_i W^Q (X_j W^K + A^K_{ij}) }{\sqrt{D_z}}$
$\alpha_{ij} = softmax(e_{ij})$
$z_{i} = \sum_j \alpha_{ij} (X_j W^V + A^V_{ij})$
其中尺寸:
X: [B, S, H,D]
each W: [H,D,D]
each A: [S, S, H,D]
我如何通过矩阵运算来计算它?
我有部分解决方案
import torch
import torch.nn.functional as F
B, S, H, D = X.shape
d_z = D # Assuming d_z is equal to D for simplicity
W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)
a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
}
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2) # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32)) # [B, S, S, H, D]
XW_V = torch.einsum('bshd,hde->bshe', X, W_V) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
alpha = F.softmax(e_ij, dim=2) # [B, S, S, H, D]
z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V) # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
但是 z 应该是 [B, S, H,D]
所以,如果我正确理解你的问题,那么你正在批量中的第
i
和 j
序列之间实施注意机制。首先,线性投影数据 (X) 以获取查询:XW_Q,然后线性投影数据以获取键:XW_K。然后添加偏差 a_K,最后计算 XW_Q @ (XW_K + a_K) 之间的点积(相似度)。
在这种情况下,查询中的每个 D 维嵌入都会与键中的每个 D 维嵌入相乘(在点积意义上)。两个向量点积的输出是一个标量,对于 e_ij 的形状应该是 [B, S, S, H],而不是 [B, S, H, D]。
import torch
import torch.nn.functional as F
X = torch.randn((10, 20, 30, 40))
B, S, H, D = X.shape
d_z = D # Assuming d_z is equal to D for simplicity
W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)
a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
e_ij_numerator = (XW_Q.unsqueeze(2) * (XW_K.unsqueeze(1) + a_K)).sum(dim=-1) # [B, S, S, H]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32)) # [B, S, S, H]
alpha = F.softmax(e_ij, dim=2) # [B, S, S, H]
XW_V = torch.einsum('bshd,hde->bshe', X, W_V) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
z_i = torch.einsum('bijh,bijhd -> bihd', alpha, (XW_V.unsqueeze(1) + a_V)) # [B, S, S, H] * [B, S, S, H, D] -> [B, S, S, H, D]
print(z_i.shape) # [B, S, H, D].
Then, after normalization, you apply softmax such that every ith row sums to 1 to get the scaling matrix alpha which is also [B, S, S, H]
Now, you project your input the get the values: X@W_V. This should result in a [B, S, H, D] Tensor.
Finally, you get the new ith sequence (z_i) by scaling every jth sequence column of XW_V by the jth scaling factor in alpha_i and sum, resulting in a [B, S, H, D] tensor as you expected.
See the modified code below: