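I am trying to build a GAN on the Fashion MNIST dataset. I have been using the model.train_on_batch() method to train my discriminator on batches of 64 images. During training, however, the output from the train_on_batch() method shows: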
2/2 [==============================] - 1s 128ms/step
This suggests that my batch of 64 is being treated as 2 batches of 32. Why does this happen?
My code:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# download the data
(x_train, _ ), (x_test, _) = tf.keras.datasets.fashion_mnist.load_data()
# make the pixel values into [-1, 1] range
x = np.concatenate((x_train, x_test), axis=0)
x = x.reshape(-1, 28, 28, 1)
x = x / 255.0 * 2 - 1
# establish latent space dimension
latent_dim = 200
# get size of images and size of output space
N, H, W, _ = x.shape
output_dim = H * W
# define function to create generator
def create_generator(latent_space_dim):
    i = tf.keras.Input(shape=(latent_space_dim,))
    x = tf.keras.layers.Dense(7*7*256, activation=tf.keras.layers.LeakyReLU(alpha=0.2))(i)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Reshape((7, 7, 256))(x)
    x = tf.keras.layers.Conv2DTranspose(128, (5, 5), padding='same', strides=(1, 1),
                                        activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2DTranspose(64, (5, 5), padding='same', strides=(2, 2),
                                        activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2DTranspose(1, (5, 5), padding='same', strides=(2, 2),
                                        activation="tanh")(x)
    model = tf.keras.Model(i, x)
    return model
# define function to create discriminator
def create_discriminator(input_dim):
    i = tf.keras.Input(shape=input_dim)
    x = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation=tf.keras.layers.LeakyReLU(alpha=0.2))(i)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(64, activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
    x = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    model = tf.keras.Model(i, x)
    return model
generator = create_generator(latent_dim)
discriminator = create_discriminator((H, W, 1))
discriminator.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics="accuracy"
)
discriminator.trainable=False
i = tf.keras.Input(shape=(latent_dim, ))
generator_output = generator(i)
discriminator_output = discriminator(generator_output)
# the combined model must end at the discriminator's output, not the generator's
combined_model = tf.keras.Model(i, discriminator_output)
combined_model.compile(
    optimizer="adam",
    loss="binary_crossentropy"
)
# perform training
epochs = 20000
batch_size = 64
ones = np.ones((batch_size,))
zeros = np.zeros((batch_size,))
for epoch in range(epochs):
    # train discriminator
    real_images_ids = np.random.randint(0, N, batch_size)
    fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)
    fake_images = generator.predict(fake_images_latent_spaces)
    real_images = x[real_images_ids]
    d_loss_fake = discriminator.train_on_batch(fake_images, zeros)
    d_loss_real = discriminator.train_on_batch(real_images, ones)
    # train generator
    fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)
    g_loss = combined_model.train_on_batch(fake_images_latent_spaces, ones)
I tried passing a batch of size 32 to the method instead, and the output was as expected:
1/1 [==============================] - 1s 128ms/step
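A note on the two step counts above: both are consistent with the samples being split into chunks of 32, which is the default batch_size of Keras's Model.predict(). As far as I can tell, train_on_batch() only returns the loss and does not print a progress bar, so the 2/2 line most likely comes from the generator.predict(...) call rather than from the discriminator update. Here is a minimal sketch of that arithmetic; predict_steps is a hypothetical helper written for illustration and is not part of Keras.

```python
import math


def predict_steps(num_samples, batch_size=32):
    """Number of progress-bar steps when num_samples inputs are split into chunks.

    The default of 32 mirrors the default batch_size of Model.predict().
    predict_steps itself is a hypothetical helper, not a Keras function.
    """
    return math.ceil(num_samples / batch_size)


print(predict_steps(64))  # 64 latent vectors -> 2 steps, shown as "2/2"
print(predict_steps(32))  # 32 latent vectors -> 1 step, shown as "1/1"
```

If that is indeed the cause, passing batch_size=batch_size to generator.predict() should give a single step, and passing verbose=0 should silence the progress bar altogether.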
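train_on_batch() is called twice in each discriminator training iteration, so the loss is computed separately for the real and the fake images. If you concatenate the real and fake images into a single array, and the corresponding labels into a single array, you can call train_on_batch() once for the discriminator and treat the combined batch as a single batch. Note that with batch_size = 64 the combined batch holds 2 * batch_size = 128 images; use batch_size = 32 if it should hold 64.
# perform training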
epochs = 20000
batch_size = 64
ones = np.ones((batch_size,))
zeros = np.zeros((batch_size,))
for epoch in range(epochs):
    # train discriminator
    real_images_ids = np.random.randint(0, N, batch_size)
    fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)
    fake_images = generator.predict(fake_images_latent_spaces)
    real_images = x[real_images_ids]
    # Concatenate real and fake images
    all_images = np.concatenate((real_images, fake_images))
    all_labels = np.concatenate((ones, zeros))
    d_loss = discriminator.train_on_batch(all_images, all_labels)
    # train generator
    fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)
    # just separate variable for target
    target_labels = np.ones((batch_size,))
    g_loss = combined_model.train_on_batch(fake_images_latent_spaces, target_labels)
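To make the shapes in the combined discriminator update concrete, here is a small NumPy-only sketch. It uses random arrays as stand-ins for real_images and fake_images (an assumption for illustration; no model or dataset is involved) and only checks what the concatenated arrays look like.

```python
import numpy as np

batch_size = 64
rng = np.random.default_rng(0)

# Stand-ins for one batch of real and generated 28x28 grayscale images in [-1, 1].
# In the training loop these would come from the dataset and from the generator.
real_images = rng.uniform(-1.0, 1.0, size=(batch_size, 28, 28, 1))
fake_images = rng.uniform(-1.0, 1.0, size=(batch_size, 28, 28, 1))

# Stack images along the batch axis; labels follow the same order (1 = real, 0 = fake).
all_images = np.concatenate((real_images, fake_images), axis=0)
all_labels = np.concatenate((np.ones(batch_size), np.zeros(batch_size)), axis=0)

print(all_images.shape)  # (128, 28, 28, 1)
print(all_labels.shape)  # (128,)
```

The label order has to match the image order, which is why ones comes first here, just as real_images does.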