单独遍历模型的每一层会得到与使用前向不同的结果

问题描述 投票:0回答:1

我想修改预训练的 resnet 的一些中间层结果。我注意到,当仅迭代特定块的层时,准确性会下降。经过进一步检查,我发现如果我迭代或使用内置的前向函数,中间结果会发生变化。

net = resnet18()
before = torch.nn.Sequential(*list(net.children())[:7])
middle = list(net.children())[7]

middle
是一个基本块,看起来像这样:

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

我想做的是从第二个

BasicBlock['conv1']
获得中间结果。这是我写的:

x = self.before(x)
bb1 = list(self.middle.children())[0]
x = bb1(x)
bb2 = list(self.middle.children())[1]

print(bb2(x))
for i, layer in enumerate(bb2.children()):
    x = layer(x)
    if i == 0:
        z = copy.copy(x)

print(x)

对于我得到的第一个打印语句:

tensor([[[[1.1271e-01, 1.1205e-01],
          [1.6054e-01, 1.4965e-01]],...)

对于第二个打印语句,我得到:

tensor([[[[ 0.0533,  0.0498],
          [ 0.0607,  0.0574]],...)

以我的理解,它们应该是完全相同的。这里发生了什么?

python machine-learning deep-learning pytorch
1个回答
0
投票

如果您打算在保持准确性的同时修改中间结果,则可以在迭代各层或操作中间输出之前使用 net.eval() 将网络设置为评估模式。这确保了 BatchNormalization 层使用其学习到的统计数据,并且 Dropout 层在前向传递期间不会应用 dropout。

修改您的代码 进口火炬 导入 torchvision.models 作为模型

net = models.resnet18(预训练=True) net.eval() # 将网络设置为评估模式

之前 = torch.nn.Sequential(*list(net.children())[:7]) 中间 = 列表(net.children())[7]

© www.soinside.com 2019 - 2024. All rights reserved.