如何从csv文件中提取图像,标签并使用割炬创建火车集?

问题描述 投票:0回答:1

我下载了一个用于面部关键点检测的数据集,图像和标签位于一个csv文件中,我使用熊猫将其提取出来,但我不知道如何将其转换为张量并将其加载到数据加载器中进行训练。

dataframe = pd.read_csv("training_facial_keypoints.csv")
dataframe['Image'] = dataframe['Image'].apply(lambda i: np.fromstring(i, sep=' '))
dataframe= dataframe.dropna()
images_array = np.vstack(dataframe['Image'].values)/255.0
images_array = images_array.astype(np.float32)
images_array = images_array.reshape(-1, 96, 96, 1)
print(images_array.shape)
labels_array = dataframe[dataframe.columns[:-1]].values
labels_array = (labels_array-48)/48
labels_array = labels_array.astype(np.float32)

我将图像和标签分为两个数组。如何从中创建一个训练集并使用变换。然后使用数据加载器加载它。

deep-learning computer-vision pytorch torch torchvision
1个回答
0
投票
创建torch.utils.data.Dataset的子类,并用您的数据填充它。您可以将所需的torchvision.transforms传递给它,然后将其应用于__getitem__(self, index)中的数据。

比您可以将其传递给torch.utils.data.DataLoader,它允许多线程加载数据。

并且PyTorch具有压倒性的documentation,您应该首先参考。

© www.soinside.com 2019 - 2024. All rights reserved.