Why is my YOLO-v8 TFLite model slower than my PyTorch model?


I originally had a trained PyTorch YOLO-v8 nano model for multi-object detection in video (10 classes: "Bicycle", "Chair", "Box", "Table", "Plastic bag", "Flowerpot", "Luggage and bags", "Umbrella", "Shopping trolley", "Person").

I converted it to a TFLite model using the ultralytics library's export function. However, when I run both models on a video stream, my TFLite model runs much more slowly (roughly 8 FPS) than my PyTorch model (roughly 20 FPS). Why is that?

Both the TFLite and PyTorch models are here: https://drive.google.com/drive/folders/1A2XUD5sV332nXv-Z756Di_QUIZV3ObFv?usp=sharing

Conversion from the trained PyTorch model to a TFLite model:

from ultralytics import YOLO

model = YOLO("yolov8n_trained.pt")
path = model.export(format="tflite") 
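
As a side note (not part of the original question), the ultralytics exporter also accepts an explicit input size and quantization flags, which can affect TFLite CPU speed. A minimal sketch follows; whether int8 helps, and the exact calibration-data argument, depend on the installed ultralytics version, and the "calibration.yaml" path is hypothetical:

from ultralytics import YOLO

model = YOLO("yolov8n_trained.pt")

# FP32 export with an explicit input size (matches the 640x640 frames used later)
model.export(format="tflite", imgsz=640)

# INT8 quantized export; the calibration dataset path below is a placeholder
model.export(format="tflite", int8=True, imgsz=640, data="calibration.yaml")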

Running a video stream through the model:

from ultralytics import YOLO
import cv2
from time import time

# Video source (use 0 for the default webcam)
stream_url = "video_stream_path"

cap = cv2.VideoCapture(stream_url)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)   # same as cap.set(3, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 640)  # same as cap.set(4, 640)

# Load your retrained YOLOv8 model 
model = YOLO("yolov8n_trained.tflite")   # use this when testing tflite model
# model = YOLO("yolov8n_trained.pt")     # use this when testing pytorch model

# Custom class names
classNames = ["Bicycle", "Chair", "Box", "Table", "Plastic bag", "Flowerpot", 
              "Luggage and bags", "Umbrella", "Shopping trolley", "Person"]

# Initialise variables used to calculate the frame rate
prev_time = time()  # start from "now" so the first frame does not report a bogus FPS
fps = 0

while True:
    success, img = cap.read()
    if not success:
        break

    # Calculate instantaneous FPS from the time since the previous frame
    curr_time = time()
    fps = 1 / (curr_time - prev_time)
    prev_time = curr_time

    # Display the frame rate on the image
    cv2.putText(img, f"FPS: {fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    # Run the YOLO model on the image
    results = model(img, stream=True)

    # Process detection results
    for r in results:
        for box in r.boxes:
            # Extract bounding box coordinates and confidence
            x1, y1, x2, y2 = map(int, box.xyxy[0])  # Convert to integer
            confidence = box.conf[0]                # Confidence score
            class_id = box.cls[0]                   # Class ID

            # Draw bounding box
            color = (0, 255, 0)  # Green color for bounding box
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Draw label with class name and confidence
            label = f"{classNames[int(class_id)]}: {confidence:.2f}"
            label_size, base_line = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            y1 = max(y1, label_size[1])
            cv2.rectangle(img, (x1, y1 - label_size[1]), (x1 + label_size[0], y1 + base_line), (0, 255, 0), cv2.FILLED)
            cv2.putText(img, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    # Display webcam feed
    cv2.imshow('Webcam', img)
    if cv2.waitKey(1) == ord('q'):
        break

# Release resources
cap.release()
cv2.destroyAllWindows()
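
One thing to note about the loop above: the reported FPS includes frame decoding, box drawing, and cv2.imshow, not just inference. A rough way to compare the two models on the model call alone is sketched below; it reuses the same placeholder paths, and the n_frames cap and verbose=False argument are my additions rather than part of the original code:

from time import perf_counter
from ultralytics import YOLO
import cv2

def benchmark(weights, source, n_frames=100):
    # Time only the model call so drawing/display cost is excluded
    model = YOLO(weights)
    cap = cv2.VideoCapture(source)
    total, frames = 0.0, 0
    while frames < n_frames:
        ok, img = cap.read()
        if not ok:
            break
        start = perf_counter()
        model(img, verbose=False)  # single-image predict call
        total += perf_counter() - start
        frames += 1
    cap.release()
    return frames / total if total else 0.0

print("PyTorch FPS:", benchmark("yolov8n_trained.pt", "video_stream_path"))
print("TFLite FPS: ", benchmark("yolov8n_trained.tflite", "video_stream_path"))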

Tags: machine-learning, pytorch, frame-rate, yolov8, tflite
1 Answer

We now support an official direct conversion from PyTorch to TF Lite. You can give it a try: https://github.com/google-ai-edge/ai-edge-torch

This conversion includes many graph optimizations for CPU performance.
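
For reference, a minimal sketch of what that conversion could look like, based on the convert/export entry points documented in the ai-edge-torch repository; the 1x3x640x640 sample input and pulling the raw torch.nn.Module out of the ultralytics wrapper via .model are assumptions:

import torch
import ai_edge_torch
from ultralytics import YOLO

# Assumption: the underlying torch.nn.Module is exposed as `.model` on the wrapper
yolo = YOLO("yolov8n_trained.pt")
torch_model = yolo.model.eval()

# convert() takes the eval-mode model plus a tuple of sample inputs
# (1x3x640x640 assumed here to match the 640x640 frames used above)
sample_inputs = (torch.randn(1, 3, 640, 640),)
edge_model = ai_edge_torch.convert(torch_model, sample_inputs)

# Save the converted model as a .tflite flatbuffer
edge_model.export("yolov8n_trained_ai_edge.tflite")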
