在TF Estimator中冻结和解冻网络层

问题描述 投票:0回答:1

我正在使用TF Estimator在数据集上训练我的模型。对于前几次训练迭代,我想冻结网络中的某些层。对于剩余的迭代,我想解冻这些层。

我找到了一些解决方案,我们在估算器的model_fn中有两个不同的优化器train_ops。

def ModelFunction(features, labels, mode, params):
    if mode == tf.estimator.ModeKeys.TRAIN:
        layerTrainingVars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "LayerName")
        #Train Op for freezing layers
        freeze_train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step(), var_list=layerTrainingVars)
        #Train Op for training all layers
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        #Based on whether we want to freeze or not, we send the corresponding train_op to the estimatorSpec. How do I do this?
        estimatorSpec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=freeze_train_op)

    return estimatorSpec

对于上述解决方案,可以基于train_op返回相应的EstimatorSpec。我尝试使用freeze_train_op进行一些训练迭代,然后终止进程,并更改train_op以使代码中没有层冻结。执行此操作后,会出现检查点错误,表示检查点中保存的图形/变量不同。我猜第一组迭代没有保存冻结层。如何以编程方式切换train_ops,以便检查点也能正常工作?

有什么更好的方法可以在TF.Estimator中进行冷冻/解冻层训练吗?

python tensorflow tensorflow-estimator
1个回答
0
投票

您可以通过将它们组合在一起返回2 train_op。

def ModelFunction(features, labels, mode, params):
    if mode == tf.estimator.ModeKeys.TRAIN:
        layerTrainingVars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "LayerName")
        #Train Op for freezing layers
        freeze_train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step(), var_list=layerTrainingVars)
        #Train Op for training all layers
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        estimatorSpec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=tf.group(freeze_train_op, train_op))

    return estimatorSpec

但这不会考虑不同的迭代。如果要在不同的迭代上训练不同的变量集,并且不想停止训练并从检查点加载权重,则需要使用会话。 Estimator api不允许会话管理。

© www.soinside.com 2019 - 2024. All rights reserved.