Tensoflow2 LSTM - 未使用参数input_shape?

问题描述 投票:0回答:1

所以我用下面的代码建立了神经网络。

import tensorflow as tf

tf_model = tf.keras.Sequential()
tf_model.add(tf.keras.layers.LSTM(50, activation='relu'))
tf_model.add(tf.keras.layers.Dense(20, activation='relu'))
tf_model.add(tf.keras.layers.Dense(10, activation='relu'))
tf_model.add(tf.keras.layers.Dense(1, activation='linear'))
tf_model.compile(optimizer='Adam', loss='mse')

我的训练集形状如下

>> ts_train_X.shape
(16469, 3, 21)

我在stackoverflow上读了很多文章和问题,为了让LSTM的数据框架有正确的形状。我找到的几乎每个页面都指定了 input_shape 参数,并将其传递给LSTM(...)或Sequential(...)。

当我看到 LSTM API 我找不到这个参数的参考。我也曾在源代码上瞥了一眼,在我看来,形状似乎是以某种方式自动推断出来的,但我不确定这一点。

这让我想到了我的问题。为什么我的代码能用?如果我不指定 input_shape 参数,作为第一层的 LSTM 层如何知道我的输入的形状?


编辑:根据评论中的建议修改标题。

python tensorflow keras deep-learning lstm
1个回答
2
投票

参数 input_shape 可以给任何keras的构造函数赋予 Layer 子类,因为这就是定义API的方式。

这段代码的工作原理是 input_shape 是作为关键字参数传递的(该 **kwargs),然后这些关键字参数由 LSTM 构造函数到 Layer 构造函数,然后继续存储信息供以后使用。这实际上意味着 input_shape 参数不需要在每个层中定义,而是作为关键字参数传递。

我认为问题在于,由于 keras 已移至 tensorflow,文档可能不完整。你可以找到更多关于 input_shape 中的参数 序列式API指南.

© www.soinside.com 2019 - 2024. All rights reserved.