如何使用pytorch中的Dataloader在自定义数据集中查找图像的标签名称?

问题描述 投票:0回答:1

我使用此代码从自定义数据集中加载训练数据,有 14 个类,即 14 个文件夹,每个文件夹有 100 个图像。

data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    
)

我想打印图像所属的类的名称。我该怎么做?

python deep-learning pytorch computer-vision pytorch-dataloader
1个回答
0
投票

您是否使用 torchvision.datasets.ImageFolder 创建

dataset_train
?如果是,这可能是您的解决方案:

class_names = {idx: cls for cls, idx in data_loader_train.dataset.class_to_idx.items()}

data_loader_train.dataset
将返回其加载(或迭代)的原始数据集,并且由于数据集是使用 torchvision.datasets.ImageFolder 制作的,因此它有一个属性
class_to_idx
,它是映射类标签(真实类)的字典name) 到其对应的索引。上面的代码将反转该字典,帮助您从类索引中获取类名。

希望有帮助!

© www.soinside.com 2019 - 2024. All rights reserved.