使用我的
__getitem__
函数检查时,它正确地提供了图像、标题和 class_id。图像是大小为 [3, 256, 256] 的张量,标题是 20 个元素的 list。
但是当我观察数据加载器时,它以错误的方式对字幕进行分组。批量大小为 32。因此,对于一批数据,dataloader 的预期组件是 32 个图像、32 个标题和 32 个 class_ids。
但是数据加载器给出了 32 张图像、20 个标题 和 32 个 class_id。 其中 20 是标题的最大长度.
数据加载器不是给出每个长度为 20 的 32 个标题,而是给出一个包含 20 个元组的列表,每个元组的长度为 32.
这里是数据加载器的代码,我只展示代码的相关部分。
class Load_Dataset(data.Dataset):
def __init__(self, data_dir, split='train',
base_size=64,
transform=None, target_transform=None):
self.transform = transform
.......
def get_caption(self, sent_ix):
# a list of strings for a sentence
sent_caption = self.captions[sent_ix]
if sent_caption[-1] == '<end>':
print('ERROR: do not need END token', sent_caption)
num_words = len(sent_caption)
# pad with '<end>' tokens
x = ['<end>'] * 20
x_len = num_words
if num_words <= 20:
x[:num_words] = sent_caption
else:
ix = list(np.arange(num_words))
np.random.shuffle(ix)
ix = ix[:20]
ix = np.sort(ix)
x[:] = [sent_caption[i] for i in ix]
x_len = 20
return x, x_len
def __getitem__(self, index):
#
key = self.filenames[index]
cls_id = self.class_id[index]
#
.....
#
img_name = '%s/images/%s.jpg' % (data_dir, key)
imgs = get_imgs(img_name, self.imsize,
bbox, self.transform, normalize=self.norm)
# random select a sentence
sent_ix = random.randint(0, self.embeddings_num)
new_sent_ix = index * self.embeddings_num + sent_ix
caps, cap_len = self.get_caption(new_sent_ix)
return imgs, caps, cap_len, cls_id, key
def __len__(self):
return len(self.filenames)
如何处理?