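I trained a ResNet-50 model on hand-sign digits 0 to 5, and I am trying to deploy it to predict classes in real time from my laptop's webcam.
The model's accuracy is 98%, so I'm fairly sure the problem isn't a poorly trained model, yet the live predictions stay stuck on 1 or 2 of the 6 classes: it always predicts digit 0 or digit 2.
Here is the code: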
import torch
import torch.nn as nn
import cv2
import numpy as np
from torchvision import models, transforms
from PIL import Image  # Import PIL for image conversion

# Define the model architecture and load weights
class ResNet50Modified(nn.Module):
    def __init__(self, num_classes=6):
        super(ResNet50Modified, self).__init__()
        self.model = models.resnet50(pretrained=True)  # Use pretrained=True for better performance
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, X):
        return self.model(X)

# Load the trained model
model = ResNet50Modified(num_classes=6)
# Load the model's state_dict for CPU
model.load_state_dict(torch.load("resnet50_modified1.pth", map_location=torch.device('cpu')))
model.eval()

# Define transformations to match training preprocessing
preprocess = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize to input size of model
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize as per ResNet standards
])

# Labels for the signs
class_names = ['Class_0', 'Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5']  # Replace with actual sign names

# Open webcam for real-time prediction
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    print("Error: Could not open webcam.")
    exit()

while True:
    ret, frame = cap.read()
    if not ret:
        print("Error: Could not read frame.")
        break

    # Convert the frame from BGR (OpenCV) to RGB (PIL)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Convert NumPy array to PIL Image
    pil_image = Image.fromarray(frame_rgb)

    # Preprocess the frame
    input_image = preprocess(pil_image)  # Use the PIL image for preprocessing
    input_image = input_image.unsqueeze(0)  # Add batch dimension

    # Predict using the model
    with torch.no_grad():
        outputs = model(input_image)

    # Apply softmax to get probabilities
    probabilities = torch.softmax(outputs, dim=1)
    # Get the predicted class and confidence
    _, predicted = torch.max(probabilities, 1)
    confidence = probabilities[0][predicted].item() * 100  # Convert to percentage
    label = class_names[predicted.item()]

    # Display the result with confidence level
    cv2.putText(frame, f"Predicted: {label}, Confidence: {confidence:.2f}%", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow("Sign Detection", frame)

    # Exit on pressing 'q'
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
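The confidence is always high but the label is wrong: even when I show 5 fingers to the webcam, the prediction stays stuck on zero.
I suspect the problem is in the frame processing. Does anyone have any insight into this?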
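This problem is mainly caused by how the class labels were encoded and loaded during training. Make sure the labels are encoded and loaded at inference exactly as they were during training, so that the indices do not get mixed up. For example, if the classes were sorted during training, make sure they are in the same sorted order at inference. I hope this helps.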
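One way to rule out an index mismatch is to save the class-to-index mapping at training time and rebuild `class_names` from it at inference, rather than typing the list by hand. The sketch below assumes the training data was loaded by something that sorts class-folder names alphabetically before numbering them, as torchvision's `ImageFolder` does; the question does not say how the data was loaded. The folder names and the file name `class_to_idx.json` are hypothetical.

```python
import json


def build_class_to_idx(folder_names):
    """Number classes the way a sorting loader would: sort names, then enumerate."""
    return {name: idx for idx, name in enumerate(sorted(folder_names))}


def save_mapping(class_to_idx, path):
    """Write the mapping next to the model weights at training time."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(class_to_idx, f)


def load_class_names(path):
    """Rebuild the index-ordered list of class names at inference time."""
    with open(path, encoding="utf-8") as f:
        class_to_idx = json.load(f)
    names = [None] * len(class_to_idx)
    for name, idx in class_to_idx.items():
        names[idx] = name
    return names


# Hypothetical folder names: alphabetical order is not numeric order.
mapping = build_class_to_idx(["zero", "one", "two", "three", "four", "five"])
print(mapping)
```

With folder names like these, index 0 corresponds to "five" and index 5 to "zero", so a hand-written list such as `['Class_0', ..., 'Class_5']` would mislabel every prediction. If the real training script used an `ImageFolder` dataset, its `class_to_idx` attribute could be passed to `save_mapping` directly.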