如何实时实现.pth模型?

问题描述 投票:0回答:1

我已经从(here)训练了模型,现在我想在我的笔记本电脑上实时实现它。我尝试了多种方法来加载模型,但没有成功。

代码:

import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image

# Load your ResNet model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=False)
pretrained_dict = torch.load('results/experiment/net_weights.pth', weights_only=True)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval()  # Set the model to evaluation mode

# Define a transform to preprocess the input
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Initialize webcam
cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # Convert the frame to a PIL image and apply transformations
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    img_t = transform(img).unsqueeze(0)  # Add batch dimension

    # Perform inference
    with torch.no_grad():
        output = model(img_t)
    
    # Get the predicted class (for example, if you have a list of class names)
    _, predicted_idx = torch.max(output, 1)
    predicted_class = predicted_idx.item()  # Convert tensor to integer

    # Display the predicted class on the frame
    cv2.putText(frame, f'Predicted: {predicted_class}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
    
    # Show the frame
    cv2.imshow('Real-time Computer Vision', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):  # Press 'q' to exit
        break

# Release the capture and close windows
cap.release()
cv2.destroyAllWindows()

我还有一个logs.json文件,但我不确定它是否有用。

{
    "batch_size": 32,
    "data_augmentation": "standard",
    "epochs": 50,
    "model": "resnet50",
    "optimizer": "adam",
    "pretrained": true,
    "train_biotic_stress_acc": [
        72.54237288135593,
        81.27118644067797,
        82.62711864406779,
        86.77966101694915,
        87.71186440677967,
        87.45762711864407,
        87.28813559322033,
        87.28813559322033,
        90.67796610169492,
        89.91525423728814,
        89.23728813559322,
        89.66101694915254,
        90.9322033898305,
        91.01694915254237,
        91.52542372881356,
        90.84745762711864,
        88.98305084745763,
        91.52542372881356,
        90.84745762711864,
        88.47457627118644,
        91.52542372881356,
        92.28813559322033,
        92.88135593220339,
        91.86440677966101,
        94.23728813559322,
        92.54237288135593,
        90.59322033898304,
        93.55932203389831,
        93.30508474576271,
        92.96610169491525,
        93.47457627118644,
        93.30508474576271,
        90.9322033898305,
        93.30508474576271,
        91.1864406779661,
        93.47457627118644,
        92.96610169491525,
        94.49152542372882,
        95.2542372881356,
        94.15254237288136,
        94.7457627118644,
        94.57627118644068,
        95.0,
        93.30508474576271,
        93.72881355932203,
        95.08474576271186,
        93.72881355932203,
        94.83050847457628,
        93.22033898305085,
        93.64406779661017
    ],
    "train_loss": [
        0.8282963037490845,
        0.5836613774299622,
        0.5153125524520874,
        0.43084049224853516,
        0.4390360713005066,
        0.40780237317085266,
        0.4658176600933075,
        0.42227157950401306,
        0.3433341979980469,
        0.3600851893424988,
        0.38862326741218567,
        0.36267176270484924,
        0.3588496446609497,
        0.33311688899993896,
        0.29994136095046997,
        0.32682493329048157,
        0.36273568868637085,
        0.3315748870372772,
        0.31836193799972534,
        0.33695703744888306,
        0.32421356439590454,
        0.2762307822704315,
        0.26089200377464294,
        0.2845413386821747,
        0.2687235176563263,
        0.28719162940979004,
        0.32435154914855957,
        0.27531903982162476,
        0.29826006293296814,
        0.27290764451026917,
        0.2443954199552536,
        0.2814570963382721,
        0.3201763331890106,
        0.27812349796295166,
        0.30083125829696655,
        0.23732589185237885,
        0.2818565368652344,
        0.25493666529655457,
        0.23779359459877014,
        0.280385285615921,
        0.22742299735546112,
        0.23158197104930878,
        0.23442105948925018,
        0.25980499386787415,
        0.24550049006938934,
        0.21719005703926086,
        0.23912093043327332,
        0.23146015405654907,
        0.247253879904747,
        0.22993849217891693
    ],
    "train_severity_acc": [
        66.10169491525424,
        75.50847457627118,
        77.45762711864407,
        78.47457627118644,
        77.79661016949153,
        80.50847457627118,
        78.13559322033899,
        80.2542372881356,
        81.94915254237289,
        81.69491525423729,
        79.49152542372882,
        80.59322033898304,
        79.66101694915254,
        83.22033898305085,
        85.0,
        81.35593220338983,
        81.35593220338983,
        82.71186440677967,
        84.0677966101695,
        84.7457627118644,
        83.38983050847457,
        84.91525423728814,
        86.44067796610169,
        86.86440677966101,
        85.84745762711864,
        84.49152542372882,
        84.15254237288136,
        83.98305084745763,
        82.37288135593221,
        85.59322033898304,
        89.23728813559322,
        84.15254237288136,
        84.40677966101696,
        84.49152542372882,
        85.16949152542372,
        88.30508474576271,
        84.57627118644068,
        84.57627118644068,
        85.16949152542372,
        84.83050847457628,
        87.37288135593221,
        87.45762711864407,
        87.54237288135593,
        85.08474576271186,
        86.94915254237289,
        87.45762711864407,
        87.11864406779661,
        86.35593220338983,
        87.20338983050847,
        89.15254237288136
    ],
    "val_biotic_stress_acc": [
        53.359683794466406,
        76.28458498023716,
        79.44664031620553,
        82.21343873517786,
        74.70355731225297,
        82.21343873517786,
        85.37549407114625,
        77.07509881422925,
        90.9090909090909,
        86.56126482213439,
        83.79446640316206,
        86.16600790513834,
        73.51778656126483,
        89.32806324110672,
        85.7707509881423,
        87.35177865612648,
        89.32806324110672,
        78.26086956521739,
        78.26086956521739,
        83.79446640316206,
        82.21343873517786,
        86.95652173913044,
        89.32806324110672,
        83.79446640316206,
        91.699604743083,
        89.32806324110672,
        90.51383399209486,
        86.56126482213439,
        88.14229249011858,
        90.51383399209486,
        85.7707509881423,
        84.18972332015811,
        82.21343873517786,
        89.32806324110672,
        87.74703557312253,
        90.51383399209486,
        85.7707509881423,
        86.95652173913044,
        89.32806324110672,
        89.32806324110672,
        91.699604743083,
        86.56126482213439,
        88.14229249011858,
        91.699604743083,
        89.32806324110672,
        88.14229249011858,
        86.95652173913044,
        90.9090909090909,
        92.09486166007905,
        84.9802371541502
    ],
    "val_loss": [
        12.854988098144531,
        0.7233027815818787,
        0.565257728099823,
        0.3938770592212677,
        0.5002809166908264,
        0.4399007558822632,
        0.7318148016929626,
        0.6214114427566528,
        0.31451746821403503,
        0.41589120030403137,
        0.5394357442855835,
        0.3834279775619507,
        0.6205282211303711,
        0.3562229573726654,
        0.4006038308143616,
        0.3478359878063202,
        0.3223670721054077,
        0.6522853374481201,
        0.5050507187843323,
        0.5499345064163208,
        0.4598318040370941,
        0.46111157536506653,
        0.2799707353115082,
        0.37483757734298706,
        0.30881232023239136,
        0.35001036524772644,
        0.33550742268562317,
        0.48047876358032227,
        0.4079080820083618,
        0.3100490868091583,
        0.38654884696006775,
        0.7133758068084717,
        0.4948892295360565,
        0.3370022475719452,
        0.34756383299827576,
        0.28952303528785706,
        0.41055554151535034,
        0.41655051708221436,
        0.4456130862236023,
        0.3225994408130646,
        0.3737548291683197,
        0.35072797536849976,
        0.35582756996154785,
        0.28029096126556396,
        0.3270091712474823,
        0.5963916182518005,
        0.4246876537799835,
        0.28707748651504517,
        0.31602567434310913,
        0.41391074657440186
    ],
    "val_severity_acc": [
        45.8498023715415,
        69.96047430830039,
        77.86561264822134,
        83.39920948616601,
        82.21343873517786,
        81.81818181818181,
        70.75098814229248,
        73.12252964426878,
        86.56126482213439,
        79.05138339920948,
        77.86561264822134,
        85.37549407114625,
        84.9802371541502,
        83.39920948616601,
        84.58498023715416,
        83.79446640316206,
        84.18972332015811,
        80.23715415019763,
        89.72332015810277,
        79.84189723320158,
        81.81818181818181,
        79.84189723320158,
        87.74703557312253,
        88.53754940711462,
        83.00395256916995,
        84.58498023715416,
        84.9802371541502,
        73.51778656126483,
        78.26086956521739,
        84.58498023715416,
        84.58498023715416,
        72.33201581027669,
        81.42292490118577,
        86.16600790513834,
        85.37549407114625,
        88.14229249011858,
        85.37549407114625,
        81.02766798418972,
        79.44664031620553,
        86.95652173913044,
        79.44664031620553,
        87.74703557312253,
        85.37549407114625,
        87.74703557312253,
        84.9802371541502,
        65.21739130434783,
        83.79446640316206,
        85.7707509881423,
        83.79446640316206,
        85.7707509881423
    ],
    "weight_decay": 0.0005
}

如何实时实现我的模型,以便在屏幕上显示疾病的严重程度和类型?

python pytorch torch torchvision
1个回答
0
投票

我不太确定“实时”在您的情况下意味着什么。但我假设您试图在 while 循环执行后立即观察其结果,并相应地查看下降曲线。

如果这是您的问题,您应该尝试使用tensorboard等日志记录工具。这是关于如何在 pytorch 中使用它的快速教程。简而言之,您应该在代码中插入

writer.add_scaler("severity", value)
并使用终端启动网络应用程序来观察您的曲线。因为是Web应用程序,所以必须手动刷新才能查看最新数据。

Tensorboard 的界面应该是这样的: enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.