如何在tf.nn.dynamic_rnn中为LSTM初始化intitial_state?

问题描述 投票:0回答:1

当单元格是LSTMCell时,我不确定如何传递initial_state的值。我正在使用LSTMStateTuple,因为它显示在下面的代码中:

c_placeholder = tf.placeholder(tf.float32, [ None, config.state_dim], name='c_lstm')

h_placeholder = tf.placeholder(tf.float32, [ None, config.state_dim], name='h_lstm')

state_tuple = tf.nn.rnn_cell.LSTMStateTuple(c_placeholder, h_placeholder)

cell = tf.contrib.rnn.LSTMCell(num_units=config.state_dim, state_is_tuple=True, reuse=not is_training)  

rnn_outs, states = tf.nn.dynamic_rnn(cell=cell, inputs=x,sequence_length=seqlen, initial_state=state_tuple, dtype= tf.float32)

但是,执行会返回此错误:

TypeError: 'Tensor' object is not iterable.

这是dynamic_rnn文档的链接

python tensorflow lstm
1个回答
0
投票

我之前见过同样的错误。我使用tf.contrib.rnn.MultiRNNCell制作的多层RNN细胞,我需要指定一个LSTMStateTuples元组 - 每层一个。就像是

state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(c_ph[i], h_ph[i])
         for i in range(nRecurrentLayers)]
    )
© www.soinside.com 2019 - 2024. All rights reserved.