在 Kaggle 笔记本中使用 ModelCheckpoint 训练 Keras 模型时出现问题(“train_function”的意外结果(空日志))

问题描述 投票:0回答:1

我在尝试使用 TensorFlow 的 ModelCheckpoint 回调在 Kaggle 笔记本中训练 Keras 模型时遇到问题。这是我的设置和我面临的错误:

设置:

我正在使用 TensorFlow 构建用于多标签分类的 Keras 模型。这是我的代码的相关部分:

from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM, BatchNormalization
from keras.callbacks import TensorBoard
from keras.callbacks import ModelCheckpoint
from keras.optimizers import AdamW

epochs = 4
loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
classifier_model.compile(optimizer='adam',
                         loss=loss,
                         metrics = 'roc-auc')

print(f'Training model with {tfhub_handle_encoder}')
checkpoint_filepath = '/kaggle/working/tmp_weights.h5'

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

history = classifier_model.fit(x=train_ds,
                               validation_data=val_ds,
                               epochs=epochs,
                               callbacks = [model_checkpoint_callback])

错误:

运行训练脚本时,遇到以下错误:

ValueError: Unexpected result of `train_function` (Empty logs). This could be due to issues in input pipeline that resulted in an empty dataset. Otherwise, please use `Model.compile(..., run_eagerly=True)`, or `tf.config.run_functions_eagerly(True)` for more information of where went wrong, or file a issue/bug to `tf.keras`.

其他背景:

  • 我正在使用 TensorFlow Hub 编码器 (
    tfhub_handle_encoder
    ) 进行文本嵌入。
  • train_ds
    val_ds
    分别是包含我的训练和验证数据的对象,它们的格式如下:
    <_TakeDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None, 6), dtype=tf.int64, name=None))>
  • 我已验证我的数据加载和预处理步骤是否正确,并且
    train_ds
    val_ds
    不为空。

要求:

对于如何使用 Kaggle 上的 Keras 训练脚本中的 ModelCheckpoint 回调解决此问题,我将不胜感激。谢谢!

python tensorflow keras deep-learning
1个回答
0
投票

检查您的指标和优化器是否包含在列表中,正如 Keras 所希望的那样。另外,它需要一个 ROC 指标,就文档而言

loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
classifier_model.compile(optimizer=['adam'],
                         loss=loss,
                         metrics = ['ROC'])
© www.soinside.com 2019 - 2024. All rights reserved.