任务如下:网站需要添加车盘图片去除背景的功能。
我决定使用 rembg 库作为基础:https://github.com/danielgatis/rembg
这个库又是在u2net的基础上工作的:https://github.com/xuebinqin/U-2-Net
但是,标准 u2net 模型会删除外部的所有背景,使磁盘内部的空间保持不变 - 辐条、孔等之间。
谷歌搜索了一下后,我得出的结论是,我可以根据我的特定需求进一步训练 u2net 模型。
动作的算法如下:
我成功训练了标准的 u2net 模型,它根据我的需要完美地剪出了黑白面具。
但是,当将模型从 .pth 转换为 .onnx 格式(这是在 rembg 中工作所需的)时,它开始工作得很差。 面具模糊且带有肥皂味。我尝试转换标准的未经训练的 u2net 模型 并在 rembg 中使用它 - 结果是相同的,蒙版模糊,背景裁剪不起作用。
因此,结论是培训成功。 问题出在转换上。
这是我训练的模型的掩码示例。
要在 u2net 中生成掩码,我使用:
python3 u2net_test.py
要在 rembg 中生成掩码,我使用以下命令:
rembg i -om -m u2net_custom -x '{"model_path": "~/.u2net/u2net_custom.onnx"}' 55.jpg 55.png
我尝试转换完成的模型。这是转换代码:
import torch
import torch.onnx
from model.u2net import U2NET
def load_model(model_path, model_class):
checkpoint = torch.load(model_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model = model_class()
model.load_state_dict(checkpoint['state_dict'])
else:
model = model_class()
model.eval()
return model
def convert_to_onnx(model, output_path):
dummy_input = torch.randn(1, 3, 320, 320)
torch.onnx.export(model, dummy_input, output_path, opset_version=12,
dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
print(f"success {output_path}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="conversion PyTorch to ONNX")
parser.add_argument('--model-path', type=str, required=True, help='path to .pth file')
parser.add_argument('--output-path', type=str, required=True, help='save ONNX file')
args = parser.parse_args()
model = load_model(args.model_path, U2NET)
convert_to_onnx(model, args.output_path)
并尝试在训练过程中保存模型:
if ite_num % save_frq == 0:
timestamp = int(time.time())
filePath = model_dir + model_name+"_%d_%d." % (ite_num, timestamp)
torch.save(net.state_dict(), filePath + 'pth')
dummy_input = torch.randn(1, 3, 320, 320)
net.eval()
torch.onnx.export(net, dummy_input, filePath + 'onnx', opset_version=12)
running_loss = 0.0
running_tar_loss = 0.0
net.train() # resume train
ite_num4val = 0
我尝试更改设置、更改库版本、更改 opset_version 以及 chatGpt 建议的其他所有内容。 结果总是一样的。 转换后模型停止工作。
我犯了什么错误?
我找到了解决办法。 我继续搜索并发现了这个教程: https://www.kaggle.com/code/bumbleboss/u2net-pth-to-onnx
我按照教程完成了转换:
import torch
import torch.onnx
from model import U2NET
def convert(model_path, output_path):
torch_model = U2NET(3, 1)
torch_model.load_state_dict(torch.load(model_path,
map_location=torch.device('cpu')), strict=False) # Добавлен map_location
torch_model.eval()
x = torch.randn(1, 3, 320, 320, requires_grad=True)
torch_out = torch_model(x)
torch.onnx.export(
torch_model,
x,
output_path,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
print(f"success {output_path}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="conversion PyTorch to ONNX")
parser.add_argument('--model-path', type=str, required=True, help='path to .pth file')
parser.add_argument('--output-path', type=str, required=True, help='save ONNX file')
args = parser.parse_args()
convert(args.model_path, args.output_path)
现在 rembg 完美地剪出了汽车轮辋上的背景:
我希望它对某人有帮助!