典型的 PyTorch 数据集实现如下:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data[:][:-1]
self.target = data[:][-1]
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.target[index]
return x, y
我想将其实现为数据集的原因是因为我想稍后使用 PyTorch DataLoader 进行小批量训练。
但是,如果
data
来自包含多个 Parquet 文件的目录,我如何为其编写 __getitem__
而不将所有数据加载到内存中?我知道 PyArrow 擅长批量加载数据,但我没有找到很好的参考来使其发挥作用。
我认为这样的事情应该有效。但是,我不确定它会非常有效。它的两个问题是 Parquet 不太适合点查找,而且我们还将数据加载到 Python 中然后返回到
torch
。我认为应该可以。
from typing import Any
import torch
import pyarrow.dataset
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data: pyarrow.dataset.Dataset, label_column: str) -> None:
self.data = data
def __len__(self) -> None:
return self.data.count_rows()
def __getitem__(self, index: int) -> tuple[list[Any], Any]:
row = self.data.take([index]).to_pylist()[0]
x = [v for k, v in row.items() if k != self.label_column]
y = row[self.label_column]
return x, y
data = pyarrow.dataset.dataset('dataset_dir', format='parquet')
torch_dataset = CustomDataset(data=data, label_column='label')