model.train_on_batch() given a batch of 64 images trains the model as if it were 2 batches

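I am trying to create a GAN using the Fashion MNIST dataset. I have been trying to use the model.train_on_batch() method to train my discriminator on batches of 64 images. However, during training, the program's output from the train_on_batch() method shows: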

2/2 [==============================] - 1s 128ms/step

This suggests that my batch of 64 is being treated as 2 batches of 32. Why does this happen?

My code:

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# download the data
(x_train, _ ), (x_test, _) = tf.keras.datasets.fashion_mnist.load_data()

# make the pixel values into [-1, 1] range
x = np.concatenate((x_train, x_test), axis=0)
x = x.reshape(-1, 28, 28, 1)
x = x / 255.0 * 2 - 1

# establish latent space dimension
latent_dim = 200

# get size of images and size of output space
N, H, W, _ = x.shape
output_dim = H * W

# define function to create generator
def create_generator(latent_space_dim):
    i = tf.keras.Input(shape=(latent_space_dim,))

    x = tf.keras.layers.Dense(7*7*256,  activation=tf.keras.layers.LeakyReLU(alpha=0.2))(i)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Reshape((7, 7, 256))(x)
    x = tf.keras.layers.Conv2DTranspose(128, (5, 5), padding='same', strides=(1, 1),
                                        activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2DTranspose(64, (5, 5), padding='same', strides=(2, 2),
                                        activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2DTranspose(1, (5, 5), padding='same', strides=(2, 2),
                                        activation="tanh")(x)

    model = tf.keras.Model(i, x)

    return model

# define function to create discriminator
def create_discriminator(input_dim):
  i = tf.keras.Input(shape=input_dim)

  x = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation=tf.keras.layers.LeakyReLU(alpha=0.2))(i)
  x = tf.keras.layers.BatchNormalization()(x)

  x = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)
  x = tf.keras.layers.BatchNormalization()(x)

  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(64, activation=tf.keras.layers.LeakyReLU(alpha=0.2))(x)

  x = tf.keras.layers.Dense(1, activation="sigmoid")(x)

  model = tf.keras.Model(i, x)

  return model

generator = create_generator(latent_dim)
discriminator = create_discriminator((H, W, 1))

discriminator.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics="accuracy"
)

discriminator.trainable=False


i = tf.keras.Input(shape=(latent_dim, ))
generator_output = generator(i)
discriminator_output = discriminator(generator_output)


# the combined model ends in the (frozen) discriminator so the generator is trained to fool it
combined_model = tf.keras.Model(i, discriminator_output)

combined_model.compile(
    optimizer="adam",
    loss="binary_crossentropy"
)

# perform training
epochs = 20000
batch_size = 64
ones = np.ones((batch_size,))
zeros = np.zeros((batch_size,))


for epoch in range(epochs):
  # train discriminator
  real_images_ids = np.random.randint(0, N, batch_size)
  fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)

  fake_images = generator.predict(fake_images_latent_spaces)
  real_images = x[real_images_ids]

  d_loss_fake = discriminator.train_on_batch(fake_images, zeros)
  d_loss_real = discriminator.train_on_batch(real_images, ones)

  # train generator
  fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)

  g_loss = combined_model.train_on_batch(fake_images_latent_spaces, ones)

I tried feeding the method batches of size 32, and the output was as expected:

1/1 [==============================] - 1s 128ms/step

1 Answer
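  • The train_on_batch() method is called twice in each discriminator training iteration, and you are computing the loss for the real and fake images separately.
  • By concatenating the real and fake images into a single array, and the corresponding labels into a single array, you can call train_on_batch() once for the discriminator, treating the combined batch as a single batch. Note that with batch_size = 64 the combined batch in the code below holds 128 images (64 real plus 64 fake); use batch_size = 32 if you want 64 images in total.
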
# perform training
epochs = 20000
batch_size = 64
ones = np.ones((batch_size,))
zeros = np.zeros((batch_size,))

for epoch in range(epochs):
  # train discriminator
  real_images_ids = np.random.randint(0, N, batch_size)
  fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)

  fake_images = generator.predict(fake_images_latent_spaces)
  real_images = x[real_images_ids]

  # Concatenate real and fake images
  all_images = np.concatenate((real_images, fake_images))
  all_labels = np.concatenate((ones, zeros))

  d_loss = discriminator.train_on_batch(all_images, all_labels)

  # train generator
  fake_images_latent_spaces = np.random.rand(batch_size, latent_dim)
  # just separate variable for target
  target_labels = np.ones((batch_size,))

  g_loss = combined_model.train_on_batch(fake_images_latent_spaces, target_labels)
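
A side note that may be worth checking: as far as I know, train_on_batch() does not print a progress bar at all, while Model.predict() does, and predict() falls back to batch_size=32 when none is passed. If that is what is happening here, the 2/2 line would come from generator.predict() splitting the 64 latent vectors into two chunks of 32, and the 1/1 line seen with 32 samples would fit the same pattern. The sketch below only reproduces that arithmetic; predict_steps is a hypothetical helper written for illustration, not a Keras function.

```python
import math

# Documented Keras default used by Model.predict() when batch_size is not given.
KERAS_PREDICT_DEFAULT_BATCH_SIZE = 32


def predict_steps(num_samples, batch_size=KERAS_PREDICT_DEFAULT_BATCH_SIZE):
    """Hypothetical helper: steps shown in predict()'s progress bar for num_samples inputs."""
    return math.ceil(num_samples / batch_size)


print(predict_steps(64))      # 64 latent vectors, default batch size
print(predict_steps(32))      # 32 latent vectors, default batch size
print(predict_steps(64, 64))  # 64 latent vectors, batch_size=64 passed explicitly
```

Under that assumption, calling generator.predict(fake_images_latent_spaces, batch_size=batch_size, verbose=0) in the training loop should run the generator in a single step and silence the progress bar, without changing how train_on_batch() treats the batch.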