How to implement the self-paced multi-task weighted loss (Kendall et al. 2018) in PyTorch?


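In this study, the authors introduce an equation (Equation 7) for weighting the individual losses of a neural network's different tasks.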
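For reference, the per-task form quoted in the code comment of the attempt further down, summed over all tasks, can be written as follows. This is a transcription that assumes each task i has its own learned σ_i and its own loss L_i; the symbol names are chosen here for illustration and should be checked against the paper.

```latex
\mathcal{L}_{\text{total}} = \sum_{i} \left( \frac{1}{2\sigma_i^{2}} \, \mathcal{L}_i + \log \sigma_i \right)
```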

I would like to implement it as a function in PyTorch so that I can use it for my model. What I have tried so far is:

import torch

# function to apply uncertainty weighting on the losses
def apply_uncertainty_weights(sigma, loss):

    """
    This function applies uncertainty weights on the given loss. 
    NOTE: This implementation is based on the study Kendall et al. 2018 (https://arxiv.org/abs/1705.07115)

    Arguments:
    sigma: A NN learned uncertainty value (initialised as torch.nn.Parameter(torch.zeros(num_tasks)))
    loss: The calculated loss between the prediction and target

    Returns:
    weighted_loss: Weighted loss
    """

    # apply uncertainty weighting
    # This is the formula in the publication -> weighted_loss = (1 / (2 * sigma**2)) * loss + torch.log(sigma)
    # but we can't use it as it won't be numerically stable/differentiable (e.g. when sigma is predicted to be 0)
    # instead use the following
    sigma = torch.nn.functional.softplus(sigma) + torch.tensor(1e-8) # this makes sure sigma is never exactly 0 or less, otherwise the following functions won't work
    log_sigma_squared = torch.log(sigma ** 2) # this is log(sigma^2)
    precision = (1/2) * torch.exp(-log_sigma_squared) # this is 1/(2*sigma^2)
    log_sigma = (1/2) * log_sigma_squared # this is log(sigma)
    weighted_loss = precision * loss + log_sigma

    # return the weighted loss
    return weighted_loss

Oddly, though, this implementation gives me negative loss values during training. What am I doing wrong?

python pytorch
1 Answer

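The first thing I noticed is that your code returns a tensor of losses, whereas a single scalar value is expected. Here are a couple of (unofficial) implementations of the paper on GitHub [1] [2]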
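To illustrate the point about scalars: `backward()` can only be called without arguments on a scalar tensor, so per-task weighted losses have to be reduced first, for example with `sum()`. A minimal sketch; the tensor `w` and all values are made up for illustration and are not taken from the question:

```python
import torch

# Hypothetical per-task weights and losses (values chosen only for illustration).
w = torch.tensor([1.0, 0.5, 0.25], requires_grad=True)
per_task = w * torch.tensor([0.2, 0.4, 0.1])  # shape (3,), not a scalar

try:
    per_task.backward()  # implicit gradient creation needs a scalar output
    backward_failed = False
except RuntimeError:
    backward_failed = True

total = per_task.sum()  # reduce to a single scalar
total.backward()        # now gradients flow back to w
```

Whether to reduce with `sum()` or `mean()` is a design choice; the formula quoted in the question adds the per-task terms, so `sum()` is used here.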

This could be an implementation in PyTorch:

def apply_uncertainty_weights(sigmas_sq, losses):
    """
    Applies uncertainty weights on multiple task losses.
    
    Args:
        sigmas_sq (torch.Tensor): A tensor of learned variances (sigmas squared) for each task.
        losses (torch.Tensor): A tensor of individual task losses.
        
    Returns:
        torch.Tensor: The total weighted loss.
    """
    precisions = 1.0 / (2.0 * sigmas_sq)

    # 1/2.0 is applied to the last additive part because we are doing log(sigma_sq) which is basically 2 * log(sigma) 
    weighted_losses = precisions * losses + (1.0 / 2.0) * torch.log(sigmas_sq)

    total_loss = torch.sum(weighted_losses)

    return total_loss

For the example values below, it appears to produce a positive loss value:

import torch

sigmas_sq = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)
losses = torch.tensor([0.2, 0.4, 0.1])

total_loss = apply_uncertainty_weights(sigmas_sq, losses)

print(total_loss)
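On the original question: with this objective a negative total is not necessarily a bug. The regularizer log σ is negative whenever σ < 1, so once the task losses become small the optimizer can push σ below 1 and the sum can drop below zero. One way to avoid both the division and the logarithm of a possibly non-positive parameter is to learn s = log σ² directly. The sketch below assumes that parameterization; the class name `UncertaintyWeightedLoss` and its arguments are hypothetical, not taken from the paper or from the linked repositories.

```python
import torch
import torch.nn as nn


class UncertaintyWeightedLoss(nn.Module):
    """Hypothetical helper: sums task losses weighted by learned log-variances."""

    def __init__(self, num_tasks):
        super().__init__()
        # s_i = log(sigma_i^2); zeros means sigma_i = 1 at the start.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        # losses: 1-D tensor holding one scalar loss per task.
        # 0.5 * exp(-s) equals 1 / (2 * sigma^2); 0.5 * s equals log(sigma).
        weighted = 0.5 * torch.exp(-self.log_vars) * losses + 0.5 * self.log_vars
        return weighted.sum()


criterion = UncertaintyWeightedLoss(num_tasks=3)
losses = torch.tensor([0.2, 0.4, 0.1])

total = criterion(losses)  # with s = 0 this is 0.5 * losses.sum()
total.backward()           # gradients reach criterion.log_vars

# Small losses combined with sigma < 1 (s < 0) give a negative total.
with torch.no_grad():
    criterion.log_vars.fill_(-2.0)
negative_total = criterion(torch.tensor([0.01, 0.01, 0.01]))
```

In a real training loop the per-task losses would typically be combined with `torch.stack`, and `criterion.parameters()` would need to be passed to the optimizer together with the model weights; otherwise `log_vars` never changes.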
