pytorch BCEWithLogitsLoss 计算 pos_weight

问题描述 投票:0回答:2

我有一个如下的神经网络用于二进制预测。我的类别严重不平衡,类别 1 只出现 2% 的次数。仅显示最后几层

self.batch_norm2 = nn.BatchNorm1d(num_filters)

self.fc2 = nn.Linear(np.sum(num_filters), fc2_neurons)

self.batch_norm3 = nn.BatchNorm1d(fc2_neurons)

self.fc3 = nn.Linear(fc2_neurons, 1)

我的损失如下。这是计算

pos_weight
参数的正确方法吗?我查看了这个 link 的官方文档,它表明
pos_weight
需要为多类分类的每个类都有一个值。不确定对于二进制类来说这是否是一个不同的场景。我尝试输入 2 个值,但出现错误

我的问题:对于二元问题,

pos_weight
是否是单个值,与多类分类不同,多类分类需要长度等于类数的列表/数组?

BCE_With_LogitsLoss=nn.BCEWithLogitsLoss(pos_weight=class_wts[0]/class_wts[1])

我的 y 变量是一个单一变量,有 0 或 1 来表示实际类别,并且神经网络输出单个值

------------------------------------------------ --更新1

根据 Shai 的回答,我有以下问题:

  1. BCEWithLogitsLoss
    - 如果是多类问题那么如何使用
    pos_weigh
    参数?
  2. 有没有在pytorch中使用焦点损失的例子?我找到了一些链接,但大多数都是旧的 - 约会了 2 或 3 或更长时间
  3. 对于训练,我对班级 1 进行过采样。焦点损失仍然合适吗?
neural-network pytorch
2个回答
2
投票

pos_weight
的文档确实有点不清楚。对于
BCEWithLogitsLoss
pos_weight
应该是大小=1:
torch.tensor

BCE_With_LogitsLoss=nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_wts[0]/class_wts[1]]))

但是,在您的情况下,pos 类仅出现 2% 的次数,我认为设置

pos_weight
是不够的。
请考虑使用焦距损失:
Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He、Piotr Dollár 密集物体检测的焦点损失(ICCV 2017)。
除了描述 Focal loss 之外,这篇论文还很好地解释了为什么 CE loss 在不平衡的情况下表现如此糟糕。我强烈推荐阅读这篇论文。

其他替代方案列于此处


0
投票

我正在处理一个多标签分割问题,其中我有 2 个类(0 - 背景,1 - 对象掩模)。这些对象太小了,所以我有一个巨大的类不平衡。当我尝试计算一批的一般正权重因子并将其用于训练时,将负类划分为正类 [class_wts[0]/class_wts[1]] - 导致一个高值,因此当我使用该参数进行训练时,损失从 40 左右的非常高的值开始,我认为这是不正常的。当我尝试较小的正权重分数(例如 4,5)时,损失从 1,2 左右的值开始,但它过度拟合并在第 30 个时期后停止改善。对于这种情况,理想的解决方案是什么?

© www.soinside.com 2019 - 2024. All rights reserved.