停止多头 Pytorch 模块的梯度流

问题描述 投票:0回答:1

我在 Pytorch 中有一个多头模型,类似于这个:

class Net(nn.Module):
    def __init__(self):
         super(Net, self).__init__()
    
         self.backbone = Backbone()
         self.proxyModule = ProxyModule()

    def forward(self, x):
         backbone_output = self.backbone(x)
         proxy_target = transform_to_targets(backbone_output)
         proxy_output = self.proxyModule(backbone_output)
         return backbone_output, proxy_target, proxy_output 

net = Net()
x,y = get_some_data()
backbone_output, proxy_target, proxy_output = net(x)
backbone_loss = Loss(backbone_output, y)
proxy_loss = Loss(proxy_output, proxy_target)
total_loss = backbone_loss + proxy_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

基本上,我想通过相同的损失来更新主干模型和代理模型。但是,我不希望 Backbone 模块通过操作

proxy_target = transform_to_targets(backbone_output)
进行更新。那里的目的只是为 proxyModule 生成输出变量。实际上,这类似于 Q-learning 场景。

我考虑了以下修改,但我不确定这是否会按我的预期工作:

class Net(nn.Module):
    def __init__(self):
         super(Net, self).__init__()
    
         self.backbone = Backbone()
         self.proxyModule = ProxyModule()

    def forward(self, x):
         backbone_output = self.backbone(x)

         with torch.set_grad_enabled(False):
              proxy_target = transform_to_targets(backbone_output)

         proxy_output = self.proxyModule(backbone_output)
         return backbone_output, proxy_target, proxy_output 

net = Net()
x,y = get_some_data()

optimizer.zero_grad()
with torch.set_grad_enabled(True):
    backbone_output, proxy_target, proxy_output = net(x)
    backbone_loss = Loss(backbone_output, y)
    proxy_loss = Loss(proxy_output, proxy_target)
    total_loss = backbone_loss + proxy_loss
    optimizer.zero_grad()
    loss.backward()
optimizer.step()

区别在于,现在我已将行

proxy_target = transform_to_targets(backbone_output)
放入上下文管理器中,在其中将梯度计算设置为 false。最近 Pytorch 中的 Autograd 机制变得更加复杂,所以我不确定这是否能达到预期的效果。

python deep-learning pytorch
1个回答
0
投票

文档给出了清晰的例子,这可以达到您想要的效果。我认为维护者有能力编写适当的测试用例来涵盖这一点。

但是如果您感觉不舒服,请做一个小测试:

  1. 检查
    proxy_target.requires_grad
    是否为
    False
  2. 仅返回
    proxy_target
    并在前向函数中创建其余部分作为新张量(它们 100% 对梯度不会产生任何影响)。存储管道之前和之后的完整模型权重,运行一次迭代并检查权重是否有任何差异。
© www.soinside.com 2019 - 2024. All rights reserved.