我正在创建一个元数据集,该元数据集组合了来自多个输入数据集的数据。
from torch.utils.data import Dataset, IterableDataset
class MetaDataset(Dataset):
def __init__(self, regular_dataset: Dataset, iterable_dataset: IterableDataset):
self.regular_dataset = regular_dataset
self.iterable_dataset = iterable_dataset
pass # Do other stuff...
当我尝试从
len(self.regular_dataset)
内部访问
MetaDataset
时,收到类型警告
Dataset
的类型定义故意不包含 __len__
。因此,我必须建立自己的类型:
from torch.utils.data import Dataset, IterableDataset
class DatasetWithLength(Dataset):
def __len__(self) -> int:
pass
class MetaDataset(Dataset):
def __init__(self, regular_dataset: DatasetWithLength, iterable_dataset: IterableDataset):
self.regular_dataset = regular_dataset
self.iterable_dataset = iterable_dataset
pass # Do other stuff...
但是现在,当我尝试这样做时,我收到了
Expected type 'DatasetWithLength', got 'FirstDataset' instead
警告:
foo = MetaDataset(
FirstDataset(),
FirstIterableDataset()
)
如何正确定义具有 length 属性的 PyTorch 数据集的类型。
Generic
,这样就不需要从DatasetWithLength
扩展。
from typing import Generic
class _DatasetWithLength(Dataset):
def __len__(self) -> int:
...
DatasetWithLength = Generic[type[_DatasetWithLength]]