当数据集太大而无法一次加载到内存中时,如何将 PyArrow 与 PyTorch 数据集集成?

问题描述 投票:0回答:1

典型的 PyTorch 数据集实现如下:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data[:][:-1]
        self.target = data[:][-1]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]

        return x, y

我想将其实现为数据集的原因是因为我想稍后使用 PyTorch DataLoader 进行小批量训练。

但是,如果

data
来自包含多个 Parquet 文件的目录,我如何为其编写
__getitem__
而不将所有数据加载到内存中?我知道 PyArrow 擅长批量加载数据,但我没有找到很好的参考来使其发挥作用。

pytorch pyarrow pytorch-dataloader
1个回答
0
投票

我认为这样的事情应该有效。但是,我不确定它会非常有效。它的两个问题是 Parquet 不太适合点查找,而且我们还将数据加载到 Python 中然后返回到

torch
。我认为应该可以。

from typing import Any

import torch
import pyarrow.dataset
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data: pyarrow.dataset.Dataset, label_column: str) -> None:
        self.data = data

    def __len__(self) -> None:
        return self.data.count_rows()

    def __getitem__(self, index: int) -> tuple[list[Any], Any]:
        row = self.data.take([index]).to_pylist()[0]
        x = [v for k, v in row.items() if k != self.label_column]
        y = row[self.label_column]

        return x, y

data = pyarrow.dataset.dataset('dataset_dir', format='parquet')
torch_dataset = CustomDataset(data=data, label_column='label')
© www.soinside.com 2019 - 2024. All rights reserved.