我有一个网络,我想在一些数据集训练(作为一个例子,说CIFAR10
)。我可以通过创建数据加载对象
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
我的问题是:假设我要做出几个不同的训练迭代。比方说,我想在第一次训练在奇数位置的所有图像的网络,然后对所有的图像在偶数位置等。为了做到这一点,我需要能够访问到这些图像。不幸的是,似乎trainset
不允许这样的访问。也就是说,试图做trainset[:1000]
或者更一般trainset[mask]
将抛出一个错误。
我所能做的,而不是
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
接着
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
然而,这将迫使我建立在每个迭代的完整数据集的新副本(因为我已经改变了trainset.train_data
所以我需要重新定义trainset
)。是否有某种方式来避免呢?
理想情况下,我想有一些“相当于”
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
shuffle=True, num_workers=2)
您可以定义数据集装载机避免重新创建数据集(只是创建为每个不同的采样新装载机)的自定义采样。
class YourSampler(Sampler):
def __init__(self, mask):
self.mask = mask
def __iter__(self):
return (self.indices[i] for i in torch.nonzero(self.mask))
def __len__(self):
return len(self.mask)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
sampler = sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
sampler = sampler2, shuffle=False, num_workers=2)
PS:你可以在这里找到更多的信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler