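Some networks take several tensors of different dimensions as input.
Tracing such a model with torch.jit.trace appears to fail because of an error in how the inputs are handled internally.
Here is a minimal reproducible example: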
import torch

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.linear2 = torch.nn.Linear(50, 200)
        self.activation = torch.nn.ReLU()
        # No dim given: this triggers the softmax UserWarning in the log below
        self.softmax = torch.nn.Softmax()

    def forward(self, t):
        # forward takes ONE argument: a tuple holding two tensors
        x1, x2 = t
        x = self.linear1(x1) + self.linear2(x2)
        x = self.activation(x)
        x = self.softmax(x)
        return x

model = SimpleModel()
sample_input = (torch.rand(1, 100), torch.rand(1, 50))
# This works as intended
output = model(sample_input)
# This breaks
traced_model = torch.jit.trace(model, sample_input)
With torch==1.9.0, this produces the following error message:
/home/user/test_error.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  x = self.softmax(x)
Traceback (most recent call last):
  File "/home/user/test_error.py", line 28, in <module>
    traced_model = torch.jit.trace(model, sample_input)
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/jit/_trace.py", line 735, in trace
    return trace_module(
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/jit/_trace.py", line 952, in trace_module
    module._c._create_method_from_trace(
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
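Apparently, simply wrapping sample_input in a tuple solves the problem:
traced_model = torch.jit.trace(model, (sample_input,))
This is not needed when the module takes a single tensor as input. The reason is that torch.jit.trace treats a tuple passed as example_inputs as the list of positional arguments for forward, so (x1, x2) is unpacked into forward(x1, x2), which is one argument too many; a single tensor, by contrast, is automatically wrapped in a one-element tuple.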
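The argument handling behind this error can be sketched in plain Python. The helper call_like_trace and the class FakeModel below are hypothetical stand-ins, not PyTorch code: they only mimic the documented rule that a tuple of example inputs is splatted into forward as positional arguments, while a non-tuple input is first wrapped into a one-element tuple. Floats play the role of the two tensors.

```python
# Hypothetical, simplified model of how torch.jit.trace passes example_inputs
# to forward. This is NOT PyTorch's implementation; it only mimics the rule:
# a tuple is the positional-argument list, anything else is wrapped first.


def call_like_trace(forward, example_inputs):
    """Call `forward` the way the tracer would with `example_inputs`."""
    if not isinstance(example_inputs, tuple):
        # A single input (e.g. one tensor) becomes a one-element tuple
        example_inputs = (example_inputs,)
    # Each tuple element becomes a separate positional argument
    return forward(*example_inputs)


class FakeModel:
    """Hypothetical stand-in for SimpleModel: forward expects ONE pair."""

    def forward(self, t):
        x1, x2 = t
        return x1 + x2


model = FakeModel()
sample_input = (1.0, 2.0)  # plays the role of the two input tensors

# Unwrapped tuple: split into two args, i.e. forward(self, 1.0, 2.0)
try:
    call_like_trace(model.forward, sample_input)
    unwrapped_error = None
except TypeError as err:
    unwrapped_error = err
print("unwrapped:", unwrapped_error)

# Wrapped tuple: one positional arg, which forward unpacks itself
wrapped_result = call_like_trace(model.forward, (sample_input,))
print("wrapped:", wrapped_result)  # 3.0

# Single non-tuple input: auto-wrapped, so no extra tuple is needed
single_result = call_like_trace(lambda x: x * 2, 5)
print("single:", single_result)  # 10
```

An alternative to wrapping is to declare forward(self, x1, x2) with two parameters; then passing sample_input unchanged matches the tracer's unpacking, at the cost of changing how the eager model is called (model(*sample_input)).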