我在尝试在 PyTorch 中加载状态字典时遇到问题。这是我的情况的细分:
我有一个名为 model_fg 的模型
在训练期间,我使用
model_fg.parameters()
返回其参数。
我在训练中使用 Adam 优化器,并使用以下代码行:
state = optim.Adam(params, lr=lr)
其中 params 的类型为 <class 'generator'>
。
在评估阶段,我尝试使用以下方法加载经过训练的参数:
model.load_state_dict(trained_params)
但是我遇到了错误:
TypeError:预期 state_dict 类似于 dict,得到
。
我认为问题源于 params 是生成器而不是字典。我正在考虑将生成器转换为字典,但我不确定如何继续。有人可以提供有关如何解决此问题的指导吗?
提前感谢您的帮助。
我相信问题源于参数是生成器而不是字典
load_state_dict
方法需要一个 dict。生成器包含作为参数数组的参数,而字典包含每个层名称 - 参数关联。将生成器转换为字典是不可能的,因为生成器没有关于哪些参数对应于哪一层的信息,即。一般来说,您无法将每个参数分配给目标模型的正确层。