I trained the model from (here), and now I want to run it in real time on my laptop. I have tried several ways to load the model, but without success.
Code:
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image

# Load your ResNet model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=False)
pretrained_dict = torch.load('results/experiment/net_weights.pth', weights_only=True)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval()  # Set the model to evaluation mode

# Define a transform to preprocess the input
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Initialize webcam
cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # Convert the frame to a PIL image and apply transformations
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    img_t = transform(img).unsqueeze(0)  # Add batch dimension

    # Perform inference
    with torch.no_grad():
        output = model(img_t)

    # Get the predicted class (for example, if you have a list of class names)
    _, predicted_idx = torch.max(output, 1)
    predicted_class = predicted_idx.item()  # Convert tensor to integer

    # Display the predicted class on the frame
    cv2.putText(frame, f'Predicted: {predicted_class}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    # Show the frame
    cv2.imshow('Real-time Computer Vision', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):  # Press 'q' to exit
        break

# Release the capture and close windows
cap.release()
cv2.destroyAllWindows()
I also have a logs.json file, but I am not sure whether it is useful.
{
"batch_size": 32,
"data_augmentation": "standard",
"epochs": 50,
"model": "resnet50",
"optimizer": "adam",
"pretrained": true,
"train_biotic_stress_acc": [
72.54237288135593,
81.27118644067797,
82.62711864406779,
86.77966101694915,
87.71186440677967,
87.45762711864407,
87.28813559322033,
87.28813559322033,
90.67796610169492,
89.91525423728814,
89.23728813559322,
89.66101694915254,
90.9322033898305,
91.01694915254237,
91.52542372881356,
90.84745762711864,
88.98305084745763,
91.52542372881356,
90.84745762711864,
88.47457627118644,
91.52542372881356,
92.28813559322033,
92.88135593220339,
91.86440677966101,
94.23728813559322,
92.54237288135593,
90.59322033898304,
93.55932203389831,
93.30508474576271,
92.96610169491525,
93.47457627118644,
93.30508474576271,
90.9322033898305,
93.30508474576271,
91.1864406779661,
93.47457627118644,
92.96610169491525,
94.49152542372882,
95.2542372881356,
94.15254237288136,
94.7457627118644,
94.57627118644068,
95.0,
93.30508474576271,
93.72881355932203,
95.08474576271186,
93.72881355932203,
94.83050847457628,
93.22033898305085,
93.64406779661017
],
"train_loss": [
0.8282963037490845,
0.5836613774299622,
0.5153125524520874,
0.43084049224853516,
0.4390360713005066,
0.40780237317085266,
0.4658176600933075,
0.42227157950401306,
0.3433341979980469,
0.3600851893424988,
0.38862326741218567,
0.36267176270484924,
0.3588496446609497,
0.33311688899993896,
0.29994136095046997,
0.32682493329048157,
0.36273568868637085,
0.3315748870372772,
0.31836193799972534,
0.33695703744888306,
0.32421356439590454,
0.2762307822704315,
0.26089200377464294,
0.2845413386821747,
0.2687235176563263,
0.28719162940979004,
0.32435154914855957,
0.27531903982162476,
0.29826006293296814,
0.27290764451026917,
0.2443954199552536,
0.2814570963382721,
0.3201763331890106,
0.27812349796295166,
0.30083125829696655,
0.23732589185237885,
0.2818565368652344,
0.25493666529655457,
0.23779359459877014,
0.280385285615921,
0.22742299735546112,
0.23158197104930878,
0.23442105948925018,
0.25980499386787415,
0.24550049006938934,
0.21719005703926086,
0.23912093043327332,
0.23146015405654907,
0.247253879904747,
0.22993849217891693
],
"train_severity_acc": [
66.10169491525424,
75.50847457627118,
77.45762711864407,
78.47457627118644,
77.79661016949153,
80.50847457627118,
78.13559322033899,
80.2542372881356,
81.94915254237289,
81.69491525423729,
79.49152542372882,
80.59322033898304,
79.66101694915254,
83.22033898305085,
85.0,
81.35593220338983,
81.35593220338983,
82.71186440677967,
84.0677966101695,
84.7457627118644,
83.38983050847457,
84.91525423728814,
86.44067796610169,
86.86440677966101,
85.84745762711864,
84.49152542372882,
84.15254237288136,
83.98305084745763,
82.37288135593221,
85.59322033898304,
89.23728813559322,
84.15254237288136,
84.40677966101696,
84.49152542372882,
85.16949152542372,
88.30508474576271,
84.57627118644068,
84.57627118644068,
85.16949152542372,
84.83050847457628,
87.37288135593221,
87.45762711864407,
87.54237288135593,
85.08474576271186,
86.94915254237289,
87.45762711864407,
87.11864406779661,
86.35593220338983,
87.20338983050847,
89.15254237288136
],
"val_biotic_stress_acc": [
53.359683794466406,
76.28458498023716,
79.44664031620553,
82.21343873517786,
74.70355731225297,
82.21343873517786,
85.37549407114625,
77.07509881422925,
90.9090909090909,
86.56126482213439,
83.79446640316206,
86.16600790513834,
73.51778656126483,
89.32806324110672,
85.7707509881423,
87.35177865612648,
89.32806324110672,
78.26086956521739,
78.26086956521739,
83.79446640316206,
82.21343873517786,
86.95652173913044,
89.32806324110672,
83.79446640316206,
91.699604743083,
89.32806324110672,
90.51383399209486,
86.56126482213439,
88.14229249011858,
90.51383399209486,
85.7707509881423,
84.18972332015811,
82.21343873517786,
89.32806324110672,
87.74703557312253,
90.51383399209486,
85.7707509881423,
86.95652173913044,
89.32806324110672,
89.32806324110672,
91.699604743083,
86.56126482213439,
88.14229249011858,
91.699604743083,
89.32806324110672,
88.14229249011858,
86.95652173913044,
90.9090909090909,
92.09486166007905,
84.9802371541502
],
"val_loss": [
12.854988098144531,
0.7233027815818787,
0.565257728099823,
0.3938770592212677,
0.5002809166908264,
0.4399007558822632,
0.7318148016929626,
0.6214114427566528,
0.31451746821403503,
0.41589120030403137,
0.5394357442855835,
0.3834279775619507,
0.6205282211303711,
0.3562229573726654,
0.4006038308143616,
0.3478359878063202,
0.3223670721054077,
0.6522853374481201,
0.5050507187843323,
0.5499345064163208,
0.4598318040370941,
0.46111157536506653,
0.2799707353115082,
0.37483757734298706,
0.30881232023239136,
0.35001036524772644,
0.33550742268562317,
0.48047876358032227,
0.4079080820083618,
0.3100490868091583,
0.38654884696006775,
0.7133758068084717,
0.4948892295360565,
0.3370022475719452,
0.34756383299827576,
0.28952303528785706,
0.41055554151535034,
0.41655051708221436,
0.4456130862236023,
0.3225994408130646,
0.3737548291683197,
0.35072797536849976,
0.35582756996154785,
0.28029096126556396,
0.3270091712474823,
0.5963916182518005,
0.4246876537799835,
0.28707748651504517,
0.31602567434310913,
0.41391074657440186
],
"val_severity_acc": [
45.8498023715415,
69.96047430830039,
77.86561264822134,
83.39920948616601,
82.21343873517786,
81.81818181818181,
70.75098814229248,
73.12252964426878,
86.56126482213439,
79.05138339920948,
77.86561264822134,
85.37549407114625,
84.9802371541502,
83.39920948616601,
84.58498023715416,
83.79446640316206,
84.18972332015811,
80.23715415019763,
89.72332015810277,
79.84189723320158,
81.81818181818181,
79.84189723320158,
87.74703557312253,
88.53754940711462,
83.00395256916995,
84.58498023715416,
84.9802371541502,
73.51778656126483,
78.26086956521739,
84.58498023715416,
84.58498023715416,
72.33201581027669,
81.42292490118577,
86.16600790513834,
85.37549407114625,
88.14229249011858,
85.37549407114625,
81.02766798418972,
79.44664031620553,
86.95652173913044,
79.44664031620553,
87.74703557312253,
85.37549407114625,
87.74703557312253,
84.9802371541502,
65.21739130434783,
83.79446640316206,
85.7707509881423,
83.79446640316206,
85.7707509881423
],
"weight_decay": 0.0005
}
How can I run my model in real time so that the disease type and its severity are shown on the screen?
I am not quite sure what "real time" means in your case, but I assume you are trying to observe the results right after the while loop runs and watch the corresponding loss curves go down.
If that is your problem, you should try a logging tool such as TensorBoard. Here is a quick tutorial on how to use it in PyTorch. In short, you insert
writer.add_scalar("severity", value, step)
into your code and launch the web app from a terminal to watch your curves. Since it is a web application, you may have to refresh the page to see the latest data.
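For example, since you already have the curves in logs.json, here is a minimal sketch that replays them into TensorBoard (assuming logs.json sits next to the script and the tensorboard package is installed; the run directory 'runs/experiment' and the tag names are my own placeholders):

import json
from torch.utils.tensorboard import SummaryWriter

# Replay the training history from logs.json into TensorBoard
with open('logs.json') as f:
    logs = json.load(f)

writer = SummaryWriter('runs/experiment')  # placeholder log directory
for epoch, loss in enumerate(logs['train_loss']):
    writer.add_scalar('train/loss', loss, epoch)
for epoch, acc in enumerate(logs['val_severity_acc']):
    writer.add_scalar('val/severity_acc', acc, epoch)
for epoch, acc in enumerate(logs['val_biotic_stress_acc']):
    writer.add_scalar('val/biotic_stress_acc', acc, epoch)
writer.close()

Then run tensorboard --logdir runs in a terminal and open the printed URL (http://localhost:6006 by default) in your browser to see the curves.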