如何迭代 Dataloader 直到看到大量样本?

问题描述 投票:0回答:2

我正在学习 pytorch,我正在尝试写一篇关于 GAN 渐进式增长 的论文。作者在给定数量的图像上训练网络,而不是给定数量的 epochs。

我的问题是:有没有办法在 pytorch 中使用默认的 DataLoader 来做到这一点?我想做这样的事情:

loader = Dataloader(..., total=800000)
for batch in iter(loader):
   ... #do training

加载程序自动循环,直到看到

800000
样本。

我认为我会是一个更好的方法,而不是计算你必须自己循环遍历数据集的次数

python deep-learning pytorch dataloader
2个回答
1
投票

您可以使用

torch.utils.data.RandomSampler
并从您的数据集中采样。这是一个最小的设置示例:

class DS(Dataset):
    def __len__(self):
        return 5
    def __getitem__(self, index):
        return torch.empty(1).fill_(index)

>>> ds = DS()

初始化随机采样器提供

num_samples
并将
replacement
设置为
True
如果
len(ds) < num_samples

,采样器将被迫多次绘制实例
>>> sampler = RandomSampler(ds, replacement=True, num_samples=10)

然后将此采样器插入新的

torch.utils.data.DataLoader

>>> dl = DataLoader(ds, sampler=sampler, batch_size=2)

>>> for batch in dl:
...     print(batch)
tensor([[6.],
        [4.]])
tensor([[9.],
        [2.]])
tensor([[9.],
        [2.]])
tensor([[6.],
        [2.]])
tensor([[0.],
        [9.]])

0
投票

torch.utils.data.RandomSampler
可用于随机抽取比数据集中存在的条目更多的条目(其中
num_samples
>
dataset_size
);

sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size)

如果从 Hugging Face 数据集采样,

dataloader_dataset
类必须配置
StopIteration
以重置迭代器(从数据集的开头开始),例如;

#parameter selection (user configured);
dataset = load_dataset(...) 
dataset_size = dataset.num_rows
number_of_dataset_repetitions = 5
num_samples = dataset_size * number_of_dataset_repetitions
batch_size = 8
drop_last = True

dataloader_dataset = DataloaderDatasetRepeatSampler(dataset, dataset_size)  
sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size, drop_last=drop_last)
loop = tqdm(loader, leave=True)

for batch_index, batch in enumerate(loop):
    ...

class DataloaderDatasetRepeatSampler(torch.utils.data.Dataset):
    def __init__(self, dataset, dataset_size):
        self.dataset = dataset
        self.dataset_size = dataset_size
        self.dataset_iterator = iter(dataset)
            
    def __len__(self):
        return self.datasetSize

    def __getitem__(self, i):
        try:
            dataset_entry = next(self.dataset_iterator)
        except StopIteration:
            #reset iterator (start from beginning of dataset)
            self.dataset_iterator = iter(self.dataset)
            dataset_entry = next(self.dataset_iterator)
        batch_sample = ...  #eg torch.Tensor(dataset_entry)
        return batch_sample
© www.soinside.com 2019 - 2024. All rights reserved.