为什么数据加载器以错误的方式对字幕进行分组?

问题描述 投票:0回答:0

使用我的

__getitem__
函数检查时,它正确地提供了图像、标题和 class_id。图像是大小为 [3, 256, 256] 的张量,标题是 20 个元素的 list

但是当我观察数据加载器时,它以错误的方式对字幕进行分组。批量大小为 32。因此,对于一批数据,dataloader 的预期组件是 32 个图像、32 个标题和 32 个 class_ids。

但是数据加载器给出了 32 张图像、20 个标题 和 32 个 class_id。 其中 20 是标题的最大长度.

数据加载器不是给出每个长度为 20 的 32 个标题,而是给出一个包含 20 个元组的列表,每个元组的长度为 32.

这里是数据加载器的代码,我只展示代码的相关部分。

class Load_Dataset(data.Dataset):
    def __init__(self, data_dir, split='train',
                 base_size=64,
                 transform=None, target_transform=None):
        self.transform = transform
        .......

    def get_caption(self, sent_ix):
        # a list of strings for a sentence
        sent_caption = self.captions[sent_ix]
        if sent_caption[-1] == '<end>':
            print('ERROR: do not need END token', sent_caption)
        num_words = len(sent_caption)
        # pad with '<end>' tokens
        x = ['<end>'] * 20
        x_len = num_words
        if num_words <= 20:
            x[:num_words] = sent_caption
        else:
            ix = list(np.arange(num_words))
            np.random.shuffle(ix)
            ix = ix[:20]
            ix = np.sort(ix)
            x[:] = [sent_caption[i] for i in ix]
            x_len = 20
        return x, x_len

    def __getitem__(self, index):
        #
        key = self.filenames[index]
        cls_id = self.class_id[index]
        #
        .....
        #
        img_name = '%s/images/%s.jpg' % (data_dir, key)
        imgs = get_imgs(img_name, self.imsize,
                        bbox, self.transform, normalize=self.norm)
        # random select a sentence
        sent_ix = random.randint(0, self.embeddings_num)
        new_sent_ix = index * self.embeddings_num + sent_ix
        caps, cap_len = self.get_caption(new_sent_ix)
        
        return imgs, caps, cap_len, cls_id, key


    def __len__(self):
        return len(self.filenames)

如何处理?

torch dataloader
© www.soinside.com 2019 - 2024. All rights reserved.