理解和反思torch.autograd.backward

问题描述 投票:0回答:1

为了定位错误,我正在尝试内省 PyTorch 中的向后计算。按照 torch Autograd 机制的描述,我向模型的每个参数添加了向后钩子,并在每个激活的

grad_fn
上添加了钩子。以下代码片段说明了如何将钩子添加到
grad_fn
:

import torch.distributed as dist


def make_hook(grad_fn, note=None):
    if grad_fn is not None and grad_fn.name is not None:
        def hook(*args, **kwargs):
            print(f"[{dist.get_rank()}] {grad_fn.name()} with {len(args)} args "
                  f"and {len(kwargs)} kwargs [{note or '/'}]")
        return hook
    else:
        return None


def register_hooks_on_grads(grad_fn, make_hook_fn):
    if not grad_fn:
        return
    hook = make_hook_fn(grad_fn)
    if hook:
        grad_fn.register_hook(hook)
    for fn, _ in grad_fn.next_functions:
        if not fn:
            continue
        var = getattr(fn, "variable", None)
        if var is None:
            register_hooks_on_grads(fn, make_hook_fn)


x = torch.zeros(15, requires_grad=True)
y = x.exp()
z = y.sum()
register_hooks_on_grads(z.grad_fn, make_hook)

运行模型时,我注意到每次调用

hook
都会获得两个参数,但没有关键字参数。对于
AddBackward
函数,第一个参数是两个张量的列表,第二个参数是一个张量的列表。对于
LinearWithGradAccumulationAndAsyncCommunicationBackward
函数也是如此。对于
MeanBackward
函数,两个参数都是各自包含一个张量的列表。

我对此的猜测是,第一个参数可能包含运算符的输入(或使用

ctx.save_for_backward
保存的任何内容),第二个参数包含梯度。我这样说对吗?我可以用
grad_fn(*args)
复制后向计算还是还有更多内容(例如状态)?

不幸的是,我没有找到任何关于此的文档。我很感激任何指向相关文档的指示。

python pytorch torch autograd
1个回答
0
投票

重新查看上述文档后,我注意到在节点上注册钩子指的是

grad_fn.register_hook
,并且节点有两种不同的钩子:一种在节点执行之前执行,一种在节点执行之后执行。在上面的代码中,我只注册了在节点执行后运行的钩子,因此当我运行训练代码并且向后运算符报告错误时,我看不到当前运行的运算符,只能看到最后一个成功的运算符。在我在节点上注册了 prehook 后,它起作用了:

def register_hooks_on_grads(grad_fn, make_hook_fn):
    if not grad_fn:
        return
    prehook, posthook = make_hook_fn(grad_fn)
    if prehook:
        grad_fn.register_prehook(prehook)
    if posthook:
        grad_fn.register_hook(posthook)
    for fn, _ in grad_fn.next_functions:
        if not fn:
            continue
        var = getattr(fn, "variable", None)
        if var is None:
            register_hooks_on_grads(fn, make_hook_fn)

pre-hook 在执行向后函数之前执行,并从前面的向后函数获取当前梯度作为输入。 post-hook 在向后函数之后执行,并另外获得

grad_fn
的输出。

事实上,我可以使用

grad_fn(*args, **kwargs)
来复制反向计算,其中
args
kwargs
是 prehook 函数的输入。

© www.soinside.com 2019 - 2024. All rights reserved.