如何在 PyTorch 训练期间正确标准化权重而不绕过 Autograd?

问题描述 投票:0回答:1

我正在 PyTorch 中实现一个神经网络,需要在前向传递过程中标准化某些层的权重。具体来说,我想通过某些层的 L2 范数对权重进行标准化。这是我的代码的简化版本:

import torch
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def __init__(self, layers, activation_function):
        super(MyModel, self).__init__()
        self.layers = torch.nn.ModuleList(layers)
        self.act_fun = activation_function

    def forward(self, X):
        output = X
        for i, layer in enumerate(self.layers):
            if i > 0:
                # Normalize the weights
                layer.weight.data = F.normalize(layer.weight, p=2, dim=1)
            if i < len(self.layers) - 1:
                output = self.act_fun(layer(output))
            else:
                output = layer(output)
        return output.squeeze()

我担心的是:

  1. Autograd兼容性:通过直接修改layer.weight.data,我是否绕过了PyTorch的autograd系统?这会阻止反向传播过程中正确计算梯度吗?
  2. 正确的梯度更新:当我调用loss.backward()时,是否会考虑权重归一化,或者我是否需要以不同的方式处理这个问题以确保正确的梯度计算?
  3. 更好的实践:是否有推荐的方法在 PyTorch 训练期间标准化层权重,以保持与 autograd 的兼容性并确保正确的梯度更新?

我读到直接修改 .data 可能会导致梯度跟踪问题,但我不确定如何在这种情况下正确实现权重标准化。

python deep-learning pytorch neural-network
1个回答
0
投票

@Karl,你重复缩小权重的问题(在每次前进后不恢复它们)是很恰当的。我以前没有见过这个,怀疑它能让权重消失(指数缩小)。
但是,缩放权重而不是激活/输出 可能是一个有效的选择,以避免在以下情况下推理精度损失(尤其是使用

BFloat16
进行训练时):

  • 假设您希望在训练期间按特定的
    scalar
    (不一定是 L2 归一化中的动态)缩放特定层的激活/输出。假设训练完成,模型已保存,现在您想要加载它并使用经典的 pytorch
    nn.Module
    前向调用进行推断。
    由于此方法没有缩放功能,因此您需要在训练后和保存模型之前将标量与权重融合(权重*=标量)。 这会引入舍入误差,因为在训练中您将像这样转发:
    (W.T * x)*scalar
    ,而之后,在推理中,您将像这样转发:
    (W.T * scalar) * x
    。其中
    x
    是缩放层的输入。
如果权重值较低(例如 ~1e-3、1e-4),则

这种精度损失可能会很严重。在这种情况下,某些值的不精确度可能达到 5-10%。 我尝试在特定上下文中缩放权重(llm 预训练),然后进行前向传递,然后通过分配给 self.weight.data
之前克隆的(未缩放的)权重数据来恢复它们。因为如果我将新的

scaled-weights
分配给
self.weight
它就会中断,因为它需要
torch.nn.Parameter.parameter
而不是
torch.Tensor
。而且我无法就地缩放权重,因为对图的叶变量进行就地修改是不可能的。
然而,这以我无法理解的方式破坏了自动分级图连接。下一层会以某种方式失去它的等级属性。 


如果有人

知道如何在按层转发和后退之前缩放权重,然后为下一个前向/后退再次恢复未缩放的权重,而不破坏自动梯度图,也不使权重变得

torch.Tensor 并且这释放了它们的参数类

© www.soinside.com 2019 - 2024. All rights reserved.