pytorch 在根据测试数据评估模型时模拟 dropout

问题描述 投票:0回答:1

我们正在使用 CIFAR 10 数据集训练 resnet 模型,我们正在尝试执行以下操作:

在我们训练好模型之后,我们想在模型评估期间模拟 dropout,当我们向它提供测试数据时。 我知道这听起来可能很奇怪,因为 dropout 是一种正则化机制,但我们将其作为实验的一部分进行

我们正在考虑尝试使用

state_dict
的一个选项,创建一个具有原始值的深层副本,然后手动修改它的值。

我们还看到

net.eval()
正在将 dropout 层改为评估模式而不是训练模式,也许有一种方法可以利用这种机制在评估期间模拟 dropout?

我想问问有没有更好的方法来实现我想做的事情?

python deep-learning pytorch
1个回答
1
投票

dropout 模块在评估模式下被禁用,即在您的模型上调用

nn.Module.eval
之后。如果您希望在调用
.eval
后启用它,那么您可以在模型中的每个
nn.Module.train
模块上调用
nn.Dropout
nn.Module.apply
方法使这非常容易。您可以将这一切封装在模型的
train
方法的重写中。

class MyModel(nn.Module):
    ... # your model's implementation
        
    # note that self.eval() just calls self.train(False) which is why
    # we are overriding train
    def train(self, mode=True):
        # recursively applies train mode to self and submodules
        super().train(mode)

        # When mode=false (i.e. .eval() is called), re-enable dropout layers
        # by recursively applying this function to each submodule
        def enable_dropout(mod: nn.Module):
            if isinstance(mod, nn.Dropout):
                mod.train()

        if not mode:
            self.apply(enable_dropout)

        return self


model = MyModel ...

在调用

model.eval()
之后,除 dropout 层之外的所有内容都将表现得就像它们处于评估模式一样。调用
model.train()
将像以前一样工作。

当然,如果您不想覆盖模型,您可以在模型外执行此操作。

# if you don't want to override the train method then you can just do the
# same thing as the above snippet outside of the class method. E.g.

model.eval()

def enable_dropout(mod: nn.Module):
    if isinstance(mod, nn.Dropout):
        mod.train()

model.apply(enable_dropout)
© www.soinside.com 2019 - 2024. All rights reserved.