下面是我来自tsai笔记本的时间序列回归问题的示例代码。
from tsai.all import *
dsid = 'AppliancesEnergy'
arch_config = {
'hidden_size':100,
'n_layers':2,
'rnn_dropout':0.2,
'fc_dropout':0.5,
'bidirectional':True
}
X, y, splits = get_regression_data(dsid, split_data=False)
learn = TSRegressor(
X,
y,
splits=splits,
bs=128,
batch_tfms=[TSStandardize(by_sample=True)],
arch=LSTM,
arch_config=arch_config,
metrics=[mae, rmse],
cbs=ShowGraph(),
verbose=True)
learn.fit_one_cycle(100, lr_max=1e-3)
learn.plot_metrics()
这个效果很好。我想做的是在 fit() 期间提前停止。 我在 fastai 中找到了回调函数“TerminateOnNaNCallback()”,并像下面这样使用 import fastai 应用了它。
learn.fit_one_cycle(100, lr_max=1e-3, cbs=TerminateOnNaNCallback())
但这不起作用。如果有人知道,请告诉我。 谢谢你。
您可以通过 tsai 库导入并使用来自 fastai 的提前停止回调:
from fastai.callback.all import EarlyStoppingCallback
然后设置您的回调:
cbs = [EarlyStoppingCallback(), ShowGraph()]
您可以定义参数,例如监控哪个指标/损失以及在没有改善的轮次后终止训练。