计算物体检测的混淆矩阵的正确方法是什么?

问题描述 投票:0回答:2

我正在尝试计算我的对象检测模型的混淆矩阵。然而,我似乎偶然发现了一些陷阱。我当前的方法是将每个预测框与每个地面实况框进行比较。如果它们的 IoU > 某个阈值,我会将预测插入到混淆矩阵中。插入后,我删除预测列表中的元素并移至下一个元素。

因为我也希望将错误分类的提案插入到混淆矩阵中,所以我将 IoU 低于阈值的元素视为与背景混淆。我当前的实现如下所示:

def insert_into_conf_m(true_labels, predicted_labels, true_boxes, predicted_boxes):
    matched_gts = []
    for i in range(len(true_labels)):
        j = 0
        while len(predicted_labels) != 0:
            if j >= len(predicted_boxes):
                break
            if bb_intersection_over_union(true_boxes[i], predicted_boxes[j]) >= 0.7:
                conf_m[true_labels[i]][predicted_labels[j]] += 1
                del predicted_boxes[j]
                del predicted_labels[j]
            else:
                j += 1
        matched_gts.append(true_labels[i])
        if len(predicted_labels) == 0:
            break
    # if there are ground-truth boxes that are not matched by any proposal
    # they are treated as if the model classified them as background
    if len(true_labels) > len(matched_gts):
        true_labels = [i for i in true_labels if not i in matched_gts or matched_gts.remove(i)]
        for i in range(len(true_labels)):
            conf_m[true_labels[i]][0] += 1

    # all detections that have no IoU with any groundtruth box are treated
    # as if the groundtruth label for this region was Background (0)
    if len(predicted_labels) != 0:
        for j in range(len(predicted_labels)):
            conf_m[0][predicted_labels[j]] += 1

行归一化矩阵如下所示:

[0.0, 0.36, 0.34, 0.30]
[0.0, 0.29, 0.30, 0.41]
[0.0, 0.20, 0.47, 0.33]
[0.0, 0.23, 0.19, 0.58]

有没有更好的方法来生成对象检测系统的混淆矩阵?或者有其他更合适的指标吗?

python object-detection confusion-matrix
2个回答
7
投票

这里有一个脚本,用于根据 TensorFlow 对象检测 API 生成的 detectors.record 文件计算混淆矩阵。 这里是文章解释这个脚本是如何工作的。

总而言之,以下是文章中的算法概要:

  1. 对于每个检测记录,算法从输入文件中提取真实框和类,以及检测到的 盒子、班级和分数。

  2. 仅考虑分数大于或等于 0.5 的检测。任何低于该值的内容都会被丢弃。

  3. 对于每个真实框,该算法会与每个检测到的框生成 IoU(并交交集)。找到匹配项,如果 两个盒子的 IoU 都大于或等于 0.5。

  4. 匹配列表被修剪以删除重复项(与多个检测框匹配的真实框,反之亦然)。如果 有重复项,始终选择最佳匹配(更大的 IoU)。

  5. 更新混淆矩阵以反映真实值和检测结果之间的匹配结果。

  6. 属于真实事实一部分但未检测到的对象计入矩阵的最后一列(对应于 真实类别)。已检测到但不属于其中一部分的对象 混淆矩阵计入矩阵的最后一行(在 对应于检测到的类别的列)。

您还可以查看脚本以获取更多信息。


0
投票

这也是使用 PyTorch 和 TorchMetrics 包的实现:

https://github.com/gui-miotto/object_detection_confusion_matrix/

免责声明:我写了这段代码。

© www.soinside.com 2019 - 2024. All rights reserved.