使用tf.metrics.auc进行培训和验证[关闭]

问题描述 投票:0回答:1

当使用相同的模型进行训练和验证时,是否有正确的方法来使用tf.metrics.auc?我目前将我的模型作为一个类实现,并使用tf.data字符串句柄在训练和验证数据之间进行交换。因为tf.metrics.auc函数会创建局部变量,所以我需要采取其他步骤来确保它们不在培训和验证之间共享吗?

最终,我想将ROC和PR值下的区域提供给tf.summary.scalar,在那里我可以看到并排的训练和验证,所以这更加复杂,因为使用tf.variable_scope将图分割成Tensorboard中的单独部分。

tensorflow
1个回答
1
投票

我设法通过创建一个reset_op来解决这个问题,我曾经在预先形成汇总步骤之前将局部变量归零

eval_reset_ops = []
for scope in ['outcome_' + str(i) for i in range(cfg.num_class)]:
    for var in tf.local_variables(scope=scope):
        eval_reset_ops.append(tf.assign(var, tf.zeros_like(var)))

然后我运行以下部分来生成摘要。

_ = sess.run(eval_reset_ops)
for _ in range(n_summary_batches):
    eval_metrics, tr_losses = sess.run([model.eval_ops, model.losses])
summary = sess.run(merged)
train_writer.add_summary(summary, step)
train_writer.flush()

在这种情况下,model.eval_opstf.metrics.auc生成的更新操作列表。

如...

auc, auc_op = tf.metrics.auc(
    labels=labels, predictions=prob, num_thresholds=1000,
   updates_collections='eval_updates', curve='ROC')
aupr, pr_op = tf.metrics.auc(
    labels=labels, predictions=prob, num_thresholds=1000,
    updates_collections='eval_updates', curve='PR')
eval_ops = [auc_op, pr_op]

这似乎工作正常。我希望有人认为这有用,或者可以指出如何改进它。

© www.soinside.com 2019 - 2024. All rights reserved.