cvae 损失为 Nan

问题描述 投票:0回答:0

我有一个条件变分自动编码器模型 (CVAE),在我的案例中,错误有时会爆炸,让我抓狂。我正在使用这两个函数来最小化损失。总损失为 KL+ log_lik。 KL 用于两个分布之间。你能帮我解决这个问题吗?我该如何修改这些损失以防止出现 nan 值?谢谢。

def log_lik(par, mu, log_var):
    """Gaussian log-likelihood """ 
    par = par.view(-1,1).float()
    mu = mu.view(-1,1).float()
    log_var = log_var.view(-1,1).float()
    sigma_square = torch.square(torch.exp(0.5*(log_var)))
    lo = -0.5*torch.sum((torch.log(2*np.pi*sigma_square) + torch.square(par-mu)/sigma_square),dim=1)/par.size(1)
    return lo

def KL(mu_r,log_var_r,mu_q,log_var_q):
    """Gaussian KL divergence"""
    sigma_q = torch.exp(0.5 * (log_var_q))
    sigma_r = torch.exp(0.5 * (log_var_r))
    t1 = torch.square(sigma_q/sigma_r)
    t2 = torch.log(torch.square(sigma_r/sigma_q))
    t3 = torch.square(mu_r - mu_q)/torch.square(sigma_r)   
    kl_loss = 0.5*torch.sum(t1 + t2 + t3, dim=1) - 0.5*t1.size(1)
   
    return kl_loss
machine-learning deep-learning
© www.soinside.com 2019 - 2024. All rights reserved.