Tensorflow中自定义估算器中的批量标准化

问题描述 投票:2回答:2

我指的是tf.layers.batch_normilization的一个注释:

注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。例如:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)

如何在Custom Estimator中实现这一点?例如,在Tensorflow的网站上查看这个例子:The complete abalone model_fn

tensorflow tensorboard
2个回答
0
投票

我想你可以通过train_op来引用EstimatorSpec的train_op参数。


1
投票

在下面的问题中,在最底部你有一个例子https://github.com/tensorflow/tensorflow/issues/16455

if mode == tf.estimator.ModeKeys.TRAIN:
    lr = 0.001
    optimizer = tf.train.RMSPropOptimizer(learning_rate=lr, decay=0.9)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op)
© www.soinside.com 2019 - 2024. All rights reserved.