我使用 Keras 网站上提供的视觉 Transformer(ViT)模型示例进行图像分类。唯一的区别是,我添加了一行代码,在模型训练完成后将其保存为“.keras”文件。
后来,我尝试加载保存的模型,并使用 `get_config()` 检查其配置:
# Reload the trained ViT model from its .keras archive, then show its config
# (in a notebook the value of the last expression is displayed).
Lmodel = load_model("VITexp.keras")
Lmodel.get_config()
但是模型无法加载,并出现以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:208, in Operation.from_config(cls, config)
207 try:
--> 208 return cls(**config)
209 except Exception as e:
TypeError: Patches.__init__() got an unexpected keyword argument 'name'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
717 try:
--> 718 instance = cls.from_config(inner_config)
719 except TypeError as e:
File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:210, in Operation.from_config(cls, config)
209 except Exception as e:
--> 210 raise TypeError(
211 f"Error when deserializing class '{cls.__name__}' using "
212 f"config={config}.\n\nException encountered: {e}"
213 )
TypeError: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.
Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
717 try:
--> 718 instance = cls.from_config(inner_config)
719 except TypeError as e:
File /opt/conda/lib/python3.10/site-packages/keras/src/models/model.py:517, in Model.from_config(cls, config, custom_objects)
515 from keras.src.models.functional import functional_from_config
--> 517 return functional_from_config(
518 cls, config, custom_objects=custom_objects
519 )
521 # Either the model has a custom __init__, or the config
522 # does not contain all the information necessary to
523 # revive a Functional model. This happens when the user creates
(...)
526 # In this case, we fall back to provide all config into the
527 # constructor of the class.
File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:517, in functional_from_config(cls, config, custom_objects)
516 for layer_data in config["layers"]:
--> 517 process_layer(layer_data)
519 # Then we process nodes in order of layer depth.
520 # Nodes that cannot yet be processed (if the inbound node
521 # does not yet exist) are re-enqueued, and the process
522 # is repeated until all nodes are processed.
File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:501, in functional_from_config.<locals>.process_layer(layer_data)
500 else:
--> 501 layer = serialization_lib.deserialize_keras_object(
502 layer_data, custom_objects=custom_objects
503 )
504 created_layers[layer_name] = layer
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
719 except TypeError as e:
--> 720 raise TypeError(
721 f"{cls} could not be deserialized properly. Please"
722 " ensure that components that are Python object"
723 " instances (layers, models, etc.) returned by"
724 " `get_config()` are explicitly deserialized in the"
725 " model's `from_config()` method."
726 f"\n\nconfig={config}.\n\nException encountered: {e}"
727 )
728 build_config = config.get("build_config", None)
TypeError: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.
Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.
Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[11], line 1
----> 1 Lmodel=load_model("VITexp.keras")
2 Lmodel.get_config()
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_api.py:176, in load_model(filepath, custom_objects, compile, safe_mode)
173 is_keras_zip = True
175 if is_keras_zip:
--> 176 return saving_lib.load_model(
177 filepath,
178 custom_objects=custom_objects,
179 compile=compile,
180 safe_mode=safe_mode,
181 )
182 if str(filepath).endswith((".h5", ".hdf5")):
183 return legacy_h5_format.load_model_from_hdf5(
184 filepath, custom_objects=custom_objects, compile=compile
185 )
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:152, in load_model(filepath, custom_objects, compile, safe_mode)
147 raise ValueError(
148 "Invalid filename: expected a `.keras` extension. "
149 f"Received: filepath={filepath}"
150 )
151 with open(filepath, "rb") as f:
--> 152 return _load_model_from_fileobj(
153 f, custom_objects, compile, safe_mode
154 )
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:170, in _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode)
168 # Construct the model from the configuration file in the archive.
169 with ObjectSharingScope():
--> 170 model = deserialize_keras_object(
171 config_dict, custom_objects, safe_mode=safe_mode
172 )
174 all_filenames = zf.namelist()
175 if _VARS_FNAME + ".h5" in all_filenames:
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
718 instance = cls.from_config(inner_config)
719 except TypeError as e:
--> 720 raise TypeError(
721 f"{cls} could not be deserialized properly. Please"
722 " ensure that components that are Python object"
723 " instances (layers, models, etc.) returned by"
724 " `get_config()` are explicitly deserialized in the"
725 " model's `from_config()` method."
726 f"\n\nconfig={config}.\n\nException encountered: {e}"
727 )
728 build_config = config.get("build_config", None)
729 if build_config and not instance.built:
TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
Exception encountered: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.
Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.
Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'
除了保存和加载模型命令之外,代码是从网站复制粘贴的。
请帮我解决这个问题。是否有特定的方法来保存这些模型以便以后在完全不同的笔记本中访问? (我使用 Kaggle 来编写此代码)
下面总结了我在问题下的评论,并在代码之后给出完整的解释。我使用的是问题中所链接示例的代码,新增或修改的行都用 `#` 注释标出。只需要修改自定义层类即可。
@keras.saving.register_keras_serializable()  # <- this line
class Patches(layers.Layer):
    """Cut a batch of images into flattened square patches.

    The input is indexed as ``(batch, height, width, channels)`` (assumed
    channels-last -- confirm against the data pipeline). The result has shape
    ``(batch, num_patches, patch_size * patch_size * channels)`` where
    ``num_patches = (height // patch_size) * (width // patch_size)``.

    Args:
        patch_size: Side length, in pixels, of each square patch.
        **kwargs: Standard ``Layer`` arguments (``name``, ``trainable``,
            ``dtype``, ...). Accepting and forwarding them is what allows
            ``from_config()`` to rebuild this layer from a saved model.
    """

    def __init__(self, patch_size, **kwargs):  # <- add **kwargs
        super().__init__(**kwargs)  # <- add **kwargs
        self.patch_size = patch_size

    def call(self, images):
        shape = ops.shape(images)
        batch_size, height, width, channels = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )
        size = self.patch_size
        patch_count = (height // size) * (width // size)
        flat_dim = size * size * channels

        extracted = ops.image.extract_patches(images, size=size)
        # Collapse the 2-D grid of patches into one sequence axis and flatten
        # each patch's pixels into a single feature vector.
        return ops.reshape(extracted, (batch_size, patch_count, flat_dim))

    def get_config(self):
        # Extend the base Layer config with the constructor argument so the
        # layer can be re-created by from_config() when the model is loaded.
        return {**super().get_config(), "patch_size": self.patch_size}
# ------------------------------------------------------------------
@keras.saving.register_keras_serializable()  # this line
class PatchEncoder(layers.Layer):
    """Project flattened patches and add a learned position embedding.

    Each patch vector is passed through a ``Dense`` projection to
    ``projection_dim`` features, and an ``Embedding`` of the patch positions
    ``0 .. num_patches - 1`` is added (broadcast over the batch axis).

    The input is assumed to be ``(batch, num_patches, patch_dim)``, e.g. the
    output of ``Patches`` -- the addition requires the patch axis to equal
    ``num_patches``. The output is ``(batch, num_patches, projection_dim)``.

    Args:
        num_patches: Number of patches per image (size of the position table).
        projection_dim: Output feature size of the projection/embedding.
        **kwargs: Standard ``Layer`` arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to the base class so ``from_config()``
            can rebuild the layer when a saved model is loaded.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):  # <- add **kwargs
        super().__init__(**kwargs)  # <- add **kwargs
        self.num_patches = num_patches
        # Kept as an attribute so get_config() can serialize it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def build(self, input_shape):
        # Actually build the sub-layers instead of only marking this layer as
        # built: when a .keras file is loaded, build() is driven by the saved
        # build_config, and the sub-layers' variables must exist before the
        # saved weights are restored into them.
        self.projection.build(input_shape)
        # call() looks up positions of shape (1, num_patches).
        self.position_embedding.build((1, self.num_patches))
        super().build(input_shape)

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Both constructor arguments must be in the config; otherwise
        # from_config() cannot call __init__() when the model is reloaded.
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,  # this line
            }
        )
        return config
添加的代码行的简短说明:
@keras.saving.register_keras_serializable()
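此装饰器把自定义层注册到 Keras 的全局自定义对象列表中,这样加载模型时,Keras 才能根据配置中保存的注册名(例如 `Custom>Patches`)找到并重建该层。注意:注册只有在类定义的代码被执行时才会生效。因此,若要在另一个完全不同的笔记本中加载模型,仍然需要在调用 `load_model()` 之前先运行这些层类的定义;也可以通过 `load_model(..., custom_objects={"Patches": Patches, "PatchEncoder": PatchEncoder})` 传入这些类。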
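`**kwargs`
它会捕获 `__init__()` 方法收到的、但用户没有显式声明的关键字参数,并把它们传给 `super().__init__()` 调用。在本例中,加载模型时 `__init__()` 收到了参数 `name`(以及 `trainable`、`dtype`),因为每个 `Layer` 类都有这些参数,它们也被保存在配置中。但原来的 `__init__()` 并没有声明 `name`,因此会抛出 `unexpected keyword argument 'name'` 错误。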
self.projection_dim = projection_dim
# ...
config.update({"projection_dim": self.projection_dim})
这两行先把 `projection_dim` 保存为 `PatchEncoder` 层的属性,再把它写入该层的配置中。这样在重新加载该层时,就能用训练时设置的参数重建它。否则,`from_config()` 调用 `__init__()` 时会缺少必需参数 `projection_dim`。