GAN 训练循环:一个时期*实际*应该看到数据集中的每张图像还是可以重复(包括示例)?

问题描述 投票:0回答:1

好吧,由于网上的例子不同,我一直对此感到困惑。 我知道一个纪元意味着整个数据集已经通过 GAN 一次。

但是,我在网上看到的大多数训练循环都是这样开始的:

def train_gan(generator, discriminator, dataset, latent_dim, n_epochs, n_batch):
 # calculate the number of batches per epoch
 batches_per_epoch = int(len(dataset) / n_batch)
 # calculate the number of training iterations
 n_steps = batches_per_epoch * n_epochs
 # gan training algorithm
 for i in range(n_steps):
 # generate points in the latent space
 z = randn(latent_dim * n_batch)
 # reshape into a batch of inputs for the network
 z = z.reshape(n_batch, latent_dim)
 # generate fake images
 fake = generator.predict(z)
 # select a batch of random real images
 ix = randint(0, len(dataset), n_batch) <---- *LOOK AT THIS LINE*
 # retrieve real images
 real = dataset[ix]

注意“#select a batch of random real images”下的代码。由于使用了“randint”,每批采样的图像是从整个数据集中随机选择的。这意味着来自训练数据集的图像可以每批重复一次(例如,图像 1、2、3 被随机选择用于第一批,然后图像 5、2、3 被随机选择用于第二批——重复 3)。

所以,在一个时代之后,整个数据集还没有被看到(因为重复被包括在内)。为什么大多数例子都是这样做的?我认为您应该这样做:1) 在每个纪元之前对整个数据集进行洗牌,2) 将数据拆分为该纪元的正确批次数,然后 3) 迭代这些预拆分批次中的每一个。这样,所有数据实际上都是每个时期都看到的,没有重复。

如果每个时期的所有数据都没有看到,这无关紧要吗?我在这里错过了什么?

python tensorflow keras deep-learning generative-adversarial-network
1个回答
0
投票

在传统意义上,一个epoch意味着数据集中的每个样本在训练过程中都被看过一次。然而,在实践中,由于现代数据集的大小,确保每个样本在每个时期都恰好出现一次并不总是可行的。因此,常用的方法是在每个训练步骤中随机抽取数据批次,每个时期的批次总数等于数据集大小除以批次大小。

在您提供的示例代码中,数据集是针对每个批次随机抽样的。这意味着某些图像可能会在不同的批次中重复出现,而其他图像可能在给定的时期内根本看不到。虽然这可能不严格遵守纪元的传统定义,但它是深度学习中常用的实用方法。

在每个 epoch 之前对整个数据集进行洗牌是一种很好的做法,有助于确保模型在训练期间接触到各种样本。但是,可能没有必要将数据分成每个时期的预定义批次,因为批次大小可以跨时期变化以实现不同的训练目标。

总而言之,虽然 epoch 的传统定义要求数据集中的每个样本只出现一次,但在实践中,由于现代数据集的大小,这可能并不总是可行的。在每个训练步骤中随机抽取一批数据是一种常用的方法。

您可以在每个时期之前打乱数据集并将其分成批次,以确保每个时期所有数据都被看到一次而没有重复:

import numpy as np

def train_gan(generator, discriminator, dataset, latent_dim, n_epochs, n_batch):
    # calculate the number of batches per epoch
    batches_per_epoch = int(len(dataset) / n_batch)
    
    # gan training algorithm
    for epoch in range(n_epochs):
        # shuffle dataset before each epoch
        np.random.shuffle(dataset)
        
        # iterate over each batch for this epoch
        for batch in range(batches_per_epoch):
            # select a batch of real images
            real = dataset[batch*n_batch:(batch+1)*n_batch]
            
            # generate points in the latent space
            z = np.random.randn(n_batch, latent_dim)
            
            # generate fake images
            fake = generator.predict(z)
            
            # train discriminator on real and fake images
            d_loss_real = discriminator.train_on_batch(real, np.ones((n_batch, 1)))
            d_loss_fake = discriminator.train_on_batch(fake, np.zeros((n_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            
            # train generator to fool the discriminator
            g_loss = gan.train_on_batch(z, np.ones((n_batch, 1)))
            
            # print progress every 10 batches
            if (batch + 1) % 10 == 0:
                print("Epoch {}/{}, Batch {}/{}: D_loss={:.4f}, G_loss={:.4f}".format(
                    epoch+1, n_epochs, batch+1, batches_per_epoch, d_loss[0], g_loss))

在此示例中,数据集在每个时期之前使用 np.random.shuffle(dataset) 进行混洗。然后,对于每个 epoch,使用 dataset[batch*n_batch:(batch+1)*n_batch] 将数据集分成多个批次,每个批次在训练循环中使用一次。这确保所有数据每个时期都可以看到一次而不会重复。

© www.soinside.com 2019 - 2024. All rights reserved.