PyTorch Lightning inference after each epoch


I'm using PyTorch Lightning, and after each epoch I run inference on a small dataset to produce a figure that I monitor with Weights & Biases.

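I thought the most natural way to do this would be a callback whose on_train_epoch_end method generates the plot. That method needs to run some inference, so I wanted to use trainer.predict. However, doing so raises the error shown below, so I guess this is not the intended way to do it.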

Minimal reproducible example:

import lightning as L
from lightning.pytorch.callbacks import Callback

import torch
from torch.utils.data import DataLoader
from torch import nn, optim

class Model(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(10, 1)
        
    def training_step(self, batch, *args):
        out = self(batch)
        return out.mean() ** 2
    
    def forward(self, x):
        return self.f(x)[:, 0]

    def train_dataloader(self):
        return DataLoader(torch.randn((100, 10)))
    
    def predict_dataloader(self):
        return DataLoader(torch.randn((100, 10)))
    
    def predict_step(self, batch):
        return self(batch)
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
class CallbackExample(Callback):
    def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
        loader = model.predict_dataloader()
        trainer.predict(model, loader)
        
        ... # save figure to wandb

model = Model()
callback = CallbackExample()
trainer = L.Trainer(max_epochs=2, callbacks=callback, accelerator="mps")

trainer.fit(model)

Error (end of the traceback):

File ~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:233, in _LoggerConnector.metrics(self)
    231 """This function returns either batch or epoch metrics."""
    232 on_step = self._first_loop_iter is not None
--> 233 assert self.trainer._results is not None
    234 return self.trainer._results.metrics(on_step)

AssertionError: 

What is the most natural and elegant way to do this?

python pytorch pytorch-lightning
2 answers

1 vote

Using .transfer_batch_to_device solved it:

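import torch

# Callback, L and Model come from the question's code
class PlotCallback(Callback):
    def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
        loader = model.predict_dataloader()
        predictions = []
        with torch.no_grad():  # inference only, no autograd graph needed
            for batch in loader:
                # Move the batch to the model's device (dataloader_idx=0)
                batch = model.transfer_batch_to_device(batch, model.device, 0)
                predictions.append(model.predict_step(batch))

        ... # save figure to wandb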
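The same loop can be factored into a small reusable helper. This is a sketch, not part of the original answer: run_inference is a hypothetical name, it assumes each batch is a single tensor (for nested batches, keep using model.transfer_batch_to_device as in the answer) and that the module already lives on device. It also switches the module to eval mode and restores the previous mode afterwards, which matters for models with dropout or batch norm, since the module is normally still in training mode when on_train_epoch_end runs.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def run_inference(module: nn.Module, loader: DataLoader, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: run `module` over every batch of `loader` without a Trainer.

    Assumes each batch is a single tensor and that `module` is already on `device`.
    """
    was_training = module.training  # remember the current mode so it can be restored
    module.eval()  # disable dropout / use running BatchNorm statistics
    outputs = []
    try:
        with torch.no_grad():  # no autograd graph is needed for plotting
            for batch in loader:
                outputs.append(module(batch.to(device)).cpu())
    finally:
        module.train(was_training)  # put the module back in the mode it was in
    return torch.cat(outputs)


# Usage with a plain module, mirroring the shapes from the question
torch.manual_seed(0)
net = nn.Linear(10, 1)
loader = DataLoader(torch.randn(100, 10), batch_size=16)
preds = run_inference(net, loader, torch.device("cpu"))
print(preds.shape)  # torch.Size([100, 1])
```

Inside PlotCallback.on_train_epoch_end this could be called as run_inference(model, model.predict_dataloader(), model.device); with the question's Model, whose forward drops the last dimension, the result would have shape (100,) instead.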

0 votes

Restoring trainer.state after the call might work (I'm not sure whether it has side effects):

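from copy import deepcopy

import lightning as L


class Evaluator(L.Callback):
    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        # trainer.predict() switches the trainer into its predict state; save what it overwrites
        trainer_state = deepcopy(trainer.state)
        current_fx_name = pl_module._current_fx_name
        results = trainer.predict(pl_module, return_predictions=True)  # predictions, e.g. for a figure
        # Restore the saved values so the fit loop sees its own state again
        trainer.state = trainer_state
        pl_module._current_fx_name = current_fx_name
        pl_module.log("val/xxx", 0, prog_bar=True)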
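The save-and-restore pattern in this answer can be wrapped in a context manager so the values are put back even if trainer.predict raises. This is a sketch under assumptions: preserved_attrs is a hypothetical helper, not a Lightning API, and it relies on the attributes being deep-copyable, as the answer's own deepcopy(trainer.state) already does. The demo uses a stand-in object rather than a real Trainer.

```python
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Iterator


@contextmanager
def preserved_attrs(obj: Any, *names: str) -> Iterator[None]:
    """Hypothetical helper: snapshot attributes of `obj` and restore them on exit."""
    saved = {name: deepcopy(getattr(obj, name)) for name in names}
    try:
        yield
    finally:
        # Runs even if the body raises, so the object is never left half-modified
        for name, value in saved.items():
            setattr(obj, name, value)


# In the callback this would read (not executed here):
#     with preserved_attrs(trainer, "state"), preserved_attrs(pl_module, "_current_fx_name"):
#         results = trainer.predict(pl_module, return_predictions=True)

# Plain-Python demo with a stand-in object (not a real Trainer)
class Dummy:
    def __init__(self) -> None:
        self.state = {"stage": "validate"}


d = Dummy()
with preserved_attrs(d, "state"):
    d.state["stage"] = "predict"  # simulate predict() mutating the state in place
print(d.state)  # {'stage': 'validate'}
```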