如何使用 DataLoader 为 BLEU 指标加载多个不同的引用?

问题描述 投票:0回答:0

我有一个图像字幕数据集,其中每个样本都由一个图像和一个字幕列表组成。

  • 每个样本都有一个或多个标题
  • 每个样本的字幕数量可以不同。

这是一个视觉示例:

我正在使用 PyTorch,我创建了自定义

Dataset
Dataloader
来训练模型和执行评估。

  • 为了训练,我在可用字幕列表中随机选择一个字幕,然后计算模型输出和模型输出与目标之间的 NLL 损失。
  • 为了评估,我想计算模型和采样标题之间的损失,以及文本生成任务中使用的其他指标,例如 BLEUROUGE。这些指标接受多个引用,所以我想传递每个样本的所有可用标题的列表。

使

Dataset
Dataloader
处理这两种情况的最佳方法是什么,即提供一个随机选择的训练标签和多参考指标的所有标签?

  • 我尝试向 Dataset 类添加一个标志,该标志将在验证和测试期间设置为 true。但是,由于每个样本都有不同数量的标签,因此 DataLoader 无法构建批次。一种解决方案可能是直接迭代数据集,但我认为必须有更好的解决方案。
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, root, split, image_transform, processor):
        file = pl.Path(root) / '{}.json'.format(split)
        with open(file) as f:
            j = json.load(f)
            self.data = list(j.values())
        self.split = split
        self.image_transform = image_transform
        self.processor = processor

    def __getitem__(self, i):
        image_path = self.data[i]['img_url']
        image = Image.open(image_path).convert('RGB')
        # randomly sample one visual sentence
        labels = self.data[i]['visual_sentences']
        if self.image_transform is not None:
            image = self.image_transform(image)
        encoding = self.processor(images=image, text=random.sample(labels, 1), padding="max_length", return_tensors="pt")
            # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        # add all the labels if not in training
        if self.split != 'train':
            encoding['labels'] = labels
        return encoding


class MyDataLoader(BaseDataLoader):
    def __init__(self, data_dir, batch_size, split, shuffle=True, validation_split=0.0, num_workers=1, processor=None):
        transform = transforms.Compose([
            transforms.Resize((224, 224))
        ])
        processor = AutoProcessor.from_pretrained(processor)
        self.data_dir = data_dir
        self.dataset = MyDataset(data_dir, split, image_transform=transform, processor=processor)
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)

  • 我是否必须手动“填充”列表,例如通过添加空字符串以使所有列表长度相等(然后在计算指标时删除这些空字符串,我想)?还有其他解决方案吗?
deep-learning pytorch dataloader bleu multimodal
© www.soinside.com 2019 - 2024. All rights reserved.