将 tf.data.dataset 对象提供给模型时,tf.callbacks.EarlyStopping 无法正常工作

问题描述 投票:0回答:1

我设置 Patient=5,如果 val_loss 连续 5 个周期没有减少,它就会停止训练。然而,训练总是在第 5 个 epoch 停止,并且最佳权重设置为第 1 个 epoch,即使 val_loss 仍在减少。

仅当我将 tf.data.dataset 对象提供给模型时才会发生这种情况。当我将 numpy 数组输入模型时,它仍然按照我的预期工作。

train_dataset = train_dataset.repeat().batch(32).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.repeat().batch(32).prefetch(tf.data.AUTOTUNE)


early_stopping_ = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,  # Number of epochs with no improvement after which training will be stopped
    verbose=1,
    restore_best_weights=True
)


model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(50, 50, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(256, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(512, activation='relu', kernel_regularizer=l2(0.01)),
    Dropout(0.5),
    Dense(128, activation='relu', kernel_regularizer=l2(0.01)),
    Dropout(0.5),
    Dense(4, activation='softmax')
])


model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


Train the model
history = model.fit(
    train_dataset,
    epochs=20,
    steps_per_epoch = math.ceil(train_size/32),
    #batch_size=32,
    validation_data = val_dataset,
    validation_steps = math.ceil(val_size//32),
    verbose=1,
    callbacks = [early_stopping]
)```
python numpy tensorflow tensor early-stopping
1个回答
0
投票

发生这种情况可能是因为训练循环中使用 tf.data.Dataset 的方式,特别是使用重复()和steps_per_epoch参数,这里有一些建议:

1- 确保根据您的数据集大小和批量大小正确计算steps_per_epoch和validation_steps:

示例:

steps_per_epoch = math.ceil(train_size / 32)
validation_steps = math.ceil(val_size / 32)

2-您在定义steps_per_epoch时将tf.data.Dataset与repeat()一起使用,这可能会导致模型认为它没有改进,从而导致过早停止。

确保设置了steps_per_epoch,以便每个纪元处理正确数量的批次。

或者,如果您使用repeat()而不指定计数,您可以尝试删除它或将其设置为repeat(count=1),以将数据集限制为每个时期的单遍。

3-如果您不需要重复数据集,您可以考虑删除重复()调用或显式定义重复次数。例如:

train_dataset = train_dataset.batch(32).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(32).prefetch(tf.data.AUTOTUNE)

4- 考虑在训练循环中添加更详细的日志记录,以更好地了解 val_loss 如何变化以及为什么可能会触发 EarlyStopping。

如果您需要更多帮助,请告诉我(:

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.