这是一个说明此问题的示例函数。它试图获取连接的 QKV 矩阵(来自 Transformer 大语言模型)的 Q 矩阵(用于 KV 缓存)。在 QKV 投影被分割为头,然后分割为每个头的 q、k、v 投影之前,
manual_q
与 unsplit_q
匹配。这些操作发生后,manual_q
和q
不仅不同,而且相差很大。这里发生了什么导致这种差异?
def get_q_matrix(self, x):
batch_size, seq_len, n_embd = x.size()
debug_qkv = self.query_key_value(x) # shape (batch_size, seq_len, n_embd)
unsplit_q, _, _ = debug_qkv.split(
self.n_embd, dim=-1
) # shape (batch_size, seq_len, n_embd // 3)
debug_qkv = debug_qkv.view(
batch_size, seq_len, self.n_head, 3 * self.head_size
) # shape (batch_size, seq_len, 4, 96)
q, _, _ = debug_qkv.split(
self.head_size, dim=-1
) # shape (batch_size, seq_len, 4, 96 // 3)
# Ensure correct weight and bias extraction
weight = self.query_key_value.weight
bias = self.query_key_value.bias
q_weight, k_weight, v_weight = weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = bias.chunk(3, dim=0)
manual_q = F.linear(x, q_weight, q_bias)
manual_q = manual_q # shape (batch_size, seq_len, n_embd)
assert torch.allclose(unsplit_q, manual_q) # Passes
print(torch.max(torch.abs(unsplit_q - manual_q))) # tensor(0.)
manual_q = manual_q.view(batch_size, seq_len, self.n_head, self.head_size)
print(torch.max(torch.abs(q - manual_q))) # tensor(35.6218)
assert torch.allclose(q, manual_q) # AssertionError
return manual_q
您的视图操作正在改变张量中数据的布局。这导致了差异。举个简单的例子:
创建一个代表打包 QKV 值的虚拟张量。张量的批量大小和序列长度为 1,以及 int 值,可以轻松跟踪值的移动情况。
import torch
d_emb = 1
n_heads = 4
# size (1, 1, d_emb*n_heads*3)
debug_qkv = torch.arange(d_emb*n_heads*3)[None,None,:]
print(debug_qkv.shape)
> torch.Size([1, 1, 12])
print(debug_qkv)
> tensor([[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]])
现在我们通过沿最终轴分割
unsplit_q
来计算 debug_qkv
。我们看到 unsplit_q
具有值 [0, 1, 2, 3]
。这是有道理的。 debug_qkv
的最后一个暗淡的值从 0 到 11。我们将其分成 3 个连续的块。第一个块 unsplit_q
的值从 0 到 3。
unsplit_q, _, _ = debug_qkv.split(n_heads, -1)
print(unsplit_q.shape)
> tensor([[[0, 1, 2, 3]]])
print(unsplit_q)
> torch.Size([1, 1, 4])
现在我们看一下视图操作。在您的代码中,这是行
debug_qkv = debug_qkv.view(batch_size, seq_len, self.n_head, 3 * self.head_size)
我们可以看到张量布局如何变化。视图操作首先填充
n_head
维度,然后填充最终维度。我们正在寻找的值 - [0, 1, 2, 3]
- 实际上已经被分割了。
qkv_reshaped = debug_qkv.view(1, 1, n_heads, d_emb*3)
print(qkv_reshaped.shape)
> torch.Size([1, 1, 4, 3])
print(qkv_reshaped)
> tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]]])
当我们将
q
从 qkv_reshaped
中分离出来时,它会向前传播
q, _, _ = qkv_reshaped.split(d_emb, dim=-1)
print(q)
> tensor([[[[0],
[3],
[6],
[9]]]])
我们可以清楚地看到这些值是如何混合的。
unsplit_q
具有连续值 [0, 1, 2, 3]
,而 q
具有重塑值 [0, 3, 6, 9]
。
这种情况下的解决方案是重塑
debug_qkv
以使头部尺寸位于最后
qkv_reshaped2 = debug_qkv.view(1, 1, d_emb*3, n_heads)
print(qkv_reshaped2)
> tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]]])
q2, _, _ = qkv_reshaped2.split(d_emb, dim=-2)
print(q2)
> tensor([[[[0, 1, 2, 3]]]])
如果你阅读了 pytorch 的多头注意力的源代码,你会发现很多类似这样的操作:
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
由于这个确切的问题,此操作最后会在头部变暗的情况下进行重塑。