我正在使用 Pytorch 的
nn.MultiheadAttention
(MHA) 开发一个自我关注模块。我的目标是实现一个因果掩码,强制每个标记只关注其之前的标记,排除自身,这与标记可以关注自身的标准自回归因果掩码不同。
这是生成自定义因果掩码的函数:
def generate_causal_mask(seq_length):
# Diagonal = 0, so each element attends only to elements before it, excluding itself
mask = torch.triu(torch.full((seq_length, seq_length), 1, dtype=torch.float32), diagonal=0).bool()
# Allow the first element to attend to itself to avoid NaN results
mask[0, 0] = False
return mask
生成的蒙版如下所示:
tensor([[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, False, True, True, True],
[False, False, False, False, False, False, True, True],
[False, False, False, False, False, False, False, True]])
这里,
True
的意思是“不能参加”。第一个元素关注自身(位置 [0, 0] 处的False
)以避免 NaN 结果。
重现问题的代码:
if __name__ == "__main__":
embed_dim = 16
batch_size = 1
seq_len = 8
mha = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)
x = torch.randn(batch_size, seq_len, embed_dim).requires_grad_(True)
causal_mask = generate_causal_mask(seq_len)
print(causal_mask)
output, _ = mha(x, x, x, attn_mask=causal_mask)
# Gradient of the output with respect to the token at position t
t = 5
loss = output[:, t].sum().backward()
print("Gradient of the token:")
print(x.grad)
当打印标记
t = 5
的输入(x.grad)的梯度时,我注意到时间步t = 5
的输出取决于其自身的值。这是出乎意料的,因为根据因果掩码,标记应该只关注其自身之前的元素。
张量([[[ 1.7815e-02, 6.0239e-02, 4.4045e-02, -1.7005e-02, -1.2529e-01、-9.8527e-02、-2.5346e-02、4.4857e-02、-9.7425e-02、1.0793e-01、1.4662e-01、1.0073e-01、-9.0143e-02、 -2.5913e-02、1.3379e-03、-9.0163e-02],
[ 2.6240e-01、1.4095e-01、2.9541e-01、6.0876e-02、-1.5522e-01、 -1.5531e-01、4.4279e-02、6.3482e-02、-2.1853e-01、2.4059e-02、2.2273e-01、1.1566e-01、6.6013e-02、-1.2247e-01、 -1.1333e-01,-1.5512e-01],
[ 5.3024e-02、4.4725e-02、6.7385e-02、5.5258e-03、-6.8150e-02、 -5.9587e-02、-1.4061e-04、2.5825e-02、-7.0633e-02、3.8935e-02、8.7158e-02、5.3142e-02、-1.6992e-02、-3.0389e-02、 -2.0005e-02, -5.6871e-02],
[ 2.9774e-01、1.1942e-01、3.1602e-01、8.5978e-02、-8.4358e-02、 -1.0587e-01、7.2915e-02、3.9608e-02、-1.8192e-01、-5.7704e-02、1.4758e-01、5.6968e-02、1.5057e-01、-1.2490e-01、 -1.3581e-01,-1.1233e-01],
[ 1.1037e-01、7.4862e-02、1.3163e-01、1.9109e-02、-1.0056e-01、 -9.2370e-02、9.9104e-03、3.9165e-02、-1.1730e-01、4.2791e-02、1.3410e-01、7.7194e-02、-1.3165e-03、-5.6924e-02、 -4.4891e-02,-8.9721e-02],
[-6.6541e-02、-1.0303e-02、-3.5482e-02、2.1983e-02、-5.1578e-02、2.0161e-01、7.2047e-02、-4.0216e-02、-1.7608 e-02, -1.2176e-02、-5.2893e-02、-1.1424e-01、4.6907e-03、-1.0784e-01、5.8249e-02、9.0503e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
给定自定义因果掩码,标记
t
处的输出应仅取决于较早时间步骤(0 到 t-1
)的标记。它不应该依赖于自身,因为对角线被屏蔽了。
此行为是 MultiheadAttention 实现中的错误,还是我误解了
attn_mask
的工作原理?如果这是预期的行为,您能否澄清如何正确实现所需的掩蔽效果?
相关库的版本:
[pip3] numpy==2.1.3 [pip3] 火炬==2.5.1 [pip3] 火炬音频==2.5.1 [pip3] torchvision==0.20.1
思考注意力权重从何而来。注意力计算:
Attention(Q, K, V) = softmax(QK^T)V
你的注意力分数是通过
attn_weights = softmax(QK^T)
计算的,注意力掩码会掩盖 Softmax 之前 QK^T
中的值。
您屏蔽了权重,使得向量
t
不关注自身,但所有其他相关的注意力权重仍然依赖于向量 t
,因此您仍然有一个非零梯度。
以您的
t=5
为例。当向量 t=5
关注第一个向量 t=0
时,这种相互作用由 attn_weights[0, 5]
加权。 attn_weights[0, 5]
的值取决于 both 向量 t=0
和向量 t=5
(加上来自 softmax 分母的向量 0-4)。即使向量 t=5
不关注自身,它关注的其他向量的注意力权重取决于向量 t=5
的值,因此仍然有一个非零梯度。