我正在尝试从 YOLOv8 模型中提取分类主干。我在 Ultralytics 官方存储库上看到了一个关于此问题的问题,但是当我尝试提到的方法时,它不起作用。 这是问题的链接
我想要做的是提取主干的权重,然后使用它,就好像它是一个分类模型来执行 FGSM 攻击,因为我尝试在整个模型上执行此操作,即使在我从中提取了类 logits 之后,它也不起作用yolo 使用 pytorch hooks
import torch
import torch.nn as nn
from ultralytics import YOLO
# Step 1: Load trained YOLOv8 model
model = YOLO('/content/best.pt')
image = '/content/stop.png'
# Step 2: Extract the backbone (CSPDarknet53)
backbone = model.model[0]
labels = torch.tensor([22])
num_classes = 29
classify_model = nn.Sequential(
backbone, # Use the CSPDarknet53 backbone
nn.AdaptiveAvgPool2d((1, 1)), # Global Average Pooling to reduce to (batch_size, channels, 1, 1)
nn.Flatten(), # Flatten to (batch_size, channels)
nn.Linear(in_features=backbone[-1].out_channels, out_features=num_classes) # Linear layer for classification
)
# Example FGSM attack
def fgsm_attack(model, images, labels, epsilon):
images.requires_grad = True
outputs = model(images)
loss = nn.CrossEntropyLoss()(outputs, labels)
model.zero_grad()
loss.backward() # Compute gradients with respect to input images
grad_sign = images.grad.data.sign()
perturbed_image = images + epsilon * grad_sign # Apply perturbation
return perturbed_image
# apply FGSM on `classify_model`
perturbed_image = fgsm_attack(classify_model, image, labels, 0.3)
但是我收到此错误“TypeError:'DetectionModel'对象不可订阅”
我尝试对整个模型进行 FGSM,即使在我使用 pytorch hooks 从 yolo 中提取类 logits 之后,它也不起作用。我还尝试提取 YOLOv8 主干并将其用作分类模型,但这也不起作用。
尝试使用此解决方案在代码中提取 YOLOv8 主干:
# Step 2: Extract the backbone (CSPDarknet53)
backbone = model.model.model[:10]
首先,不要只是
model.model
更深入地了解模型结构,以避免出现 'TypeError: 'DetectionModel' object is not subscriptable' 错误。
其次,如果您查看模型的结构,主干是模型的前 10 层(0-9),而不仅仅是第一个。