Python darts 中的 RNN 训练指标

问题描述 投票:0回答:1

我目前正在使用 python darts 训练 RNNModel。为了比较不同的训练模型,我想从 fit 方法中提取 train_loss 和 val_loss 。我该怎么做?我读过一些有关度量集合的内容,但不知道如何使用它。

这是我当前的代码

from darts.models import RNNModel
from darts import TimeSeries

train = # training data as TimeSeries
model = RNNModel(model="LSTM", input_chunk_length=self.past_samples)
model.fit(train)

训练期间控制台中会显示损失,但我不知道如何访问它。

到目前为止,我已尝试在网上查找任何文档,并向 bing chat 和 ChatGPT 寻求帮助。然而他们告诉我使用不存在的

model.history.history["loss"]

python machine-learning recurrent-neural-network metrics
1个回答
0
投票
要从 Darts 中

train_loss

val_loss
 方法中提取 
fit
RNNModel
,您可以使用自定义 PyTorch Lightning 回调。

from darts import TimeSeries from darts.models import RNNModel from pytorch_lightning.callbacks import Callback import matplotlib.pyplot as plt
创建自定义回调来记录每个时期结束时的训练和验证损失。

class LossRecorder(Callback): def __init__(self): self.train_loss_history = [] self.val_loss_history = [] def on_train_epoch_end(self, trainer, pl_module): self.train_loss_history.append(trainer.callback_metrics["train_loss"].item()) self.val_loss_history.append(trainer.callback_metrics["val_loss"].item())
初始化 

LossRecorder

 回调。

loss_recorder = LossRecorder()
初始化 

RNNModel

 并将 
LossRecorder
 回调传递给 
pl_trainer_kwargs
 参数。

model = RNNModel( model="LSTM", pl_trainer_kwargs={"callbacks": [loss_recorder]} )
使用您的训练和验证数据训练模型。

model.fit(train_series, val_series=val_series)
训练后,绘制记录的训练和验证损失。

plt.plot(loss_recorder.train_loss_history, label='Train Loss') plt.plot(loss_recorder.val_loss_history, label='Val Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.show()
    
© www.soinside.com 2019 - 2024. All rights reserved.