GradCAM visualization with PyTorch for a multi-label, multi-class problem


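I am trying to implement GradCAM visualization in PyTorch with ResNet-50 for a multi-label, multi-class problem. I am running into problems with the output of the last convolutional layer.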
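
For reference, Grad-CAM (Selvaraju et al., 2017) weights each feature map $A^k$ of the chosen convolutional layer by the spatial average of the gradient of a class score $y^c$. In a multi-label model, $y^c$ can simply be the logit of label $c$, so one map is computed per label of interest:

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left(\sum_{k}\alpha_k^c A^k\right)
$$

Here $Z$ is the number of spatial positions $(i, j)$. Taking a channel-wise `torch.mean(..., dim=1)` instead of the sum over $k$ divides the map by the number of channels $K$, which has no effect once the map is divided by its maximum.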

Here is my code:

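```python
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set by the forward hook on the last convolutional layer
last_conv_layer_output = None


def hook_fn(module, inputs, output):
    # Forward hooks are called as hook(module, inputs, output);
    # store the layer's output so gradients can be taken w.r.t. it
    global last_conv_layer_output
    last_conv_layer_output = output


def grad_cam(fname, model):
    x = cv2.imread(fname)  # OpenCV loads images as BGR uint8
    img = cv2.resize(x, (224, 224), interpolation=cv2.INTER_AREA)
    # The ImageNet mean/std below are for RGB input
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(img_rgb).unsqueeze(0).to(device)

    model = model.to(device).eval()
    # Register the hook *before* the forward pass, and run the pass without
    # torch.no_grad(): Grad-CAM needs gradients w.r.t. the conv activation
    handle = model.layer4[-1].conv3.register_forward_hook(hook_fn)
    model_out = model(img_tensor)  # raw logits, shape (1, num_classes)
    handle.remove()

    # Multi-label: an independent sigmoid per label, then a threshold
    pred_bool = (torch.sigmoid(model_out) > 0.5).int().squeeze(0)

    print("This image has label:")
    class_names = ['A', 'b', 'c', 'd', ...]  # remaining names elided; one per output

    # Loop through each label and print whether it was predicted
    for i in range(len(class_names)):
        print(f'{class_names[i]}: {pred_bool[i].item()}')

    # Gradient of the highest-scoring logit w.r.t. the captured activation
    class_out = model_out[0, torch.argmax(model_out[0])]
    grads = torch.autograd.grad(class_out, last_conv_layer_output)[0]  # (1, C, H, W)
    pooled_grads = torch.mean(grads, dim=(0, 2, 3))  # one weight per channel, (C,)

    # Weight each channel by its pooled gradient and combine the channels
    heatmap = torch.mean(pooled_grads[None, :, None, None] * last_conv_layer_output, dim=1)
    heatmap = F.relu(heatmap)
    heatmap = heatmap / (torch.max(heatmap) + 1e-8)  # avoid 0/0 if all values are 0

    # Detach from the graph and move to the CPU before converting to NumPy
    heatmap = heatmap.squeeze().detach().cpu().numpy()
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # BGR colour map

    INTENSITY = 0.5
    # Blend in float, then clip back to the valid uint8 range
    img1 = np.clip(heatmap * INTENSITY + img, 0, 255).astype(np.uint8)

    # Matplotlib expects RGB, so convert both BGR images before showing them
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()

    plt.imshow(cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()
```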
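
Grad-CAM needs two things from the target layer: its activation, captured by a forward hook that is registered before the forward pass, and the gradient of a label's logit with respect to that activation, which only exists if the forward pass runs outside `torch.no_grad()`. Below is a minimal sketch of that pattern for the multi-label case. `GradCAM` is a hypothetical helper (not a torchvision class), the model is assumed to return raw logits, and every label whose sigmoid probability exceeds `threshold` gets its own map instead of only the `argmax` label. A tiny stand-in network replaces ResNet-50 so the sketch runs without downloading weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradCAM:
    """Hypothetical minimal Grad-CAM helper using a forward hook."""

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.activation = None
        # Keep the handle so the hook can be removed afterwards
        self.handle = target_layer.register_forward_hook(self._save_activation)

    def _save_activation(self, module, inputs, output):
        # Returning None leaves the forward pass unchanged
        self.activation = output

    def remove(self):
        self.handle.remove()

    def __call__(self, x: torch.Tensor, threshold: float = 0.5):
        """Return {class_index: (H, W) heatmap} for each label above threshold."""
        self.model.eval()
        logits = self.model(x)  # (1, num_classes); gradients stay enabled
        probs = torch.sigmoid(logits)[0]  # independent per-label probabilities
        predicted = torch.nonzero(probs > threshold).flatten().tolist()

        heatmaps = {}
        for c in predicted:
            # retain_graph=True lets the same graph be reused for the next label
            grads = torch.autograd.grad(logits[0, c], self.activation,
                                        retain_graph=True)[0]
            weights = grads.mean(dim=(2, 3), keepdim=True)  # alpha_k^c, (1, K, 1, 1)
            cam = F.relu((weights * self.activation).sum(dim=1))[0]  # (H, W)
            heatmaps[c] = (cam / (cam.max() + 1e-8)).detach()  # scaled to [0, 1]
        return heatmaps


# Tiny stand-in model (hypothetical); nn.ReLU() is not in-place,
# so the hooked conv output stays untouched
torch.manual_seed(0)
tiny = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4),
)
cam = GradCAM(tiny, target_layer=tiny[0])
maps = cam(torch.randn(1, 3, 32, 32), threshold=0.0)  # every label passes
cam.remove()
for c, m in maps.items():
    print(c, tuple(m.shape))
```

With the ResNet-50 from the question, the equivalent call would be `GradCAM(model, model.layer4[-1].conv3)`; for a 224×224 input each map is 7×7 and still has to be resized and colour-mapped as in the question's code. Hooking `model.layer4[-1]` (the output of the last bottleneck block, after the residual addition and ReLU) is a common alternative.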

My model is defined as follows:

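```python
import torch
import torchvision

# ImageNet-pretrained ResNet-50 (`weights=` replaces the deprecated `pretrained=True`)
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
# Replace the 1000-class ImageNet head with a 14-label head (outputs raw logits)
model.fc = torch.nn.Linear(model.fc.in_features, 14)
```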
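
Because `model.fc` is a plain `torch.nn.Linear`, the network outputs raw logits, not probabilities. Assuming it was trained with a per-label sigmoid loss such as `torch.nn.BCEWithLogitsLoss` (an assumption; the question does not say), predictions are usually made by applying `torch.sigmoid` before comparing with 0.5; comparing the logits directly with 0.5 amounts to a probability threshold of about 0.62. The hypothetical helper `decode_multilabel` below sketches this. Note that the logit 0.3 is below 0.5, yet its probability (about 0.57) is above it.

```python
import torch


def decode_multilabel(logits: torch.Tensor, class_names: list, threshold: float = 0.5):
    """Hypothetical helper: turn (1, num_classes) logits into {class_name: predicted?}."""
    if logits.shape[-1] != len(class_names):
        raise ValueError("expected one class name per model output")
    # Each label gets its own sigmoid; a softmax would make the labels compete
    probs = torch.sigmoid(logits.detach()).squeeze(0)
    return {name: bool(p > threshold) for name, p in zip(class_names, probs)}


# Hypothetical logits for a 4-label model (the real head has 14 outputs)
logits = torch.tensor([[2.0, -1.0, 0.3, 0.0]])
print(decode_multilabel(logits, ["A", "b", "c", "d"]))
# {'A': True, 'b': False, 'c': True, 'd': False}
```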

pytorch computer-vision multilabel-classification resnet