我需要按类别拆分 CIFAR10 数据集,以便我可以为每个类别创建具有相同数量样本的较小样本。
我怎样才能最好地实现这一点?
import numpy
sorted_by_value = [0]*10
for i in range(10):
sorted_by_value[i] =(train.data[numpy.where(numpy.array(train.targets) == i)])
numpy.random.shuffle(sorted_by_value[i])
对于任何数据集,你可以将 10 替换为类别数,然后就可以了。
您可以使用 torch.utils.data.Subset 来实现此目的。下面是如何制作仅包含数字 0 到 4 的 MNIST 训练集子集的示例:
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
classes = [0, 1, 2, 3, 4]
indices = [i for i, (k, v) in enumerate(trainset) if v in classes]
trainset = torch.utils.data.Subset(trainset, indices)