When should one call .eval() and .train() when doing MAML with the PyTorch higher library?


I was looking at the omniglot MAML example and saw that their test code has net.train() at the top. This seems like a mistake, since it means the batch statistics from every task are shared during meta-testing:

```python
def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )

    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })
```
However, whenever I do use eval, I see my MAML model diverge (even though my test was on mini-imagenet):

```
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5942, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
eval_loss=0.9859228551387786, eval_acc=0.5907692521810531
args.meta_learner.lr_inner=0.01
==== in forward2
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(171440.6875, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(208426.0156, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(17067344., grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(40371.8125, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.0911e+11, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(21.3515, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(5.4257e+13, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(128.9109, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(3994.7734, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1682896., grad_fn=<NormBackward1>)
eval_loss_sanity=nan, eval_acc_santiy=0.20000000298023224
```
So what should we do to avoid this divergence?

Note:

  • Retraining is really expensive. It took me 18 days to train a 5CNN with MAML. A distributed solution would really help here:
  • https://github.com/learnables/learn2learn/issues/170
  • Perhaps only use train() during training (although running eval during training might be a good idea too, so that the batch statistics get saved in the checkpoint).
  • Or, next time, train things with batch statistics from the start (see the sketch after this list).
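Regarding the last two bullets, here is a minimal sketch (my own illustration, not code from the example; the conv_block helper is hypothetical) of what "training with batch statistics from the start" can mean in practice: construct the BN layers with track_running_stats=True, so running_mean/running_var are accumulated during meta-training and land in the checkpoint, leaving you free to choose between train() and eval() behaviour at meta-test time:

```python
import torch
import torch.nn as nn

# Hypothetical 5-CNN-style block; only the BatchNorm2d flag matters here.
def conv_block(in_ch, out_ch, track_stats=True):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        # track_running_stats=True (the default) keeps running_mean/running_var
        # as buffers: they are updated on every train()-mode forward pass and
        # are included in state_dict(), i.e. in the checkpoint.
        nn.BatchNorm2d(out_ch, track_running_stats=track_stats),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )

net = nn.Sequential(conv_block(3, 64), conv_block(64, 64))
net.train()
_ = net(torch.randn(4, 3, 84, 84))  # a train()-mode forward updates the buffers

ckpt = net.state_dict()
print([k for k in ckpt if 'running' in k])  # running_mean / running_var are saved
```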

Related:

  • https://github.com/facebookresearch/higher/issues/107
  • https://discuss.pytorch.org/t/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-higher-library/136022
  • How to use batch norm without forgetting the batch statistics it just used in PyTorch?
  • https://discuss.pytorch.org/t/how-does-pytorch-s-batch-norm-know-if-the-forward-pass-its-doing-is-for-inference-or-training/16857/10
  • https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du/551153#551153
  • https://github.com/tristandeleu/pytorch-maml/issues/19
machine-learning deep-learning pytorch meta-learning
1 Answer

TLDR: Use mdl.train(), since it uses batch statistics (but inference will no longer be deterministic). You probably don't want to use mdl.eval() in meta-learning.


Intended behaviour of BN:

  • Importantly, during inference (eval/test) running_mean and running_std are used, which were computed from training (because a deterministic output is wanted, using an estimate of the population statistics).
  • During training, the batch statistics are used, while a running average is kept to estimate the population statistics. I assume the reason batch statistics are used during training is to introduce noise that regularizes training (noise robustness).
  • In meta-learning, I think using batch statistics at test time is best (rather than computing running averages), since we are supposed to be seeing new tasks/distributions anyway. The price we pay is a loss of determinism. Out of curiosity, it would be interesting to see how accurate the population statistics estimated from the meta-train set would be. A short sketch contrasting the two modes follows below.
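To make the difference between the two modes concrete, here is a minimal sketch (my own illustration, not code from the question) showing that in train() mode BatchNorm2d normalizes with the current batch's statistics and updates its running buffers, while in eval() mode it uses the stored running_mean/running_var:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 16, 16) * 5 + 2   # batch with non-trivial mean/std

bn.train()
y = bn(x)                      # normalized with this batch's statistics
print(y.mean().item())         # ~0: batch statistics were used
print(bn.running_mean)         # the running buffers were also updated

bn.eval()
y = bn(x)                      # normalized with running_mean/running_var
print(y.mean().item())         # not ~0: the population estimate is used instead
```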
This is probably why I don't see divergence in my tests with mdl.train().

So make sure to use mdl.train() (since that uses batch statistics, see https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d), but also make sure the new running statistics it accumulates (which would be cheating) are either not saved or not used later.
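One way to honour that last caveat, sketched below (my own illustration, not part of the higher example; save_bn_buffers/restore_bn_buffers and run_meta_test are hypothetical names): keep the model in train() mode during meta-testing so batch statistics are used, but snapshot the BatchNorm buffers beforehand and restore them afterwards, so the statistics accumulated on meta-test batches are never saved or reused:

```python
import copy
import torch.nn as nn

def save_bn_buffers(net):
    # Snapshot the state of every BatchNorm layer (running_mean, running_var,
    # num_batches_tracked, plus its affine parameters).
    return {
        name: copy.deepcopy(m.state_dict())
        for name, m in net.named_modules()
        if isinstance(m, nn.modules.batchnorm._BatchNorm)
    }

def restore_bn_buffers(net, saved):
    for name, m in net.named_modules():
        if name in saved:
            m.load_state_dict(saved[name])

# Usage around meta-testing (run_meta_test is a hypothetical evaluation loop):
# saved = save_bn_buffers(net)
# net.train()                       # BN uses batch statistics on each test task
# run_meta_test(net)
# restore_bn_buffers(net, saved)    # the meta-test running stats are discarded
```

Alternatively, constructing the BatchNorm layers with track_running_stats=False avoids maintaining running statistics altogether, at the cost of eval() mode no longer having population estimates to fall back on.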
