生成器无法在一对多 CGAN 中学习

问题描述 投票:0回答:1

我对 GAN 还很陌生,并尝试在 MNIST 数据集上实现一对多 CGAN,以从总和中生成图像序列(输出图像的数量等于生成器的数量)。例如,如果输入为 [噪声、噪声、噪声、噪声] 且条件 = 11,则 4 生成器 1 判别器 CGANS 模型将生成 4 位数字图像:[3, 2, 6]。

但是,我遇到了一个问题,判别器的损失不断减少到接近于零,但生成器的损失却增加了。因此,生成器产生噪音而不是有意义的图像。

我认为这是因为判别器太好了,所以我添加了 Dropout 层并减少了过滤器的数量,但没有任何改变。

我正在尝试让这个模型按照我上面描述的那样正常工作。

自定义数据集:该数据集由 X(data_sizedigit_numchannelsheightwidth)组成,其中 digit_num 表示输入序列中的位数。 Y 是所有可能的结果 (digit_num * 9 + 1)。

digit_num = 3
label_num = 9 * digit_num + 1
data_size = 120000

dataset = SumMNISTDataset(
    "mnist",
    0,
    digit_num,
    data_size,
    transforms.Compose([transforms.Grayscale(), 
                        transforms.Normalize(127.5, 127.5)]),
)

dataloader = DataLoader(dataset, batch_size, True, drop_last=True)

生成器和鉴别器架构:论文中的模型对高分辨率数据集使用更深的网络。就我而言,我针对 MNIST 数据集缩小了它。

class Generator(nn.Module):
def __init__(self, latent_dim, filter_num, label_num, embed_num=50, bias=False):
    super().__init__()
    self.pre_main = nn.Sequential(
        # 7 x 7 x 128
        nn.ConvTranspose2d(latent_dim, filter_num * 4, 7, 1, 0, bias=bias),
        nn.BatchNorm2d(filter_num * 4),
        nn.LeakyReLU(0.2),
    )
    self.condition = nn.Sequential(
        # 1 x 50
        nn.Embedding(label_num, embed_num),
        nn.Linear(embed_num, 49, bias=bias),
    )
    self.main = nn.Sequential(
        # 14 x 14 x 64
        nn.ConvTranspose2d(filter_num * 4 + 1, filter_num * 2, 4, 2, 1, bias=bias),
        nn.BatchNorm2d(filter_num * 2),
        nn.LeakyReLU(0.2),
        # 28 x 28 x 1
        nn.ConvTranspose2d(filter_num * 2, 1, 4, 2, 1, bias=bias),
        nn.Tanh(),
    )

def forward(self, x, y):
    y = self.condition(y).reshape(-1, 1, 7, 7)
    x = self.pre_main(x)
    x = torch.cat((x, y), dim=1)
    x = self.main(x)
    return x

class Discriminator(nn.Module):
def __init__(self, filter_num, label_num, embed_num=50, bias=True):
    super().__init__()
    self.condition = nn.Sequential(
        # 28 x 28 x 50
        nn.Embedding(label_num, embed_num),
        nn.Linear(embed_num, 28 * 28, bias=bias),
    )
    self.main = nn.Sequential(
        # 14 x 14 x 64
        nn.Conv2d(2, filter_num, 3, 2, 1, bias=bias),
        nn.BatchNorm2d(filter_num),
        nn.LeakyReLU(0.2),
        # 7 x 7 x 128
        nn.Conv2d(filter_num, filter_num * 2, 3, 2, 1, bias=bias),
        nn.BatchNorm2d(filter_num * 2),
        nn.LeakyReLU(0.2),
        # Dense
        nn.Flatten(),
        nn.Linear(7 * 7 * filter_num * 2, 1, bias=bias),
    )

def forward(self, x, y):
    y = self.condition(y).reshape(-1, 1, 28, 28)
    x = torch.cat((x, y), dim=1)
    x = self.main(x)
    return x

初始化:生成器的数量等于输入序列中的数字。理论上,如果生成器的数量等于 1,则该模型相当于 CGAN,但是,即使有 1 个生成器,模型仍然无法收敛。

learning_rate = 0.0002
beta_1 = 0.5
latent_dim = 100
filter_num = 32
generator_num = digit_num
omega = 1 / generator_num


def weight_ini_G(model):
    if type(model) == nn.Linear:
        nn.init.constant_(model.weight.data, 1 / generator_num)
    elif type(model) == nn.BatchNorm2d:
        nn.init.constant_(model.weight.data, 1 / generator_num)
        nn.init.constant_(model.bias.data, 0)


def weight_ini_D(model):
    if type(model) == nn.Linear:
        nn.init.normal_(model.weight.data, 0.0, 0.2)
    elif type(model) == nn.BatchNorm2d:
        nn.init.normal_(model.weight.data, 1.0, 0.2)
        nn.init.constant_(model.bias.data, 0)


Gs = [
   Generator(latent_dim, filter_num,  label_num).to(device).apply(weight_ini_G)
   for _ in range(generator_num)
]
D = Discriminator(filter_num, label_num).to(device).apply(weight_ini_D)

G_optimizers = [
    optim.Adam(G.parameters(), learning_rate, betas=(beta_1, 0.999)) for G in Gs
]
D_optimizer = optim.Adam(D.parameters(), learning_rate, betas=(beta_1, 0.999))

bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

