如何在 PyTorch 中创建 14 种疾病类别的多标签混淆矩阵?

问题描述 投票:0回答:1

我正在研究 14 种不同疾病类别的多标签分类任务。我已经训练了我的模型,并且想要生成一个多标签混淆矩阵,其中 x 轴和 y 轴代表 14 个类别。

但是,当我尝试使用当前代码生成混淆矩阵时,它会为每个类创建一个单独的混淆矩阵。相反,我想要一个统一的混淆矩阵,其中真实标签和预测标签位于两个轴上相同的 14 个类。

以下是我的设置的关键细节:

  • 我正在使用 PyTorch。
  • 我有一个经过训练的模型并可以访问 test_loader。
  • 我已经成功地使用我的模型进行预测,但我一直困惑于如何将结果聚合到单个混淆矩阵中以进行多标签分类。

这是我用来生成混淆矩阵的代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import multilabel_confusion_matrix

# Initialize lists to store true labels and predictions
all_labels = []
all_predictions = []

# Disable gradient calculation for inference
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass to get model predictions
        outputs = best_model(images)

        # Store true labels and predicted probabilities
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(outputs.cpu().numpy())  # Get the raw output (probabilities)

# Convert to numpy arrays for easier manipulation
all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)

# Apply thresholding to convert probabilities to binary predictions
binary_predictions = (all_predictions > 0.5).astype(int)

# Compute the multilabel confusion matrix
confusion_mtx = multilabel_confusion_matrix(all_labels, binary_predictions)

# Function to plot the multilabel confusion matrix
def plot_multilabel_confusion_matrix(confusion_mtx, class_names):
    num_classes = confusion_mtx.shape[0]
    ncols = 3  # Set the number of columns for the plot
    nrows = (num_classes + ncols - 1) // ncols  # Calculate the number of rows needed

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 4, nrows * 4))
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration

    for i in range(num_classes):
        ax = axes[i]
        ax.matshow(confusion_mtx[i], cmap=plt.cm.Blues, alpha=0.5)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(class_names[i])

        # Set x and y axis ticks to show "Positive" first and "Negative" second
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Positive', 'Negative'])  # Positive first
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['Positive', 'Negative'])  # Positive first

        # Show the counts
        for j in range(confusion_mtx[i].shape[0]):
            for k in range(confusion_mtx[i].shape[1]):
                ax.text(k, j, confusion_mtx[i][j, k], ha='center', va='center')

    # Hide any unused subplots
    for i in range(num_classes, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Plot the multilabel confusion matrix
plot_multilabel_confusion_matrix(confusion_mtx, class_names)

我得到的是 14 个独立的混淆矩阵,但我需要一个混淆矩阵,所有 14 个类都在两个轴上表示。

这是我得到的图片: enter image description here

这就是我想要的: enter image description here

deep-learning pytorch confusion-matrix multilabel-classification multiclass-classification
1个回答
0
投票

文档指出

multilabel_confusion_matrix
不是您要找的东西 :

multilabel_confusion_matrix 计算类或样本的多标签混淆矩阵,在多类任务中,标签以一对一的方式二值化

相反,我相信您想使用confusion_matrix,其文档指出它返回:

形状为 (n_classes, n_classes) 的混淆矩阵,其第 i 行和第 j 列条目表示真实标签为第 i 类、预测标签为第 j 类的样本数量。

© www.soinside.com 2019 - 2024. All rights reserved.