Keras回调AttributeError:'ModelCheckpoint'对象没有属性'_implements_train_batch_hooks'

问题描述 投票:0回答:3

我正在使用 Keras(带有 TensorFlow 后端)来实现神经网络,并且只想保存在训练期间最小化验证集损失的模型。为此,我实例化了一个 ModelCheckpoint 并在调用模型的 fit 方法时传递它。但是,当我这样做时,我收到以下错误:“

AttributeError: 'ModelCheckpoint' object has no attribute '_implements_train_batch_hooks'
”。我在网上找到的最接近我的问题的是this post,有一个类似的错误,其中问题来自于混合来自
keras
tf.keras
的模块,但这不是我的情况,因为我的所有模块都是从
导入的keras
。我在网上和 Keras 文档中查找了一段时间,但找不到任何可以解释此错误的内容。以下是与该问题最相关的代码部分:

导入模块

from keras.models import Sequential
from keras.layers import Embedding, Conv1D, Dense, Dropout, GlobalMaxPool1D, Concatenate
from keras.callbacks import ModelCheckpoint

ModelCheckpoint实例化、模型编译以及调用fit方法:

checkpoint = ModelCheckpoint('../model_best.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, 
                    epochs = 10, batch_size = 64,
                    validation_data = (x_val, y_val),
                    callbacks = [checkpoint])

...这是完整的回溯:

Traceback (most recent call last):

  File "/Users/thisuser/thisrepo/classifier.py", line 39, in <module>
    callbacks = [checkpoint])

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 826, in fit
    steps=data_handler.inferred_steps)

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 231, in __init__
    cb._implements_train_batch_hooks() for cb in self.callbacks)

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 231, in <genexpr>
    cb._implements_train_batch_hooks() for cb in self.callbacks)

AttributeError: 'ModelCheckpoint' object has no attribute '_implements_train_batch_hooks'

我使用的版本是:

  • Python:3.7.7
  • Keras:2.3.0-tf

有谁知道可能是什么原因造成的?如果需要,我可以稍微修改我的代码以将其全部放在这里,以便它是可重现的。预先感谢您的帮助!

python tensorflow keras callback tf.keras
3个回答
7
投票

我最近也遇到这个问题了

我发现了什么:最近 keras 或 tensorflow 版本被开发人员更新了,这导致了问题。

解决方案:由于keras的开发者要求大家切换到tf.keras版本,所以你需要替换你的代码import部分

来自:

import keras

致:

import tensorflow.keras as keras

之后一切都对我有用。


0
投票

替换: 从 keras.callbacks 导入 ModelCheckpoint 到: 从tensorflow.keras.callbacks导入ModelCheckpoint


0
投票

我遇到了同样的错误,并在运行“!pip install tf_keras”后使用了“import tf_keras”。然后“from tf_keras import callbacks”成功训练我的模型。这可能是计算机/本地环境特定问题。

© www.soinside.com 2019 - 2024. All rights reserved.