我有一个简单的 PyTorch 模型,我正在尝试将其转换为 ONNX 格式。
forward()
函数由对 nn.transformer.encoder()
的调用组成。使用 torch.onnx.export()
成功完成到 ONNX 的转换。但是,当我使用 onnxruntime_test
测试模型时,它会失败,除了特定的输入情况。
我怀疑问题与运行时的动态轴和重塑有关,但我不确定确切的原因。
我在下面提供了一个最小的示例,我在其中创建模型,使用两个不同的张量运行它,将其导出到 ONNX,并尝试使用 ONNX 复制该过程。
任何有关为什么会发生此问题的见解都将不胜感激。谢谢!
#!/usr/bin/env python3
import torch.nn as nn
from torch import Tensor, rand
import torch.onnx
import onnx
onnx_model = 'MicroTest.onnx'
class Test_trans(nn.Module):
def __init__( self, emb_size=100):
super(Test_trans, self).__init__()
self.transformer = nn.Transformer(emb_size, 2, 2, 2, 512, 0.1)
def forward(self, src: Tensor):
return self.transformer.encoder(src)
def process_one_torch(session, ten):
print('Tensor In size:', ten.size(), end='\t')
memory = session(ten)
print('Mem size:', memory.size());
def process_one_onnx(session, npa):
ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(npa)
print('In ortvalue.shape:', ortvalue.shape(), end='\t')
memory = session.run(None, {session.get_inputs()[0].name: ortvalue})
print('ONNX mem.shape:', memory[0].shape)
mini = Test_trans()
c_tensor_12 = rand((12,1,100))
c_tensor_10 = rand((10,1,100))
print('################################# Torch ###################################')
process_one_torch(mini, c_tensor_12)
process_one_torch(mini, c_tensor_10)
torch.onnx.export(mini, # model being run
c_tensor_12, # model input (or a tuple for multiple inputs)
onnx_model, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=False,
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes = {'input' : {0: 'max_len'}})
print('################################# ONNX RT #################################')
import onnxruntime
session = onnxruntime.InferenceSession(onnx_model, providers=["CPUExecutionProvider"])
print('Session inputs:', session.get_inputs()[0])
process_one_onnx(session, c_tensor_12.numpy())
process_one_onnx(session, c_tensor_10.numpy()) #This one crashes
使用我得到的创建的onnx模型运行onnxruntime_test: ` onnxruntime_test MicroTest.onnx 2024-06-12 17:11:22.877402346 [E:onnxruntime:,sequential_executor.cc:514 ExecuteKernel]运行Reshape节点时返回非零状态代码。名称:'/encoder/layers.0/self_attn/Reshape_4' 状态消息:/croot/onnxruntime_1711063034809/work/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:44 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime: :TensorShape&, onnxruntime::TensorShapeVector&, bool) input_shape_size == size 为 false。输入张量无法重新整形为请求的形状。输入形状:{1,1,100},请求形状:{12,2,50}
回溯(最近一次调用最后一次): 文件“/home/if/miniconda3/envs/cpu/bin/onnxruntime_test”,第 11 行,位于 系统退出(主()) ^^^^^^ 文件“/home/if/miniconda3/envs/cpu/lib/python3.11/site-packages/onnxruntime/tools/onnxruntime_test.py”,第 159 行,在 main 中 exit_code, _, _ = run_model(args.model_path, args.num_iters, args.debug, args.profile, args.symbolic_dims) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 文件“/home/if/miniconda3/envs/cpu/lib/python3.11/site-packages/onnxruntime/tools/onnxruntime_test.py”,第 118 行,在 run_model 中 outputs = sess.run([], feeds) # 获取所有输出 ^^^^^^^^^^^^^^^^^^^^ 文件“/home/if/miniconda3/envs/cpu/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py”,第 220 行,运行中 返回 self._sess.run(output_names, input_feed, run_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^ onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : 运行 Reshape 节点时返回非零状态代码。名称:'/encoder/layers.0/self_attn/Reshape_4' 状态消息:/croot/onnxruntime_1711063034809/work/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:44 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime: :TensorShape&, onnxruntime::TensorShapeVector&, bool) input_shape_size == size 为 false。输入张量无法重新整形为请求的形状。输入形状:{1,1,100},请求形状:{12,2,50} ` 我所了解的是内部被重塑了与预期不同的东西。
好的这显然是 ONNX 导出器的版本问题。我尝试了旧版本的
pytorch.onnx.export()
,一切都运行顺利,onnxruntime_test
没有产生任何错误。不幸的是,旧版本只支持opset 15,这导致了其他问题,但主要问题已经解决了。