我正在使用条件变分自动编码器进行预测,我的标签是 [1,100, 200, 5, 17],当我运行模型时,我的损失激增,达到了 nan。我没有规范化标签的数据。我在两个分布和对数似然之间使用 kl 散度如下:
def KL(mu_r, log_var_r, mu_q, log_var_q):
    """Numerically stable KL divergence KL(N(mu_r, sigma_r^2) || N(mu_q, sigma_q^2))
    between two diagonal Gaussians, computed per sample.

    Args:
        mu_r, log_var_r: mean and log-variance of the first Gaussian, shape (batch, dim).
        mu_q, log_var_q: mean and log-variance of the second Gaussian, shape (batch, dim).

    Returns:
        Tensor of shape (batch,), one KL value per sample.
    """
    # Work directly in log-variance space. The original computed
    #   t1 = log(exp(0.5*log_var_q) / exp(0.5*log_var_r))
    # which exponentiates both log-variances before dividing and taking the
    # log again — this overflows/underflows for large |log_var| and is the
    # likely source of the NaN loss. The forms below are algebraically
    # identical but never leave log space unnecessarily.
    t1 = 0.5 * (log_var_q - log_var_r)           # == log(sigma_q / sigma_r)
    t2 = torch.exp(log_var_r - log_var_q)        # == sigma_r^2 / sigma_q^2
    t3 = torch.square(mu_r - mu_q) * torch.exp(-log_var_q)
    kl_loss = torch.sum(t1 + 0.5 * t2 + 0.5 * t3, dim=1) - 0.5 * t1.size(1)
    return kl_loss
def log_lik(par, mu, log_var):
    """Gaussian log-likelihood of `par` under N(mu, exp(log_var)), per element.

    All inputs are flattened to column vectors of shape (N, 1), so the
    result contains one log-likelihood per scalar target.

    Args:
        par: target values.
        mu: predicted means (same number of elements as `par`).
        log_var: predicted log-variances (same number of elements as `par`).

    Returns:
        Tensor of shape (N,) where N = par.numel().
    """
    par = par.view(-1, 1).float()
    mu = mu.view(-1, 1).float()
    log_var = log_var.view(-1, 1).float()
    # Use log(2*pi*sigma^2) = log(2*pi) + log_var and 1/sigma^2 = exp(-log_var).
    # The original exp->square->log round-trip produced inf/NaN for large
    # |log_var|; this form is identical but stays in log space.
    sq_err = torch.square(par - mu) * torch.exp(-log_var)
    lo = -0.5 * torch.sum(np.log(2 * np.pi) + log_var + sq_err, dim=1) / par.size(1)
    return lo
我还在每个训练批次上使用以下函数。有人能帮我看看问题出在哪里吗？
def train_batch(model, optimizer, device, batch, labels):
    """Run one optimization step of the CVAE on a single batch.

    Computes the negative ELBO = -(log-likelihood - KL), backpropagates and
    updates the model parameters.

    Args:
        model: the CVAE; model(batch, labels) must return
            (mu_x, log_var_x, mu_q, log_var_q, mu_r, log_var_r).
        optimizer: optimizer over the model's parameters.
        device: unused here; kept for interface compatibility.
        batch: input batch; its first dimension is the batch size.
        labels: prediction targets fed to the likelihood term.

    Returns:
        (loss, kl_loss / batch_size, L_loss / batch_size) — the training
        loss and the per-sample averages of each term.
    """
    model.train()
    optimizer.zero_grad()
    length = float(batch.size(0))
    mu_x, log_var_x, mu_q, log_var_q, mu_r, log_var_r = model(batch, labels)
    # Loss terms. NOTE: the original post was missing the closing
    # parenthesis on this KL(...) call, which is a SyntaxError.
    kl_loss_b = KL(mu_r, log_var_r, mu_q, log_var_q)
    L_loss_b = log_lik(labels, mu_x, log_var_x)
    L_loss = torch.sum(L_loss_b)
    kl_loss = torch.sum(kl_loss_b)
    # Maximize the ELBO == minimize its negation, averaged over the batch.
    loss = -(L_loss - kl_loss) / length
    loss.backward()
    # update the weights
    optimizer.step()
    return loss, kl_loss / length, L_loss / length
我认为 `t1` 是不稳定的：它先把两个对数方差指数化，再对两个可能很大的数相除并取对数，容易溢出得到 inf/nan。更稳定的等价写法是直接在对数空间相减：t1 = (0.5 * log_var_q) - (0.5 * log_var_r)。