我使用位置参数和关键字参数定义了一个带有forward(..) 函数的简单 nn.Module:
import torch
import torch.nn as nn
cuda0 = torch.device('cuda:0')
x = torch.tensor([[[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]]).to(device=cuda0)
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(1, 1, 2)
).to(device=cuda0)
def forward(self, cond, **kwargs):
if (cond):
return self.net(kwargs['input'])
else:
return torch.tensor(0).to(device=cuda0)
module = MyModule()
module(torch.tensor(True).to(device=cuda0), **{'input': x})
接下来,我尝试将此模块导出到onnx:
torch.onnx.export(module,
args=(torch.tensor(True).to(device=cuda0), {'input': x}),
f='sample.onnx', input_names=['input'], output_names=['output'], export_params=True)
但这会导致错误:
TypeError: forward() takes 2 positional arguments but 3 were given
我想,我正在根据文档这样做:
元组中除最后一个元素之外的所有元素都将作为非关键字传递 参数和命名参数将从最后一个元素开始设置。
https://pytorch.org/docs/stable/onnx.html
我做错了什么?
火炬1.8.0
您可能需要将命名参数排列为包含在元组中的字典,如下所示:
参数 = ( X, { “y”:输入_y, “z”:输入_z } )
参考:https://pytorch.org/docs/stable/onnx_torchscript.html#module-torch.onnx