Why doesn't the PyTorch Lightning module save the logged val loss? ModelCheckpoint error


I'm training an LSTM-based model on Kaggle, using PyTorch Lightning with the wandb logger.

This is my model class:

import torch
from torch import nn
import pytorch_lightning as pl

# row_dim_in_embedding and embedding_dim are defined elsewhere in the notebook

class Model(pl.LightningModule):
    def __init__(
        self,
        bidirectional: bool = False,
        lstm_layers: int = 1,
        lstm_dropout: float = 0.4,
        fc_dropout: float = 0.4,
        lr: float = 0.01,
        lr_scheduler_patience: int = 2,
    ):
        super().__init__()
        self.lr = lr
        self.save_hyperparameters()

        # LSTM
        self.encoder_lstm = nn.LSTM(
            input_size=row_dim_in_embedding,
            hidden_size=embedding_dim,
            num_layers=lstm_layers,
            bidirectional=bidirectional,
            dropout=lstm_dropout if lstm_layers > 1 else 0,
            batch_first=True,
        )

        # Fully-connected
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Sequential(
            nn.Linear(
                embedding_dim * num_directions, embedding_dim * num_directions * 2
            ),
            nn.ReLU(),
            nn.Dropout(fc_dropout),
            nn.Linear(embedding_dim * num_directions * 2, row_dim_in_embedding),
        )

        self.loss_function = nn.MSELoss()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, patience=self.hparams.lr_scheduler_patience
                ),
                "monitor": "val_loss",
            },
        }

    def forward(self, x, prev_state):
        ...

    def training_step(self, batch, batch_idx):
        loss, _ = self._step(batch)

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, embeddings = self._step(batch)

        self.log("val_loss", loss)
        
        return {
            'val_loss': loss,
            'preds': embeddings # this is consumed by my custom callback
        }

    def test_step(self, batch, batch_idx):
        loss, _ = self._step(batch)

        self.log("test_loss", loss)

And this is how I use it:

model = Model(
    bidirectional=False,
    lstm_layers=1,
    lstm_dropout=0.4,
    fc_dropout=0.4,
    lr=0.01,
    lr_scheduler_patience=2
)

...

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    every_n_train_steps=100,
    verbose=True
)

trainer = pl.Trainer(
    accelerator='gpu',
    precision=16,
    max_epochs=100,
    callbacks=[early_stopping, checkpoint_callback, lr_monitor, custom_callback],
    log_every_n_steps=50,
    logger=wandb_logger,
    auto_lr_find=True,
)

trainer.tune(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

trainer.fit(model, train_dataloader, val_dataloader)

When I don't run

trainer.tune(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

trainer.fit works perfectly, but as soon as I call trainer.tune I get a ModelCheckpoint error like this:

MisconfigurationException: `ModelCheckpoint(monitor='val_loss')` could not find the monitored key in the returned metrics: ['train_loss', 'epoch', 'step']. HINT: Did you call `log('val_loss', value)` in the `LightningModule`?
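
My understanding (an assumption on my part) is that with auto_lr_find=True, trainer.tune essentially runs the learning-rate finder before training, roughly equivalent to this explicit call in the PL 1.x API (a sketch only, not a proposed fix):

# Sketch of what trainer.tune(...) does here, as I understand it
lr_finder = trainer.tuner.lr_find(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
model.lr = lr_finder.suggestion()  # tune() applies the suggested LR automatically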

So even though I do log val_loss, it doesn't get saved. On the Trainer I set log_every_n_steps=50, and on the ModelCheckpoint I set every_n_train_steps=100, so by the time the ModelCheckpoint first runs, val_loss should already have been logged.
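
For comparison, this is the kind of checkpoint configuration I would expect to avoid the timing problem, since it only fires after validation, when val_loss has definitely been logged (just a sketch to illustrate what I mean, not something the error message suggests):

# Sketch: checkpointing aligned with validation instead of training steps,
# so the monitored key exists whenever the callback fires
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    verbose=True,
    # no every_n_train_steps: by default the callback runs at the end of
    # each validation epoch, after val_loss has been logged
)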

I printed the val loss inside validation_step, and it is computed before the ModelCheckpoint runs. I also defined an on_train_batch_end hook in my custom callback to inspect the logged training metrics, and it turns out val_loss really is missing there.
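
A minimal sketch of the kind of inspection hook I mean (not my exact callback, just to show how I'm checking the metrics that ModelCheckpoint sees):

# trainer.callback_metrics is the dict that ModelCheckpoint reads its
# monitored key from
from pytorch_lightning.callbacks import Callback

class MetricsInspector(Callback):
    # the trailing *_ absorbs the extra dataloader_idx argument that older
    # PL versions also pass to this hook
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *_):
        print(sorted(trainer.callback_metrics.keys()))  # 'val_loss' is missing here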

python machine-learning deep-learning pytorch pytorch-lightning