如何在单个图像像素上运行backward(),而不会导致再次尝试向后浏览图形的错误?

问题描述 投票:0回答:1

我有这个代码

...
vertex_id = 275

deform_verts.retain_grad() # input
predicted_silhouette.retain_grad() # output

impact_img = torch.zeros_like(predicted_silhouette, requires_grad=False)
for i in range(image_size):
  for j in range(image_size):
      pixel = predicted_silhouette[i][j]
      pixel.retain_grad()
      pixel.backward()
      impact = deform_verts.grad[vertex_id] 
      impact_img[i][j] += impact.sum()


plt.imshow(impact_img.detach().cpu().numpy())

我正在尝试根据单一的 Deform_verts 条目如何影响整个图像来创建图像。为此,我遍历输出图像的每个像素并调用

.backward()
并将其渐变插入到新图像中以实现可视化目的。但是,当我在第一个像素上调用
backward()
时,我无法在第二个像素上调用它,因为我怀疑中间变量已在反向传播中使用。我尝试使用
retain_graph
但我认为这没有达到我想要的效果,因为图像看起来与我的预期相去甚远。

python pytorch
1个回答
0
投票

使用

deform_verts.grad.zero_()
读取梯度后需要将其归零。
backward
方法实际上是累积梯度,但你只想保留图形,而不是之前的梯度

© www.soinside.com 2019 - 2024. All rights reserved.