如何在嵌套作用域中使用tf.get_collection()作用域过滤

问题描述 投票:0回答:1

我试图通过在范围内定义变量并使用tf.get_collection()中的范围过滤来检索一组变量:

with tf.variable_scope('inner'):
    v = tf.get_variable(name='foo', shape=[1])
    ...
    # more variables
    ...

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'inner')
# do stuff with variables

这通常工作正常但有时我的代码由已经定义了自己的范围的模块调用get_collection()不再找到变量:

with tf.variable_scope('outer'):
    with tf.variable_scope('inner'):
        v = tf.get_variable(name='foo', shape=[1])
        ...
        # more variables
        ...

我认为过滤是一个正则表达式,因为我可以通过在我的作用域搜索术语前加上.*来使get_collection()工作,但这有点hacky。有没有更好的方法来处理这个?

python tensorflow
1个回答
0
投票

当我想训练模型时,我使用get_collection(),但之前,我必须恢复模型数据,如下面的代码所示:

with tf.Session() as sess:
    last_check = tf.train.latest_checkpoint(tf_data)
    saver = tf.train.import_meta_graph(last_check + '.meta')
    print (last_check +'.meta')
    saver.restore(sess, last_check)
    ######
    Model_variables = tf.GraphKeys.MODEL_VARIABLES
    Global_Variables = tf.GraphKeys.GLOBAL_VARIABLES
    ######
    all_vars = tf.get_collection(Model_variables)
    # print (all_vars)
    pesos=[]
    for i in all_vars:
        print (str(i) + '  -->  '+ str(i.eval()))
© www.soinside.com 2019 - 2024. All rights reserved.