如何在 tsai 时间序列模型训练中保存最佳模型

问题描述 投票:0回答:2

我正在使用 tsai 的 TSClassifier 训练我的数据,但我不知道如何保存所有 epoch 中表现最好的模型——就像 Keras/TensorFlow 中的 ModelCheckpoint 回调那样。我希望在所有 epoch 中,保存在测试集上 val_loss 最好(最低)的那个模型。

下面是我的代码,请问如何在 tsai 库中保存最佳模型?

import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
from pickle import load
from multiprocessing import Process
import numpy as np
from tsai.all import *
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

dataset_idx = 0


def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, "rb") as f:
        return load(f)


X_train = _load_pickle(f"X_train_{dataset_idx}.pkl")
y_train = _load_pickle(f"y_train_{dataset_idx}.pkl")
# X_test / y_test are loaded but never passed to TSClassifier below, so the
# test set plays no part in training or in the reported valid_loss.
X_test = _load_pickle(f"X_test_{dataset_idx}.pkl")
y_test = _load_pickle(f"y_test_{dataset_idx}.pkl")
print("dataset loaded")

# No `splits` argument is given, so the train/validation partition is left to
# tsai's default behaviour — TODO confirm what that default is for your version.
learn = TSClassifier(X_train, y_train, arch=InceptionTimePlus, arch_config=dict(fc_dropout=0.5))

print("training started")
# 5 epochs with the one-cycle schedule and a maximum learning rate of 5e-4.
learn.fit_one_cycle(5, 0.0005)
# Export the final (last-epoch) learner, not necessarily the best one.
learn.export("tsai_"+str(dataset_idx)+".pkl")
python deep-learning time-series
2个回答
0
投票

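下面是对我有效的实现。关键在于:先把训练集和测试集合并成一个数组,再通过 splits 参数把测试集那部分指定为验证集,这样训练时报告的 valid_loss 就是在测试集上计算的;最后配合 SaveModelCallback(monitor='valid_loss') 保存 valid_loss 最好的模型。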

dataset_idx = 0
print("training info:", dataset_idx)


def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, "rb") as f:
        return load(f)


# .transpose((0, 2, 1)) swaps the last two axes of each 3-D array. Presumably
# the pickles are stored as (samples, steps, vars) while tsai expects
# (samples, vars, steps) — TODO confirm against how the files were written.
X_train = _load_pickle(f"X_train_{dataset_idx}.pkl").transpose((0, 2, 1))
y_train = _load_pickle(f"y_train_{dataset_idx}.pkl")
X_test = _load_pickle(f"X_test_{dataset_idx}.pkl").transpose((0, 2, 1))
y_test = _load_pickle(f"y_test_{dataset_idx}.pkl")

# Stack train and test into one dataset. The first n_train rows form the
# training split and the remaining rows (the test set) form the validation
# split, so the 'valid_loss' monitored below is computed on the test set.
n_train = X_train.shape[0]
X = np.concatenate([X_train, X_test], axis=0)
del X_train, X_test  # free the per-split copies once they are stacked
y = np.concatenate([y_train, y_test], axis=0)
del y_train, y_test
print("dataset loaded")

splits = [list(range(n_train)), list(range(n_train, y.shape[0]))]
learn = TSClassifier(X, y, splits=splits, arch=InceptionTimePlus, bs=256, arch_config=dict(fc_dropout=0.5))
print("training started")
# SaveModelCallback watches 'valid_loss' every epoch and saves the model under
# the name "ITP_<dataset_idx>" whenever that value improves.
learn.fit_one_cycle(20, 0.001, cbs=SaveModelCallback(monitor='valid_loss', fname="ITP_"+str(dataset_idx)))

-2
投票

可以使用 fastai 提供的 SaveModelCallback 回调:它在训练过程中监控指定的指标(这里是 valid_loss),并在该指标改善时保存模型。

# Checkpoint the model whenever 'valid_loss' improves during the 5 one-cycle
# epochs; the checkpoint is saved under the name 'best_model'.
best_model_cb = SaveModelCallback(monitor='valid_loss', fname='best_model')
learn.fit_one_cycle(5, 0.0005, cbs=best_model_cb)



# Evaluate the trained model on the held-out test arrays.
# NOTE(review): this assumes the DataLoader built here carries y_test as its
# targets — confirm for your tsai version (tsai also offers
# learn.get_X_preds(X_test, y_test)).
dl = learn.dls.test_dl(X_test, y_test)
# Learner.validate returns [loss, *metric values], not (preds, targets);
# get_preds returns the predictions and targets that accuracy() expects.
preds, targets = learn.get_preds(dl=dl)
print("Test set accuracy:", accuracy(preds, targets))
© www.soinside.com 2019 - 2024. All rights reserved.