如何从.pt文件创建Pytorch数据集?

问题描述 投票:0回答:1

我已将保存为.pt文件的MNIST图像转换为Google驱动器中的文件夹。我正在用Colab编写我的Pytorch代码。

我想使用这些文件,并创建一个数据集,将这些图像存储为Tensors。我怎样才能做到这一点?

在训练期间转换图像花费的时间太长。因此,转换它们并将它们全部保存为.pt文件。我只想将它们作为数据集加载回来并在我的模型中使用它们。

python computer-vision pytorch mnist dcgan
1个回答
1
投票

您保存图像的方法确实是一个好主意。在这种情况下,您只需编写自己的数据集类来加载图像。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

class ReaderDataset(Dataset):
    def __init__(self, filename):
        # load the images from file

    def __len__(self):
        # return total dataset size

    def __getitem__(self, index):
        # write your code to return each batch element

然后您可以按如下方式创建Dataloader。

train_dataset = ReaderDataset(filepath)
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel
)
# args is a dictionary containing parameters
# batchify is a custom function that prepares each mini-batch
© www.soinside.com 2019 - 2024. All rights reserved.