Here is a simplified version of the model I'm working on:
    import torch
    import torch.nn as nn

    class InferContextModel(nn.Module):
        def __init__(self, input_size, context_size, output_size):
            super().__init__()
            self.context_size = context_size
            self.embedding_size = input_size
            self.output_size = output_size
            # Step size for the gradient-based hidden-state update
            self.alpha = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
            # Predicts the current input from the hidden state
            self.linear_layer = nn.Linear(self.context_size, self.embedding_size)
            # Maps the hidden state to the output logits
            self.output_layer = nn.Linear(self.context_size, self.output_size)
            self.previous_input = None

        def init_hidden(self, batch_size, device):
            return torch.zeros(batch_size, self.context_size, device=device)

        def recurrence(self, input, hidden):
            # hidden is a leaf tensor here, so the flag can be set in place
            hidden.requires_grad_(True)
            if self.previous_input is not None:
                prev_prediction = self.linear_layer(hidden)
                loss = 0.5 * torch.sum((prev_prediction - input) ** 2)
                hidden_grad = torch.autograd.grad(loss, hidden, retain_graph=False)[0]
                # Update the hidden state without recording the update itself
                with torch.no_grad():
                    hidden = hidden - self.alpha * hidden_grad.detach()
                hidden = hidden.detach()
            self.previous_input = input.detach()
            return self.output_layer(hidden), hidden

        def forward(self, input, context=None, num_steps=1):
            seq_len = input.size(0)
            batch_size = input.size(1)
            outputs = torch.empty(
                seq_len,
                batch_size,
                self.output_size,
                device=input.device,
            )
            if context is None:
                context = self.init_hidden(batch_size, input.device)
            for i in range(seq_len):
                output = None
                for _ in range(num_steps):
                    output, context = self.recurrence(input[i], context)
                outputs[i] = output
            return outputs, context
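To make this self-contained, here is a toy driver (the sizes are invented; in the real project the module sits inside Context_RNN and is evaluated from estimate_loss, which appears to run under @torch.no_grad() judging by the decorate_context frame in the traceback):

    # Hypothetical driver, sizes are arbitrary
    model = InferContextModel(input_size=8, context_size=16, output_size=10)
    x = torch.randn(5, 2, 8)  # (seq_len, batch, input_size)

    logits, ctx = model(x)      # fine: a graph back to the hidden state is recorded

    with torch.no_grad():       # mimics a @torch.no_grad() estimate_loss
        logits, ctx = model(x)  # raises the RuntimeError shown below
        # note: previous_input persists from the first call, so the
        # gradient branch is entered on the very first timestep here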
Obviously this model doesn't do anything very meaningful, but I don't think the complexity of what the model does should have anything to do with an error I can't get past. When I run the model, I get this error:
Traceback (most recent call last):
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/project5_main.py", line 327, in <module>
losses = estimate_loss(model, eval_iters, train_data, val_data,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/utils.py", line 44, in estimate_loss
logits, loss = model(X.T, Y)
^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 195, in forward
logits, context = self.rnn(x)
^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 151, in forward
output, context = self.recurrence(input[i], context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 116, in recurrence
context_grad = torch.autograd.grad(loss, context, retain_graph=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 436, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
What I'm trying to do should be fairly simple. I just want the gradient with respect to the hidden state, but I don't want the hidden state to be updated along with the model parameters. Any suggestions?
I've run this problem through several SOTA LLMs multiple times, but haven't found a working solution yet.
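In isolation, the mechanism I'm after seems straightforward. A standalone sketch (all names and shapes here are invented for illustration):

    import torch
    import torch.nn as nn

    h = torch.zeros(2, 16, requires_grad=True)  # hidden state, deliberately not an nn.Parameter
    predictor = nn.Linear(16, 8)                # stand-in for linear_layer
    target = torch.randn(2, 8)                  # stand-in for the incoming input

    loss = 0.5 * torch.sum((predictor(h) - target) ** 2)
    (h_grad,) = torch.autograd.grad(loss, h)    # gradient w.r.t. the hidden state only
    with torch.no_grad():
        h = h - 1.0 * h_grad                    # manual update; no optimizer touches h

Since h is never registered as a parameter, an optimizer built from model.parameters() never updates it, which is exactly the behavior I want inside recurrence.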
I thought it would be enough to add

    context = self.init_hidden(batch_size, input.device).requires_grad_(True)

since context is not declared as an nn.Parameter, it should not be touched along with the model parameters, while any computation based on this tensor should still be recorded in a graph that traces back to it.
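That reasoning seems to hold, but as far as I can tell only while grad mode is enabled. A quick standalone check (names invented):

    import torch

    ctx = torch.zeros(2, 16).requires_grad_(True)

    y = (ctx * 2.0).sum()
    print(y.grad_fn is not None)  # True: the graph traces back to ctx

    with torch.no_grad():
        y = (ctx * 2.0).sum()
    print(y.grad_fn)              # None: nothing is recorded under no_grad

So if estimate_loss really is wrapped in @torch.no_grad(), no graph is ever built back to the hidden state no matter how requires_grad is set, which would explain the RuntimeError above. Is there a clean way to keep this inner gradient computation working while the surrounding evaluation code stays gradient-free?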