tf.keras.models.load_model 无法加载使用 tf.keras.models.save_model 保存的模型

问题描述 投票:0回答:1

我正在尝试加载之前训练过的模型,但它给了我这个错误。有什么方法可以将此文件格式转换为新格式或如何解决此问题?如果可能的话,我宁愿不必训练新模型。

#Model was previously saved with     
tf.keras.models.save_model(file_path)
#Load back in (causes error)
model_a = tf.keras.models.load_model(file_path)

这是收到的错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-12-120117fe47ab> in <cell line: 3>()
      1 #@title Load the checkpoint
      2 checkpoint_path = '/drive/MyDrive/dummy/2' # @param {type:"string"}
----> 3 model_a = tf.keras.models.load_model(checkpoint_path)
      4 tf.keras.utils.plot_model(ad._model, show_shapes = True)

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode)
    197         )
    198     else:
--> 199         raise ValueError(
    200             f"File format not supported: filepath={filepath}. "
    201             "Keras 3 only supports V3 `.keras` files and "

ValueError: File format not supported: filepath=/drive/MyDrive/dummy/2. Keras 3 only supports V3 `.keras` files and legacy H5 format files (`.h5` extension). Note that the legacy SavedModel format is not supported by `load_model()` in Keras 3. In order to reload a TensorFlow SavedModel as an inference-only layer in Keras 3, use `keras.layers.TFSMLayer(/drive/MyDrive/dummy/2, call_endpoint='serving_default')` (note that your `call_endpoint` might have a different name).
python tensorflow keras
1个回答
0
投票

看起来您用来加载 keras 模型的

checkpoint_path
变量没有扩展名。从 documentation 来看,
load_model
文件路径参数必须具有
.keras
扩展名,当您调用
save_model
时,该扩展名不包含在文件中,因此
load_model
无法确定文件是否位于
.keras
.h5
格式。要解决此问题,您只需在保存模型时添加
.keras
文件扩展名

如果

2
是一个目录

import os
checkpoint_path = '/drive/MyDrive/dummy/2'
model_filepath = os.path.join(checkpoint_path, 'model.keras')
#Model was previously saved with     
tf.keras.models.save_model(model_filepath)
#Load back in (causes error)
model_a = tf.keras.models.load_model(model_filepath)

如果

2
是模型的预期文件名

import os
checkpoint_path = '/drive/MyDrive/dummy/2.keras'
#Model was previously saved with     
tf.keras.models.save_model(file_path)
#Load back in (causes error)
model_a = tf.keras.models.load_model(file_path)
© www.soinside.com 2019 - 2024. All rights reserved.