我有一个条件变分自动编码器模型 (CVAE),在我的案例中,错误有时会爆炸,让我抓狂。我正在使用这两个函数来最小化损失。总损失为 KL+ log_lik。 KL 用于两个分布之间。你能帮我解决这个问题吗?我该如何修改这些损失以防止出现 nan 值?谢谢。
def log_lik(par, mu, log_var):
"""Gaussian log-likelihood """
par = par.view(-1,1).float()
mu = mu.view(-1,1).float()
log_var = log_var.view(-1,1).float()
sigma_square = torch.square(torch.exp(0.5*(log_var)))
lo = -0.5*torch.sum((torch.log(2*np.pi*sigma_square) + torch.square(par-mu)/sigma_square),dim=1)/par.size(1)
return lo
def KL(mu_r,log_var_r,mu_q,log_var_q):
"""Gaussian KL divergence"""
sigma_q = torch.exp(0.5 * (log_var_q))
sigma_r = torch.exp(0.5 * (log_var_r))
t1 = torch.square(sigma_q/sigma_r)
t2 = torch.log(torch.square(sigma_r/sigma_q))
t3 = torch.square(mu_r - mu_q)/torch.square(sigma_r)
kl_loss = 0.5*torch.sum(t1 + t2 + t3, dim=1) - 0.5*t1.size(1)
return kl_loss