Real-time ResNet prediction

Question  Votes: 0  Answers: 1

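I trained a ResNet50 model on hand-sign digits 0 through 5, and I am trying to deploy it to predict the class in real time from my laptop's webcam.

The model reaches 98% accuracy, so I am fairly sure the error is not caused by poor training. Even so, the real-time predictions stay stuck on 1 or 2 of the 6 classes: the model always predicts digit 0 or digit 2.

Here is the code: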

import torch
import torch.nn as nn
import cv2
import numpy as np
from torchvision import models, transforms
from PIL import Image  # Import PIL for image conversion

# Define the model architecture and load weights
class ResNet50Modified(nn.Module):
    def __init__(self, num_classes=6):
        super(ResNet50Modified, self).__init__()
        self.model = models.resnet50(pretrained=True)  # These ImageNet weights are overwritten by load_state_dict below
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, X):
        return self.model(X)

# Load the trained model
model = ResNet50Modified(num_classes=6)
# Load the model's state_dict for CPU
model.load_state_dict(torch.load("resnet50_modified1.pth", map_location=torch.device('cpu')))
model.eval()

# Define transformations to match training preprocessing
preprocess = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize to input size of model
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize as per ResNet standards
])

# Labels for the signs
class_names = ['Class_0', 'Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5']  # Replace with actual sign names

# Open webcam for real-time prediction
cap = cv2.VideoCapture(0)

if not cap.isOpened():
    print("Error: Could not open webcam.")
    exit()

while True:
    ret, frame = cap.read()
    if not ret:
        print("Error: Could not read frame.")
        break

    # Convert the frame from BGR (OpenCV) to RGB (PIL)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Convert NumPy array to PIL Image
    pil_image = Image.fromarray(frame_rgb)

    # Preprocess the frame
    input_image = preprocess(pil_image)  # Use the PIL image for preprocessing
    input_image = input_image.unsqueeze(0)  # Add batch dimension

    # Predict using the model
    with torch.no_grad():
        outputs = model(input_image)
        
        # Apply softmax to get probabilities
        probabilities = torch.softmax(outputs, dim=1)
        
        # Get the predicted class and confidence
        _, predicted = torch.max(probabilities, 1)
        confidence = probabilities[0][predicted].item() * 100  # Convert to percentage
        label = class_names[predicted.item()]

    # Display the result with confidence level
    cv2.putText(frame, f"Predicted: {label}, Confidence: {confidence:.2f}%", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow("Sign Detection", frame)

    # Exit on pressing 'q'
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

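The confidence is always high but the label is wrong: even when I show five fingers to the webcam, it stays stuck on zero.

I suspect the problem is in the frame processing. Does anyone have any insight into this?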

python opencv machine-learning pytorch
1 Answer
0 votes

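This problem is mainly due to how the class labels were encoded and loaded during training. Make sure the labels are encoded and loaded at inference exactly as they were during training, so that the indices do not get mixed up. In other words, if the classes were sorted during training, make sure they are sorted the same way during inference. I hope this helps.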
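To make the index check concrete, here is a minimal sketch. It assumes the training data was loaded with something like torchvision's `ImageFolder`, which assigns class indices by sorting the folder names as strings, and that the resulting `class_to_idx` mapping was saved to JSON at training time. The folder names, the JSON content, and the helper `build_class_names` are all hypothetical; the question does not show how the labels were actually built.

```python
import json


def build_class_names(class_to_idx):
    """Invert a name -> index mapping into a list ordered by training index."""
    n = len(class_to_idx)
    # Every index from 0 to n-1 must appear exactly once.
    if sorted(class_to_idx.values()) != list(range(n)):
        raise ValueError("indices must cover 0..N-1 exactly once")
    names = [None] * n
    for name, idx in class_to_idx.items():
        names[idx] = name
    return names


# Hypothetical mapping, as it might have been saved during training with
# json.dump(train_dataset.class_to_idx, f). Folder names sorted as strings
# put "five" at index 0 and "zero" at index 5.
saved = json.loads(
    '{"five": 0, "four": 1, "one": 2, "three": 3, "two": 4, "zero": 5}'
)

class_names = build_class_names(saved)
print(class_names[0])  # prints "five": what the model means by index 0
```

With a mapping like this one, a hard-coded list such as `['Class_0', ..., 'Class_5']` would label a five-finger sign as class 0. Building `class_names` from the saved mapping instead of typing it by hand keeps inference aligned with whatever order was used in training.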
