如何保存和重新加载 Tensorflow/Keras 和 Keras-cv YOLO 模型

问题描述 投票:0回答:1

我一直在研究 Keras 网站的示例: https://keras.io/examples/vision/yolov8/ 在 Tensorflow/Keras 中构建 YOLOv8 模型。我已成功训练模型,尽管在使用回调时出现错误,因此已删除回调并尝试在模型训练完成后手动保存模型。

yolo.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3
)
yolo.save('my_yolo_mdl.keras')

yolo_load=tf.keras.models.load_model('my_yolo_mdl.keras')

我在这里收到模型未编译的警告,因此我按照与训练前相同的方式对其进行编译:

optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=LEARNING_RATE,
    global_clipnorm=GLOBAL_CLIPNORM,
)

yolo_load.compile(
    optimizer=optimizer, classification_loss="binary_crossentropy", box_loss="ciou"
)

然后,当我尝试使用编译的模型和形状 (1,:,:,3) 的图像进行预测时,出现错误:

yolo_load.predict(img)

TypeError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2440, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2425, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2413, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py", line 616, in predict_step
        return self.decode_predictions(outputs, args[-1])
    File "/usr/local/lib/python3.10/dist-packages/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py", line 609, in decode_predictions
        return self.prediction_decoder(box_preds, scores)

    TypeError: '_DictWrapper' object is not callable

如果我使用原始模型而不是加载的模型,此输入数组将按预期工作。我哪里出错了? 谢谢

tensorflow keras deep-learning yolov8
1个回答
0
投票

这里同样的问题。你找到这个问题的答案了吗?

© www.soinside.com 2019 - 2024. All rights reserved.