我正在使用 Keras(带有 TensorFlow 后端)来实现神经网络,并且只想保存在训练期间最小化验证集损失的模型。为此,我实例化了一个 ModelCheckpoint 并在调用模型的 fit 方法时传递它。但是,当我这样做时,我收到以下错误:“
AttributeError: 'ModelCheckpoint' object has no attribute '_implements_train_batch_hooks'
”。我在网上找到的最接近我的问题的是this post,有一个类似的错误,其中问题来自于混合来自keras
和tf.keras
的模块,但这不是我的情况,因为我的所有模块都是从导入的keras
。我在网上和 Keras 文档中查找了一段时间,但找不到任何可以解释此错误的内容。以下是与该问题最相关的代码部分:
导入模块:
from keras.models import Sequential
from keras.layers import Embedding, Conv1D, Dense, Dropout, GlobalMaxPool1D, Concatenate
from keras.callbacks import ModelCheckpoint
ModelCheckpoint实例化、模型编译以及调用fit方法:
checkpoint = ModelCheckpoint('../model_best.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
history = model.fit(x_train, y_train,
epochs = 10, batch_size = 64,
validation_data = (x_val, y_val),
callbacks = [checkpoint])
...这是完整的回溯:
Traceback (most recent call last):
File "/Users/thisuser/thisrepo/classifier.py", line 39, in <module>
callbacks = [checkpoint])
File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
return method(self, *args, **kwargs)
File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 826, in fit
steps=data_handler.inferred_steps)
File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 231, in __init__
cb._implements_train_batch_hooks() for cb in self.callbacks)
File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 231, in <genexpr>
cb._implements_train_batch_hooks() for cb in self.callbacks)
AttributeError: 'ModelCheckpoint' object has no attribute '_implements_train_batch_hooks'
我使用的版本是:
有谁知道可能是什么原因造成的?如果需要,我可以稍微修改我的代码以将其全部放在这里,以便它是可重现的。预先感谢您的帮助!
我最近也遇到这个问题了
我发现了什么:最近 keras 或 tensorflow 版本被开发人员更新了,这导致了问题。
解决方案:由于keras的开发者要求大家切换到tf.keras版本,所以你需要替换你的代码import部分。
来自:
import keras
致:
import tensorflow.keras as keras
之后一切都对我有用。
替换: 从 keras.callbacks 导入 ModelCheckpoint 到: 从tensorflow.keras.callbacks导入ModelCheckpoint
我遇到了同样的错误,并在运行“!pip install tf_keras”后使用了“import tf_keras”。然后“from tf_keras import callbacks”成功训练我的模型。这可能是计算机/本地环境特定问题。