我指的是tf.layers.batch_normilization的一个注释:
注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。例如:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
如何在Custom Estimator中实现这一点?例如,在Tensorflow的网站上查看这个例子:The complete abalone model_fn
我想你可以通过train_op来引用EstimatorSpec的train_op参数。
在下面的问题中,在最底部你有一个例子https://github.com/tensorflow/tensorflow/issues/16455
if mode == tf.estimator.ModeKeys.TRAIN:
lr = 0.001
optimizer = tf.train.RMSPropOptimizer(learning_rate=lr, decay=0.9)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(mode=mode,
loss=loss,
train_op=train_op)