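I am working on a multi-label classification task with 14 different disease classes. I have trained my model and want to generate a multi-label confusion matrix whose x and y axes both represent the 14 classes.
However, when I generate the confusion matrix with my current code, it creates a separate confusion matrix for each class. Instead, I want a single, unified confusion matrix in which the true labels and the predicted labels sit on the two axes, each covering the same 14 classes.
Here are the key details of my setup:
Here is the code I use to generate the confusion matrix: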
import numpy as np
import matplotlib.pyplot as plt
import torch  # needed for torch.no_grad() below
from sklearn.metrics import multilabel_confusion_matrix
# Initialize lists to store true labels and predictions
all_labels = []
all_predictions = []
# Disable gradient calculation for inference
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass to get model predictions
        outputs = best_model(images)
        # Store true labels and raw model outputs
        all_labels.extend(labels.cpu().numpy())
        # Assumes the model already outputs probabilities (e.g. it ends in a sigmoid)
        all_predictions.extend(outputs.cpu().numpy())
# Convert to numpy arrays for easier manipulation
all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)
# Apply thresholding to convert probabilities to binary predictions
binary_predictions = (all_predictions > 0.5).astype(int)
# Compute the multilabel confusion matrix
confusion_mtx = multilabel_confusion_matrix(all_labels, binary_predictions)
# Function to plot the multilabel confusion matrix
def plot_multilabel_confusion_matrix(confusion_mtx, class_names):
    num_classes = confusion_mtx.shape[0]
    ncols = 3  # Set the number of columns for the plot
    nrows = (num_classes + ncols - 1) // ncols  # Calculate the number of rows needed
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 4, nrows * 4))
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
    for i in range(num_classes):
        ax = axes[i]
        ax.matshow(confusion_mtx[i], cmap=plt.cm.Blues, alpha=0.5)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(class_names[i])
        # sklearn lays out each 2x2 block as [[TN, FP], [FN, TP]],
        # so index 0 is the negative class and index 1 is the positive class
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Negative', 'Positive'])
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['Negative', 'Positive'])
        # Show the counts
        for j in range(confusion_mtx[i].shape[0]):
            for k in range(confusion_mtx[i].shape[1]):
                ax.text(k, j, confusion_mtx[i][j, k], ha='center', va='center')
    # Hide any unused subplots
    for i in range(num_classes, len(axes)):
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()
# Plot the multilabel confusion matrix
plot_multilabel_confusion_matrix(confusion_mtx, class_names)
What I get is 14 separate confusion matrices, but I need a single confusion matrix with all 14 classes represented on both axes.
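The documentation indicates that multilabel_confusion_matrix is not what you are looking for:
The multilabel_confusion_matrix calculates class-wise or sample-wise multilabel confusion matrices, and in multiclass tasks, labels are binarized under a one-vs-rest way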
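To make the quoted behaviour concrete, here is a minimal sketch on toy data (the arrays are made-up values, not the asker's data). It shows that the result has shape (n_classes, 2, 2), i.e. one binary matrix per class, which is why the plot in the question ends up with 14 panels.

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Toy multi-hot data: 4 samples, 3 classes (hypothetical values).
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [1, 0, 0],
                   [0, 1, 1]])

mcm = multilabel_confusion_matrix(y_true, y_pred)
print(mcm.shape)  # (3, 2, 2): one 2x2 matrix per class

# Each 2x2 block is laid out as [[TN, FP], [FN, TP]].
for class_index, block in enumerate(mcm):
    tn, fp, fn, tp = block.ravel()
    print(f"class {class_index}: TN={tn} FP={fp} FN={fn} TP={tp}")
```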
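Instead, I believe you want to use confusion_matrix, whose documentation states that it returns:
Confusion matrix of shape (n_classes, n_classes) whose i-th row and j-th column entry indicates the number of samples with true label being i-th class and predicted label being j-th class.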
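One caveat: confusion_matrix expects a single class index per sample and raises a ValueError when given multi-hot (multilabel indicator) arrays. The sketch below is therefore only one possible reading of the question, under assumptions that are mine rather than the asker's. Option A reduces every sample to one class with argmax, which discards extra labels. Option B is not confusion_matrix at all but a label co-occurrence count (true-label matrix transposed, times predicted-label matrix), named plainly here as a swapped-in technique. The data, n_classes = 3, and the variable names are all hypothetical.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

n_classes = 3  # 14 in the question; 3 here to keep the output readable

# Hypothetical multi-hot labels and predicted probabilities (samples x classes).
y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 1, 1],
                   [0, 0, 1]])
scores = np.array([[0.9, 0.2, 0.1],
                   [0.3, 0.8, 0.1],
                   [0.1, 0.4, 0.7],
                   [0.6, 0.1, 0.3]])

# Option A (assumption: one dominant class per sample).
# argmax keeps only the first positive label of each row, so the second
# label of a multi-label sample is silently dropped.
true_idx = y_true.argmax(axis=1)
pred_idx = scores.argmax(axis=1)
cm = confusion_matrix(true_idx, pred_idx, labels=np.arange(n_classes))
print(cm)

# Option B (not a true confusion matrix): co-occurrence counts.
# cooc[i, j] = number of samples that truly have label i AND were predicted label j.
# A sample with several labels contributes to several cells.
binary_pred = (scores > 0.5).astype(int)
cooc = y_true.T @ binary_pred
print(cooc)
```

Either matrix is n_classes x n_classes and can be drawn with a single matshow call instead of one subplot per class. Which option is appropriate depends on whether samples in the dataset really carry more than one label at a time.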