什么可能导致 VAE(变分自动编码器)即使在训练后也输出随机噪声?

问题描述 投票:0回答:2

我已经在 CIFAR10 数据集上训练了 VAE。然而,当我尝试从 VAE 生成图像时,我得到的只是一堆灰色噪声。此 VAE 的实现遵循《生成深度学习》一书中的实现,但代码使用 PyTorch 代替 TensorFlow。 包含训练和生成的笔记本可以在

这里

找到,而 VAE 的实际实现可以在这里找到。 我已经尝试过:

禁用辍学。
  1. 增加潜在空间的维度。
  2. 所有方法都没有显示出任何改进。

我已经证实:

输入大小与输出大小匹配
  1. 随着训练过程中损失的减少,反向传播成功运行。
python tensorflow deep-learning pytorch autoencoder
2个回答
4
投票

数据标准化
  1. VAE损失的实施。
关于 1.

,您的 CIFAR10DataModule 类使用

mean = 0.5
std = 0.5
标准化 CIFAR10 图像的 RGB 通道。由于像素值最初在 [0,1] 范围内,因此归一化图像的像素值在 [-1,1] 范围内。但是,您的
Decoder
类对重建图像应用
nn.Sigmoid()
激活。因此,重建图像的像素值在 [0,1] 范围内。我建议删除这种均值标准标准化,以便“真实”图像和重建图像的像素值都在 [0,1] 范围内。

关于 2.

:由于您正在处理 RGB 图像,因此 MSE 损失是有意义的。 MSE 损失背后的想法是“高斯解码器”。该解码器假设“真实图像”的像素值是由独立的高斯分布生成的,其均值是重建图像(即解码器的输出)的像素值并且具有给定的方差。您对重建损失(即 r_loss = F.mse_loss(predictions, targets))的实现相当于固定方差。利用

本文
的想法,我们可以做得更好,并获得该方差参数的“最优值”的解析表达式。最后,应将所有像素的重建损失相加(reduction = 'sum')。要理解原因,请查看重建损失的分析表达式(例如,请参阅
这篇博文
,其中考虑了 BCE 损失)。 这是重构后的

LitVAE

类的样子:

class LitVAE(pl.LightningModule):
    def __init__(self,
                 learning_rate: float = 0.0005,
                 **kwargs) -> None:
        """
        Parameters
        ----------
        - `learning_rate: float`:
            learning rate for the optimizer
        - `**kwargs`:
            arguments to pass to the variational autoencoder constructor
        """
        super(LitVAE, self).__init__()
        
        self.learning_rate = learning_rate 

        self.vae = VariationalAutoEncoder(**kwargs)

    def forward(self, x) -> _tensor_size_3_t: 
        return self.vae(x)

    def training_step(self, batch, batch_idx):
        r_loss, kl_loss, sigma_opt = self.shared_step(batch)
        loss = r_loss + kl_loss
        
        self.log("train_loss_step", loss)
        return {"loss": loss, 'log':{"r_loss": r_loss / len(batch[0]), "kl_loss": kl_loss / len(batch[0]), 'sigma_opt': sigma_opt}}

    def training_epoch_end(self, outputs) -> None:
        # add computation graph
        if(self.current_epoch == 0):
            sample_input = torch.randn((1, 3, 32, 32))
            sample_model = LitVAE(**MODEL_PARAMS)
            
            self.logger.experiment.add_graph(sample_model, sample_input)
            
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("train_loss_epoch", epoch_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss

        self.log("valid_loss_step", loss)

        return {"loss": loss}

    def validation_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("valid_loss_epoch", epoch_loss, self.current_epoch)

    def test_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss
        
        self.log("test_loss_step", loss)
        return {"loss": loss}

    def test_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("test_loss_epoch", epoch_loss, self.current_epoch)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
        
    def shared_step(self, batch) -> torch.TensorType: 
        # images are both samples and targets thus original 
        # labels from the dataset are not required
        true_images, _ = batch

        # perform a forward pass through the VAE 
        # mean and log_variance are used to calculate the KL Divergence loss 
        # decoder_output represents the generated images 
        mean, log_variance, generated_images = self(true_images)

        r_loss, kl_loss, sigma_opt = self.calculate_loss(mean, log_variance, generated_images, true_images)
        return r_loss, kl_loss, sigma_opt

    def calculate_loss(self, mean, log_variance, predictions, targets):
        mse = F.mse_loss(predictions, targets, reduction='mean')
        log_sigma_opt = 0.5 * mse.log()
        r_loss = 0.5 * torch.pow((targets - predictions) / log_sigma_opt.exp(), 2) + log_sigma_opt
        r_loss = r_loss.sum()
        kl_loss = self._compute_kl_loss(mean, log_variance)
        return r_loss, kl_loss, log_sigma_opt.exp()

    def _compute_kl_loss(self, mean, log_variance): 
        return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())

    def average_metric(self, metrics, metric_name):
        avg_metric = torch.stack([x[metric_name] for x in metrics]).mean()
        return avg_metric

10 个 epoch 后,这就是重建图像的样子:


-1
投票

我按照你的笔记本和variational_encoder python脚本进行操作,即使使用重构的LitVAE代码,我在前向传播时也遇到错误。
  • 附上截图

非常感谢任何帮助。短暂性脑缺血发作。

© www.soinside.com 2019 - 2024. All rights reserved.