如何将使用 nn.DataParallel 训练并保存的模型检查点,加载到不使用 nn.DataParallel 的模型上?我尝试从 state_dict 的键中删除 "module." 前缀,但现在遇到了另一个错误。这是 ResNet-50 检查点的链接。
from collections import OrderedDict

import torch
from torchvision.models import ResNet50_Weights, resnet50

# Load the model
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['state_dict']

# Build a new OrderedDict whose keys do not carry the `module.` prefix that
# nn.DataParallel adds to every parameter name.
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip the prefix only when it is actually present, instead of blindly
    # dropping the first seven characters of every key.
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v

# Load params.
# NOTE: for this particular checkpoint the call below still raises the
# "Unexpected key(s)" error quoted after this snippet, because the checkpoint
# contains extra "*.aux_bn.*" entries that a plain resnet50() does not define.
model.load_state_dict(new_state_dict)
这会产生错误
RuntimeError: Error(s) in loading state_dict for ResNet:
Unexpected key(s) in state_dict: "bn1.aux_bn.weight", "bn1.aux_bn.bias", "bn1.aux_bn.running_mean", "bn1.aux_bn.running_var", "bn1.aux_bn.num_batches_tracked", "layer1.0.bn1.aux_bn.weight", "layer1.0.bn1.aux_bn.bias", "layer1.0.bn1.aux_bn.running_mean", "layer1.0.bn1.aux_bn.running_var", "layer1.0.bn1.aux_bn.num_batches_tracked", "layer1.0.bn2.aux_bn.weight", "layer1.0.bn2.aux_bn.bias", "layer1.0.bn2.aux_bn.running_mean", "layer1.0.bn2.aux_bn.running_var", "layer1.0.bn2.aux_bn.num_batches_tracked", "layer1.0.bn3.aux_bn.weight", "layer1.0.bn3.aux_bn.bias", "layer1.0.bn3.aux_bn.running_mean", "layer1.0.bn3.aux_bn.running_var", "layer1.0.bn3.aux_bn.num_batches_tracked", "layer1.0.downsample.1.aux_bn.weight", "layer1.0.downsample.1.aux_bn.bias", "layer1.0.downsample.1.aux_bn.running_mean", "layer1.0.downsample.1.aux_bn.running_var", "layer1.0.downsample.1.aux_bn.num_batches_tracked", "layer1.1.bn1.aux_bn.weight", "layer1.1.bn1.aux_bn.bias", "layer1.1.bn1.aux_bn.running_mean", "layer1.1.bn1.aux_bn.running_var", "layer1.1.bn1.aux_bn.num_batches_tracked", "layer1.1.bn2.aux_bn.weight", "layer1.1.bn2.aux_bn.bias", "layer1.1.bn2.aux_bn.running_mean", "layer1.1.bn2.aux_bn.running_var", "layer1.1.bn2.aux_bn.num_batches_tracked", "layer1.1.bn3.aux_bn.weight", "layer1.1.bn3.aux_bn.bias",
正常加载是这样的
# Naive attempt: build a plain ResNet-50 and load the checkpoint's state_dict
# as-is, without renaming any keys.
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['state_dict']
# This call fails: every saved key carries a "module." prefix (see the error
# output that follows), so none of them match the model's parameter names.
model.load_state_dict(state_dict)
给出错误
Unexpected key(s) in state_dict: "module.conv1.weight",
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", ...
Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.bn1.aux_bn.weight", "module.bn1.aux_bn.bias", "module.bn1.aux_bn.running_mean", "module.bn1.aux_bn.running_var", "module.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.bn1.num_batches_tracked", "module.layer1.0.bn1.aux_bn.weight", "module.layer1.0.bn1.aux_bn.bias", "module.layer1.0.bn1.aux_bn.running_mean", "module.layer1.0.bn1.aux_bn.running_var", "module.layer1.0.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv2.weight", "module.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.bn2.num_batches_tracked", "module.layer1.0.bn2.aux_bn.weight", "module.layer1.0.bn2.aux_bn.bias", "module.layer1.0.bn2.aux_bn.running_mean", "module.layer1.0.bn2.aux_bn.running_var", "module.layer1.0.bn2.aux_bn.num_batches_tracked", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", "module.layer1.0.bn3.running_mean", "module.layer1.0.bn3.running_var", "module.layer1.0.bn3.num_batches_tracked", "module.layer1.0.bn3.aux_bn.weight", "module.layer1.0.bn3.aux_bn.bias", "module.layer1.0.bn3.aux_bn.running_mean", "module.layer1.0.bn3.aux_bn.running_var", "module.layer1.0.bn3.aux_bn.num_batches_tracked", "module.layer1.0.downsample.0.weight", "module.layer1.0.downsample.1.weight", "module.layer1.0.downsample.1.bias", "module.layer1.0.downsample.1.running_mean", "module.layer1.0.downsample.1.running_var", "module.
非常感谢。
您删除 "module." 前缀的做法是正确的,剩下的问题在于:这个 resnet50 模型是使用自定义归一化层初始化的,您可以在 here 看到它的用法。该层在 aux_bn.py 中定义,正是它产生了您看到的形如 "bn*.aux_bn" 的键。
代码应该在正确的初始化下运行:
# MixBatchNorm2d is the custom normalization layer defined in the checkpoint
# authors' aux_bn.py; it must be importable here (the exact import path
# depends on where that file lives in your project — TODO confirm).
prefix = 'module.'  # prepended to every key by nn.DataParallel
checkpoint = torch.load(checkpoint_path)
# Strip the prefix only from keys that actually have it, rather than blindly
# dropping the first seven characters of every key.
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['state_dict'].items()
}
# Build the model with the same norm layer the checkpoint was trained with so
# that the "*.aux_bn.*" entries in the checkpoint have matching parameters.
model = resnet50(num_classes=1_000, norm_layer=MixBatchNorm2d)
model.load_state_dict(state_dict)