在 Huggingface Vilt 模型之上添加分类头

问题描述 投票:0回答:1

我想在 pytorch 中的 huggingface vilt Transformer 之上添加一个分类层,以便我可以对我的文本标签进行分类。

通常在正常设置下,vilt 会获取图像、问题对并在前向传递后输出问题的答案

我想让任务成为分类任务而不是文本生成任务。我有一组标签,我希望 vilt 分配哪个标签最有可能成为给定问题的答案。

我对变形金刚完全陌生,对如何完成这项任务知之甚少。有人可以帮我吗?

我检查了这个媒体博客,但无法理解它。

python deep-learning pytorch huggingface-transformers huggingface
1个回答
0
投票

您可以在 Vilt 模型之上添加自己的 Classification_Head。

这只是概述,请根据您的要求进行更改

class ClassificationHead(nn.Module):
    def __init__(self, input_size, num_classes):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
    
    def forward(self, x):
        return self.fc(x)

# Define your number of classes
num_classes = ..  # Number of classes in your classification task

# Create the classification head
classification_head = ClassificationHead(vilt_model.config.hidden_size, num_classes)


# Training Loop
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs = batch["input_ids"]
        labels = batch["labels"]

        # Forward pass
        outputs = vilt_model(inputs).last_hidden_state[:, 0, :]
        logits = classification_head(outputs)

        # Calculate loss
        loss = criterion(logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Inference
with torch.no_grad():
    inputs = ..  # Prepare your input data
    outputs = vilt_model(inputs).last_hidden_state[:, 0, :]
    logits = classification_head(outputs)
    predicted_labels = logits.argmax(dim=1)

© www.soinside.com 2019 - 2024. All rights reserved.