为了定位错误,我正在尝试内省 PyTorch 中的向后计算。按照 torch Autograd 机制的描述,我向模型的每个参数添加了向后钩子,并在每个激活的
grad_fn
上添加了钩子。以下代码片段说明了如何将钩子添加到 grad_fn
:
import torch.distributed as dist
def make_hook(grad_fn, note=None):
if grad_fn is not None and grad_fn.name is not None:
def hook(*args, **kwargs):
print(f"[{dist.get_rank()}] {grad_fn.name()} with {len(args)} args "
f"and {len(kwargs)} kwargs [{note or '/'}]")
return hook
else:
return None
def register_hooks_on_grads(grad_fn, make_hook_fn):
if not grad_fn:
return
hook = make_hook_fn(grad_fn)
if hook:
grad_fn.register_hook(hook)
for fn, _ in grad_fn.next_functions:
if not fn:
continue
var = getattr(fn, "variable", None)
if var is None:
register_hooks_on_grads(fn, make_hook_fn)
x = torch.zeros(15, requires_grad=True)
y = x.exp()
z = y.sum()
register_hooks_on_grads(z.grad_fn, make_hook)
运行模型时,我注意到每次调用
hook
都会获得两个参数,但没有关键字参数。对于 AddBackward
函数,第一个参数是两个张量的列表,第二个参数是一个张量的列表。对于 LinearWithGradAccumulationAndAsyncCommunicationBackward
函数也是如此。对于 MeanBackward
函数,两个参数都是各自包含一个张量的列表。
我对此的猜测是,第一个参数可能包含运算符的输入(或使用
ctx.save_for_backward
保存的任何内容),第二个参数包含梯度。我这样说对吗?我可以用 grad_fn(*args)
复制后向计算还是还有更多内容(例如状态)?
不幸的是,我没有找到任何关于此的文档。我很感激任何指向相关文档的指示。
重新查看上述文档后,我注意到在节点上注册钩子指的是
grad_fn.register_hook
,并且节点有两种不同的钩子:一种在节点执行之前执行,一种在节点执行之后执行。在上面的代码中,我只注册了在节点执行后运行的钩子,因此当我运行训练代码并且向后运算符报告错误时,我看不到当前运行的运算符,只能看到最后一个成功的运算符。在我在节点上注册了 prehook 后,它起作用了:
def register_hooks_on_grads(grad_fn, make_hook_fn):
if not grad_fn:
return
prehook, posthook = make_hook_fn(grad_fn)
if prehook:
grad_fn.register_prehook(prehook)
if posthook:
grad_fn.register_hook(posthook)
for fn, _ in grad_fn.next_functions:
if not fn:
continue
var = getattr(fn, "variable", None)
if var is None:
register_hooks_on_grads(fn, make_hook_fn)
pre-hook 在执行向后函数之前执行,并从前面的向后函数获取当前梯度作为输入。 post-hook 在向后函数之后执行,并另外获得
grad_fn
的输出。
事实上,我可以使用
grad_fn(*args, **kwargs)
来复制反向计算,其中 args
和 kwargs
是 prehook 函数的输入。