为什么 PyTorch 优化器中叶子变量的就地操作不会导致错误?

问题描述 投票:0回答:1

当我想对叶变量进行就地操作时,出现错误:

import torch
x = torch.tensor([2.0, 10.0], requires_grad=True)

y = x[0]**2 + x[1]**2

y.backward()

x.add_(0.2, alpha=0.2) # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

但是我们在训练循环中的优化器

optim.step()
中没有这个问题。您知道该方法是对神经网络参数进行操作,这些参数也是叶变量。我检查了
step()
函数代码,没有看到
with no_grad()
语句,所以我想知道为什么这个方法中的就地操作不会导致错误。这是优化器代码中的就地操作:
param.add_(grad, alpha=-lr)
。中间没有梯度分离。

python pytorch
1个回答
0
投票

tl;dr 他们正在使用

no_grad
,只是在另一个位置并且以稍微复杂的方式

详细:

沿着堆栈向上走。

param.add_(grad, alpha=-lr)
_single_tensor_sgd 中的一条线。
_single_tensor_sgd
sgd 调用。
sgd
optim.SGD.step 调用。

step
函数具有_use_grad_for_ Differentiable装饰器。这是装饰器的代码:

def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo

        prev_grad = torch.is_grad_enabled()
        try:
            # documentation here omitted for brevity
            torch.set_grad_enabled(self.defaults["differentiable"])
            torch._dynamo.graph_break()
            ret = func(self, *args, **kwargs)
        finally:
            torch._dynamo.graph_break()
            torch.set_grad_enabled(prev_grad)
        return ret

    functools.update_wrapper(_use_grad, func)
    return _use_grad

所有 pytorch 优化器都有一个

self.defaults
属性,该属性包含多个标志,包括
differentiable
标志。
_use_grad_for_differentiable
装饰器根据优化器的
self.defaults["differentiable"]
标志设置分级状态。

这样做是为了实现灵活性。

大多数优化器(即Pytorch中的标准优化器)都有

self.defaults["differentiable"]=False
,所以这个装饰器的操作就像
no_grad
。这就是 SGD 中的就地操作不会引发错误的原因。

使用

differentiable
标志允许人们以一种如果简单地将
no_grad
应用于整个事物就不可能实现的方式来试验可微优化器。

您也可以测试一下。这段代码运行良好:

import torch
import torch.nn as nn

x = nn.Parameter(torch.tensor([2.0, 10.0]))

opt = torch.optim.SGD([x], 1e-3)

loss = x.sum()
loss.backward()

opt.step()

此代码在就地操作中引发错误:

import torch
import torch.nn as nn

x = nn.Parameter(torch.tensor([2.0, 10.0]))

opt = torch.optim.SGD([x], 1e-3)
opt.defaults['differentiable'] = True

loss = x.sum()
loss.backward()

opt.step()
© www.soinside.com 2019 - 2024. All rights reserved.