如何在 PyTorch 中使用类权重和焦点损失来处理多类分类的不平衡数据集

问题描述 投票:0回答:4

我正在研究语言任务的多类分类(4 类),并且我正在使用 BERT 模型进行分类任务。我正在关注此博客作为参考我的 BERT Fine Tuned 模型返回

nn.LogSoftmax(dim=1)

我的数据非常不平衡,所以我使用

sklearn.utils.class_weight.compute_class_weight
来计算类别的权重,并使用损失中的权重。

class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy  = nn.NLLLoss(weight=weights) 

我的结果不太好,所以我想到用

Focal Loss
进行实验,并有一个焦点损失的代码。

class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits
    self.reduce = reduce

  def forward(self, inputs, targets):
    BCE_loss = nn.CrossEntropyLoss()(inputs, targets)

    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss

我现在有 3 个问题。首先也是最重要的是

  1. 我应该使用带有焦点损失的班级权重吗?
  2. 如果我必须在其中实现权重
    Focal Loss
    ,我可以在
    weights
    中使用
    nn.CrossEntropyLoss()
  3. 参数吗
  4. 如果此工具不正确,该工具的正确代码应该是什么,包括权重(如果可能)
python machine-learning deep-learning neural-network pytorch
4个回答
4
投票

您可以通过以下方式找到问题的答案:

  1. 焦点损失自动处理类别不平衡,因此焦点损失不需要权重。 alpha 和 gamma 因子处理焦点损失方程中的类别不平衡。
  2. 不需要额外的权重,因为焦点损失使用 alpha 和 gamma 调制因子来处理它们
  3. 根据焦点损失公式,您提到的实现是正确的,但我在使我的模型与此版本收敛时遇到了困难,因此,我使用了 mmdetection 框架中的以下实现
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

您还可以尝试另一个可用的焦点损失版本


3
投票

我想OP现在应该已经得到答案了。我写这篇文章是为了其他可能思考这个问题的人。

OP 实现 Focal Loss 时存在一个问题:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

在这一行中,相同的

alpha
值乘以每个类的输出概率,即 (
pt
)。此外,代码没有显示我们如何获得
pt
。可以在here找到焦点损失的一个非常好的实现。但此实现仅适用于二元分类,因为它对于
alpha
张量中的两个类具有
1-alpha
self.alpha

在多类分类或多标签分类的情况下,

self.alpha
张量应包含等于标签总数的元素数量。这些值可以是标签的逆标签频率或逆标签归一化频率(只是要小心频率为 0 的标签)。


3
投票

我认为你问题中的实现是错误的。 alpha 是班级权重。

在交叉熵中,类权重是 alpha_t,如以下表达式所示:

你看到它是 alpha_t 而不是 alpha。

在焦损失中,公式为

我们可以从这个流行的 Pytorch 实现中看到 alpha 的作用与类权重相同。

参考资料:

  1. https://amaarora.github.io/2020/06/29/FocalLoss.html#alpha-and-gamma
  2. https://github.com/clcarwin/focal_loss_pytorch

0
投票

我自己也在寻找这个,发现大多数实现方式都很麻烦。人们可以使用 pytorch 的 CrossEntropyLoss 来代替(它已经具有权重参数并忽略索引,...)并添加焦点项:

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2, ignore_index=-100, reduction='mean'):
        super().__init__()
        # use standard CE loss without reducion as basis
        self.CE = nn.CrossEntropyLoss(weight=alpha, reduction='none', ignore_index=ignore_index)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        '''
        input (B, N)
        target (B)
        '''
        minus_logpt = self.CE(input, target)
        pt = torch.exp(-minus_logpt) # don't forget the minus here
        focal_loss = (1-pt)**self.gamma * minus_logpt
        
        if self.reduction == 'mean':
            focal_loss = focal_loss.mean()
        elif self.reduction == 'sum':
            focal_loss = focal_loss.sum()
        return focal_loss
© www.soinside.com 2019 - 2024. All rights reserved.