Detected a call to `Model.fit` inside a `tf.function`. `Model.fit` is a high-level endpoint that manages its own `tf.function`

Question

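I'm running into a problem when training an LSTM model with Keras in a Google Colaboratory notebook. The goal is to predict the "moh" outage values for "unit1" from time-series data. However, when I try to fit the model to the data, I get the following error: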

RuntimeError: Detected a call to `Model.fit` inside a `tf.function`. `Model.fit` is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.fit` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.

Here is the code I'm using:

```python
# Importing required libraries
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense
from tensorflow.keras.callbacks import EarlyStopping

# Define the LSTM model
def create_lstm_model(input_size, output_size, lstm_layer_sizes, dropout_rates):
    lstm_model = Sequential()
    for size, rate in zip(lstm_layer_sizes, dropout_rates):
        lstm_model.add(LSTM(units=size, return_sequences=True))
        lstm_model.add(Dropout(rate=rate))
    lstm_model.add(Dense(units=output_size))
    return lstm_model

# Set the parameters
input_size = 6
output_size = 3
unit = 'unit1'
outage = 'moh'
lstm_layer_sizes = (64,128,256,128,64)
dropout_rates = (0.05,0.05,0.05,0.05,0.05)

# Prepare the data (omitting data retrieval steps for brevity)
y = kinerja_df_extended_nanremoved_standardized[f'{unit}_{outage}s']
current_dates = kinerja_df_extended_nanremoved_standardized['date']
x = np.array([current_dates[i:i+input_size] for i in range(len(current_dates)-input_size+1)])
y = np.array([y[i:i+output_size] for i in range(len(y)-output_size+1)])

# Instantiate and compile the model
lstm_model = create_lstm_model(input_size=input_size, output_size=output_size, lstm_layer_sizes=lstm_layer_sizes, dropout_rates=dropout_rates)
lstm_model.compile(optimizer='adam', loss='mean_squared_error')

# The following line causes the error
history = lstm_model.fit(x=x, y=y, batch_size=1, epochs=128, validation_split=0.1, shuffle=False)

# Plot the training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()
```

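I've searched online for a solution, but haven't found anything that addresses this particular error in my context. How can I fix this and train my LSTM model successfully?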


python tensorflow keras deep-learning google-colaboratory
Answer

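The `RuntimeError` is raised because `Model.fit` is being called inside a `tf.function`. `Model.fit` is a high-level endpoint that manages its own `tf.function`, so it must be called outside all enclosing `tf.function`s. In the code posted in the question, `fit` is already called at the top level of the cell, so check whether that code (or a helper that runs it) is wrapped in `@tf.function`. If you need the model inside a `tf.function`, call it directly on tensors, as in `model(x)`, instead of calling `fit`.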
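A minimal sketch of the two patterns the error message contrasts. It uses a toy `Dense` model and random data rather than the LSTM from the question, and assumes TensorFlow 2.x with the bundled `tf.keras` API; the function names `train` and `predict` are illustrative only. The failing pattern is shown as a comment and is not executed:

```python
import numpy as np
import tensorflow as tf

# Toy model and random data, for illustration only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mean_squared_error")

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

# Pattern that triggers the RuntimeError (not executed here):
#
#     @tf.function
#     def train():
#         model.fit(x, y)  # fit() would be traced inside a tf.function
#
# Inside a tf.function, call the model directly on tensors instead.
@tf.function
def predict(batch):
    return model(batch, training=False)

preds = predict(tf.constant(x))

# fit() is called at the top level, outside any tf.function.
history = model.fit(x, y, epochs=2, batch_size=4, verbose=0)
print(preds.shape, len(history.history["loss"]))
```

The point of the split is that `fit` builds and runs its own compiled training step, so it belongs in ordinary Python code, while the model's forward pass is safe to trace into a graph.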

Try the following:

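```python
# Solution
# Keep the call to `Model.fit` outside of all enclosing `tf.function`s,
# e.g. do not run it from a helper decorated with @tf.function.
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense

# Define the LSTM model
def create_lstm_model(input_size, output_size, lstm_layer_sizes, dropout_rates):
    lstm_model = Sequential()
    last = len(lstm_layer_sizes) - 1
    for i, (size, rate) in enumerate(zip(lstm_layer_sizes, dropout_rates)):
        # Only the intermediate LSTM layers return full sequences; the last one
        # returns its final output so Dense produces (batch, output_size), like y.
        lstm_model.add(LSTM(units=size, return_sequences=(i < last)))
        lstm_model.add(Dropout(rate=rate))
    lstm_model.add(Dense(units=output_size))
    return lstm_model

# Set the parameters
input_size = 6
output_size = 3
unit = 'unit1'
outage = 'moh'
lstm_layer_sizes = (64,128,256,128,64)
dropout_rates = (0.05,0.05,0.05,0.05,0.05)

# Prepare the data (omitting data retrieval steps for brevity)
# Note: LSTM layers expect numeric input shaped (samples, timesteps, features),
# and x and y must contain the same number of samples. As written, x is 2-D,
# is built from the 'date' column, and has input_size - output_size fewer samples than y.
y = kinerja_df_extended_nanremoved_standardized[f'{unit}_{outage}s']
current_dates = kinerja_df_extended_nanremoved_standardized['date']
x = np.array([current_dates[i:i+input_size] for i in range(len(current_dates)-input_size+1)])
y = np.array([y[i:i+output_size] for i in range(len(y)-output_size+1)])

# Instantiate and compile the model
lstm_model = create_lstm_model(input_size=input_size, output_size=output_size, lstm_layer_sizes=lstm_layer_sizes, dropout_rates=dropout_rates)
lstm_model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model: call fit() at the top level of the notebook cell,
# not inside any function decorated with @tf.function
history = lstm_model.fit(x=x, y=y, batch_size=1, epochs=128, validation_split=0.1, shuffle=False)

# Plot the training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()
```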
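Separately from the `RuntimeError`, the windowing in the code above is likely to cause further errors in `fit`: `x` gets `len(current_dates) - input_size + 1` samples while `y` gets `len(y) - output_size + 1`, so with `input_size = 6` and `output_size = 3` the sample counts differ by 3, and `LSTM` layers expect 3-D input of shape `(samples, timesteps, features)` whereas `x` is 2-D. The sketch below shows one way to build aligned windows. It assumes a numeric 1-D input series (not the `date` column) and assumes each target is the `output_size` values immediately following its input window; `make_windows` is a hypothetical helper, not part of Keras or NumPy.

```python
import numpy as np

def make_windows(series, input_size, output_size):
    """Hypothetical helper: split a 1-D numeric series into aligned
    (input window, following target window) pairs for an LSTM."""
    series = np.asarray(series, dtype="float32")
    n_samples = len(series) - input_size - output_size + 1
    if n_samples < 1:
        raise ValueError("series is too short for the requested window sizes")
    # Each input window is followed directly by its target window (no overlap).
    x = np.stack([series[i:i + input_size] for i in range(n_samples)])
    y = np.stack([series[i + input_size:i + input_size + output_size]
                  for i in range(n_samples)])
    # LSTM layers expect (samples, timesteps, features); add a feature axis.
    return x[..., np.newaxis], y

series = np.arange(20)  # stand-in for a standardized numeric column
x, y = make_windows(series, input_size=6, output_size=3)
print(x.shape, y.shape)  # (12, 6, 1) (12, 3)
```

The alignment rule (targets follow inputs with no overlap) is the assumption to revisit if the model is meant to predict something else, for example values that overlap the input window.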