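I am implementing a class that should provide some generic behavior for setting up train/validation/test dataloaders with PyTorch Lightning's LightningDataModule. I want this generic class to supply the shared functionality but leave the initialization of certain attributes to the classes that inherit from it. My attempt should illustrate what I have in mind: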
from typing import Protocol

import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader


class HasTrainValTestDatasets(Protocol):
    @property
    def train_ds(self) -> Dataset: ...

    @property
    def val_ds(self) -> Dataset: ...

    @property
    def test_ds(self) -> Dataset: ...


class GenericTrainValTestDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size_train: int,
        batch_size_eval: int,
        num_workers: int = 0,
    ):
        self._batch_size_train = batch_size_train
        self._batch_size_eval = batch_size_eval
        self._num_workers = num_workers

    def train_dataloader(self: HasTrainValTestDatasets) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self._batch_size_train, num_workers=self._num_workers)

    def val_dataloader(self: HasTrainValTestDatasets) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)

    def test_dataloader(self: HasTrainValTestDatasets) -> DataLoader:
        return DataLoader(self.test_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)
I was inspired by this Stack Overflow answer: https://stackoverflow.com/a/59128961/8543212.
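What I want to express generically is that each DataLoader is created from train/val/test_ds using batch_size_(train/eval) and num_workers, where the datasets are to be provided by the classes that inherit from this generic class.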
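Unfortunately, I can't get the type hints right. My goal is to use a Protocol to force users of this generic class to provide train/val/test_ds. However, I can't make mypy happy with the example above, because: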
d.py:30: error: "HasTrainValTestDatasets" has no attribute "_batch_size_train" [attr-defined]
d.py:30: error: "HasTrainValTestDatasets" has no attribute "_num_workers" [attr-defined]
d.py:35: error: "HasTrainValTestDatasets" has no attribute "_batch_size_eval" [attr-defined]
d.py:35: error: "HasTrainValTestDatasets" has no attribute "_num_workers" [attr-defined]
d.py:40: error: "HasTrainValTestDatasets" has no attribute "_batch_size_eval" [attr-defined]
d.py:40: error: "HasTrainValTestDatasets" has no attribute "_num_workers" [attr-defined]
Is there a way to tell mypy that self is both a HasTrainValTestDatasets protocol and a GenericTrainValTestDataModule?
For those wondering why I insist on this design: I couldn't think of a better way to generalize https://lightning.ai/docs/pytorch/stable/data/datamodule.html for my purposes (but I may be wrong).
Steps to reproduce (assuming my demo is stored in d.py):
virtualenv venv
pip install torch pytorch-lightning
mypy --install-types d.py
I think I have (roughly) found the solution I was looking for:
from typing import Protocol

import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor


class HasTrainValTestDatasets(Protocol):
    train_ds: Dataset
    val_ds: Dataset
    test_ds: Dataset


class GenericTrainValTestDataModule(pl.LightningDataModule, HasTrainValTestDatasets):
    def __init__(self, batch_size_train: int, batch_size_eval: int, num_workers: int = 0):
        self._batch_size_train = batch_size_train
        self._batch_size_eval = batch_size_eval
        self._num_workers = num_workers

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self._batch_size_train, num_workers=self._num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_ds, batch_size=self._batch_size_eval, num_workers=self._num_workers)


class ConcreteDataModuleOk(GenericTrainValTestDataModule):
    def __init__(self):
        super().__init__(8, 8, 0)
        self.train_ds = FakeData(transform=ToTensor())
        self.val_ds = FakeData(transform=ToTensor())
        self.test_ds = FakeData(transform=ToTensor())


class ConcreteDataModuleBad(GenericTrainValTestDataModule):
    def __init__(self):
        super().__init__(8, 8, 0)


cdm_ok = ConcreteDataModuleOk()
print(cdm_ok.train_ds[0][0].shape)
print(next(iter(cdm_ok.train_dataloader()))[0].shape)
cdm_bad = ConcreteDataModuleBad()  # <-- mypy will complain here
This needs one additional dependency:
pip install torchvision
When I run: mypy --ignore-missing-imports d.py
I get:
d.py:49: error: Cannot instantiate abstract class "ConcreteDataModuleBad" with abstract attributes "test_ds", "train_ds" and "val_ds" [abstract]
Found 1 error in 1 file (checked 1 source file)
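This is exactly what I wanted to achieve. If I remove Protocol, mypy no longer detects the problem, so I guess the protocol is doing its job. One caveat: for some reason mypy does not detect when train/val/test_ds has the wrong type (e.g. a string instead of a Dataset), but I can live with that.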
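A possible explanation for the missing type error, offered as a guess rather than a confirmed diagnosis: by default mypy does not type-check the body of a function that has no annotations, and def __init__(self): in the concrete classes above is unannotated. Annotating it with -> None (or running mypy with --check-untyped-defs) should make mypy examine the assignments. The sketch below is a reduced version of the same pattern; FakeDataset, HasTrainDataset, GenericModule, AnnotatedOk and AnnotatedBad are hypothetical names, and FakeDataset stands in for torch's Dataset so the snippet runs without torch installed.

```python
from typing import Protocol


class FakeDataset:
    """Hypothetical stand-in for torch.utils.data.Dataset (assumption)."""

    def __len__(self) -> int:
        return 3


class HasTrainDataset(Protocol):
    # Declared without a value, so explicit subclasses must provide it.
    train_ds: FakeDataset


class GenericModule(HasTrainDataset):
    def __init__(self, batch_size: int) -> None:
        self._batch_size = batch_size


class AnnotatedOk(GenericModule):
    def __init__(self) -> None:  # annotated, so mypy checks this body
        super().__init__(8)
        self.train_ds = FakeDataset()


class AnnotatedBad(GenericModule):
    def __init__(self) -> None:  # annotated, so mypy checks this body
        super().__init__(8)
        # mypy should report an incompatible assignment here (str vs FakeDataset).
        # Python itself accepts the assignment at runtime without complaint.
        self.train_ds = "not a dataset"
```

Note that the check is purely static: both classes can be instantiated at runtime, so the wrong type only surfaces if mypy is actually run over the annotated code.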
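Also, as @Reinderien said, this could be done with an abstract class as well (perhaps with the same result)? I'll stick with my solution because it looks concise to me.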
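For comparison, here is a minimal sketch of what the abstract-class variant might look like. It is not @Reinderien's code; the names GenericTrainValTestBase, ConcreteOk, ConcreteBad, FakeDataset and num_train_batches are hypothetical, and FakeDataset again replaces torch's Dataset so the snippet is self-contained. Only the training dataset is shown; val_ds and test_ds would follow the same shape. Presumably the same idea carries over when pl.LightningDataModule is added as a base class, but that combination is not tested here.

```python
import math
from abc import ABC, abstractmethod


class FakeDataset:
    """Hypothetical stand-in for torch.utils.data.Dataset (assumption)."""

    def __init__(self, size: int) -> None:
        self._size = size

    def __len__(self) -> int:
        return self._size


class GenericTrainValTestBase(ABC):
    """Shared logic lives here; subclasses must supply the dataset."""

    def __init__(self, batch_size_train: int) -> None:
        self._batch_size_train = batch_size_train

    @property
    @abstractmethod
    def train_ds(self) -> FakeDataset:
        """Training dataset, to be provided by each subclass."""

    def num_train_batches(self) -> int:
        # Generic behavior that relies on the subclass-provided dataset.
        return math.ceil(len(self.train_ds) / self._batch_size_train)


class ConcreteOk(GenericTrainValTestBase):
    def __init__(self) -> None:
        super().__init__(batch_size_train=8)
        self._train_ds = FakeDataset(size=10)

    @property
    def train_ds(self) -> FakeDataset:
        return self._train_ds


class ConcreteBad(GenericTrainValTestBase):
    # Does not implement train_ds, so it stays abstract.
    def __init__(self) -> None:
        super().__init__(batch_size_train=8)
```

One difference from the Protocol version: instantiating ConcreteBad fails at runtime with a TypeError as well as being reported by mypy, at the cost of writing a property per dataset in every subclass.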