我正在实施我的 LSTM 模型用于讽刺检测(二进制)。我写了以下代码:
model = Sequential()
model.add(Embedding(vocab_size, embed_dim,
weights=[embed_matrix], input_length=max_tweet_len, trainable=False))
model.add(LSTM(lstm_out1, Dropout(0.2), Dropout(0.2)))
#model.add(MaxPooling1D(2))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
opt = Adam(learning_rate=0.001)
model.compile(loss='binary_crossentropy',
optimizer=opt,
metrics=['accuracy'])
我的模型预测 NaNs。
[[nan nan]
[nan nan]
[nan nan]
...
[nan nan]
[nan nan]
[nan nan]]
为什么会这样?我的输入维度是 (75830, 79)。我检查了缺失值/nans。数据中没有这种异常。我使用了这段代码。在所有情况下,代码都返回 false。
check_nan = df['is_sarcastic'].isnull().values.any()
print(check_nan)
check_na = df['is_sarcastic'].isna().values.any()
print(check_na)
check_nan_ = df['tweet_text'].isnull().values.any()
print(check_nan_)
check_na_ = df['tweet_text'].isna().values.any()
print(check_na_)
编辑
仍然是 NaN 损失。
model.add(LSTM(lstm_out1, Dropout(0.2), Dropout(0.2)))
这个 Dropout 层看起来不正确。我认为你应该使用 dropout=0.2, recurrent_dropout=0.2。参考文档https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM.
此外,您的嵌入初始化可以包含 weights=[embed_matrix] 中的 nan 值
但是你为什么要使用 binary_crossentropy?你有 2 个输出神经元。奇怪的是你有 2 个输出神经元的 sigmoid 激活。 在这种情况下,您有 2 个选择:
您可以使用带有 2 个输出神经元和 softmax 激活的 sparse_categorical_crossentropy(或 categorical_crossentropy,具体取决于您是否有 one_hot)。
您可以使用只有 1 个输出神经元和 sigmoid 激活的 binary_crossentropy。
我想那是错误的。
编辑: 我将根据您的代码为您提供一个工作示例,以便您进行比较。
x_train = np.random.randint(0, 100, (1000, 10))
y_train = np.random.randint(0, 2, (1000,))
vocab_size = 100
embed_dim = 10
max_tweet_len = 10
model = keras.models.Sequential()
model.add(layers.Embedding(vocab_size, embed_dim, input_length=max_tweet_len, trainable=False))
model.add(layers.LSTM(10, dropout=0.2, recurrent_dropout=0.2))
#model.add(MaxPooling1D(2))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
opt = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='binary_crossentropy',
optimizer=opt,
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=32)