我正在构建并优化一个 CNN，用于对这个数据集中的汽车图像进行分类。我的基线模型只使用了一个非常简单的架构，却得到了出乎意料的高准确率（accuracy）。我担心是数据加载方式不当，导致训练集和验证集之间出现了数据泄漏，希望大家能给些建议。
加载数据集
# Load the images from `data_dir` and split them 80/20 into training and
# validation sets.
#
# About leakage: with subset='both' and a fixed `seed`,
# image_dataset_from_directory lists the files once, shuffles that list once
# with the seed, and slices it into the two subsets. Every file therefore lands
# in exactly one subset, so this call does not by itself leak images between
# train and validation.
# NOTE(review): leakage can still come from the data itself. Near-duplicate
# photos of the same car (burst shots, pre-augmented copies, the same image
# saved twice) are separate files and can end up on both sides of the split.
# If validation accuracy looks too good, check the dataset for such duplicates.
batch_size = 16
img_size = (64, 64)
train_dataset, val_dataset = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    label_mode='categorical',  # one-hot labels, as categorical_crossentropy expects
    seed=1,
    subset='both',
    validation_split=0.2,
    image_size=img_size,
    batch_size=batch_size,
)

# `class_names` exists only on the dataset object returned by
# image_dataset_from_directory; it is no longer available once the dataset is
# mapped below, so grab it here (e.g. for a confusion-matrix callback).
class_names = train_dataset.class_names

# Scale pixels from [0, 255] to [0, 1]. Rescaling has no learned state, so
# applying it to the validation split uses no statistics from that split.
normalization_layer = Rescaling(1./255)

# Parallel map keeps element order by default; prefetch overlaps input
# preparation with training steps. Neither changes the data the model sees.
train_dataset = train_dataset.map(
    lambda x, y: (normalization_layer(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.map(
    lambda x, y: (normalization_layer(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)
构建模型
def baseline_model(input_shape=(64, 64, 3), *, num_classes=5):
    """Build and compile a small two-stage CNN baseline classifier.

    Architecture: two stages of Conv2D(3x3, 'valid', ReLU) + MaxPooling2D(2x2)
    with 16 and 32 filters, then Flatten -> Dense(64, ReLU) ->
    Dense(num_classes, softmax).

    Args:
        input_shape: Shape of a single input image as (height, width,
            channels). Must agree with the ``image_size`` used when loading
            the data.
        num_classes: Width of the softmax output layer. Must equal the number
            of classes, i.e. the length of the one-hot label vectors.
            Defaults to 5, the value previously hard-coded.

    Returns:
        A compiled Keras ``Sequential`` model. It uses the RMSprop optimizer
        and categorical cross-entropy loss, and reports accuracy, precision
        and recall.
    """
    model = Sequential([
        # Stage 1: 16 feature maps, then halve the spatial resolution.
        Conv2D(filters=16, kernel_size=(3, 3), activation='relu',
               padding='valid', input_shape=input_shape),
        MaxPooling2D((2, 2)),
        # Stage 2: 32 feature maps, then halve the spatial resolution again.
        Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
               padding='valid'),
        MaxPooling2D((2, 2)),
        # Classifier head.
        Flatten(),
        Dense(64, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])
    # NOTE: with default arguments, Keras Precision/Recall compare each
    # softmax output against a 0.5 threshold on its own and pool the counts
    # across all class columns. They are not per-class (macro) scores.
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy',
                           Precision(name='precision'),
                           Recall(name='recall')])
    return model
# Build the compiled baseline model.
# NOTE(review): this rebinds `baseline_model` from the factory function to the
# model it returns, so the factory can no longer be called afterwards (e.g. to
# retrain from freshly initialised weights). Consider a distinct name such as
# `model`, and update the code that uses it.
baseline_model = baseline_model()
拟合模型
# Train for 10 epochs. Keras evaluates the model on the held-out validation
# split at the end of every epoch.
# To log a confusion matrix each epoch, also pass
# callbacks=[ConfusionMatrixCallback(val_dataset, class_names)].
history = baseline_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
)

# Plot the per-epoch curves recorded in `history`.
plot_training_history(history)