Exploding gradients in multi-head self-attention (NaN training and validation loss) - Vision Transformer

Problem description

This multi-head self-attention code makes the training and validation losses become NaN, but when I remove this part everything goes back to normal. I know that training and validation losses turning NaN usually means the gradients are exploding, but I cannot find what in my code causes it. When I compare it with the official PyTorch code it looks similar. When I use nn.MultiheadAttention the gradients do not explode, but with my own code they do. No error message is shown. Does anyone know what is wrong with the code below?

import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class MultiHeadAttention(nn.Module):
  def __init__(self, in_dim, num_heads=8, dropout=0):
    super().__init__()
    self.num_heads = num_heads
    self.head_dim = in_dim // num_heads
    self.conv_q = nn.Conv2d(in_dim, in_dim, kernel_size=1)
    self.conv_k = nn.Conv2d(in_dim, in_dim, kernel_size=1)
    self.conv_v = nn.Conv2d(in_dim, in_dim, kernel_size=1)
    self.att_drop = nn.Dropout(dropout)
    self.proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
    self.proj_drop = nn.Dropout(dropout)

  def forward(self, x):

    b, _, h, w = x.shape
    
    q = self.conv_q(x)
    k = self.conv_k(x)
    v = self.conv_v(x)

    q = rearrange(q, "b (nh hd) h w -> b nh (h w) hd", nh=self.num_heads)
    k = rearrange(k, "b (nh hd) h w -> b nh (h w) hd", nh=self.num_heads)
    v = rearrange(v, "b (nh hd) h w -> b nh (h w) hd", nh=self.num_heads)

    att_score = q @ k.transpose(2, 3) ** (self.head_dim ** -0.5)
    att_score = F.softmax(att_score, dim=-1)
    att_score = self.att_drop(att_score)

    x = att_score @ v

    x = rearrange(x, 'b nh (h w) hd -> b (nh hd) h w', h=h, w=w)

    x = self.proj(x)
    x = self.proj_drop(x)

    return x
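
For reference, a minimal reproduction sketch (not from the original post), assuming a feature-map input of shape (batch, channels, height, width) with channels divisible by num_heads; with the scaling line written as above, the output typically already contains NaN before any backward pass:

import torch

# Hypothetical shapes for illustration: 2 images, 64 channels, 14x14 feature map.
attn = MultiHeadAttention(in_dim=64, num_heads=8)
x = torch.randn(2, 64, 14, 14)
out = attn(x)
print(out.shape)               # torch.Size([2, 64, 14, 14])
print(torch.isnan(out).any())  # typically True with the scaling line as written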
Tags: machine-learning, deep-learning, pytorch, computer-vision, vision-transformer
1 Answer

Changing the scaling to use / (math.sqrt(self.head_dim)) fixes the problem. I am not sure why, since self.head_dim ** -0.5 should be equivalent to dividing by math.sqrt(self.head_dim). Maybe someone else can explain?
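
A note on the likely root cause (my reading, not part of the original answer): the two forms are only equivalent if the scale is applied to the result of the matrix product. In the question's line, Python's ** binds more tightly than @, so the exponent is applied element-wise to k.transpose(2, 3) before the matrix multiplication, the attention logits are never scaled, and raising negative entries to a fractional power produces NaN, which then propagates through the softmax. A small sketch of the corrected scaling and a demonstration of the precedence issue:

import math
import torch

head_dim = 8
q = torch.randn(2, 8, 196, head_dim)
k = torch.randn(2, 8, 196, head_dim)

# Buggy: ** has higher precedence than @, so the exponent is applied element-wise
# to k.transpose(2, 3); negative entries raised to a fractional power give NaN.
buggy = q @ k.transpose(2, 3) ** (head_dim ** -0.5)
print(torch.isnan(buggy).any())   # True (almost surely for random inputs)

# Fixed: scale the matmul result; equivalent to multiplying by head_dim ** -0.5.
fixed = (q @ k.transpose(2, 3)) / math.sqrt(head_dim)
print(torch.isnan(fixed).any())   # False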
