torch.load(ml_model) 我收到 AttributeError: Can't get attribute 'ResNet1D' on <module '__main__'>

问题描述 投票:0回答:3

我已经使用 Google Colab 在名为 model_prep.py 的文件中成功训练了卷积神经网络模型。该模型的准确率达到 92%。现在我对模型很满意,我已经使用 pyTorch 来保存我的模型。

torch.save(model, '/content/drive/MyDrive/myModel.pt')

我对此的理解是,一旦模型经过充分训练,我可以使用 pyTorch 保存训练后的模型,然后将其加载到未来的项目中以对新数据进行预测。因此,我创建了一个单独的 test.py 文件,在其中加载经过训练的模型,如下所示,

model = torch.load('/content/drive/MyDrive/myModel.pt')
model.eval()

但是在新的 test.py 文件中,我收到一条错误消息

AttributeError: Can't get attribute 'ResNet1D' on <module '__main__'>

尽管在创建训练模型(model_prep.py)的同一个笔记本中加载模型时不会发生此错误。仅当将模型加载到没有模型架构的单独笔记本中时,才会出现此错误。我该如何解决这个问题?我想将经过训练的模型加载到一个新的单独文件中以对新数据执行。有人可以提出解决方案吗?

将来,我想使用 tkinter 创建一个 GUI,并部署经过训练的模型,以使用 tkinter 文件中的新数据检查预测。这可能吗?

python machine-learning pytorch google-colaboratory
3个回答
2
投票

正如 Pytorch 博客(here)所述,以这种方式保存模型将使用 Python 的 pickle 模块保存整个模块。这种方法的缺点是序列化数据绑定到特定的类以及保存模型时使用的确切目录结构。原因是pickle不保存模型类本身。相反,它保存包含该类的文件的路径,该路径在加载时使用。因此,您的代码在其他项目中使用或重构后可能会以各种方式损坏。

我用 TorchScript 方法修复了这个问题。我们保存模型

model_scripted = torch.jit.script(model)# Export to TorchScript

model_scripted.save('model_scripted.pt') # Save

并加载它

model = torch.jit.load('model_scripted.pt')
model.eval()

更多详情这里


0
投票
  1. 即使我也面临着同样的错误。这试图说的是通过调用类来创建模型的实例,然后执行
    torch.load()
  2. 如果转到 PyTorch 的 Saving and Loading Models 上的博客之一,在 load 部分,您可以清楚地看到这一行
    # Model class must be defined somewhere
  3. 因此,我建议,在您的
    test.py
    文件中尝试定义模型类,就像您在
    train.py
    中所做的那样(猜测这是您创建模型的文件名),然后如下所示加载。
model = ModelClass()
model = torch.load(PATH, , map_location=torch.device('cpu')) #<--- if current device is 'CPU'
model.eval() #<---- To prevent it from going to retraining mode.

0
投票

保存模型仅限于训练期间使用的确切目录结构。 如果您希望能够在任何地方加载模型,请按照以下步骤操作:

  1. 按照标准 PyTorch 步骤(https://pytorch.org/tutorials/recipes/recipes/ saving_and_loading_models_for_inference.html#)加载模型在训练的同一目录中(以便导入与训练期间使用的导入相匹配)保存并加载整个模型
  2. 模型加载后:
torch.save(model.state_dict(), PATH) # save the model weights
  1. 将模型加载到您想要使用的任何位置:
model = ResNet1D() # instantiate model class / definition
model.load_state_dict(torch.load(PATH))
© www.soinside.com 2019 - 2024. All rights reserved.