如何从 PyTorch 中的 TensorFlow 检查点加载 Adam 值? PyTorch 的优化器 state_dict 的结构是什么?

问题描述 投票:0回答:0

我需要将模型从 TF1 移植到 PyTorch。除了如何传递 Adam 值外,我几乎已经想通了所有事情。

准确地说,TF 检查点中的每个权重都有一个关联的

Adam
Adam_1
值,我假设它们是 PyTorch 优化器中的
exp_avg
exp_avg_sq
state_dict

问题是 PyTorch 优化器

state_dict
没有我理解的结构。我们有:

state_dict["optimizer"]["state"][0]["exp_avg"].shape = torch.Size([150697984])

state_dict["optimizer"]["state"][1]["exp_avg"].shape = torch.Size([25327])

总共有

150 723 311
值。但是,如果我计算
state_dict["model"]
中的元素数量,我会得到
159 319 780
,即使我忽略 BatchNorm 参数,它仍然是
159 213 487
,所以比优化器值多了大约 900 万。

还有一个小问题,优化器值存储在一个长列表(一维张量)中,我不知道如何将该列表中的位置映射到模型中的参数,我需要将其映射到 TF 检查点,因为它存储嵌套在权重旁边的 Adam 值。

(同样的问题与优化器值被分成两个块一样。)

我找不到关于这个映射的信息,也对数字不匹配感到困惑。

tensorflow deep-learning pytorch
© www.soinside.com 2019 - 2024. All rights reserved.