我正在使用 PyTorch Lightning,并且在每个 epoch 结束后,我都会对一个小数据集进行推理,以生成一张图(figure),并通过 Weights & Biases(wandb)对其进行监控。
我认为最自然的做法是使用一个带有 `on_train_epoch_end` 方法的回调(callback)来生成这张图。该方法需要执行一些推理,因此我想使用 `trainer.predict`。然而,这样做时我收到了下面的错误,所以我猜这并不是预期的用法。
最小可重现示例:
import lightning as L
from lightning.pytorch.callbacks import Callback
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
class Model(L.LightningModule):
    """Toy LightningModule built around a single ``nn.Linear(10, 1)`` layer."""

    def __init__(self):
        super().__init__()
        self.f = nn.Linear(10, 1)

    # --- computation -----------------------------------------------------

    def forward(self, x):
        """Apply the linear layer and keep column 0 of its output."""
        projected = self.f(x)
        return projected[:, 0]

    def training_step(self, batch, *args):
        """Return the squared mean of the forward output as the loss."""
        prediction = self(batch)
        return prediction.mean() ** 2

    def predict_step(self, batch):
        """Return the forward output for ``batch``."""
        return self(batch)

    # --- optimisation ----------------------------------------------------

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``lr=1e-3``."""
        return optim.Adam(self.parameters(), lr=1e-3)

    # --- data ------------------------------------------------------------

    def train_dataloader(self):
        """Return a DataLoader over a fresh ``torch.randn`` tensor of shape (100, 10)."""
        samples = torch.randn((100, 10))
        return DataLoader(samples)

    def predict_dataloader(self):
        """Return a DataLoader over a fresh ``torch.randn`` tensor of shape (100, 10)."""
        samples = torch.randn((100, 10))
        return DataLoader(samples)
class CallbackExample(Callback):
    """Reproducer callback: calls ``trainer.predict`` after every training epoch."""

    def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
        """Run the trainer's predict loop on ``model.predict_dataloader()``.

        NOTE: this nested ``trainer.predict`` call, made while ``trainer.fit``
        is still running, is what the question reports as ending in the
        ``AssertionError`` from ``_LoggerConnector.metrics``. It is kept
        as-is on purpose because it is the minimal reproducible example.
        """
        trainer.predict(model, model.predict_dataloader())
        ...  # save figure to wandb
# Reproducer entry point: fit the toy model for two epochs with the
# callback attached, using the "mps" accelerator.
model = Model()
callback = CallbackExample()
trainer = L.Trainer(
    max_epochs=2,
    callbacks=callback,
    accelerator="mps",
)
trainer.fit(model)
File ~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:233, in _LoggerConnector.metrics(self)
231 """This function returns either batch or epoch metrics."""
232 on_step = self._first_loop_iter is not None
--> 233 assert self.trainer._results is not None
234 return self.trainer._results.metrics(on_step)
AssertionError:
最自然、最优雅的做法是什么?
不通过 `trainer.predict`,而是使用 `.transfer_batch_to_device` 手动把每个 batch 移到模型所在的设备上,再直接调用 `predict_step`,解决了这个问题:
class PlotCallback(Callback):
    """Run inference by hand at the end of every training epoch.

    Iterates the predict dataloader and calls ``predict_step`` directly
    instead of going through ``trainer.predict``, so the trainer's own
    state is left alone while ``fit`` is still running.
    """

    def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
        """Collect predictions for every batch of ``model.predict_dataloader()``.

        Args:
            trainer: The running trainer (not used by this hook).
            model: The module being trained; supplies the dataloader, the
                target device and ``predict_step``.
        """
        loader = model.predict_dataloader()

        # This hook fires mid-training, so modules are in train mode and
        # autograd is on. Remember each submodule's flag so that modules the
        # user deliberately keeps in eval mode are restored exactly.
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()

        predictions = []
        try:
            # No gradients are needed for plotting; skip building the graph.
            with torch.no_grad():
                for batch in loader:
                    # Dataloader batches are not moved to the device
                    # automatically outside the trainer's loops.
                    batch = model.transfer_batch_to_device(batch, model.device, 0)
                    predictions.append(model.predict_step(batch))
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        ...  # save figure to wandb (build it from `predictions`)
在调用 `trainer.predict` 之后恢复 `trainer.state` 可能有效(不确定是否有副作用):
class Evaluator(L.Callback):
    """Run ``trainer.predict`` inside a validation hook, then restore trainer state.

    NOTE(review): this is a workaround that touches the private attribute
    ``pl_module._current_fx_name``; whether it has side effects on other
    trainer internals is not established here -- verify against the
    installed Lightning version.
    """

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Predict with ``pl_module`` and then log ``"val/xxx"``.

        Args:
            trainer: The running trainer; its ``state`` is snapshotted before
                the nested predict call and put back afterwards.
            pl_module: The module being validated.
        """
        # Function-scope import keeps this snippet self-contained.
        from copy import deepcopy

        saved_state = deepcopy(trainer.state)
        saved_fx_name = pl_module._current_fx_name
        try:
            results = trainer.predict(pl_module, return_predictions=True)
        finally:
            # Restore even if predict raises, so the surrounding loop does not
            # continue with the state left behind by the predict call.
            trainer.state = saved_state
            pl_module._current_fx_name = saved_fx_name

        # Placeholder metric; a real implementation would derive it from `results`.
        pl_module.log("val/xxx", 0, prog_bar=True)