如何将泡菜文件中的数据集加载到PyTorch中?

问题描述 投票:0回答:1

我以整数矩阵的形式在单独的泡菜文件中具有X_train(inputs)和Y_train(labels)。现在,我需要加载它们并使用PyTorch进行训练。我尝试了torch.utils.data.DataLoadertorchvision.datasets.DatasetFolder,但没有任何效果,否则我可能在某个地方出错了。请提出同样的正确方法。

deep-learning pytorch pickle torch torchvision
1个回答
0
投票

您确实应该通过一些示例对问题进行清晰的描述。无论如何,据我了解,您正在寻找类似的东西。

import pickle
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader


class YourDataset(Dataset):

    def __init__(self, X_Train, Y_Train, transform=None):
        self.X_Train = X_Train
        self.Y_Train = Y_Train
        self.transform = transform

    def __len__(self):
        return len(self.X_Train)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.X_Train[idx]
        y = self.Y_Train[idx]

        if self.transform:
            x = self.transform(x)
            y = self.transform(y)

        return x, y


file = open('FILENAME_X_train', 'rb')
X_train = pickle.load(file)
file.close()

file = open('FILENAME_Y_train', 'rb')
Y_train = pickle.load(file)
file.close()

your_dataset = YourDataset(X_train, Y_train, transform=transforms.Compose([transforms.ToTensor()]))

your_data_loader = DataLoader(your_dataset, batch_size=8, shuffle=True, num_workers=0)

请注意,我尚未测试代码,但我认为它提供了总体思路。希望对您有所帮助。

© www.soinside.com 2019 - 2024. All rights reserved.