这是一个简单的例子。我尝试将网络(Resnet50)分为两部分:
head
和使用tail
的children
。从概念上讲,这应该可行,但事实并非如此。这是为什么?
import torch
import torch.nn as nn
from torchvision.models import resnet50
head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*list(resnet.children())[-2:])
x = torch.zeros(1, 3, 160, 160)
resnet(x).shape # torch.Size([1, 1000])
head(x).shape # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape # Error: RuntimeError: size mismatch, m1: [2048 x 1], m2: [2048 x 1000] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136
仅供参考,尾巴只不过是
Sequential(
(0): AdaptiveAvgPool2d(output_size=(1, 1))
(1): Linear(in_features=2048, out_features=1000, bias=True)
)
所以我实际上知道如果我能这样做的话。但是,为什么重塑功能(
view
)不在孩子身上呢?
pool =resnet._modules['avgpool']
fc = resnet._modules['fc']
fc(pool(head(x)).view(1, -1))
您想要做的是将特征提取器与分类器分开。
我应该立即指出的是,Resnet不是顺序模型(顾名思义 - 残差网络 - 它是残差)!
因此将其编译为
nn.Sequential
将不准确。模型定义(按 .children()
排序的层)与该模型的 forward
函数的实际底层实现之间存在差异。
您使用
view(1, -1)
执行的展平并未在所有 torchvision.models.resnet*
模型中注册为图层。相反,它是在 forward
定义中的 this line上执行的:
x = torch.flatten(x, 1)
他们可以将其注册为
__init__
中的一个层,作为 self.flatten = nn.Flatten()
,在 forward
实现中用作 x = self.flatten(x)
。
即便如此,
fc(pool(head(x)).view(1, -1))
与resnet(x)
完全不同(cf.第一点)。
将
nn.Flatten
模块添加到 tail
似乎可以解决您的问题:
import torch
import torch.nn as nn
from torchvision.models import resnet50
resnet = resnet50()
head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*[list(resnet.children())[-2], nn.Flatten(start_dim=1), list(resnet.children())[-1]])
x = torch.zeros(1, 3, 160, 160)
resnet(x).shape # torch.Size([1, 1000])
head(x).shape # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape # torch.Size([1, 1000])