How to load a pretrained VGG16 model with TensorFlow and generate a confusion matrix


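I trained a VGG16 model (from tensorflow.keras.applications) on images from 15 classes for 100 epochs. After training, I saved the model to a file named "best_model.h5", but unfortunately I forgot to include the code for generating a confusion matrix in my script. Is there a way to use my saved model file (best_model.h5) to get the confusion matrix now, without training the model from scratch again?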

Here is the code I used to train the model:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint


print("Set up directories for training and validation data...")
train_dir = '/scratch/user/nabarunkar/Disaster Dataset/train'
validation_dir = '/scratch/user/nabarunkar/Disaster Dataset/validation'

print("Image preprocessing for data augmentation...")
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

validation_datagen = ImageDataGenerator(rescale=1./255)

print("Load images from directories...")

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),  # VGG16 expects 224x224 input
    batch_size=32,
    class_mode='categorical'  # Ensure this is set to 'categorical'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'  # Ensure this is set to 'categorical'
)


print("Build the VGG16 model...")
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

print("Freeze the base model...")
base_model.trainable = False

print("Add custom classification layers on top...")
model = Sequential([
    base_model,
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(15, activation='softmax') # 15 categories
])

print("Compile the model...")
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

print("Callbacks...")
checkpoint = ModelCheckpoint(
    'best_model2.h5',  # Path to save the best model
    monitor='val_loss',  # Metric to monitor
    save_best_only=True,  # Only save the model if the monitored metric improves
    mode='min',  # Save when the metric is minimized (for loss)
    verbose=1  # Print messages when saving the model
)

print("Train the model...")
history = model.fit(
    train_generator,
    epochs=100,
    validation_data=validation_generator,
    callbacks=[checkpoint]
)

print("Evaluate the model on the validation set...")
val_loss, val_accuracy = model.evaluate(validation_generator)
print(f'Validation accuracy: {val_accuracy:.2f}')

I tried running this snippet to load the saved model and generate the confusion matrix:

import numpy as np
from tensorflow.keras.models import load_model
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import re

def extract_numbers(class_name):
    """Extract the numerical part from class names like 'cat01building'"""
    numbers = re.findall(r'\d+', class_name)
    return numbers[0] if numbers else class_name

def generate_confusion_matrix(model_path, validation_generator, class_names):
    # Load the saved model
    model = load_model('./best_model2.h5')
    
    # Get predictions
    # Reset the generator to ensure we start from the beginning
    validation_generator.reset()
    
    # Get predictions for the validation set
    predictions = model.predict(validation_generator)
    predicted_classes = np.argmax(predictions, axis=1)
    
    # Get true labels
    true_classes = validation_generator.classes
    
    # Generate confusion matrix
    cm = confusion_matrix(true_classes, predicted_classes)
    
    # Simplify class names to just numbers
    simplified_labels = [extract_numbers(name) for name in class_names]
    
    # Plot confusion matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=simplified_labels,
                yticklabels=simplified_labels)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    plt.tight_layout()
    plt.show()
    
    # Print classification report
    print("\nClassification Report:")
    print(classification_report(true_classes, predicted_classes, 
                              target_names=class_names))  




train_dir = '/content/dataset/train'
validation_dir = '/content/dataset/validation'

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

validation_datagen = ImageDataGenerator(rescale=1./255)

# Load images from directories

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),  # VGG16 expects 224x224 input
    batch_size=32,
    class_mode='categorical'  # Ensure this is set to 'categorical'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'  # Ensure this is set to 'categorical'
)

class names = ['cat07', 'cat08', 'cat09', 'cat10', 'cat11', 'cat12', 'cat13', 'cat14', 'cat15', 'cat16', 'cat17', 'cat18', 'cat19', 'cat20', 'cat21', 'cat22', 'cat23']


# Generate confusion matrix
generate_confusion_matrix('best_model2.h5', validation_generator, class_names)

But no matter what I try, I always get the same error, shown below:


> AttributeError: Exception encountered when calling Flatten.call().

'list' object has no attribute 'shape'

Arguments received by Flatten.call():
  • args=(['<KerasTensor shape=(None, 7, 7, 512), dtype=float32, sparse=False, name=keras_tensor_450>'],)
  • kwargs=<class 'inspect._empty'>

Tags: python-3.x tensorflow keras image-classification vgg-net

1 Answer

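Looking at your code, I can see that you wrote "class names" (with a space), which reads as a new class called names, instead of assigning to the variable class_names. Please check the script again for this typo.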
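As a minimal sketch of one way to avoid typing the list by hand, the class names could be derived from the mapping that flow_from_directory exposes as validation_generator.class_indices. The helper names class_names_from_indices and check_class_count are hypothetical, and example_indices is a made-up stand-in for the real mapping, so treat this as an illustration under those assumptions rather than the asker's actual fix.

```python
def class_names_from_indices(class_indices):
    """Return class names ordered by their integer label.

    `class_indices` is assumed to look like the dict exposed as
    `generator.class_indices`, e.g. {'cat07': 0, 'cat08': 1, ...}.
    """
    ordered = sorted(class_indices.items(), key=lambda item: item[1])
    return [name for name, _ in ordered]


def check_class_count(class_names, num_outputs):
    """Raise early if the label list does not match the model's output size."""
    if len(class_names) != num_outputs:
        raise ValueError(
            f"{len(class_names)} class names for {num_outputs} model outputs"
        )
    return class_names


# Hypothetical stand-in for validation_generator.class_indices
example_indices = {"cat09": 2, "cat07": 0, "cat08": 1}

# Note the underscore: `class_names`, not `class names`
class_names = class_names_from_indices(example_indices)

# In the real script, num_outputs would be the size of the final Dense layer
check_class_count(class_names, num_outputs=3)
print(class_names)
```

In the question's script, the hand-typed list has 17 entries while the final Dense layer has 15 units, so a count check like this would probably flag a mismatch before classification_report does. Separately, flow_from_directory shuffles by default, so it may be worth passing shuffle=False for the validation generator so that validation_generator.classes lines up with the order of the predictions.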
