onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT


I am trying to convert a PyTorch model to ONNX and get this error:

raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace 'torch_scatter::scatter_max'. If you are trying to export a custom operator, make sure you registered it with the right domain and version.

So I registered scatter_max with the following code:

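import torch
import torch_scatter  # importing the package loads the torch.ops.torch_scatter operators
from typing import Optional, Tuple

def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None, fill_value=0) -> Tuple[torch.Tensor, torch.Tensor]:
    # Forward the call to the compiled torch_scatter operator
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)

# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("torch_scatter::scatter_max", scatter_max, 9)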
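
For context on what a replacement has to produce: `torch.onnx.register_custom_op_symbolic` calls the symbolic function with the graph context `g` as its first argument, followed by the op's inputs, and expects it to return ONNX values built with `g.op(...)`. For `scatter_max` that means two outputs, the maxima and their argmax positions. As a reference to check an exported graph against, here is a minimal pure-Python sketch of those semantics for 1-D inputs. `scatter_max_reference` is a hypothetical helper, and the empty-slot conventions (value 0, argmax `len(src)`) are an assumption about torch_scatter's behavior, not something stated in the question.

```python
from typing import List, Tuple


def scatter_max_reference(
    src: List[float], index: List[int], dim_size: int
) -> Tuple[List[float], List[int]]:
    """Pure-Python model of scatter_max for 1-D src and index.

    out[j] is the largest src[i] with index[i] == j and argmax[j] is that i.
    Slots that receive nothing keep 0 and argmax len(src) (assumed conventions).
    """
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    out = [0.0] * dim_size
    argmax = [len(src)] * dim_size
    seen = [False] * dim_size  # lets negative maxima replace the 0 placeholder
    for i, (value, j) in enumerate(zip(src, index)):
        # Strict '>' means the first position wins on ties in this sketch.
        if not seen[j] or value > out[j]:
            out[j] = value
            argmax[j] = i
            seen[j] = True
    return out, argmax


out, argmax = scatter_max_reference([2.0, 0.0, 1.0, 4.0, 3.0], [1, 0, 1, 2, 2], dim_size=4)
print(out)     # [0.0, 2.0, 4.0, 0.0]
print(argmax)  # [1, 0, 3, 5]
```

Any ONNX graph emitted for `scatter_max` should agree with a reference like this on small inputs before being used in the full model.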

I got the ONNX file, but when I run the demo I get the following error:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running ScatterElements node. Name:'/hsn/ScatterElements_1' Status Message: Indices vs updates dimensions differs at position=1 1 vs 64
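
The runtime message refers to a rule of the ONNX `ScatterElements` operator: `indices` and `updates` must have exactly the same shape, with no broadcasting, whereas torch_scatter broadcasts `index` against `src` implicitly. One plausible reading of `1 vs 64` is an `[N, 1]` index paired with `[N, 64]` updates. The sketch below models this in pure Python, assuming 2-D inputs, `axis=0`, and a max reduction; `scatter_elements_max`, `expand_index` and `shape` are hypothetical helpers, not part of ONNX or torch_scatter.

```python
from typing import List, Tuple

Matrix = List[List[float]]
IndexMatrix = List[List[int]]


def shape(rows: list) -> Tuple[int, int]:
    """Shape of a 2-D nested list as (rows, columns)."""
    return (len(rows), len(rows[0]) if rows else 0)


def scatter_elements_max(data: Matrix, indices: IndexMatrix, updates: Matrix) -> Matrix:
    """Model of ONNX ScatterElements with axis=0 and a max reduction on 2-D lists."""
    # ONNX requires indices and updates to have exactly the same shape.
    # There is no broadcasting, so an [N, 1] index against [N, C] updates is rejected.
    for pos, (a, b) in enumerate(zip(shape(indices), shape(updates))):
        if a != b:
            raise ValueError(
                f"Indices vs updates dimensions differs at position={pos} {a} vs {b}"
            )
    out = [list(row) for row in data]
    for i, row in enumerate(indices):
        for j, target in enumerate(row):
            # axis=0: updates[i][j] is combined into out[indices[i][j]][j]
            out[target][j] = max(out[target][j], updates[i][j])
    return out


def expand_index(index: IndexMatrix, num_cols: int) -> IndexMatrix:
    """Repeat each single-column index row num_cols times (like index.expand_as(src))."""
    return [[row[0]] * num_cols for row in index]


data = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
src = [[1.0, 5.0, 2.0], [3.0, 4.0, 6.0], [7.0, 0.0, 1.0]]  # updates, shape (3, 3)
index = [[0], [1], [0]]                                     # shape (3, 1)

try:
    scatter_elements_max(data, index, src)
except ValueError as err:
    print(err)  # Indices vs updates dimensions differs at position=1 1 vs 3

print(scatter_elements_max(data, expand_index(index, 3), src))
# [[7.0, 5.0, 2.0], [3.0, 4.0, 6.0]]
```

If the mismatch in the exported graph does come from the implicit broadcast, expanding the index to the shape of `src` before scattering (for instance with `index.expand_as(src)` in PyTorch) is one untested thing to try.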

What am I doing wrong?


python pytorch torch onnx torch-scatter