PyTorch 线性运算在重塑后变化很大

问题描述 投票:0回答:1

这是一个说明此问题的示例函数。它试图获取连接的 QKV 矩阵(来自 Transformer 大语言模型)的 Q 矩阵(用于 KV 缓存)。在 QKV 投影被分割为头,然后分割为每个头的 q、k、v 投影之前,

manual_q
unsplit_q
匹配。这些操作发生后,
manual_q
q
不仅不同,而且相差很大。这里发生了什么导致这种差异?

    def get_q_matrix(self, x):
        batch_size, seq_len, n_embd = x.size()

        debug_qkv = self.query_key_value(x)  # shape (batch_size, seq_len, n_embd)
        unsplit_q, _, _ = debug_qkv.split(
            self.n_embd, dim=-1
        )  # shape (batch_size, seq_len, n_embd // 3)

        debug_qkv = debug_qkv.view(
            batch_size, seq_len, self.n_head, 3 * self.head_size
        )  # shape (batch_size, seq_len, 4, 96)
        q, _, _ = debug_qkv.split(
            self.head_size, dim=-1
        )  # shape (batch_size, seq_len, 4, 96 // 3)

        # Ensure correct weight and bias extraction
        weight = self.query_key_value.weight
        bias = self.query_key_value.bias

        q_weight, k_weight, v_weight = weight.chunk(3, dim=0)
        q_bias, k_bias, v_bias = bias.chunk(3, dim=0)

        manual_q = F.linear(x, q_weight, q_bias)
        manual_q = manual_q  # shape (batch_size, seq_len, n_embd)

        assert torch.allclose(unsplit_q, manual_q) # Passes
        print(torch.max(torch.abs(unsplit_q - manual_q)))  # tensor(0.)

        manual_q = manual_q.view(batch_size, seq_len, self.n_head, self.head_size)

        print(torch.max(torch.abs(q - manual_q)))  # tensor(35.6218)
        assert torch.allclose(q, manual_q) # AssertionError
        return manual_q
python debugging pytorch transformer-model attention-model
1个回答
0
投票

您的视图操作正在改变张量中数据的布局。这导致了差异。举个简单的例子:

创建一个代表打包 QKV 值的虚拟张量。张量的批量大小和序列长度为 1,以及 int 值,可以轻松跟踪值的移动情况。

import torch

d_emb = 1
n_heads = 4

# size (1, 1, d_emb*n_heads*3)
debug_qkv = torch.arange(d_emb*n_heads*3)[None,None,:]

print(debug_qkv.shape)
> torch.Size([1, 1, 12])
print(debug_qkv)
> tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]]])

现在我们通过沿最终轴分割

unsplit_q
来计算
debug_qkv
。我们看到
unsplit_q
具有值
[0, 1, 2, 3]
。这是有道理的。
debug_qkv
的最后一个暗淡的值从 0 到 11。我们将其分成 3 个连续的块。第一个块
unsplit_q
的值从 0 到 3。

unsplit_q, _, _ = debug_qkv.split(n_heads, -1)

print(unsplit_q.shape)
> tensor([[[0, 1, 2, 3]]])
print(unsplit_q)
> torch.Size([1, 1, 4])

现在我们看一下视图操作。在您的代码中,这是行

debug_qkv = debug_qkv.view(batch_size, seq_len, self.n_head, 3 * self.head_size)

我们可以看到张量布局如何变化。视图操作首先填充

n_head
维度,然后填充最终维度。我们正在寻找的值 -
[0, 1, 2, 3]
- 实际上已经被分割了。

qkv_reshaped = debug_qkv.view(1, 1, n_heads, d_emb*3)
print(qkv_reshaped.shape)
> torch.Size([1, 1, 4, 3])
print(qkv_reshaped)
> tensor([[[[ 0,  1,  2],
           [ 3,  4,  5],
           [ 6,  7,  8],
           [ 9, 10, 11]]]])

当我们将

q
qkv_reshaped

中分离出来时,它会向前传播
q, _, _ = qkv_reshaped.split(d_emb, dim=-1)
print(q)
> tensor([[[[0],
           [3],
           [6],
           [9]]]])

我们可以清楚地看到这些值是如何混合的。

unsplit_q
具有连续值
[0, 1, 2, 3]
,而
q
具有重塑值
[0, 3, 6, 9]

这种情况下的解决方案是重塑

debug_qkv
以使头部尺寸位于最后

qkv_reshaped2 = debug_qkv.view(1, 1, d_emb*3, n_heads)
print(qkv_reshaped2)

> tensor([[[[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]]]])

q2, _, _ = qkv_reshaped2.split(d_emb, dim=-2)
print(q2)

> tensor([[[[0, 1, 2, 3]]]])

如果你阅读了 pytorch 的多头注意力的源代码,你会发现很多类似这样的操作:

q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

由于这个确切的问题,此操作最后会在头部变暗的情况下进行重塑。

© www.soinside.com 2019 - 2024. All rights reserved.