我有一个非常大的 HDF5 格式的数据集,我无法一次将其全部加载到内存中。我正在使用 Torch 的自定义数据集。
这是代码:
import time
from utils import get_vocab_and_skipgrams
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import h5py
import numpy as np
import torch
class CustomSkipGramDataset(Dataset):
def __init__(self, filename, window_size, data_dir="training_data", data_exists=True):
self.window_size = window_size
self.filename = filename
self.data_exists = data_exists
self.vocab_path = os.path.join(data_dir, "vocab.npy")
self.hdf5_path = os.path.join(data_dir, "skipgram.h5")
if not data_exists:
get_vocab_and_skipgrams(filename, data_dir)
self.vocab = np.load(self.vocab_path, allow_pickle=True).tolist()
self.vocab_size = len(self.vocab)
self.hf = h5py.File(self.hdf5_path, "r")
self.dataset = self.hf["positive_skips"]
def __len__(self):
return self.dataset.shape[0]
def __getitem__(self, index):
x, y = self.dataset[index]
return x, y
现在当我像这样直接加载它时:
with h5py.File("./training_data/skipgram.h5") as hf:
dataset = hf["positive_skips"]
for a in range(1,100):
print(torch.tensor(dataset[a:100*a]))
与 Torch 自定义数据集相比,它确实非常快。几乎快了 100 倍。我知道我做错了。
我使用带有 shuffle=True 的 Dataloader,一旦将 shuffle 更改为 False,它就可以正常工作。现在用随机播放从磁盘读取会很慢,这是有道理的。