以某pytorch数据集的子集

问题描述 投票:3回答:1

我有一个网络,我想在一些数据集训练(作为一个例子,说CIFAR10)。我可以通过创建数据加载对象

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

我的问题是:假设我要做出几个不同的训练迭代。比方说,我想在第一次训练在奇数位置的所有图像的网络,然后对所有的图像在偶数位置等。为了做到这一点,我需要能够访问到这些图像。不幸的是,似乎trainset不允许这样的访问。也就是说,试图做trainset[:1000]或者更一般trainset[mask]将抛出一个错误。

我所能做的,而不是

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]

接着

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

然而,这将迫使我建立在每个迭代的完整数据集的新副本(因为我已经改变了trainset.train_data所以我需要重新定义trainset)。是否有某种方式来避免呢?

理想情况下,我想有一些“相当于”

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)
python machine-learning neural-network torch pytorch
1个回答
12
投票

您可以定义数据集装载机避免重新创建数据集(只是创建为每个不同的采样新装载机)的自定义采样。

class YourSampler(Sampler):
    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          sampler = sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          sampler = sampler2, shuffle=False, num_workers=2)

PS:你可以在这里找到更多的信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

© www.soinside.com 2019 - 2024. All rights reserved.