在 tsai 的 fit() 过程中如何“提前停止”

问题描述 投票:0回答:1

下面是我来自tsai笔记本的时间序列回归问题的示例代码。

from tsai.all import *

dsid = 'AppliancesEnergy'
arch_config = {
    'hidden_size':100, 
    'n_layers':2, 
    'rnn_dropout':0.2, 
    'fc_dropout':0.5, 
    'bidirectional':True
    }

X, y, splits = get_regression_data(dsid, split_data=False)
learn = TSRegressor(
    X, 
    y, 
    splits=splits, 
    bs=128, 
    batch_tfms=[TSStandardize(by_sample=True)], 
    arch=LSTM, 
    arch_config=arch_config, 
    metrics=[mae, rmse], 
    cbs=ShowGraph(), 
    verbose=True)

learn.fit_one_cycle(100, lr_max=1e-3)
learn.plot_metrics()

这个效果很好。我想做的是在 fit() 期间提前停止。 我在 fastai 中找到了回调函数“TerminateOnNaNCallback()”,并像下面这样使用 import fastai 应用了它。

learn.fit_one_cycle(100, lr_max=1e-3, cbs=TerminateOnNaNCallback())

但这不起作用。如果有人知道,请告诉我。 谢谢你。

python deep-learning time-series
1个回答
0
投票

您可以通过 tsai 库导入并使用来自 fastai 的提前停止回调

from fastai.callback.all import EarlyStoppingCallback

然后设置您的回调:

cbs = [EarlyStoppingCallback(), ShowGraph()]

您可以定义参数,例如监控哪个指标/损失以及在没有改善的轮次后终止训练。

© www.soinside.com 2019 - 2024. All rights reserved.