PyTorch 中的 L1/L2 正则化

问题描述 投票:0回答:8

如何在 PyTorch 中添加 L1/L2 正则化而不需要手动计算?

python pytorch loss-function regularized
8个回答
92
投票

使用

weight_decay > 0
进行 L2 正则化:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

81
投票

请参阅文档。向优化器添加

weight_decay
参数以进行 L2 正则化。


47
投票

以前的答案虽然在技术上是正确的,但在性能方面效率低下,并且不太模块化(很难在每层的基础上应用,例如由

keras
层提供的)。

PyTorch L2 实现

为什么 PyTorch 在

L2
实例中实现
torch.optim.Optimizer

让我们看一下

torch.optim.SGD
源代码(目前为功能优化程序),特别是这部分:

for i, param in enumerate(params):
    d_p = d_p_list[i]
    # L2 weight decay specified HERE!
    if weight_decay != 0:
        d_p = d_p.add(param, alpha=weight_decay)
  • 可以看到,
    d_p
    (参数的导数,梯度)被修改并重新分配以加快计算速度(不保存临时变量)
  • 它具有
    O(N)
    的复杂性,没有像
    pow
  • 那样复杂的数学
  • 它不涉及
    autograd
    无需任何扩展图

将其与

O(n)
**2
操作、加法以及参与反向传播进行比较。

数学

让我们看看带有

L2
正则化因子的
alpha
方程(对于 L1 ofc 也可以这样做):

如果我们用

L2
正则化对任何损失求导。参数
w
(与损失无关),我们得到:

所以它只是为每个权重的梯度添加

alpha * weight
这正是 PyTorch 上面所做的!

L1 正则化层

使用这个(和一些 PyTorch 魔法),我们可以想出非常通用的 L1 正则化层,但让我们首先看看

L1
的一阶导数(
sgn
是正负号函数,返回
1
表示正输入,
-1 
为负数,
0
0
):

带有

WeightDecay
接口的完整代码位于torchlayers第三方库中,提供诸如仅正则化权重/偏差/特定命名参数之类的东西(免责声明:我是作者),但下面概述的想法的本质(参见评论):

class L1(torch.nn.Module):
    def __init__(self, module, weight_decay):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay

        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # Not dependent on backprop incoming values, placeholder
    def _weight_decay_hook(self, *_):
        for param in self.module.parameters():
            # If there is no gradient or it was zeroed out
            # Zeroed out using optimizer.zero_grad() usually
            # Turn on if needed with grad accumulation/more safer way
            # if param.grad is None or torch.all(param.grad == 0.0):

            # Apply regularization on it
            param.grad = self.regularize(param)

    def regularize(self, parameter):
        # L1 regularization formula
        return self.weight_decay * torch.sign(parameter.data)

    def forward(self, *args, **kwargs):
        # Simply forward and args and kwargs to module
        return self.module(*args, **kwargs)

阅读有关钩子的更多信息在此答案中或相应的 PyTorch 文档(如果需要)。

使用也非常简单(应该与梯度累积和 PyTorch 层一起使用):

layer = L1(torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)) 

28
投票

对于 L2 正则化,

l2_lambda = 0.01
l2_reg = torch.tensor(0.)

for param in model.parameters():
    l2_reg += torch.norm(param)

loss += l2_lambda * l2_reg

参考资料:


22
投票

开箱即用的 L2 正则化

是的,pytorch optimizers 有一个名为

weight_decay
的参数,它对应于 L2 正则化因子:

sgd = torch.optim.SGD(model.parameters(), weight_decay=weight_decay)

L1正则化实现

L1 没有类似的参数,但这很容易手动实现:

loss = loss_fn(outputs, labels)
l1_lambda = 0.001
l1_norm = sum(torch.linalg.norm(p, 1) for p in model.parameters())

loss = loss + l1_lambda * l1_norm

L2 的等效手动实现是:

l2_reg = sum(p.pow(2).sum() for p in model.parameters())

来源:使用 PyTorch 进行深度学习 (8.5.2)


18
投票

用于 L1 正则化并仅包含

weight

l1_reg = torch.tensor(0., requires_grad=True)

for name, param in model.named_parameters():
    if 'weight' in name:
        l1_reg = l1_reg + torch.linalg.norm(param, 1)

total_loss = total_loss + 10e-4 * l1_reg

6
投票

有趣的是

torch.norm
与直接方法相比,在 CPU 上速度较慢,在 GPU 上速度更快。

import torch
x = torch.randn(1024,100)
y = torch.randn(1024,100)

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

出:

1000 loops, best of 3: 910 µs per loop
1000 loops, best of 3: 1.76 ms per loop

另一方面:

import torch
x = torch.randn(1024,100).cuda()
y = torch.randn(1024,100).cuda()

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

出:

10000 loops, best of 3: 50 µs per loop
10000 loops, best of 3: 26 µs per loop

1
投票

扩展好的答案:正如所说,添加到损失中的 L2 范数相当于权重衰减 iff 您使用没有动量的普通 SGD。否则,例如对于亚当来说,情况并不完全相同。 AdamW论文[1]指出权重衰减实际上更稳定。这就是为什么您应该使用权重衰减,这是优化器的一个选项。并考虑使用

AdamW
代替
Adam

另请注意,您可能不希望所有参数 (

model.parameters()
) 上的权重衰减,而只希望在一个子集上进行权重衰减。请参阅此处的示例:

[1] 解耦权重衰减正则化 (AdamW),2017

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.