所以我正在尝试基于此代码的简单的生成对抗网络(GAN)。https://github.com/eriklindernoren/Keras-GAN。该代码中的GAN示例使用的是mnist数据集
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
您能帮我如何将mnist.load_data()更改为我自己的自定义数据集吗?我是该领域的初学者,请谅解。谢谢
我不知道您是否已解决此问题,但我会尽力提供答案。首先mnist.load_data()
不能更改为您自己的自定义数据集,它包含mnist
数据。为了进一步为您提供帮助,我需要深入了解什么是“自己的”数据集?
[当我一直使用“我自己的”数据集时,我通常以这种方式将它们存储在NumPy数组中,因为我知道结构,我可以只使用numpy.load()
。
这里是从目录加载一堆图像的示例:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_images():
normalize = lambda x: (x.astype('float32') - 127.5) / 127.5 # normalize to between -1: and 1
data_gen = ImageDataGenerator(preprocessing_function=normalize, zoom_range=0.2,
horizontal_flip=True,rotation_range=0.05)
x_train = data_gen.flow_from_directory(INPUT_DIR,
target_size = (IMAGE_SIZE,IMAGE_SIZE),
batch_size = BATCH_SIZE,
shuffle = True,
save_to_dir='augmented',
class_mode = 'input',
subset = "training")
return x_train
x_train = load_images()