我知道 torch.nn.CrossEntropyLoss(weight=?) 参数“weight”是为了平衡不同类的样本之间的不平衡,它是类的参数,它具有类数量的长度。 那么如果我想在计算损失时设置不同样本的权重,其长度为样本数,该怎么办? ps:就像keras.Sequential.fit中的“sample_weight”一样(sample_weight=?)
尝试使用 torch.nn.Parameter() 但似乎没用; 期待一些解决方案。
如果您想在交叉熵损失中使用样本权重,您需要像这样进行临时操作:
entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
loss = entropy_loss(output, target)
loss = (loss*sample_weights).mean()
loss.backward()
参见https://discuss.pytorch.org/t/per-class-and-per-sample-weighting/25530/4
这假设您的样本权重已标准化。
学习中的每一步都使用sample_weight
loss = criterion(outputs, batch_targets)
weighted_loss = (loss * sample_weight).mean()
如果其他人可能遇到像我一样的问题,这是我的解决方案:
My_CrossEntropy 类(torch.nn.Module): def init(自身): 超级()。init()
def forward(self, x, y, sample_weight):
x = torch.nn.functional.log_softmax(x, dim=1)
y = torch.nn.functional.one_hot(y,num_classes = 8)
y = y * sample_weight.reshape(-1, 1)
loss = -torch.mean(torch.sum(y * x, dim=1), dim=0)
return loss