This multi-head self-attention code makes both the training loss and the validation loss become NaN, but when I remove it everything goes back to normal. I know that NaN training and validation losses usually mean the gradients are exploding, but I can't see what in my code would cause that. When I compare it with the official PyTorch code it looks very similar. When I use nn.MultiheadAttention the gradients do not explode, but with my own code they do. No error message is shown. Does anyone know what is wrong with the code below?
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, num_heads=8, dropout=0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = in_dim // num_heads
        # 1x1 convolutions produce the query, key, and value projections
        self.conv_q = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.conv_k = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.conv_v = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.att_drop = nn.Dropout(dropout)
        self.proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        b, _, h, w = x.shape
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)
        # split channels into heads and flatten the spatial dims into tokens
        q = rearrange(q, "b (nh hd) h w -> b nh (h w) hd", nh=self.num_heads)
        k = rearrange(k, "b (nh hd) h w -> b nh (h w) hd", nh=self.num_heads)
        v = rearrange(v, "b (nh hd) h w -> b nh (h w) hd", nh=self.num_heads)
        att_score = q @ k.transpose(2, 3) ** (self.head_dim ** -0.5)
        att_score = F.softmax(att_score, dim=-1)
        att_score = self.att_drop(att_score)
        x = att_score @ v
        # fold the heads back into the channel dimension and restore h x w
        x = rearrange(x, 'b nh (h w) hd -> b (nh hd) h w', h=h, w=w)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
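For comparison, the built-in module mentioned above can be exercised on the same kind of feature map roughly like this (a minimal sketch; the sizes in_dim=64, num_heads=8, and the 1x64x16x16 input are made up for illustration, and the spatial dimensions are flattened into tokens first):

import torch
import torch.nn as nn
from einops import rearrange

in_dim, num_heads = 64, 8  # hypothetical sizes for illustration
mha = nn.MultiheadAttention(in_dim, num_heads, dropout=0.0, batch_first=True)

x = torch.randn(1, in_dim, 16, 16)             # b c h w
tokens = rearrange(x, "b c h w -> b (h w) c")  # flatten spatial dims to tokens
out, _ = mha(tokens, tokens, tokens)           # self-attention: q = k = v
out = rearrange(out, "b (h w) c -> b c h w", h=16, w=16)
print(out.shape)  # torch.Size([1, 64, 16, 16])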
Changing the scaling to use / (math.sqrt(self.head_dim)) fixes the problem. I'm not sure why this happens, since self.head_dim ** -0.5 should be equivalent to / (math.sqrt(self.head_dim)). Maybe someone else can explain?
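One thing worth noting when comparing the two forms: the scalings themselves are mathematically identical, but in Python ** binds more tightly than @, so in the original line the exponent is applied element-wise to k.transpose(2, 3) before the matrix multiplication, and a fractional power of a negative entry is NaN. A small sketch of the precedence difference (shapes are arbitrary, chosen only for illustration):

import torch

head_dim = 8
q = torch.randn(2, 4, head_dim)
k = torch.randn(2, 4, head_dim)

# ** binds more tightly than @, so this exponentiates k element-wise first;
# fractional powers of the negative entries in k come out as NaN
bad = q @ k.transpose(1, 2) ** (head_dim ** -0.5)

# parenthesising (or dividing by math.sqrt(head_dim)) scales the matmul result
good = (q @ k.transpose(1, 2)) * (head_dim ** -0.5)

print(bad.isnan().any())   # tensor(True) whenever k has negative entries
print(good.isnan().any())  # tensor(False)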