torch.Tensor.backward() 是如何工作的?

问题描述 投票:0回答:3

最近在研究Pytorch以及该包的backward函数。 我明白如何使用它,但是当我尝试时

x = Variable(2*torch.ones(2, 2), requires_grad=True)
x.backward(x)
print(x.grad)

我期待

tensor([[1., 1.],
        [1., 1.]])

因为它是恒等函数。然而,它又回来了

tensor([[2., 2.],
        [2., 2.]]).

为什么会出现这种情况?

pytorch gradient torch
3个回答
1
投票

实际上,这就是您要找的:

情况 1:当 z = 2*x**3 + x 时

import torch
from torch.autograd import Variable
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x*x*x*2+x
z.backward(torch.ones_like(z))
print(x.grad)

输出:

tensor([[25., 25.],
        [25., 25.]])

情况 2:当 z = x*x

x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x*x
z.backward(torch.ones_like(z))
print(x.grad)

输出:

tensor([[4., 4.],
        [4., 4.]])

情况 3:当 z = x(您的情况)

x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x
z.backward(torch.ones_like(z))
print(x.grad)

输出:

tensor([[1., 1.],
        [1., 1.]])

要了解更多如何在 pytorch 中计算梯度,请检查 this


0
投票

我认为您误解了如何使用

tensor.backward()
backward()
里面的参数不是dy/dx的x。

例如,如果通过某种操作从x得到y,则

y.backward(w)
,首先pytorch会得到
l = dot(y,w)
,然后计算
dl/dx
。 因此,对于您的代码,
l = 2x
首先由pytorch计算,然后
dl/dx
是您的代码返回的值。

当你做

y.backward(w)
时,如果y不是标量,只需将
backward()
的参数满1即可;否则就没有参数。


0
投票

backward()可以计算当前张量的梯度,如下所示:

*备注:

  • retain_grad() 可以使非叶张量
    grad
    t2
    t3
    可访问。
  • backward()
    retain_graph=True
    可以累积用于计算的梯度。
import torch

t1 = torch.tensor(3.0, requires_grad=True)

print(f"t1, Value={t1.item()}, Gradient={t1.grad}\n") # t1, Value=3.0, Gradient=None

t2 = t1 * 4

t2.retain_grad()

print(f"t1, Value={t1.item()}, Gradient={t1.grad}")   # t1, Value=3.0, Gradient=None
print(f"t2, Value={t2.item()}, Gradient={t2.grad}\n") # t2, Value=12.0, Gradient=None

t2.backward(retain_graph=True)

print(f"t1, Value={t1.item()}, Gradient={t1.grad}")   # t1, Value=3.0, Gradient=4.0
print(f"t2, Value={t2.item()}, Gradient={t2.grad}\n") # t2, Value=12.0, Gradient=1.0

t3 = t2 * 5

t3.retain_grad()

print(f"t1, Value={t1.item()}, Gradient={t1.grad}")   # t1, Value=3.0, Gradient=4.0
print(f"t2, Value={t2.item()}, Gradient={t2.grad}")   # t2, Value=12.0, Gradient=1.0
print(f"t3, Value={t3.item()}, Gradient={t3.grad}\n") # t3, Value=60.0, Gradient=None

t3.backward()

print(f"t1, Value={t1.item()}, Gradient={t1.grad}") # t1, Value=3.0, Gradient=24.0
print(f"t2, Value={t2.item()}, Gradient={t2.grad}") # t2, Value=12.0, Gradient=6.0
print(f"t3, Value={t3.item()}, Gradient={t3.grad}") # t3, Value=60.0, Gradient=1.0
© www.soinside.com 2019 - 2024. All rights reserved.