“在 PyTorch 中加载状态字典时出错:遇到 TypeError”

问题描述 投票:0回答:1

我在尝试在 PyTorch 中加载状态字典时遇到问题。这是我的情况的细分:

  1. 我有一个名为 model_fg 的模型

  2. 在训练期间,我使用

    model_fg.parameters()
    返回其参数。

  3. 我在训练中使用 Adam 优化器,并使用以下代码行:

    state = optim.Adam(params, lr=lr)
    其中 params 的类型为
    <class 'generator'>

  4. 在评估阶段,我尝试使用以下方法加载经过训练的参数:

model.load_state_dict(trained_params)

但是我遇到了错误:

TypeError:预期 state_dict 类似于 dict,得到

我认为问题源于 params 是生成器而不是字典。我正在考虑将生成器转换为字典,但我不确定如何继续。有人可以提供有关如何解决此问题的指导吗?

提前感谢您的帮助。

python deep-learning pytorch neural-network
1个回答
0
投票

我相信问题源于参数是生成器而不是字典

你是对的,你拥有的对象是一个生成器,而

load_state_dict
方法需要一个 dict。生成器包含作为参数数组的参数,而字典包含每个层名称 - 参数关联。将生成器转换为字典是不可能的,因为生成器没有关于哪些参数对应于哪一层的信息,即。一般来说,您无法将每个参数分配给目标模型的正确层。

© www.soinside.com 2019 - 2024. All rights reserved.