我处于一个不寻常的环境中,我不应该使用运行统计数据(因为这会被认为是作弊,例如元学习)。然而,我经常对一组点(实际上是 5 个点)进行前向传递,然后我只想使用之前的统计数据对 1 个点进行评估,但批归一化会忘记它刚刚使用的批统计数据。我尝试对它应该的值进行硬编码,但出现奇怪的错误(即使我取消注释诸如检查尺寸大小之类的 pytorch 代码本身的内容)。 如何对之前的批次统计数据进行硬编码,以便批次范数适用于新的单个数据点,然后重置它们以获取新的下一批数据?
注意:我不想更改批量标准化图层类型。
我尝试过的示例代码:
def set_tracking_running_stats(model):
for attr in dir(model):
if 'bn' in attr:
target_attr = getattr(model, attr)
target_attr.track_running_stats = True
target_attr.running_mean = torch.nn.Parameter(torch.zeros(target_attr.num_features, requires_grad=False))
target_attr.running_var = torch.nn.Parameter(torch.ones(target_attr.num_features, requires_grad=False))
target_attr.num_batches_tracked = torch.nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False)
# target_attr.reset_running_stats()
return
我最多的评论错误:
raise ValueError('预期 2D 或 3D 输入(获得 {}D 输入)'
ValueError:预期 2D 或 3D 输入(获得 1D 输入)和
IndexError: 维度超出范围(预计在 [-1, 0] 范围内,但得到 1)
相关:
解决方案是使用
mdl.train()
它本身使用批量统计:
同样默认情况下,在训练过程中,该层会不断运行对其计算的均值和方差的估计,然后将其用于评估期间的归一化。运行估计值保持默认值https://pytorch.org/docs/stable/ generated/torch.nn.BatchNorm2d.htmlhttps://discuss.pytorch.org/t/how-to-use-have-batch-norm-not-forget-batch-statistics-it-just-used/103437/4momentum
0.1.
如果track_running_stats
设置为
,则该层不会继续运行估计,并且也会在评估期间使用批量统计数据。False
batch_norm.reset_running_stats()
。
您可以在任何给定时间在您的模型上运行此函数(在您的情况下,当包含 5 个训练示例的新批次到达时),它应该可以解决问题:def reset_all_running_stats(model):
for module in model.modules():
if isinstance(module, torch.nn.BatchNorm2d):
module.reset_running_stats()