如何使用批量标准化而不忘记刚刚在 Pytorch 中使用的批量统计信息?

问题描述 投票:0回答:2

我处于一个不寻常的环境中,我不应该使用运行统计数据(因为这会被认为是作弊,例如元学习)。然而,我经常对一组点(实际上是 5 个点)进行前向传递,然后我只想使用之前的统计数据对 1 个点进行评估,但批归一化会忘记它刚刚使用的批统计数据。我尝试对它应该的值进行硬编码,但出现奇怪的错误(即使我取消注释诸如检查尺寸大小之类的 pytorch 代码本身的内容)。 如何对之前的批次统计数据进行硬编码,以便批次范数适用于新的单个数据点,然后重置它们以获取新的下一批数据?

注意:我不想更改批量标准化图层类型。

我尝试过的示例代码:

def set_tracking_running_stats(model): for attr in dir(model): if 'bn' in attr: target_attr = getattr(model, attr) target_attr.track_running_stats = True target_attr.running_mean = torch.nn.Parameter(torch.zeros(target_attr.num_features, requires_grad=False)) target_attr.running_var = torch.nn.Parameter(torch.ones(target_attr.num_features, requires_grad=False)) target_attr.num_batches_tracked = torch.nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False) # target_attr.reset_running_stats() return

我最多的评论错误:

raise ValueError('预期 2D 或 3D 输入(获得 {}D 输入)'
ValueError:预期 2D 或 3D 输入(获得 1D 输入)


IndexError: 维度超出范围(预计在 [-1, 0] 范围内,但得到 1)

相关:

    https://discuss.pytorch.org/t/how-to-use-have-batch-norm-not-forget-batch-statistics-it-just-used/103437
  • 使用 PyTorch 高级库执行 MAML 时何时应该调用 .eval() 和 .train()?
machine-learning deep-learning pytorch pytorch-higher
2个回答
0
投票

解决方案是使用

mdl.train()

它本身使用批量统计:


同样默认情况下,在训练过程中,该层会不断运行对其计算的均值和方差的估计,然后将其用于评估期间的归一化。运行估计值保持默认值
momentum

0.1.

如果 

track_running_stats

设置为

False
,则该层不会继续运行估计,并且也会在评估期间使用批量统计数据。

https://pytorch.org/docs/stable/ generated/torch.nn.BatchNorm2d.html

参考:

https://discuss.pytorch.org/t/how-to-use-have-batch-norm-not-forget-batch-statistics-it-just-used/103437/4


0
投票
batch_norm.reset_running_stats()

您可以在任何给定时间在您的模型上运行此函数(在您的情况下,当包含 5 个训练示例的新批次到达时),它应该可以解决问题:

def reset_all_running_stats(model): for module in model.modules(): if isinstance(module, torch.nn.BatchNorm2d): module.reset_running_stats()

© www.soinside.com 2019 - 2024. All rights reserved.