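I want to fine-tune an object detector in PyTorch. For this I used this tutorial:
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
However, the Faster R-CNN model doesn't suit my use case, so I fine-tuned SSDLite instead. I wrote this code to set up a new classification head: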
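import torch.nn as nn
import torchvision
from functools import partial
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead

# Load SSDLite pretrained on COCO
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
# Output channels of each backbone feature map at the 320x320 input size
in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
# Number of anchors per location for each feature map
num_anchors = model.anchor_generator.num_anchors_per_location()
# Same BatchNorm settings torchvision uses when it builds SSDLite
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
num_classes = 2  # torchvision detection models count the background as a class
# Replace only the classification head; the box regression head keeps its pretrained weights
model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)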
Since my model performs poorly, I'd like to ask the community: is the code above correct?
Thanks in advance.
If your goal is to create a model with a custom num_classes, you can do it as follows:
import torch
import torchvision

num_classes = 2
# Step 1: build the model with the custom number of classes (no COCO weights yet).
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)
# Step 2: load the model's own state_dict and the default pretrained checkpoint.
# default_pretrained_model_path is the downloaded COCO checkpoint; on Windows you can find it under C:\Users\user\.cache\torch\hub\checkpoints
mstate_dict = model.state_dict()
cstate_dict = torch.load(default_pretrained_model_path)
# Step 3: drop checkpoint entries whose shape differs from the model's (the class-dependent layers).
for k in mstate_dict.keys():
    if k in cstate_dict and mstate_dict[k].shape != cstate_dict[k].shape:
        print('key {} will be removed, orishape: {}, training shape: {}'.format(k, cstate_dict[k].shape, mstate_dict[k].shape))
        cstate_dict.pop(k)
# Step 4: load the remaining weights; strict=False skips the removed keys.
model.load_state_dict(cstate_dict, strict=False)
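Step 3 can also be factored into a small helper that leaves the loaded checkpoint untouched and reports which keys were dropped. This is a sketch, not code from the answer: the name `drop_mismatched_keys` and the example key names are hypothetical, and the helper only assumes that every value has a `.shape` (true for tensors, since `torch.Size` is a tuple).

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Mapping, Tuple


def drop_mismatched_keys(
    model_sd: Mapping[str, Any], ckpt_sd: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Return (filtered copy of ckpt_sd, list of dropped keys).

    An entry is dropped when the key exists in both dicts but the shapes differ.
    The input dicts are not modified.
    """
    filtered = dict(ckpt_sd)
    dropped = []
    for key, value in model_sd.items():
        if key in filtered and tuple(filtered[key].shape) != tuple(value.shape):
            dropped.append(key)
            del filtered[key]
    return filtered, dropped


# Usage with stand-in objects that only carry a .shape (real code would pass tensors).
model_sd = {"backbone.w": SimpleNamespace(shape=(8, 3)), "head.cls.w": SimpleNamespace(shape=(12, 8))}
ckpt_sd = {"backbone.w": SimpleNamespace(shape=(8, 3)), "head.cls.w": SimpleNamespace(shape=(546, 8))}
filtered, dropped = drop_mismatched_keys(model_sd, ckpt_sd)
print(dropped)  # ['head.cls.w']
```

With a real model you would call it as `filtered, dropped = drop_mismatched_keys(model.state_dict(), torch.load(default_pretrained_model_path))` and then `model.load_state_dict(filtered, strict=False)`; the value returned by `load_state_dict` lists `missing_keys` and `unexpected_keys`, which is worth printing to confirm that only the head layers were skipped.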
Hope this helps. Good luck!
This is my first time doing something like this, but I got good results with:
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=num_classes, weights_backbone='DEFAULT', trainable_backbone_layers=0)
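So I reuse only the existing backbone, train the rest from scratch, and keep the backbone frozen. Compared with the approach in the question and Briliantn's answer, this needed at least 10x less training time to reach a similar point, and with the backbone frozen you can increase the batch size considerably, which speeds training up further. Once the model stops improving with the frozen backbone (at a smaller learning rate), I unfreeze the backbone and train a bit more with a very small learning rate.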