Converting a “.h5” model to “.tflite” with TensorFlow 2.0 fails


I am trying to convert a model from .h5 to .tflite.

The code I'm using:

import torch
import torch.nn as nn
import detectron2
from detectron2.modeling import build_model
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)

import torch.ao.quantization.quantize_fx as quantize_fx
import copy

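model_path = '/kaggle/input/vehicleobjectdetection/pytorch/v1/2/model_final.h5'

# `cfg` (a detectron2 config) is assumed to be defined in an earlier notebook cell; it is not shown here
cfg.MODEL.WEIGHTS = model_path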


# Instantiate the model (adjust parameters as necessary)
model = build_model(cfg)


from keras.models import load_model
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model_path)
tfmodel = converter.convert()

I get this error:

AttributeError                            Traceback (most recent call last)
Cell In[54], line 28
     25 import tensorflow as tf
     27 converter = tf.lite.TFLiteConverter.from_keras_model(model_path)
---> 28 tfmodel = converter.convert()

File /opt/conda/lib/python3.10/site-packages/tensorflow/lite/python/lite.py:1139, in _export_metrics.<locals>.wrapper(self, *args, **kwargs)
   1136 @functools.wraps(convert_func)
   1137 def wrapper(self, *args, **kwargs):
   1138   # pylint: disable=protected-access
-> 1139   return self._convert_and_export_metrics(convert_func, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/tensorflow/lite/python/lite.py:1093, in TFLiteConverterBase._convert_and_export_metrics(self, convert_func, *args, **kwargs)
   1091 self._save_conversion_params_metric()
   1092 start_time = time.process_time()
-> 1093 result = convert_func(self, *args, **kwargs)
   1094 elapsed_time_ms = (time.process_time() - start_time) * 1000
   1095 if result:

File /opt/conda/lib/python3.10/site-packages/tensorflow/lite/python/lite.py:1606, in TFLiteKerasModelConverterV2.convert(self)
   1602 if saved_model_convert_result:
   1603   return saved_model_convert_result
   1605 graph_def, input_tensors, output_tensors, frozen_func = (
-> 1606     self._freeze_keras_model()
   1607 )
   1609 graph_def = self._optimize_tf_model(
   1610     graph_def, input_tensors, output_tensors, frozen_func
   1611 )
   1613 return super(TFLiteKerasModelConverterV2, self).convert(
   1614     graph_def, input_tensors, output_tensors
   1615 )

File /opt/conda/lib/python3.10/site-packages/tensorflow/lite/python/convert_phase.py:215, in convert_phase.<locals>.actual_decorator.<locals>.wrapper(*args, **kwargs)
    213 except Exception as error:
    214   report_error_message(str(error))
--> 215   raise error from None

File /opt/conda/lib/python3.10/site-packages/tensorflow/lite/python/convert_phase.py:205, in convert_phase.<locals>.actual_decorator.<locals>.wrapper(*args, **kwargs)
    202 @functools.wraps(func)
    203 def wrapper(*args, **kwargs):
    204   try:
--> 205     return func(*args, **kwargs)
    206   except ConverterError as converter_error:
    207     if converter_error.errors:

File /opt/conda/lib/python3.10/site-packages/tensorflow/lite/python/lite.py:1543, in TFLiteKerasModelConverterV2._freeze_keras_model(self)
   1537 input_signature = None
   1538 # If the model's call is not a `tf.function`, then we need to first get its
   1539 # input signature from `model_input_signature` method. We can't directly
   1540 # call `trace_model_call` because otherwise the batch dimension is set
   1541 # to None.
   1542 # Once we have better support for dynamic shapes, we can remove this.
-> 1543 if not isinstance(self._keras_model.call, _def_function.Function):
   1544   # Pass `keep_original_batch_size=True` will ensure that we get an input
   1545   # signature including the batch dimension specified by the user.
   1546   # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
   1547   input_signature = _model_input_signature(
   1548       self._keras_model, keep_original_batch_size=True
   1549   )
   1551 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF

AttributeError: 'str' object has no attribute 'call'
1 answer

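The error occurs because you are passing a file path (a string) to tf.lite.TFLiteConverter.from_keras_model, which expects a Keras model object, not a path. To fix it, first load the model with the Keras load_model function and then pass the resulting model object to the converter.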
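The last frame of the traceback (`_freeze_keras_model`) reads `self._keras_model.call`, an attribute that only a model object has. The plain-Python sketch below reproduces the same failure without TensorFlow; `FakeModel` is a hypothetical stand-in for a `tf.keras.Model`, used only to show the contrast.

```python
# A plain str, which is what the question passes to from_keras_model
model_path = "model_final.h5"


class FakeModel:
    """Hypothetical stand-in for tf.keras.Model, which defines a call() method."""

    def call(self, inputs):
        return inputs


print(hasattr(FakeModel(), "call"))  # True: a model object has .call
print(hasattr(model_path, "call"))   # False: a path string does not

try:
    model_path.call  # mirrors `self._keras_model.call` in _freeze_keras_model
    message = None
except AttributeError as err:
    message = str(err)

print(message)  # 'str' object has no attribute 'call'
```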

Correct usage:

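import tensorflow as tf
keras_model = tf.keras.models.load_model(model_path)  # a Keras Model object, not a path string
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tfmodel = converter.convert()  # the serialized .tflite model as bytes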
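If the same conversion code may receive either a path or an already-loaded model, one option is to normalize the argument before it reaches the converter. This is a minimal sketch: `as_model_object` is a hypothetical helper (not part of TensorFlow or Keras), and the loader is passed in as a parameter so the helper itself does not import TensorFlow.

```python
import os
from typing import Any, Callable, Union

PathLike = Union[str, os.PathLike]


def as_model_object(model_or_path: Union[Any, PathLike],
                    loader: Callable[[PathLike], Any]) -> Any:
    """Return a model object suitable for TFLiteConverter.from_keras_model.

    A str or os.PathLike argument is treated as a file path and loaded with
    `loader`; anything else is assumed to already be a model and is returned
    unchanged.
    """
    if isinstance(model_or_path, (str, os.PathLike)):
        return loader(model_or_path)
    return model_or_path


# Hypothetical usage in the question's setting:
#   import tensorflow as tf
#   keras_model = as_model_object(model_path, tf.keras.models.load_model)
#   converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
```

This only fixes the argument type. For `load_model` to succeed, the file itself must still be a complete Keras model saved in HDF5 format, not a weights-only file.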