避免在tensorflow中重复图形(LSTM模型)

问题描述 投票:1回答:1

我有以下简化代码(实际上,展开的LSTM模型):

def func(a, b):
    with tf.variable_scope('name'):
        res = tf.add(a, b)
    print(res.name)
    return res

func(tf.constant(10), tf.constant(20))

每当我运行最后一行时,它似乎都会改变图形。但我不希望图表发生变化。实际上我的代码是不同的,是一个神经网络模型,但它太大了,所以我添加了上面的代码。我想在不改变模型图的情况下调用func但它会改变。我读到了TensorFlow中的变量范围,但似乎我根本不理解它。

python tensorflow while-loop lstm recurrent-neural-network
1个回答
3
投票

你应该看一下tf.nn.dynamic_rnn的源代码,特别是在_dynamic_rnn_loop上的python/ops/rnn.py函数 - 它解决了同样的问题。为了不炸毁图形,它使用tf.while_loop重用相同的图形操作来获取新数据。但是这种方法增加了一些限制,即在循环中穿过的张量的形状必须是不变的。请参阅tf.while_loop文档中的示例:

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
© www.soinside.com 2019 - 2024. All rights reserved.