如何为 PyTorch TensorBoard 将多条 PR 曲线绘制到一张图表上?

问题描述 投票:0回答:1

下面是绘制PR曲线的代码取自https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html#assessing-trained-models-with-tensorboard

from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter()

def add_pr_curve_tensorboard(class_index, test_probs, test_preds):
    tensorboard_preds = test_preds == class_index
    tensorboard_probs = test_probs[:, class_index]

    writer.add_pr_curve(classes[class_index],
                        tensorboard_preds,
                        tensorboard_probs,
                        global_step=0)
    writer.close()


for i in range(len(classes)):
    add_pr_curve_tensorboard(i, test_probs, test_preds)

我想如果我将第一个参数

tag
更改为相同的值
'all_pr_curves'
,它将有助于在同一张图表上绘制所有PR曲线。

writer.add_pr_curve(
    'all_pr_curves',
    tensorboard_preds,
    tensorboard_probs,
    global_step=0)

但事实证明后一个会覆盖前一个。所以最后只显示最后一条 PR 曲线。

有没有办法在同一张图表上绘制所有 PR 曲线?

python pytorch tensorboard
1个回答
0
投票

你想要一个迷你图。遗憾的是,目前 Tensorboard 中不提供该功能。请参阅此讨论

© www.soinside.com 2019 - 2024. All rights reserved.