How do I correctly apply LayerNorm after MultiheadAttention with different input shapes (batch_first vs. the default) in PyTorch?

Question · 0 votes · 1 answer

I am working on an audio recognition task with a Transformer-based model in PyTorch. My input features are produced by a CNN-based embedding layer and have shape [batch_size, d_model, n_token], where n_token is the sequence length and d_model is the feature dimension.
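
For context, the channels-first layout comes from the convolutional front end. A minimal stand-in (the layer and sizes below are illustrative, not my actual model) looks like this:

import torch
import torch.nn as nn

# Illustrative stand-in for the CNN embedding: Conv1d outputs channels-first,
# i.e. [batch_size, d_model, n_token] (all sizes here are arbitrary examples)
embedding = nn.Conv1d(in_channels=40, out_channels=256, kernel_size=3, padding=1)
features = embedding(torch.randn(8, 40, 100))
print(features.shape)  # torch.Size([8, 256, 100]) = [batch_size, d_model, n_token]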

By default, nn.MultiheadAttention (with batch_first=False) expects inputs of shape (seq, batch, feature). To keep things more intuitive, I chose to set batch_first=True and permute the data from [batch_size, d_model, n_token] to [batch_size, n_token, d_model], so that the time dimension comes right before the feature dimension. Here is a simplified code snippet:

# Original shape: [batch_size, d_model, n_token]
data = concat_cls_token(data)   # [batch_size, d_model, n_token+1]
data = data.permute(0, 2, 1)    # [batch_size, n_token+1, d_model]

multihead_att = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
data, _ = multihead_att(data, data, data)
# Result shape: [batch_size, n_token+1, d_model]

After applying multi-head attention, I use LayerNorm(d_model) directly on this [batch_size, n_token+1, d_model] tensor. My understanding is that LayerNorm normalizes over the feature dimension, so as long as the feature dimension (d_model) comes last, it should work correctly. But I have two main questions:

1. If I stick with multi-head attention's default format (seq, batch, feature), i.e. [n_token+1, batch_size, d_model], will LayerNorm(d_model) still normalize correctly along the feature dimension, without my having to permute the tensor again? (A toy sketch of this layout follows below.)
2. In practice, what is the best approach for a task like mine (audio sequence recognition)? Is it recommended to keep the data in [batch_size, seq_len, d_model] format before calling LayerNorm, or is using (seq, batch, feature) perfectly acceptable as long as the feature dimension comes last?
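
To make question 1 concrete, here is a toy-sized sketch of the seq-first layout I am asking about (the shapes and head count are illustrative):

import torch
import torch.nn as nn

batch_size, n_token, d_model, num_heads = 2, 5, 8, 2

# Default layout (batch_first=False): [seq, batch, feature]
x = torch.randn(n_token + 1, batch_size, d_model)

mha = nn.MultiheadAttention(d_model, num_heads)   # batch_first defaults to False
out, _ = mha(x, x, x)                             # [n_token+1, batch_size, d_model]

# d_model is still the last dimension here, so this is the call in question:
out = nn.LayerNorm(d_model)(out)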

My advisor and I are both a bit unsure about this. I would really appreciate any guidance or clarification. Below are more details of my forward method, along with the corresponding AttentionBlock implementation, for reference:

def forward(self, x: torch.Tensor):
    # Initial: x is [batch_size, d_model, num_tokens]
    x = self.expand(x)
    x = self.concat_cls_token(x)   # [batch_size, d_model, num_tokens+1]
    x = x.permute(0, 2, 1)         # [batch_size, num_tokens+1, d_model]
    x = self.positional_encoder(x)
    x = self.attention_block(x)    # [batch_size, num_tokens+1, d_model]
    x = x.permute(0, 2, 1)         # [batch_size, d_model, num_tokens+1]

    x = self.get_cls_token(x)      # [batch_size, d_model, 1]
    y = self.class_mlp(x)          # [batch_size, n_classes]
    return y

And the AttentionBlock implementation:

import torch
import torch.nn as nn
from collections import OrderedDict


class AttentionBlock(nn.Module):
    @staticmethod
    def make_ffn(hidden_dim: int) -> torch.nn.Module:
        return nn.Sequential(
            OrderedDict([
                ("ffn_linear1", nn.Linear(in_features=hidden_dim, out_features=hidden_dim)),
                ("ffn_relu", nn.ReLU()),
                ("ffn_linear2", nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
            ])
        )

    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = self.make_ffn(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor):
        attn_output, _ = self.attention(x, x, x)
        x = self.layer_norm1(x + attn_output)
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x
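
For reference, a hypothetical usage sketch (toy sizes) showing that the block maps a [batch_size, seq_len, d_model] tensor to the same shape:

# Hypothetical usage with toy sizes: input and output are both
# [batch_size, seq_len, d_model]
block = AttentionBlock(embed_dim=256, n_head=4)
tokens = torch.randn(8, 101, 256)   # e.g. seq_len = n_token + 1 after the CLS token
out = block(tokens)
print(out.shape)                    # torch.Size([8, 101, 256])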

Any advice or best practices would be greatly appreciated. Thanks a lot!

audio deep-learning pytorch transformer-model pattern-recognition
1 Answer
0 votes

From the LayerNorm documentation:

torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None)

""
Applies Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.
"""

As the documentation states, the mean and standard deviation are computed over the last D dimensions. If you create the layer as nn.LayerNorm(d_model), the last dimension of the input is assumed to have size d_model, and layer norm is applied over that dimension. The other dimensions of the tensor are irrelevant.
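
A minimal check (toy shapes) that the same module behaves identically in both layouts, since only the trailing d_model dimension matters:

import torch
import torch.nn as nn

batch_size, seq_len, d_model = 2, 5, 8
ln = nn.LayerNorm(d_model)

x_bf = torch.randn(batch_size, seq_len, d_model)  # [batch, seq, feature]
x_sf = x_bf.transpose(0, 1)                       # [seq, batch, feature]

# d_model is last in both layouts, so each token's feature vector is
# normalized the same way; the outputs match elementwise.
print(torch.allclose(ln(x_bf).transpose(0, 1), ln(x_sf)))  # True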
