In this study, the authors introduce an equation (Equation 7) to weight the individual losses of the different tasks of a neural network.
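For reference, the per-task weighting I am trying to implement (the same formula quoted in the code comment below) is:

weighted_loss = (1 / (2 * sigma**2)) * loss + log(sigma)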
I would like to implement this as a function in pytorch so that I can use it for my model. What I have tried so far is:
import torch

# function to apply uncertainty weighting on the losses
def apply_uncertainty_weights(sigma, loss):
    """
    This function applies uncertainty weights on the given loss.
    NOTE: This implementation is based on the study Kendall et al. 2018 (https://arxiv.org/abs/1705.07115)
    Arguments:
        sigma: A NN-learned uncertainty value (initialised as torch.nn.Parameter(torch.zeros(num_tasks)))
        loss: The calculated loss between the prediction and target
    Returns:
        weighted_loss: Weighted loss
    """
    # apply uncertainty weighting
    # This is the formula in the publication -> weighted_loss = (1 / (2 * sigma**2)) * loss + torch.log(sigma)
    # but we can't use it as-is since it won't be numerically stable/differentiable (e.g. when sigma is predicted to be 0)
    # instead use the following
    sigma = torch.nn.functional.softplus(sigma) + 1e-8   # makes sure sigma is never exactly 0 or less, otherwise the log below won't work
    log_sigma_squared = torch.log(sigma ** 2)             # this is log(sigma^2)
    precision = (1 / 2) * torch.exp(-log_sigma_squared)   # this is 1 / (2 * sigma^2)
    log_sigma = (1 / 2) * log_sigma_squared               # this is log(sigma)
    weighted_loss = precision * loss + log_sigma
    # return the weighted loss
    return weighted_loss
But strangely, this implementation gives me negative loss values during training. What am I doing wrong?
The first thing I noticed is that your code returns the loss tensor, but a single scalar value is expected. Here are a couple of (unofficial) implementations of the paper on Github: [1] [2]

This could be an implementation in Pytorch:
def apply_uncertainty_weights(sigmas_sq, losses):
    """
    Applies uncertainty weights on multiple task losses.
    Args:
        sigmas_sq (torch.Tensor): A tensor of learned variances (sigmas squared) for each task.
        losses (torch.Tensor): A tensor of individual task losses.
    Returns:
        torch.Tensor: The total weighted loss.
    """
    precisions = 1.0 / (2.0 * sigmas_sq)
    # the factor 1/2 on the additive term turns log(sigma_sq) into log(sigma), since log(sigma_sq) = 2 * log(sigma)
    weighted_losses = precisions * losses + (1.0 / 2.0) * torch.log(sigmas_sq)
    total_loss = torch.sum(weighted_losses)
    return total_loss
It seems to produce positive loss values:
import torch
sigmas_sq = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)
losses = torch.tensor([0.2, 0.4, 0.1])
total_loss = apply_uncertainty_weights(sigmas_sq, losses)
print(total_loss)
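If it helps, here is a minimal sketch (my own, not from the paper or the linked repos) of how the function above could be wired into a training loop, with the per-task uncertainties learned jointly with the model. The model, the dummy data, and the choice to learn log(sigma^2) so that sigmas_sq = exp(log_vars) stays strictly positive are all assumptions for illustration:

import torch

# sketch only: `model`, the dummy data and the log-variance parameterisation are placeholders
num_tasks = 2
log_vars = torch.nn.Parameter(torch.zeros(num_tasks))   # learn log(sigma^2) per task
model = torch.nn.Linear(8, num_tasks)                    # stand-in for a real multi-task model
optimizer = torch.optim.Adam(list(model.parameters()) + [log_vars], lr=1e-3)

x = torch.randn(16, 8)                                   # dummy inputs
targets = torch.randn(16, num_tasks)                     # dummy targets, one column per task

for step in range(100):
    optimizer.zero_grad()
    preds = model(x)
    # one scalar loss per task
    losses = torch.stack([
        torch.nn.functional.mse_loss(preds[:, t], targets[:, t])
        for t in range(num_tasks)
    ])
    sigmas_sq = torch.exp(log_vars)                      # always > 0, so torch.log(sigmas_sq) is safe
    total_loss = apply_uncertainty_weights(sigmas_sq, losses)
    total_loss.backward()
    optimizer.step()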