我是PyTorch的新手,在最近的几天里,我一直在努力使用Dataset类来构建自定义数据集。
我正在使用此数据集(https://www.kaggle.com/ianmoone0617/flower-goggle-tpu-classification/kernels),问题在于图像和它们的标签位于单独的文件夹中,我不知道如何对其进行串联。
这是我正在使用的代码:
class MyDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
self.labels = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
if torch.is_tensor(index):
index = index.tolist()
image_name = os.path.join(self.root_dir, self.labels.iloc[index, 0])
image = io.imread(image_name)
if self.transform:
image = self.transform(image)
return (image, labels)
而文件夹的结构如下:structure of the folders
我真的很想了解这一点,所以在此先谢谢大家!!
似乎您快到了。有很多方法可以解决这个问题。例如,您可以在初始化期间读取两个csv文件以构建字典,该字典将flowers_idx.csv
中的标签字符串映射到flowers_label.csv
中指定的标签索引。
import torch import torchvision.transforms as tt from torchvision.datasets.folder import default_loader from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data_csv, label_csv, root_dir, transform=None): self.data_entries = pd.read_csv(data_csv) self.root_dir = root_dir self.transform = transform label_map = pd.read_csv(label_csv) self.label_str_to_idx = {label_str: label_idx for label_idx, label_str in label_map.iloc} def __len__(self): return len(self.labels) def __getitem__(self, index): if torch.is_tensor(index): index = index.item() label = self.label_str_to_idx[self.data_entries.iloc[index, 1]] image_path = os.path.join(self.root_dir, f'{self.data_entries.iloc[index, 0]}.jpeg') # torchvision datasets generally return PIL image rather than numpy ndarray image = default_loader(image_path) # alternative to load ndarray using skimage.io # image = io.imread(image_path) if self.transform: image = self.transform(image) return (image, label)
注意,这将返回
PIL
图像而不是ndarray,因为通常是Torchvision数据集返回的图像。这意味着您可以根据需要使用内置的Torchvision转换。
目前,一个简单的用例可能是:
import torchvision.transforms as tt dataset_dir = '/home/jodag/datasets/527293_966816_bundle_archive' # TODO add more transforms/data-augmentation etc... transform = tt.Compose(( tt.ToTensor(), )) dataset = MyDataset( os.path.join(dataset_dir, 'flowers_idx.csv'), os.path.join(dataset_dir, 'flowers_label.csv'), os.path.join(dataset_dir, 'flower_tpu/flower_tpu/flowers_google/flowers_google'), transform) image, label = dataset[0]
在训练或验证期间,您可能会使用
DataLoader
来采样数据集。