如果自上一个时期提高了验证准确性,如何编写自定义回调以在每个时期保存模型

问题描述 投票:0回答:1

下面是我编写的自定义回调函数,但是它不起作用:

class bestval(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.history={'loss': [],'acc': [],'val_loss': [],'val_acc': []}

    def on_epoch_end(self, epoch, logs={}):
        #appending val_acc in history
        if logs.get('val_acc', -1) != -1:
            self.history['val_acc'].append(logs.get('val_acc'))
        # Trying to compare current epoch val_acc with all the values in self.history['val_acc']
        if logs.get('val_acc')> [i for i in self.history['val_acc']]:
            filepath="model_save/weights-{epoch:02d}-{val_acc:.4f}.hdf5"
            # Saving the model using TF built-in callback 
            checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(filepath=filepath, 
            monitor='val_acc',  verbose=1, mode='auto')
bestobj= bestval()

拟合模型:

model.fit(xtr,ytr, epochs=4, validation_data=(xte,yte), batch_size=128, callbacks=[bestobj])

当我执行以上操作时,出现以下错误:

ValueError:包含多个元素的数组的真值不明确。使用a.any()或a.all()

我知道我在做一些愚蠢的事情,但我不知道如何解决。任何帮助,将不胜感激。

tensorflow keras callback deep-learning tf.keras
1个回答
0
投票

我猜错误是在下一行中,您正在尝试将值与列表进行比较。if logs.get('val_acc')> [i for i in self.history['val_acc']]:

尝试,for i in self.history['val_acc']: if logs.get('val_acc')>i: #your code

© www.soinside.com 2019 - 2024. All rights reserved.