想象你有这样简单的事情:
x = torch.tensor([4.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
output = x * y + x / y
grad_x = torch.ones_like(output)
autograd.grad(output, x, grad_outputs=grad_x)
这会产生这样的计算图:
现在我想访问中间节点的梯度,即
MulBackward0
、DivBackward0
和AddBackward0
。我知道 PyTorch 默认情况下不存储它。我知道我可以将其明确化,然后定义 retain_grad
等。
但是是否有可能只将钩子附加到这些节点中的任何一个,而不必循环遍历图形等?