我们正在使用 CIFAR 10 数据集训练 resnet 模型,我们正在尝试执行以下操作:
在我们训练好模型之后,我们想在模型评估期间模拟 dropout,当我们向它提供测试数据时。 我知道这听起来可能很奇怪,因为 dropout 是一种正则化机制,但我们将其作为实验的一部分进行
我们正在考虑尝试使用
state_dict
的一个选项,创建一个具有原始值的深层副本,然后手动修改它的值。
我们还看到
net.eval()
正在将 dropout 层改为评估模式而不是训练模式,也许有一种方法可以利用这种机制在评估期间模拟 dropout?
我想问问有没有更好的方法来实现我想做的事情?
dropout 模块在评估模式下被禁用,即在您的模型上调用
nn.Module.eval
之后。如果您希望在调用 .eval
后启用它,那么您可以在模型中的每个 nn.Module.train
模块上调用 nn.Dropout
。 nn.Module.apply
方法使这非常容易。您可以将这一切封装在模型的 train
方法的重写中。
class MyModel(nn.Module):
... # your model's implementation
# note that self.eval() just calls self.train(False) which is why
# we are overriding train
def train(self, mode=True):
# recursively applies train mode to self and submodules
super().train(mode)
# When mode=false (i.e. .eval() is called), re-enable dropout layers
# by recursively applying this function to each submodule
def enable_dropout(mod: nn.Module):
if isinstance(mod, nn.Dropout):
mod.train()
if not mode:
self.apply(enable_dropout)
return self
model = MyModel ...
在调用
model.eval()
之后,除 dropout 层之外的所有内容都将表现得就像它们处于评估模式一样。调用 model.train()
将像以前一样工作。
当然,如果您不想覆盖模型,您可以在模型外执行此操作。
# if you don't want to override the train method then you can just do the
# same thing as the above snippet outside of the class method. E.g.
model.eval()
def enable_dropout(mod: nn.Module):
if isinstance(mod, nn.Dropout):
mod.train()
model.apply(enable_dropout)