我发现在 Keras 库上训练 CNN 模型的脚本存在奇怪的问题。当我在第一个纪元结束时尝试学习简单的 CNN 模型时,Keras 抛出一个错误:
文件“File.py”,第 86 行,位于 generateModelForData(train_dataset,validation_dataset) 文件“File.py”,第 44 行,在generateModelForData 中 model.fit(train_dataset、batch_size=1、epochs=1000、verbose=1、validation_data=validation_dataset、callbacks=[my_callbacks]) 文件“.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py”,第 122 行,在 error_handler 中 从 None 引发 e.with_traceback(filtered_tb) 文件“.local/lib/python3.10/site-packages/keras/src/utils/tree.py”,第 236 行,map_struct raise ValueError("必须提供至少一个结构") ValueError:必须提供至少一个结构
我的模型看起来像这样:
model = Sequential()
model.add(Conv2D(8, (3, 3), activation='relu', input_shape=[32, 128, 3]))
model.add(BatchNormalization())
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(8, (3, 3), activation="relu"))
model.add(MaxPooling2D(2, 2))
model.add(Flatten())
model.add(Dense(250, activation="relu"))
model.add(BatchNormalization())
model.add(Dense(120, activation="relu"))
model.add(Dense(50, activation="relu"))
model.add(Dense(1, activation="sigmoid", name="output_layer"))
optimizer = keras.optimizers.Adam(learning_rate=1e-4)
#
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
#
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy', Precision(), Recall(), AUC()])
my_callbacks = [EarlyStopping(monitor='accuracy', patience=10, mode='max'), model_checkpoint_callback]
#
model.summary()
model.fit(train_dataset, batch_size=1, epochs=1000, verbose=1, validation_data=validation_dataset, callbacks=[my_callbacks])
数据集就是这样生成的:
image_generator = ImageDataGenerator(rescale=1/255)
train_dataset = image_generator.flow_from_directory(batch_size=5,
directory='new_heatmaps/train',
shuffle=True,
target_size=(32, 128),
subset="training",
class_mode='binary')
validation_dataset = image_generator.flow_from_directory(batch_size=5,
directory='new_heatmaps/test',
shuffle=True,
target_size=(32, 128),
subset="validation",
class_mode='binary')
造成这种结构错误的原因可能是什么?当 Keras 抛出某种错误时?
这有点奇怪,但我会告诉你什么对我有用。在发布此内容之前,我验证了几次,因为对于其他数据集,因为与您相同 - 此错误仅在某些数据集上引发,而不是在其他数据集上引发。