我正在 PyTorch 中实现一个神经网络,需要在前向传递过程中标准化某些层的权重。具体来说,我想通过某些层的 L2 范数对权重进行标准化。这是我的代码的简化版本:
import torch
import torch.nn.functional as F
class MyModel(torch.nn.Module):
def __init__(self, layers, activation_function):
super(MyModel, self).__init__()
self.layers = torch.nn.ModuleList(layers)
self.act_fun = activation_function
def forward(self, X):
output = X
for i, layer in enumerate(self.layers):
if i > 0:
# Normalize the weights
layer.weight.data = F.normalize(layer.weight, p=2, dim=1)
if i < len(self.layers) - 1:
output = self.act_fun(layer(output))
else:
output = layer(output)
return output.squeeze()
我担心的是:
我读到直接修改 .data 可能会导致梯度跟踪问题,但我不确定如何在这种情况下正确实现权重标准化。
@Karl,你重复缩小权重的问题(在每次前进后不恢复它们)是很恰当的。我以前没有见过这个,怀疑它能让权重消失(指数缩小)。
但是,缩放权重而不是激活/输出 可能是一个有效的选择,以避免在以下情况下推理精度损失(尤其是使用
BFloat16
进行训练时):
scalar
(不一定是 L2 归一化中的动态)缩放特定层的激活/输出。假设训练完成,模型已保存,现在您想要加载它并使用经典的 pytorch nn.Module
前向调用进行推断。 (W.T * x)*scalar
,而之后,在推理中,您将像这样转发:(W.T * scalar) * x
。其中 x
是缩放层的输入。这种精度损失可能会很严重。在这种情况下,某些值的不精确度可能达到 5-10%。
我尝试在特定上下文中缩放权重(llm 预训练),然后进行前向传递,然后通过分配给 self.weight.data
之前克隆的(未缩放的)权重数据来恢复它们。因为如果我将新的
scaled-weights
分配给 self.weight
它就会中断,因为它需要 torch.nn.Parameter.parameter
而不是 torch.Tensor
。而且我无法就地缩放权重,因为对图的叶变量进行就地修改是不可能的。 然而,这以我无法理解的方式破坏了自动分级图连接。下一层会以某种方式失去它的等级属性。
torch.Tensor
并且这释放了它们的参数类
。