具有长度的数据集的 PyTorch 类型

问题描述 投票:0回答:1

我正在创建一个元数据集,该元数据集组合了来自多个输入数据集的数据。

from torch.utils.data import Dataset, IterableDataset

class MetaDataset(Dataset):
    def __init__(self, regular_dataset: Dataset, iterable_dataset: IterableDataset):
        self.regular_dataset = regular_dataset
        self.iterable_dataset = iterable_dataset
        pass # Do other stuff...

当我尝试从

len(self.regular_dataset)
 内部访问 
MetaDataset

时,收到类型警告

事实证明,PyTorch

Dataset
的类型定义故意不包含
__len__
。因此,我必须建立自己的类型:

from torch.utils.data import Dataset, IterableDataset

class DatasetWithLength(Dataset):
    def __len__(self) -> int:
        pass

class MetaDataset(Dataset):
    def __init__(self, regular_dataset: DatasetWithLength, iterable_dataset: IterableDataset):
        self.regular_dataset = regular_dataset
        self.iterable_dataset = iterable_dataset
        pass # Do other stuff...

但是现在,当我尝试这样做时,我收到了

Expected type 'DatasetWithLength', got 'FirstDataset' instead
警告:

foo = MetaDataset(
    FirstDataset(),
    FirstIterableDataset()
)

如何正确定义具有 length 属性的 PyTorch 数据集的类型。

python pytorch python-typing
1个回答
0
投票

您需要使用

Generic
,这样就不需要从
DatasetWithLength
扩展。

from typing import Generic

class _DatasetWithLength(Dataset):
    def __len__(self) -> int:
        ...

DatasetWithLength = Generic[type[_DatasetWithLength]]
© www.soinside.com 2019 - 2024. All rights reserved.