我想在 pytorch 中的 huggingface vilt Transformer 之上添加一个分类层,以便我可以对我的文本标签进行分类。
通常在正常设置下,vilt 会获取图像、问题对并在前向传递后输出问题的答案
我想让任务成为分类任务而不是文本生成任务。我有一组标签,我希望 vilt 分配哪个标签最有可能成为给定问题的答案。
我对变形金刚完全陌生,对如何完成这项任务知之甚少。有人可以帮我吗?
我检查了这个媒体博客,但无法理解它。
您可以在 Vilt 模型之上添加自己的 Classification_Head。
这只是概述,请根据您的要求进行更改
class ClassificationHead(nn.Module):
def __init__(self, input_size, num_classes):
super(ClassificationHead, self).__init__()
self.fc = nn.Linear(input_size, num_classes)
def forward(self, x):
return self.fc(x)
# Define your number of classes
num_classes = .. # Number of classes in your classification task
# Create the classification head
classification_head = ClassificationHead(vilt_model.config.hidden_size, num_classes)
# Training Loop
for epoch in range(num_epochs):
for batch in dataloader:
inputs = batch["input_ids"]
labels = batch["labels"]
# Forward pass
outputs = vilt_model(inputs).last_hidden_state[:, 0, :]
logits = classification_head(outputs)
# Calculate loss
loss = criterion(logits, labels)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Inference
with torch.no_grad():
inputs = .. # Prepare your input data
outputs = vilt_model(inputs).last_hidden_state[:, 0, :]
logits = classification_head(outputs)
predicted_labels = logits.argmax(dim=1)