.pth 到 .onnx 的转换破坏了 u2net 模型

问题描述 投票:0回答:1

任务如下:网站需要添加车盘图片去除背景的功能。

我决定使用 rembg 库作为基础:https://github.com/danielgatis/rembg

这个库又是在u2net的基础上工作的:https://github.com/xuebinqin/U-2-Net

但是,标准 u2net 模型会删除外部的所有背景,使磁盘内部的空间保持不变 - 辐条、孔等之间。

谷歌搜索了一下后,我得出的结论是,我可以根据我的特定需求进一步训练 u2net 模型。

动作的算法如下:

  1. 进一步训练模型
  2. 将其加载到rembg中
  3. 使用自定义模型进行裁剪

我成功训练了标准的 u2net 模型,它根据我的需要完美地剪出了黑白面具。

但是,当将模型从 .pth 转换为 .onnx 格式(这是在 rembg 中工作所需的)时,它开始工作得很差。 面具模糊且带有肥皂味。我尝试转换标准的未经训练的 u2net 模型 并在 rembg 中使用它 - 结果是相同的,蒙版模糊,背景裁剪不起作用。

因此,结论是培训成功。 问题出在转换上。

这是我训练的模型的掩码示例。

原图

我训练的模型生成的掩码

我训练的模型转换为.onnx格式后生成的掩码

要在 u2net 中生成掩码,我使用:

python3 u2net_test.py

要在 rembg 中生成掩码,我使用以下命令:

rembg i -om -m u2net_custom -x '{"model_path": "~/.u2net/u2net_custom.onnx"}' 55.jpg 55.png

我尝试转换完成的模型。这是转换代码:

import torch
import torch.onnx
from model.u2net import U2NET

def load_model(model_path, model_class):
    checkpoint = torch.load(model_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model = model_class()
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model = model_class()
    model.eval()
    return model

def convert_to_onnx(model, output_path):
    dummy_input = torch.randn(1, 3, 320, 320)
    torch.onnx.export(model, dummy_input, output_path, opset_version=12,
                          dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                                        'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
    print(f"success {output_path}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="conversion PyTorch to ONNX")
    parser.add_argument('--model-path', type=str, required=True, help='path to .pth file')
    parser.add_argument('--output-path', type=str, required=True, help='save ONNX file')

    args = parser.parse_args()
    model = load_model(args.model_path, U2NET)
    convert_to_onnx(model, args.output_path)

并尝试在训练过程中保存模型:

        if ite_num % save_frq == 0:
            timestamp = int(time.time())
            filePath = model_dir + model_name+"_%d_%d." % (ite_num, timestamp)

            torch.save(net.state_dict(), filePath + 'pth')

            dummy_input = torch.randn(1, 3, 320, 320)
            net.eval()
            torch.onnx.export(net, dummy_input, filePath + 'onnx', opset_version=12)

            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume train
            ite_num4val = 0

我尝试更改设置、更改库版本、更改 opset_version 以及 chatGpt 建议的其他所有内容。 结果总是一样的。 转换后模型停止工作。

我犯了什么错误?

python machine-learning onnx rembg pth
1个回答
0
投票

我找到了解决办法。 我继续搜索并发现了这个教程: https://www.kaggle.com/code/bumbleboss/u2net-pth-to-onnx

我按照教程完成了转换:

import torch
import torch.onnx
from model import U2NET

def convert(model_path, output_path):
torch_model = U2NET(3, 1)
torch_model.load_state_dict(torch.load(model_path, 
map_location=torch.device('cpu')), strict=False)  # Добавлен map_location
torch_model.eval()

x = torch.randn(1, 3, 320, 320, requires_grad=True)
torch_out = torch_model(x)

torch.onnx.export(
    torch_model,
    x,
    output_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

print(f"success {output_path}")

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="conversion PyTorch to ONNX")
parser.add_argument('--model-path', type=str, required=True, help='path to .pth file')
parser.add_argument('--output-path', type=str, required=True, help='save ONNX file')

args = parser.parse_args()
convert(args.model_path, args.output_path)

现在 rembg 完美地剪出了汽车轮辋上的背景:

我希望它对某人有帮助!

© www.soinside.com 2019 - 2024. All rights reserved.