I have trained a VAE on the CIFAR10 dataset. However, when I try to generate images from the VAE, all I get is a bunch of gray noise. The implementation of this VAE follows the one in the book Generative Deep Learning, but the code uses PyTorch instead of TensorFlow. The notebook containing the training and generation can be found here, while the actual implementation of the VAE can be found here.

I have tried:

Disabling dropout.
I have verified:

The input size matches the output size
The data is normalized
Your CIFAR10DataModule class normalizes the RGB channels of the CIFAR10 images using mean = 0.5 and std = 0.5. Since the pixel values are initially in the range [0, 1], the normalized images have pixel values in the range [-1, 1]. However, your Decoder class applies an nn.Sigmoid() activation to the reconstructed images, so their pixel values lie in the range [0, 1]. I suggest removing this mean-std normalization so that both the "true" images and the reconstructed images have pixel values in the range [0, 1].

Regarding 2.: Since you are working with RGB images, the MSE loss makes sense. The idea behind the MSE loss is that of a "Gaussian decoder". This decoder assumes that the pixel values of the "true" image are generated by independent Gaussian distributions whose means are the pixel values of the reconstructed image (i.e. the output of the decoder) and which have a given variance. Your implementation of the reconstruction loss (i.e. r_loss = F.mse_loss(predictions, targets)) amounts to fixing this variance. Using the ideas of this paper, we can do better and obtain an analytic expression for the "optimal value" of this variance parameter. Finally, the reconstruction loss should be summed over all pixels (reduction = 'sum'). To understand why, have a look at the analytic expression of the reconstruction loss (see, for instance, this blog post, which considers the BCE loss).

Here is what the refactored LitVAE class looks like:
class LitVAE(pl.LightningModule):
    def __init__(self,
                 learning_rate: float = 0.0005,
                 **kwargs) -> None:
        """
        Parameters
        ----------
        - `learning_rate: float`:
            learning rate for the optimizer
        - `**kwargs`:
            arguments to pass to the variational autoencoder constructor
        """
        super(LitVAE, self).__init__()

        self.learning_rate = learning_rate
        self.vae = VariationalAutoEncoder(**kwargs)

    def forward(self, x) -> _tensor_size_3_t:
        return self.vae(x)

    def training_step(self, batch, batch_idx):
        r_loss, kl_loss, sigma_opt = self.shared_step(batch)
        loss = r_loss + kl_loss
        self.log("train_loss_step", loss)
        return {"loss": loss, 'log': {"r_loss": r_loss / len(batch[0]),
                                      "kl_loss": kl_loss / len(batch[0]),
                                      'sigma_opt': sigma_opt}}

    def training_epoch_end(self, outputs) -> None:
        # add computation graph
        if self.current_epoch == 0:
            sample_input = torch.randn((1, 3, 32, 32))
            sample_model = LitVAE(**MODEL_PARAMS)
            self.logger.experiment.add_graph(sample_model, sample_input)

        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("train_loss_epoch", epoch_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss
        self.log("valid_loss_step", loss)
        return {"loss": loss}

    def validation_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("valid_loss_epoch", epoch_loss, self.current_epoch)

    def test_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss
        self.log("test_loss_step", loss)
        return {"loss": loss}

    def test_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("test_loss_epoch", epoch_loss, self.current_epoch)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def shared_step(self, batch) -> torch.TensorType:
        # images are both samples and targets, thus the original
        # labels from the dataset are not required
        true_images, _ = batch

        # perform a forward pass through the VAE;
        # mean and log_variance are used to calculate the KL divergence loss,
        # decoder_output represents the generated images
        mean, log_variance, generated_images = self(true_images)

        r_loss, kl_loss, sigma_opt = self.calculate_loss(mean, log_variance, generated_images, true_images)
        return r_loss, kl_loss, sigma_opt

    def calculate_loss(self, mean, log_variance, predictions, targets):
        mse = F.mse_loss(predictions, targets, reduction='mean')
        log_sigma_opt = 0.5 * mse.log()
        r_loss = 0.5 * torch.pow((targets - predictions) / log_sigma_opt.exp(), 2) + log_sigma_opt
        r_loss = r_loss.sum()

        kl_loss = self._compute_kl_loss(mean, log_variance)
        return r_loss, kl_loss, log_sigma_opt.exp()

    def _compute_kl_loss(self, mean, log_variance):
        return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())

    def average_metric(self, metrics, metric_name):
        avg_metric = torch.stack([x[metric_name] for x in metrics]).mean()
        return avg_metric
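As a standalone sanity check (a sketch added here, not part of the original notebook), the optimal-sigma reconstruction loss used in calculate_loss can be verified numerically: once the variance is set to the MSE, the squared-error term sums to exactly 0.5 · N over N pixels, so the total reconstruction loss collapses to N · (0.5 + log sigma_opt):

```python
import torch
import torch.nn.functional as F

def gaussian_nll_optimal_sigma(predictions, targets):
    """Gaussian-decoder reconstruction loss with the analytically optimal sigma.

    The optimal variance of the Gaussian decoder is the MSE itself, so
    sigma_opt = sqrt(MSE) and log(sigma_opt) = 0.5 * log(MSE).
    """
    mse = F.mse_loss(predictions, targets, reduction='mean')
    log_sigma_opt = 0.5 * mse.log()
    r_loss = 0.5 * torch.pow((targets - predictions) / log_sigma_opt.exp(), 2) + log_sigma_opt
    return r_loss.sum(), log_sigma_opt.exp()

# With sigma^2 equal to the MSE, the squared-error term sums to 0.5 * N,
# so the total reconstruction loss is N * (0.5 + log(sigma_opt)).
targets = torch.rand(2, 3, 32, 32)
predictions = torch.rand(2, 3, 32, 32)
loss, sigma_opt = gaussian_nll_optimal_sigma(predictions, targets)
n = targets.numel()
print(torch.allclose(loss, n * (0.5 + sigma_opt.log())))  # True
```

This also makes clear why reduction = 'sum' matters: the log-sigma term is accumulated once per pixel, so averaging instead of summing would rescale the reconstruction term relative to the KL term.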
After 10 epochs, this is what the reconstructed images look like: