我想我们可以使用以下代码段来创建一堆LSTM并将其状态初始化为零。
lstm_cell = tf.contrib.rnn.BasicLSTMCell(
hidden_size, forget_bias=0.0, state_is_tuple=True)
cell = tf.contrib.rnn.MultiRNNCell([lstm_cell] * num_layers, state_is_tuple=True)
cell.zero_state(batch_size, tf_float32)
我不想使用BasicLSTMCell,而是使用CUDNN
cudnn_cell = tf.contrib.cudnn_rnn.CudnnLSTM(
num_layers, hidden_size, dropout=config.keep_prob)
在这种情况下,我怎么能在cudnn_cell上做与cell.zero_state(batch_size, tf_float32)
相同的事情?
定义可以在:tensorflow cudnn_rnn's code中找到
关于initial_states:
with tf.Graph().as_default():
lstm = CudnnLSTM(num_layers, num_units, direction, ...)
outputs, output_states = lstm(inputs, initial_states, training=True)
因此,除了嵌入输入之外,您只需要添加初始状态。在编码器 - 解码器系统中,它看起来像:
encoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size)
encoder_output, encoder_state = encoder_cell(encoder_embedding_input)
decoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size)
decoder_output, decoder_state = encoder_cell(decoder_embedding_input,
initial_states=encoder_state)
在这里,encoder_state
是tuple
作为(final_c_state, final_h_state)
。这两个州的形状都是(1, batch, hidden_size)
如果您的编码器是双向RNN,那将会有点棘手,因为输出状态现在变为qazxsw poi。
因此,我使用迂回的方式来解决它。
(2, batch, hidden_size)
虽然我没有尝试多层RNN,但我认为它也可以用类似的方式解决。