使用 Pytorch Lightning 实现更快的 RCNN

问题描述 投票:0回答:1

Faster RCNN 模型对于 model.eval() 和 model.train() 有不同的行为。

model.eval()
模式
model(images, targets)
- 返回预测

仅在

model.train()
model model(images)
时有回波损耗

我正在使用 Pytorch Lightning 并希望在训练过程中获得 torchmetrics.detection.mean_ap 一些 (N) 的训练批次。

如何做到? 我可以通过 val_check_intervalval_check_interval + on_validation_epoch_end 来完成,但除了验证集之外还针对训练批次?

pytorch-lightning faster-rcnn
1个回答
0
投票

顺便说一句,问题可以更有针对性。关键部分是如何使用 PyTorch Lightning 在特定批次间隔的训练过程中计算 mAP。

要实现此目的,您可以创建自定义回调:

class MAPCallback(pl.Callback):
    def __init__(self, every_n_batches):
        self.every_n_batches = every_n_batches
        self.map_metric = MeanAveragePrecision()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if (batch_idx + 1) % self.every_n_batches == 0:
            # Calculate and log mAP here

将此回调添加到您的 Trainer:

trainer = pl.Trainer(callbacks=[MAPCallback(every_n_batches=N)])

这种方法允许您在训练期间以指定的时间间隔计算 mAP,而无需修改主训练循环。

请检查闪电回调。 https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback.on_train_batch_end

© www.soinside.com 2019 - 2024. All rights reserved.