Pytorch LSTM 与 LSTMCell

问题描述 投票:0回答:2

Pytorch(当前版本 1.1)中的 LSTMLSTMCell 有什么区别?看起来 LSTMCell 是 LSTM 的一个特例(即只有一层、单向、无 dropout)。

那么,这两种实现的目的是什么呢?除非我遗漏了什么,否则使用 LSTM 对象作为 LSTMCell 是很简单的(或者,使用多个 LSTMCell 来创建 LSTM 对象也很容易)

python pytorch lstm recurrent-neural-network lstm-stateful
2个回答
50
投票

是的,你可以逐个模仿,将它们分开的原因是效率。

LSTMCell
是一个带有参数的单元格:

  • 输入形状批次×输入尺寸;
  • 形状批次 x 隐藏维度的 LSTM 隐藏状态元组。

这是方程的直接实现。

LSTM
是在“for 循环”中应用 LSTM 单元(或多个 LSTM 单元)的层,但该循环使用 cuDNN 进行了深度优化。它的输入是

  • 输入形状为批次×输入长度×输入维度的三维张量;
  • 可选地,LSTM 的初始状态,即形状为批量 × 隐藏暗淡的隐藏状态元组(如果 LSTM 是双向的,则为此类元组的元组)

您经常可能希望在不同的上下文中使用 LSTM 单元,而不是将其应用于序列,即创建一个在树状结构上运行的 LSTM。当您在序列到序列模型中编写解码器时,您还可以在循环中调用该单元,并在解码序列结束符号时停止循环。


1
投票

让我展示一些具体的例子:

# LSTM example:
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
# LSTMCell example:
>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
        hx, cx = rnn(input[i], (hx, cx))
        output.append(hx)

主要区别:

  1. LSTM:参数
    2
    ,代表
    num_layers
    ,循环层数。有
    seq_len * num_layers=5 * 2
    细胞。 没有循环,但有更多单元格。
  2. LSTMCell:在
    for
    循环(
    seq_len=5
    次)中,
    ith
    实例的每个输出将成为
    (i+1)th
    实例的输入。只有一个单元格,真正的循环

如果我们在LSTM中设置

num_layers=1
或者再添加一个LSTMCell,上面的代码将是相同的。

显然,在 LSTM 中应用并行计算更容易。

© www.soinside.com 2019 - 2024. All rights reserved.