我尝试建立一个用于中文多标签文本分类任务的模型,但该模型的性能不够好(大约60%的准确率),我来寻求如何增强它的帮助。
我基于github项目构建了一个模型:
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
class BertMultiLabelCls(nn.Module):
def __init__(self, hidden_size, class_num, dropout=0.1):
super(BertMultiLabelCls, self).__init__()
self.fc = nn.Linear(hidden_size, class_num)
self.drop = nn.Dropout(dropout)
self.bert = BertModel.from_pretrained("bert-base-chinese")
def forward(self, input_ids, attention_mask, token_type_ids):
outputs = self.bert(input_ids, attention_mask, token_type_ids)
cls = self.drop(outputs[1])
out = F.sigmoid(self.fc(cls))
return out
我的数据集是2000个查询标签对,有13个标签,查询直播中观众提出的问题。我将数据集按照训练/测试/验证对应的 3:1:1 进行分割。我的标签不平衡,并且没有使用上采样/下采样策略。
训练过程中的损失和准确率,其中横轴代表epoch:
验证准确度停止增加近 60%,我的测试数据集的结果相同。 我尝试了各种方法,包括添加更完整的连接层/添加剩余连接,但结果仍然相同。
这是我的训练参数(如果有帮助的话):
lr = 2e-5
batch_size = 128
max_len = 64
hidden_size = 768
epochs = 30
optimizer = AdamW(model.parameters(), lr=lr)
criterion = nn.BCELoss() # loss function
除了数据集之外,还有关于如何改进模型的建议吗?因为我正在做并行的事情并且我知道如何改进它。但我对网络本身确实是新手。
考虑到您的不平衡数据集,Focal Loss 可能是 BCELoss 的一个有价值的替代方案。它专注于难以分类的示例,减少了在不平衡场景中主导损失的简单负面因素的影响。
卷积神经网络 (CNN) 可以捕获局部模式和 n-gram 特征,补充 BERT 捕获的全局上下文。考虑在 BERT 编码器之前或之后添加 CNN 层。
我们还可以实现回调并监控验证准确性,如果在一定数量的 epoch 内没有改善,则尽早停止训练。这可以防止对训练数据的过度拟合。