目前我正在使用 TensorBoardLogger 来满足我的所有需求,它很完美,但我不喜欢它处理检查点命名的方式。我希望能够手动指定文件名和放置检查点的文件夹,我应该怎么做?
是的,这要归功于 ModelCheckPoint 回调:
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
dirpath="best_models",
filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
)
trainer = Trainer(callbacks=[checkpoint_callback])
将在目录中创建一个检查点
best_models/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
例如
上面的答案效果很好,但会导致一个恼人的错误,你不知道哪个检查点先出现。例如,每个 epoch 多次设置检查点会让您得到如下名称:
['epoch=0_val_loss=4.18.ckpt',
...
'epoch=2_val_loss=2.29.ckpt',
'epoch=2_val_loss=2.18.ckpt']
想象一下,由于过度拟合或其他原因,你的 val-loss 开始攀升。现在您实际上并不知道哪个检查点是最新的。也许您可以尝试通过查看文件写入时间来手动排序,但可能会出现权限错误、分布式训练系统/策略的奇怪计时问题等。(不幸的是,我的团队就是这种情况)。
接受@LukeTheWalker 给出的伟大的答案,再向前迈进一步:
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = CustomModelCheckpoint( # Define a simple custom class
dirpath="best_models",
filename='{epoch}-{val_loss:.2f}')
class CustomModelCheckpoint(ModelCheckpoint):
def __init__(self, dirpath, filename):
self.num_ckpts = 0
self.file_name = f"ckpt_{self.num_ckpts}" + filename
super().__init__(dirpath=dirpath, filename=self.file_name)
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
super().on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)
self.num_ckpts += 1
self.file_name = f"ckpt_{self.num_ckpts}" + "_{epoch}_{val_loss:.2f}" # Update filename for next checkpoint
trainer.checkpoint_callback.filename = self.file_name # Money line! this is where the update gets applied
trainer = Trainer(callbacks=[checkpoint_callback])
这对我来说有点不直观,但是
filename
参数仅在训练器中保存和编辑,而不是在检查点对象(这只是一个字典)上保存和编辑。因为 CustomModelCheckpoint
的超级 __init__
只被调用一次,所以你无法通过更改对象的文件名来更新它。 无论如何,这将在目录中创建检查点 best_models/
:
['ckpt_0_epoch=2-val_loss=4.18.ckpt',
...
'ckpt_23_epoch=2-val_loss=2.18.ckpt',
'ckpt_24_epoch=2-val_loss=2.29.ckpt']
您可以想象通过这种命名方式进行更多的定制。事实上,运行
print(dir(trainer))
,您会发现可以通过这种方式管理很多事情。