我正在尝试从 this repo 运行一个脚本来测试 PyTorch 模型。我只是使用其默认值(使用到
python test.py
的正确路径)运行它。但是运行的时候却报如下错误。问题是什么?怎么解决?
这是
best.pt
中的片段:
test.py
错误:
device = 'cuda:' + str(opt.device) if opt.device != 'cpu' else 'cpu'
model = XLSR(opt.SR_rate)
# load pretrained model
if opt.model.endswith('.pt') and os.path.exists(opt.model):
model.load_state_dict(torch.load(opt.model, map_location=device))
else:
model.load_state_dict(torch.load(os.path.join(opt.save_dir, 'best.pt'), map_location=device))