辅助函数:澄清一下,generate_hybrid函数将计算digit_num轴上所有图像的平均值以进行训练。

def generate_fake():
    rand_labels = torch.randint(0, label_num, (batch_size, 1), device=device)
    images = [
        Gs[g](torch.randn((batch_size, latent_dim, 1, 1), device=device), rand_labels)
        for g in range(generator_num)
    ]
    images = torch.stack(images, axis=1).detach_()

    return images, rand_labels


def generate_real():
    images, labels = next(iter(dataloader))
    return images.to(device), labels.to(device)


def generate_hybrid(images):
    if images.shape[0] == digit_num:
        images = torch.mean(images, dim=0)
    elif images.shape[1] == digit_num:
        images = torch.mean(images, dim=1)

    return images

更新生成器

def update_generators(real, fake):
    ones = torch.ones((batch_size, 1), device=device)
    f_images, f_labels = fake
    r_images, _ = real
    total_loss = 0

    for g in range(generator_num):
        hybrid_fake = generate_hybrid(f_images)
        # r_image = r_images[:, g, :, :]

        preds = D(hybrid_fake, f_labels)

        bce_loss = bce(preds, ones)
        # l1_loss = l1(f_image, r_image)
        loss = bce_loss

        Gs[g].zero_grad()
        loss.backward()
        G_optimizers[g].step()

        total_loss += loss.item()

    return total_loss / generator_num

更新鉴别器

def update_discriminator(real, fake):
    half_batch_size = batch_size // 2
    zeros = torch.zeros((half_batch_size, 1), device=device)
    ones = torch.ones((half_batch_size, 1), device=device)
    f_images, f_labels = fake
    r_images, r_labels = real

    f_images = f_images[:half_batch_size]
    f_labels = f_labels[:half_batch_size]
    r_images = r_images[:half_batch_size]
    r_labels = r_labels[:half_batch_size]

    total_loss = 0

    # Train on Real
    hybrid_real = generate_hybrid(r_images)
    real_preds = D(hybrid_real, r_labels)

    bce_r_loss = bce(real_preds, ones)
    D.zero_grad()
    bce_r_loss.backward()

    # Train of Fake
    hybrid_fake = generate_hybrid(f_images)
    fake_preds = D(hybrid_fake, f_labels)

    bce_f_loss = bce(fake_preds, zeros)
    bce_f_loss.backward()
    D_optimizer.step()

    total_loss = (bce_f_loss.item() + bce_r_loss.item()) / 2

    return total_loss

训练模型

D_losses = []
G_losses = []
epochs = 5
fixed_noise = torch.randn((4, latent_dim, 1, 1), device=device)
fixed_label = torch.randint(0, label_num, (4,), device=device)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}:")
    for batch in range(data_size // batch_size):
        # Generate Fake Images
        fake = generate_fake()

        # Generate Real Images
        real = generate_real()

        D_loss = update_discriminator(real, fake)

        fake = generate_fake()
        G_loss = update_generators(real, fake)

        if batch % 100 == 0:
            print(
                f"[Batch: {(batch + 1) * batch_size :7d}/{data_size}  D_Loss: {D_loss}  G_Loss: {G_loss}]"
            )
            generate_image(epoch, batch, fixed_noise, fixed_label)

        D_losses.append(D_loss)
        G_losses.append(G_loss)

由于模型显示的结果毫无意义,我在大约 1 个时期停止了模型。

发电机损失 鉴别器损失

这是我的代码

deep-learning one-to-many mnist generative-adversarial-network cgan
1个回答
0
投票

问题

尝试模型的痛点之一是它们很快就会变得强大。我会建议一些事情。

  1. 调整学习率 通过减少鉴别器相对于生成器的
    learning rate
    。这将减慢判别器的训练速度,为生成器提供更多学习机会。 例如:
# Adjusted learning rates
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate * 0.1, betas=(beta_1, 0.999))
G_optimizers = [
    optim.Adam(G.parameters(), lr=learning_rate * 2, betas=(beta_1, 0.999)) for G_ops in Gopss
]
  1. 在判别器损失函数中实现梯度

    penalty term
    以强制判别器决策边界的平滑性可以对模型有很大帮助,我将向您指出此存储库以获取有关您当前正在做什么的更多信息参考:github_repo

  2. 对于真实图像尝试较低的值,例如 0.8,对于假图像尝试稍高的值,例如 0.2。

  3. 了解如何实现生成器的特征匹配损失。这种损失有助于计算从鉴别器的中间层提取的特征的统计数据。参考,指导您github_ref_feature_matching_implementation

  4. 我没有看到任何类型的批处理,我可能是错的,但如果你没有看到,你应该考虑它

  5. 减少鉴别器容量:您说您已经减少了过滤器的数量并添加了丢失层,这些都是很好的步骤。但是,您可以通过减少层数或过滤器大小来进一步减少鉴别器的容量。这对 GNN 会有很大帮助。

  6. 在实现 GNN 时,尤其是使用公共数据集或其他方式时,使用噪声进行训练可能会有所帮助,向鉴别器的输入引入噪声(例如,向输入图像添加高斯噪声)以使鉴别器的工作变得更加困难。

  7. 最后是潜在维度,使用不同维度的潜在空间(例如,增加或减少

    latent_dim

训练模型就是试验内置参数

© www.soinside.com 2019 - 2024. All rights reserved.