在 PyTorch 中注册中间节点的 Hook

问题描述 投票:0回答:1

想象你有这样简单的事情:

x = torch.tensor([4.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
output = x * y + x / y

grad_x = torch.ones_like(output)
autograd.grad(output, x, grad_outputs=grad_x)

这会产生这样的计算图:

enter image description here

现在我想访问中间节点的梯度,即

MulBackward0
DivBackward0
AddBackward0
。我知道 PyTorch 默认情况下不存储它。我知道我可以将其明确化,然后定义
retain_grad
等。

但是是否有可能只将钩子附加到这些节点中的任何一个,而不必循环遍历图形等?

python pytorch pyhook
1个回答
0
投票

在 PyTorch 中,您可以将钩子注册到模块(顺便说一下,还有张量),用于 backward (分别为 forward)传递,以对其梯度进行操作(分别输出)。

您可以将乘法和除法运算包装在单独的模块中并挂钩它们。

© www.soinside.com 2019 - 2024. All rights reserved.