如何在 PyTorch 中使用类型提示?如果我检查类中返回的模型,我会得到
models.common.AutoShape
。然而,models
显示为未知。
import torch
model : models.common.AutoShape = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
我也想明确输入结果。
您始终可以对模型使用类型检查:
from torch import nn
class MyModel(nn.Module):
... # write class definition here
my_model: MyModel = MyModel()
你的问题在于错误的导入:
AutoShape
不是从torch
导入的,而是从YOLOv5导入的(参见源),所以你必须像from yolov5.models import AutoShape
一样导入它。我不知道如何回答你的第二个问题。