当达到特定的验证准确率时如何停止训练?

问题描述 投票:0回答:2

我正在训练一个卷积网络,一旦验证错误达到 90%,我想停止训练。我考虑过使用 EarlyStopping 并将基线设置为 0.90,但只要验证准确度低于给定数量的 epoch 的基线(此处仅为 0),它就会停止训练。所以我的代码是:

es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])

当我使用此代码时,我的训练在第一个时期后停止并给出给定的结果:

Train on 60000 samples, validate on 10000 samples

Epoch 1/30
60000/60000 - 7s - loss: 0.4600 - acc: 0.8330 - val_loss: 0.3426 - val_acc: 0.8787

一旦验证准确率达到 90% 或以上,我还能尝试什么来停止训练?

这是其余的代码:

  tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(152, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer=Adam(learning_rate=0.001),loss='sparse_categorical_crossentropy', metrics=['accuracy'])
es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])
python tensorflow keras deep-learning conv-neural-network
2个回答
10
投票

提前停止回调将搜索停止增加(或减少)的值,因此这对于您的问题来说不是一个很好的用途。但是

tf.keras
允许您使用 自定义回调

举个例子:

class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None): 
        val_acc = logs["val_acc"]
        if val_acc >= self.threshold:
            self.model.stop_training = True

对于 TF 2.3 或更高版本,您可能必须使用

"val_accuracy"
而不是
"val_acc"
。感谢您Christian Westbrook在评论中的注释。

上述回调在每个纪元结束时将从所有可用日志中提取验证准确性。然后它会将其与用户定义的阈值(在您的情况下为 90%)进行比较。如果满足标准,训练将停止。

您只需拨打:

my_callback = MyThresholdCallback(threshold=0.9)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2, callbacks=[my_callback])

或者,如果您想立即停止,可以使用

def on_batch_end(...)
。 然而,这需要参数
batch, logs
而不是
epoch, logs


7
投票

现有的答案看起来不错,但我过去使用过较短的版本:

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if logs.get('accuracy') >= 9e-1:
            self.model.stop_training = True

你可以这样实现:

callback = CustomCallback()

history = model.fit(..., callbacks=[callback])
© www.soinside.com 2019 - 2024. All rights reserved.