如何在 PyTorchLightning 中手动指定检查点路径

问题描述 投票:0回答:2

目前我正在使用 TensorBoardLogger 来满足我的所有需求,它很完美,但我不喜欢它处理检查点命名的方式。我希望能够手动指定文件名和放置检查点的文件夹,我应该怎么做?

pytorch pytorch-lightning
2个回答
2
投票

是的,这要归功于 ModelCheckPoint 回调:

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="best_models",
    filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
)
trainer = Trainer(callbacks=[checkpoint_callback])

将在目录中创建一个检查点

best_models/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
例如


0
投票

问题

上面的答案效果很好,但会导致一个恼人的错误,你不知道哪个检查点先出现。例如,每个 epoch 多次设置检查点会让您得到如下名称:

['epoch=0_val_loss=4.18.ckpt',
...
'epoch=2_val_loss=2.29.ckpt', 
'epoch=2_val_loss=2.18.ckpt'] 

想象一下,由于过度拟合或其他原因,你的 val-loss 开始攀升。现在您实际上并不知道哪个检查点是最新的。也许您可以尝试通过查看文件写入时间来手动排序,但可能会出现权限错误、分布式训练系统/策略的奇怪计时问题等。(不幸的是,我的团队就是这种情况)。

解决方案

接受@LukeTheWalker 给出的伟大的答案,再向前迈进一步:

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = CustomModelCheckpoint(  # Define a simple custom class
    dirpath="best_models",
    filename='{epoch}-{val_loss:.2f}')

class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, dirpath, filename):
        self.num_ckpts = 0
        self.file_name = f"ckpt_{self.num_ckpts}" + filename
        super().__init__(dirpath=dirpath, filename=self.file_name)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        super().on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)

        self.num_ckpts += 1
        self.file_name = f"ckpt_{self.num_ckpts}" + "_{epoch}_{val_loss:.2f}"  # Update filename for next checkpoint
        trainer.checkpoint_callback.filename = self.file_name  # Money line! this is where the update gets applied
trainer = Trainer(callbacks=[checkpoint_callback])

这对我来说有点不直观,但是

filename
参数仅在训练器中保存和编辑,而不是在检查点对象(这只是一个字典)上保存和编辑。因为
CustomModelCheckpoint
的超级
__init__
只被调用一次,所以你无法通过更改对象的文件名来更新它。 无论如何,这将在目录中创建检查点
best_models/
:

['ckpt_0_epoch=2-val_loss=4.18.ckpt',
...
'ckpt_23_epoch=2-val_loss=2.18.ckpt',
'ckpt_24_epoch=2-val_loss=2.29.ckpt']

您可以想象通过这种命名方式进行更多的定制。事实上,运行

print(dir(trainer))
,您会发现可以通过这种方式管理很多事情。

© www.soinside.com 2019 - 2024. All rights reserved.