pytorch数据集中每个类的实例数。

问题描述 投票:1回答:1

我想用PyTorch做一个简单的图像分类器.我是这样把数据加载到数据集和dataLoader中的。

batch_size = 64
validation_split = 0.2
data_dir = PROJECT_PATH+"/categorized_products"
transform = transforms.Compose([transforms.Grayscale(), CustomToTensor()])

dataset = ImageFolder(data_dir, transform=transform)

indices = list(range(len(dataset)))

train_indices = indices[:int(len(indices)*0.8)] 
test_indices = indices[int(len(indices)*0.8):]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)

我想分别打印出训练数据和测试数据中每个类的图像数量,就像这样。

在训练数据中:

  • 鞋子: 20
  • 衬衫。14

在测试数据。

  • 鞋子: 4
  • 衬衫。3

我试了一下。

from collections import Counter
print(dict(Counter(sample_tup[1] for sample_tup in dataset.imgs)))

但我得到了这个错误。

AttributeError: 'MyDataset' object has no attribute 'img'
python pytorch torch dataloader
1个回答
3
投票

你需要使用 .targets 来访问数据的标签,例如,在MNIST数据集中,它将打印出这样的内容。

print(dict(Counter(dataset.targets)))

它会打印出类似这样的内容(例如在MNIST数据集中)。

{5: 5421, 0: 5923, 4: 5842, 1: 6742, 9: 5949, 2: 5958, 3: 6131, 6: 5918, 7: 6265, 8: 5851}

同时,你也可以使用 .classes.class_to_idx 来获取标签id到类的映射。

print(dataset.class_to_idx)
{'0 - zero': 0,
 '1 - one': 1,
 '2 - two': 2,
 '3 - three': 3,
 '4 - four': 4,
 '5 - five': 5,
 '6 - six': 6,
 '7 - seven': 7,
 '8 - eight': 8,
 '9 - nine': 9}

编辑: 方法1

从注释中可以看出,要想分别得到训练集和测试集的类分布,只需在子集上进行如下迭代即可。

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# labels in training set
train_classes = [label for _, label in train_dataset]
Counter(train_classes)
Counter({0: 4757,
         1: 5363,
         2: 4782,
         3: 4874,
         4: 4678,
         5: 4321,
         6: 4747,
         7: 5024,
         8: 4684,
         9: 4770})

编辑(2): 方法2

由于你有一个大的数据集,而且正如你所说的,它需要相当多的时间来迭代所有的训练集,还有一个方法。

你可以用 .indices 的子集,它指的是原始数据集中被选为子集的索引。

train_classes = [dataset.targets[i] for i in train_dataset.indices]
Counter(train_classes) # if doesn' work: Counter(i.item() for i in train_classes)
© www.soinside.com 2019 - 2024. All rights reserved.