In PyTorch, how do I compute the gradient of a matrix multiplication with respect to a hidden state inside the forward pass?

Here is a simplified version of the model I'm working on:

import torch
import torch.nn as nn

class InferContextModel(nn.Module):
    def __init__(self, input_size, context_size, output_size):
        super().__init__()
        self.context_size = context_size
        self.embedding_size = input_size
        self.output_size = output_size

        self.alpha = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        self.linear_layer = nn.Linear(self.output_size, self.embedding_size)
        # Maps the hidden state to the per-step output (used in recurrence)
        self.output_layer = nn.Linear(self.context_size, self.output_size)
        self.previous_input = None

    def init_hidden(self, batch_size, device):
        return torch.zeros(batch_size, self.context_size, device=device)

    def recurrence(self, input, hidden):
        # Mark the hidden state so autograd can differentiate with respect to it
        hidden.requires_grad_(True)

        if self.previous_input is not None:
            # Note: in this simplified version the prediction does not depend on hidden
            prev_prediction = self.linear_layer(self.previous_input)
            loss = 0.5 * torch.sum((prev_prediction - input) ** 2)
            hidden_grad = torch.autograd.grad(loss, hidden, retain_graph=False)[0]

            # Update the hidden state without tracking gradients
            with torch.no_grad():
                hidden = hidden - self.alpha * hidden_grad.detach()
                hidden = hidden.detach()

        self.previous_input = input.detach()
        return self.output_layer(hidden), hidden

    def forward(self, input, context=None, num_steps=1):
        # input is sequence-first: (seq_len, batch_size, input_size)
        batch_size = input.shape[1]
        seq_len = input.size(0)
        outputs = torch.empty(
            seq_len,
            batch_size,
            self.output_size,
            device=input.device
        )

        if context is None:
            context = self.init_hidden(batch_size, input.device)

        for i in range(seq_len):
            output = None
            for _ in range(num_steps):
                output, context = self.recurrence(input[i], context)
            outputs[i] = output

        return outputs, context

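Obviously this model doesn't make much sense, but I don't think the complexity of what the model does is related to the error I can't get past. When I run the model, I get this error: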

Traceback (most recent call last):
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/project5_main.py", line 327, in <module>
    losses = estimate_loss(model, eval_iters, train_data, val_data,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/utils.py", line 44, in estimate_loss
    logits, loss = model(X.T, Y)
                   ^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 195, in forward
    logits, context = self.rnn(x)
                      ^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 151, in forward
    output, context = self.recurrence(input[i], context)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 116, in recurrence
    context_grad = torch.autograd.grad(loss, context, retain_graph=False)[0]
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 436, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
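A detail in the traceback may explain the error: the call to estimate_loss goes through decorate_context in torch/utils/_contextlib.py, which is how PyTorch's grad-mode decorators such as @torch.no_grad() appear in a stack trace. The decorator itself is not shown in the question, so this is an inference. Under no_grad, operations are not recorded, so the loss has no grad_fn even if the hidden state requires grad, and torch.autograd.grad then raises the same "element 0 of tensors does not require grad" error. The sketch below reproduces that with toy tensors (weight, hidden and target are illustrative, not from the question) and shows that wrapping the computation in torch.enable_grad() lets the gradient be computed even inside a no_grad block.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(3, 4, requires_grad=True)  # stands in for a layer's weight
hidden = torch.zeros(2, 4, requires_grad=True)  # hidden state: plain tensor, not an nn.Parameter
target = torch.randn(2, 3)

# Simulate a caller decorated with @torch.no_grad() (assumed for estimate_loss).
with torch.no_grad():
    loss = 0.5 * torch.sum((hidden @ weight.T - target) ** 2)
    # No graph was recorded, so the loss has no grad_fn although hidden requires grad.
    print(loss.requires_grad, loss.grad_fn)  # False None
    try:
        torch.autograd.grad(loss, hidden)
        error_message = ""
    except RuntimeError as err:
        error_message = str(err)
        print(error_message)

    # Locally re-enabling grad mode makes autograd record the graph again.
    with torch.enable_grad():
        loss = 0.5 * torch.sum((hidden @ weight.T - target) ** 2)
        hidden_grad = torch.autograd.grad(loss, hidden)[0]

print(hidden_grad.shape)  # torch.Size([2, 4])
```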

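What I'm trying to do should be fairly simple. I just want the gradient with respect to the hidden state, but I don't want the hidden state to be updated as part of the model's parameters. Any suggestions?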
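For this goal the hidden state only needs to be a tensor with requires_grad=True that the loss actually depends on; it does not have to be registered as a parameter. (In the simplified recurrence above, prev_prediction is computed from previous_input rather than from hidden; if the real loss also does not depend on hidden, autograd.grad would still fail, this time complaining that an input was not used in the graph.) The sketch below assumes a bias-free projection (the name projection and all sizes are illustrative) so that the prediction is h Wᵀ. It computes ∂L/∂h with torch.autograd.grad, checks it against the closed form (h Wᵀ − x) W for L = ½‖h Wᵀ − x‖², and confirms that an ordinary optimizer step leaves the hidden state unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, context_size, embedding_size = 2, 5, 3

# Hypothetical bias-free projection, so prediction = hidden @ W.T
projection = nn.Linear(context_size, embedding_size, bias=False)
optimizer = torch.optim.SGD(projection.parameters(), lr=0.1)

# The hidden state is an ordinary leaf tensor that requires grad; it is never
# registered on a module, so the optimizer does not know about it.
hidden = torch.randn(batch_size, context_size, requires_grad=True)
target = torch.randn(batch_size, embedding_size)

loss = 0.5 * torch.sum((projection(hidden) - target) ** 2)
# Gradient with respect to the hidden state only (no .grad fields are accumulated).
(hidden_grad,) = torch.autograd.grad(loss, hidden)

# Closed form for L = 0.5 * ||h W^T - x||^2 is dL/dh = (h W^T - x) W
W = projection.weight.detach()
analytic = (hidden.detach() @ W.T - target) @ W
print(torch.allclose(hidden_grad, analytic, atol=1e-6))  # True

# A regular training step updates only registered parameters, never `hidden`.
hidden_before = hidden.detach().clone()
optimizer.zero_grad()
loss = 0.5 * torch.sum((projection(hidden) - target) ** 2)
loss.backward()
optimizer.step()
print(torch.equal(hidden.detach(), hidden_before))  # True
```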

I've tried several times to solve this with multiple SOTA LLMs, but haven't found a working solution yet.

pytorch gradient
1 Answer

I think adding the following should be enough:

context = self.init_hidden(batch_size, input.device).requires_grad_(True)

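context is not declared as an nn.Parameter, so it won't be updated along with the model parameters, but computations based on this tensor will record a computation graph that traces back to it.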
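A minimal sketch of this approach inside a module, under stated assumptions: ContextStep, decoder and evaluate are hypothetical names, and the decoder is assumed to predict the input from the context. Because the no_grad update returns a tensor that no longer requires grad, the context is re-marked with requires_grad_(True) at the start of every step, not only in init_hidden. The gradient computation is also wrapped in torch.enable_grad(), since requires_grad_(True) by itself does not make autograd record anything when the caller runs under @torch.no_grad(), which the question's traceback suggests estimate_loss does.

```python
import torch
import torch.nn as nn


class ContextStep(nn.Module):
    """Hypothetical module: refines a context by gradient steps on a reconstruction loss."""

    def __init__(self, input_size: int, context_size: int, alpha: float = 0.1):
        super().__init__()
        self.context_size = context_size
        # Kept as a parameter to mirror the question; the no_grad update gives it no gradient.
        self.alpha = nn.Parameter(torch.tensor(alpha))
        # Assumed decoder: predicts the input from the context via a matrix multiplication.
        self.decoder = nn.Linear(context_size, input_size)

    def init_hidden(self, batch_size: int, device=None) -> torch.Tensor:
        # As in the answer: a plain tensor that requires grad, not an nn.Parameter.
        return torch.zeros(batch_size, self.context_size, device=device).requires_grad_(True)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Record a graph even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            # Re-mark the context each step: the previous update returned a tensor without grad.
            context = context.detach().requires_grad_(True)
            loss = 0.5 * torch.sum((self.decoder(context) - x) ** 2)
            (context_grad,) = torch.autograd.grad(loss, context)
        # Gradient-descent update of the context, kept out of the autograd graph.
        with torch.no_grad():
            context = context - self.alpha * context_grad
        return context


@torch.no_grad()  # mimics an evaluation helper such as estimate_loss
def evaluate(model: ContextStep, x: torch.Tensor, num_steps: int = 5) -> torch.Tensor:
    context = model.init_hidden(x.shape[0], x.device)
    for _ in range(num_steps):
        context = model(x, context)
    return context


torch.manual_seed(0)
model = ContextStep(input_size=3, context_size=4)
x = torch.randn(2, 3)
context = evaluate(model, x)

print(context.shape)                  # torch.Size([2, 4])
print(context.requires_grad)          # False
print(len(list(model.parameters())))  # 3: alpha, decoder.weight, decoder.bias
```

alpha stays an nn.Parameter here only to mirror the question; because the update runs under no_grad, it receives no gradient from it. Learning alpha would require keeping the update inside the graph (for example by calling autograd.grad with create_graph=True), which is a different trade-off in memory and compute.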

© www.soinside.com 2019 - 2024. All rights reserved.