How do I fix an in-place operation error when indexing a leaf Variable for a gradient update?

Problem description · Votes: 0 · Answers: 4

When I try to index a leaf Variable to update it with a customized shrink function after the gradient step, I run into an in-place operation error. I cannot work around it. Any help is greatly appreciated!

import torch.nn as nn
import torch
import numpy as np
from torch.autograd import Variable, Function

# hyper parameters
batch_size = 100 # batch size of images
ld = 0.2 # sparse penalty
lr = 0.1 # learning rate

x = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,10,10))), requires_grad=False)  # original

# depends on size of the dictionary, number of atoms.
D = Variable(torch.from_numpy(np.random.normal(0,1,(500,10,10))), requires_grad=True)

# ht: sparse representation
ht = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,500,1,1))), requires_grad=True)

# Dictionary loss function
loss = nn.MSELoss()

# customized shrink function to update gradient
shrink_ht = lambda x: torch.stack([torch.sign(i)*torch.max(torch.abs(i)-lr*ld,0)[0] for i in x])

### sparse representation optimizer_ht, single image.
optimizer_ht = torch.optim.SGD([ht], lr=lr, momentum=0.9) # optimizer for sparse representation

## update for the batch
for idx in range(len(x)):
    optimizer_ht.zero_grad() # clear up gradients
    loss_ht = 0.5*torch.norm((x[idx]-(D*ht[idx]).sum(dim=0)),p=2)**2
    loss_ht.backward() # backpropagation and calculate gradients
    optimizer_ht.step() # update parameters with gradients
    ht[idx] = shrink_ht(ht[idx])  # customized shrink function.

RuntimeError                              Traceback (most recent call last)
in ()
     15     loss_ht.backward() # backpropagation and calculate gradients
     16     optimizer_ht.step() # update parameters with gradients
---> 17     ht[idx] = shrink_ht(ht[idx]) # customized shrink function.
     18
     19

/home/miniconda3/lib/python3.6/site-packages/torch/autograd/variable.py in __setitem__(self, key, value)
     85             return MaskedFill.apply(self, key, value, True)
     86         else:
---> 87             return SetItem.apply(self, key, value)
     88
     89     def __deepcopy__(self, memo):

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

Specifically, the following line seems to be the one that raises the error, since it both indexes and updates a leaf Variable:

ht[idx] = shrink_ht(ht[idx])  # customized shrink function.

Thanks.

W.S.

python neural-network deep-learning gradient-descent pytorch
4 Answers
12 votes

I just found out: to update the variable, it needs to be ht.data[idx] instead of ht[idx]. We can use .data to access the underlying tensor directly, which autograd does not track.
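
A sketch of how this fix slots into the loop from the question (same names as above; assignments made through .data are not tracked by autograd, so the in-place check does not fire):

for idx in range(len(x)):
    optimizer_ht.zero_grad() # clear up gradients
    loss_ht = 0.5*torch.norm((x[idx]-(D*ht[idx]).sum(dim=0)), p=2)**2
    loss_ht.backward() # backpropagation and calculate gradients
    optimizer_ht.step() # update parameters with gradients
    # write into the raw tensor instead of the Variable
    ht.data[idx] = shrink_ht(ht.data[idx])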


2 votes

The problem comes from ht requiring grad:

ht = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,500,1,1))), requires_grad=True)

For variables that require gradients, PyTorch does not let you assign to them (or to slices of them). You cannot do:

ht[idx] = some_tensor

This means you need to find another way to perform the custom shrink, using built-in PyTorch functions such as squeeze, unsqueeze, etc.

Another option is to assign the shrink_ht(ht[idx]) slices into a different Variable or tensor that does not require gradients, as sketched below.
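
A minimal sketch of that second option under the question's setup (ht_shrunk is an illustrative name, not from the original post):

# a separate buffer that does not require grad, so assigning into it is allowed
ht_shrunk = Variable(torch.zeros(ht.size()).double(), requires_grad=False)
for idx in range(len(x)):
    ht_shrunk[idx] = shrink_ht(ht[idx])  # no in-place error on a no-grad Variable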


1 vote

Using ht.data[idx] here works, but the newer convention is to use torch.no_grad() explicitly, e.g.:

with torch.no_grad(): 
    ht[idx] = shrink_ht(ht[idx])

Note that this in-place operation carries no gradient. In other words, gradients flow back only to the shrunk values of ht, not to the unshrunk values of ht.
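
A self-contained sketch of this pattern, and of the note above, with illustrative values (not from the original post):

import torch

t = torch.randn(5, requires_grad=True)  # leaf that requires grad
with torch.no_grad():                    # writes in this block are not recorded by autograd
    t[0] = t[0].abs()                    # in-place update of a leaf is now permitted
loss = (t ** 2).sum()
loss.backward()                          # gradients are taken at the modified values
print(t.grad)                            # 2*t, computed with the updated t[0]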


0 votes

I got the same error below:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

when I set requires_grad=True and inplace=True for tensor() and ReLU() respectively, as shown below:

import torch
from torch import nn
                                                    # ↓↓↓↓ Here ↓↓↓↓
tensor1 = torch.tensor([-3, -2, -1, 0, 1, 2, 3.], requires_grad=True)
               # ↓ Here ↓
relu = nn.ReLU(inplace=True)
relu(input=tensor1) # Error

So, I use only one of requires_grad=True or inplace=True at a time, or I use no_grad() or a non-leaf tensor; then I get the results shown below:

import torch
from torch import nn
                                                  # ↓↓↓↓ Here ↓↓↓↓
tensor1 = torch.tensor([-3, -2, -1, 0, 1, 2, 3.], requires_grad=True)

relu = nn.ReLU()
relu(input=tensor1)
# tensor([0., 0., 0., 0., 1., 2., 3.], grad_fn=<ReluBackward0>)

tensor1 = torch.tensor([-3., -2., -1., 0., 1., 2., 3.])
               # ↓ Here ↓
relu = nn.ReLU(inplace=True)
relu(input=tensor1)
# tensor([0., 0., 0., 0., 1., 2., 3.])

tensor1 = torch.tensor([-3, -2, -1, 0, 1, 2, 3.], requires_grad=True)
     # ↓↓ Here ↓↓
with torch.no_grad():
    relu = nn.ReLU(inplace=True)
    print(relu(input=tensor1))
    # tensor([0., 0., 0., 0., 1., 2., 3.], requires_grad=True)

tensor1 = torch.tensor([-3, -2, -1, 0, 1, 2, 3.], requires_grad=True)
tensor2 = tensor1 + 0 # Here

relu = nn.ReLU(inplace=True)
print(relu(input=tensor2)) # Here
# tensor([0., 0., 0., 0., 1., 2., 3.], grad_fn=<ReluBackward0>)