对于GAN,使用自定义数据集代替MNIST

问题描述 投票:1回答:2

所以我正在尝试基于此代码的简单的生成对抗网络(GAN)。https://github.com/eriklindernoren/Keras-GAN。该代码中的GAN示例使用的是mnist数据集

# Load the dataset
(X_train, _), (_, _) = mnist.load_data()

您能帮我如何将mnist.load_data()更改为我自己的自定义数据集吗?我是该领域的初学者,请谅解。谢谢

dataset mnist gan
2个回答
0
投票

我不知道您是否已解决此问题,但我会尽力提供答案。首先mnist.load_data()不能更改为您自己的自定义数据集,它包含mnist数据。为了进一步为您提供帮助,我需要深入了解什么是“自己的”数据集?

[当我一直使用“我自己的”数据集时,我通常以这种方式将它们存储在NumPy数组中,因为我知道结构,我可以只使用numpy.load()


0
投票

这里是从目录加载一堆图像的示例:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_images():
    normalize = lambda x: (x.astype('float32') - 127.5) / 127.5 # normalize to between -1: and 1
    data_gen = ImageDataGenerator(preprocessing_function=normalize, zoom_range=0.2, 
                                  horizontal_flip=True,rotation_range=0.05)

    x_train = data_gen.flow_from_directory(INPUT_DIR,
                                            target_size = (IMAGE_SIZE,IMAGE_SIZE),
                                            batch_size = BATCH_SIZE,
                                            shuffle = True,
                                            save_to_dir='augmented',
                                            class_mode = 'input',
                                            subset = "training")

    return x_train

x_train = load_images()
© www.soinside.com 2019 - 2024. All rights reserved.