我正在学习 pytorch,我正在尝试写一篇关于 GAN 渐进式增长 的论文。作者在给定数量的图像上训练网络,而不是给定数量的 epochs。
我的问题是:有没有办法在 pytorch 中使用默认的 DataLoader 来做到这一点?我想做这样的事情:
loader = Dataloader(..., total=800000)
for batch in iter(loader):
... #do training
加载程序自动循环,直到看到
800000
样本。
我认为我会是一个更好的方法,而不是计算你必须自己循环遍历数据集的次数
torch.utils.data.RandomSampler
并从您的数据集中采样。这是一个最小的设置示例:
class DS(Dataset):
def __len__(self):
return 5
def __getitem__(self, index):
return torch.empty(1).fill_(index)
>>> ds = DS()
初始化随机采样器提供
num_samples
并将replacement
设置为True
即如果len(ds) < num_samples
:,采样器将被迫多次绘制实例
>>> sampler = RandomSampler(ds, replacement=True, num_samples=10)
torch.utils.data.DataLoader
:
>>> dl = DataLoader(ds, sampler=sampler, batch_size=2)
>>> for batch in dl:
... print(batch)
tensor([[6.],
[4.]])
tensor([[9.],
[2.]])
tensor([[9.],
[2.]])
tensor([[6.],
[2.]])
tensor([[0.],
[9.]])
torch.utils.data.RandomSampler
可用于随机抽取比数据集中存在的条目更多的条目(其中 num_samples
> dataset_size
);
sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size)
如果从 Hugging Face 数据集采样,
dataloader_dataset
类必须配置 StopIteration
以重置迭代器(从数据集的开头开始),例如;
#parameter selection (user configured);
dataset = load_dataset(...)
dataset_size = dataset.num_rows
number_of_dataset_repetitions = 5
num_samples = dataset_size * number_of_dataset_repetitions
batch_size = 8
drop_last = True
dataloader_dataset = DataloaderDatasetRepeatSampler(dataset, dataset_size)
sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size, drop_last=drop_last)
loop = tqdm(loader, leave=True)
for batch_index, batch in enumerate(loop):
...
class DataloaderDatasetRepeatSampler(torch.utils.data.Dataset):
def __init__(self, dataset, dataset_size):
self.dataset = dataset
self.dataset_size = dataset_size
self.dataset_iterator = iter(dataset)
def __len__(self):
return self.datasetSize
def __getitem__(self, i):
try:
dataset_entry = next(self.dataset_iterator)
except StopIteration:
#reset iterator (start from beginning of dataset)
self.dataset_iterator = iter(self.dataset)
dataset_entry = next(self.dataset_iterator)
batch_sample = ... #eg torch.Tensor(dataset_entry)
return batch_sample