训练后无法加载 Vision Transformer 模型

问题描述 投票:0回答:1

我使用了 Keras 官网上提供的 Vision Transformer(视觉 Transformer)图像分类示例。唯一的区别是,我添加了一行代码,在模型训练完成后将其保存为“.keras”文件。

之后,我尝试加载保存的模型,并使用“get_config()”查看其配置。

# Load the model saved after training, then inspect its configuration.
# Loading rebuilds each custom layer (e.g. Patches) from its saved config,
# so those classes must accept that config — see the traceback below.
Lmodel=load_model("VITexp.keras")
Lmodel.get_config()

但是代码无法加载模型并给出以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:208, in Operation.from_config(cls, config)
    207 try:
--> 208     return cls(**config)
    209 except Exception as e:

TypeError: Patches.__init__() got an unexpected keyword argument 'name'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    717 try:
--> 718     instance = cls.from_config(inner_config)
    719 except TypeError as e:

File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:210, in Operation.from_config(cls, config)
    209 except Exception as e:
--> 210     raise TypeError(
    211         f"Error when deserializing class '{cls.__name__}' using "
    212         f"config={config}.\n\nException encountered: {e}"
    213     )

TypeError: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.

Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    717 try:
--> 718     instance = cls.from_config(inner_config)
    719 except TypeError as e:

File /opt/conda/lib/python3.10/site-packages/keras/src/models/model.py:517, in Model.from_config(cls, config, custom_objects)
    515     from keras.src.models.functional import functional_from_config
--> 517     return functional_from_config(
    518         cls, config, custom_objects=custom_objects
    519     )
    521 # Either the model has a custom __init__, or the config
    522 # does not contain all the information necessary to
    523 # revive a Functional model. This happens when the user creates
   (...)
    526 # In this case, we fall back to provide all config into the
    527 # constructor of the class.

File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:517, in functional_from_config(cls, config, custom_objects)
    516 for layer_data in config["layers"]:
--> 517     process_layer(layer_data)
    519 # Then we process nodes in order of layer depth.
    520 # Nodes that cannot yet be processed (if the inbound node
    521 # does not yet exist) are re-enqueued, and the process
    522 # is repeated until all nodes are processed.

File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:501, in functional_from_config.<locals>.process_layer(layer_data)
    500 else:
--> 501     layer = serialization_lib.deserialize_keras_object(
    502         layer_data, custom_objects=custom_objects
    503     )
    504 created_layers[layer_name] = layer

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    719 except TypeError as e:
--> 720     raise TypeError(
    721         f"{cls} could not be deserialized properly. Please"
    722         " ensure that components that are Python object"
    723         " instances (layers, models, etc.) returned by"
    724         " `get_config()` are explicitly deserialized in the"
    725         " model's `from_config()` method."
    726         f"\n\nconfig={config}.\n\nException encountered: {e}"
    727     )
    728 build_config = config.get("build_config", None)

TypeError: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.

Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.

Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[11], line 1
----> 1 Lmodel=load_model("VITexp.keras")
      2 Lmodel.get_config()

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_api.py:176, in load_model(filepath, custom_objects, compile, safe_mode)
    173         is_keras_zip = True
    175 if is_keras_zip:
--> 176     return saving_lib.load_model(
    177         filepath,
    178         custom_objects=custom_objects,
    179         compile=compile,
    180         safe_mode=safe_mode,
    181     )
    182 if str(filepath).endswith((".h5", ".hdf5")):
    183     return legacy_h5_format.load_model_from_hdf5(
    184         filepath, custom_objects=custom_objects, compile=compile
    185     )

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:152, in load_model(filepath, custom_objects, compile, safe_mode)
    147     raise ValueError(
    148         "Invalid filename: expected a `.keras` extension. "
    149         f"Received: filepath={filepath}"
    150     )
    151 with open(filepath, "rb") as f:
--> 152     return _load_model_from_fileobj(
    153         f, custom_objects, compile, safe_mode
    154     )

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:170, in _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode)
    168 # Construct the model from the configuration file in the archive.
    169 with ObjectSharingScope():
--> 170     model = deserialize_keras_object(
    171         config_dict, custom_objects, safe_mode=safe_mode
    172     )
    174 all_filenames = zf.namelist()
    175 if _VARS_FNAME + ".h5" in all_filenames:

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    718     instance = cls.from_config(inner_config)
    719 except TypeError as e:
--> 720     raise TypeError(
    721         f"{cls} could not be deserialized properly. Please"
    722         " ensure that components that are Python object"
    723         " instances (layers, models, etc.) returned by"
    724         " `get_config()` are explicitly deserialized in the"
    725         " model's `from_config()` method."
    726         f"\n\nconfig={config}.\n\nException encountered: {e}"
    727     )
    728 build_config = config.get("build_config", None)
    729 if build_config and not instance.built:

TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

Exception encountered: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.

Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.

Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'

除了保存和加载模型命令之外,代码是从网站复制粘贴的。

请帮我解决这个问题。是否有特定的方法来保存这类模型,以便之后在另一个完全不同的笔记本中加载使用?(我在 Kaggle 上编写这段代码。)

python tensorflow keras image-classification vision-transformer
1个回答
0
投票

下面总结了我在问题下的评论,并在代码之后给出完整解释。我使用的是问题中链接的代码,新增或修改的行都用 # 注释标出。只需要修改自定义层类即可。

@keras.saving.register_keras_serializable()  # <- added: register the class with Keras
class Patches(layers.Layer):
    """Split a batch of channels-last images into flattened square patches.

    Args:
        patch_size: Side length of each square patch.
        **kwargs: Standard ``Layer`` keyword arguments (``name``,
            ``trainable``, ``dtype``, ...). Keras passes these back in when it
            rebuilds the layer from its saved config, so they must be accepted
            and forwarded to the base class.
    """

    def __init__(self, patch_size, **kwargs):  # <- added **kwargs
        super().__init__(**kwargs)  # <- added **kwargs
        self.patch_size = patch_size

    def call(self, images):
        """Return patches reshaped to ``(batch, n_patches, patch_size**2 * channels)``."""
        dims = ops.shape(images)
        batch, rows, cols, depth = dims[0], dims[1], dims[2], dims[3]
        size = self.patch_size
        # Number of whole patches that fit along each spatial axis.
        grid_h, grid_w = rows // size, cols // size
        extracted = keras.ops.image.extract_patches(images, size=size)
        return ops.reshape(extracted, (batch, grid_h * grid_w, size * size * depth))

    def get_config(self):
        """Extend the base layer config with ``patch_size`` so the layer can be reloaded."""
        return {**super().get_config(), "patch_size": self.patch_size}

# ------------------------------------------------------------------


@keras.saving.register_keras_serializable()  # <- added: register the class with Keras
class PatchEncoder(layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the number of rows in
            the position-embedding table.
        projection_dim: Output width of the linear projection and of each
            position embedding.
        **kwargs: Standard ``Layer`` keyword arguments (``name``,
            ``trainable``, ``dtype``, ...) that Keras passes back in when it
            rebuilds the layer from its saved config.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):  # <- added **kwargs
        super().__init__(**kwargs)  # <- added **kwargs
        self.num_patches = num_patches
        self.projection_dim = projection_dim  # <- added: needed by get_config()
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def build(self, input_shape):  # <- added build method
        # Build the sub-layers explicitly so their weights exist as soon as this
        # layer is built (including when Keras restores it from build_config),
        # rather than only marking this layer as built and leaving the Dense /
        # Embedding sub-layers unbuilt.
        self.projection.build(input_shape)
        # Position indices passed in call() have shape (1, num_patches).
        self.position_embedding.build((1, self.num_patches))
        super().build(input_shape)

    def call(self, patch):
        # Position indices 0..num_patches-1, with a leading axis of size 1 so
        # the embeddings broadcast over the batch of projected patches.
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Extend the base layer config with the constructor arguments."""
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        config.update({"projection_dim": self.projection_dim})  # <- added
        return config

添加的代码行的简短说明:

@keras.saving.register_keras_serializable()

此装饰器把自定义层注册到 Keras 的自定义对象注册表中(保存的配置里记录为 registered_name,例如 'Custom>Patches'),这样加载模型时 Keras 才能按名称找到并重建该类。

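`**kwargs`

`**kwargs` 用于捕获 `__init__()` 方法收到的、用户没有显式声明的关键字参数,并把它们传给 `super().__init__()` 调用。在本例中,`__init__()` 收到了参数 `name`,因为每个 `Layer` 类都有这个参数,但 `name` 最初并不在 `Patches` 预期的参数中。从错误信息中的 config 可以看到,Keras 还会传入 `trainable` 和 `dtype`,它们同样由 `**kwargs` 接收。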

self.projection_dim = projection_dim
# ...
config.update({"projection_dim": self.projection_dim})

这两行把 `projection_dim` 保存到 `PatchEncoder` 层的配置中。这样重新加载该层时,Keras 就能用原来设置的参数重建它(`__init__()` 需要 `projection_dim` 这个参数)。

© www.soinside.com 2019 - 2024. All rights reserved.