I am using the bert-base-uncased model in Python to assign a title to a sentence. Below is the code I wrote; I need the prediction to come from possible_labels. What are the possible ways to predict the title based on possible_labels?
# Load the pre-trained BERT model and tokenizer for text classification
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Define the class labels
possible_labels = [
"Greeting",
"Farewell",
"Expression of Gratitude",
"Apology",
"Request",
"Command or Instruction",
"Advertisement or Promotion",
"News Update",
"Question",
"Positive Sentiment",
"Negative Sentiment",
"Neutral Statement",
# Add more labels as needed
]
# Input sentence to classify
text = "working on work"
# Tokenize and encode the input sentence
encoded_input = tokenizer(text, return_tensors="pt")
# Make a prediction
output = model(**encoded_input)
# Get the predicted class index
predicted_class_index = torch.argmax(output.logits, dim=1).item()
# Get the predicted label
predicted_label = possible_labels[predicted_class_index]
# Print the result
print("Input Sentence:", text)
print("Predicted Label:", predicted_label)
# Print all probabilities
probabilities = torch.nn.functional.softmax(output.logits, dim=1).tolist()[0]
for label, prob in zip(possible_labels, probabilities):
print(f"Probability for {label}: {prob:.4f}")
This won't work because the classification head of the model you are loading is untrained. When you run the code, you also get a message:
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You have two obvious options: (1) fine-tune the model so its classification head actually matches your labels, or (2) use a model that was already trained for zero-shot classification. A minimal sketch of option 1 comes first, followed by the zero-shot pipeline for option 2.
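Option 1 requires labeled training data. The sketch below is only illustrative: train_texts and train_labels are hypothetical placeholders, and the training loop is deliberately minimal (no validation, no learning-rate schedule). The key point is passing num_labels=len(possible_labels) so the new classification head has the right size.
# Option 1 (sketch): fine-tune bert-base-uncased with a classification head
# sized to your label set. train_texts and train_labels are hypothetical
# placeholders; replace them with real labeled data.
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

possible_labels = ["Greeting", "Farewell", "Question"]  # shortened here; use your full list

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(possible_labels),  # size the new head to the label set
)

# Hypothetical labeled data: one label index per sentence
train_texts = ["hello there", "see you tomorrow", "what time is it?"]
train_labels = [0, 1, 2]

encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors="pt")
dataset = TensorDataset(
    encodings["input_ids"], encodings["attention_mask"], torch.tensor(train_labels)
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()
for epoch in range(3):  # a real run needs far more data and epochs
    for input_ids, attention_mask, labels in loader:
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()

# Once the head is trained, the argmax from the question maps onto possible_labels
model.eval()
with torch.no_grad():
    encoded_input = tokenizer("working on work", return_tensors="pt")
    predicted_class_index = torch.argmax(model(**encoded_input).logits, dim=1).item()
print("Predicted Label:", possible_labels[predicted_class_index])
Option 2 requires no training at all: a zero-shot classification model scores the sentence against your candidate labels directly.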
from transformers import pipeline
possible_labels = [
"Greeting",
"Farewell",
"Expression of Gratitude",
"Apology",
"Request",
"Command or Instruction",
"Advertisement or Promotion",
"News Update",
"Question",
"Positive Sentiment",
"Negative Sentiment",
"Neutral Statement",
# Add more labels as needed
]
# maybe try other models
pipe = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
pipe("working on work", candidate_labels=possible_labels)
Output:
{'sequence': 'working on work',
'labels': ['Question',
'Request',
'Positive Sentiment',
'Neutral Statement',
'Command or Instruction',
'News Update',
'Negative Sentiment',
'Advertisement or Promotion',
'Expression of Gratitude',
'Apology',
'Farewell',
'Greeting'],
'scores': [0.17385907471179962,
0.15512266755104065,
0.1200760081410408,
0.11586172878742218,
0.08471532166004181,
0.0735798180103302,
0.0702444538474083,
0.06881491094827652,
0.04482586681842804,
0.0351007916033268,
0.03142492100596428,
0.026374489068984985]}
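The returned labels are sorted by descending score, so the top prediction can be read directly from the result. A minimal usage example, assuming the pipeline result is stored in a variable:
result = pipe("working on work", candidate_labels=possible_labels)
print("Input Sentence:", result["sequence"])
print("Predicted Label:", result["labels"][0])   # highest-scoring candidate label
print("Score:", result["scores"][0])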