我正在尝试用Python编写一个类层次结构,以便子类可以重写方法
predict
以获得更窄的返回类型,该返回类型本身就是父类返回类型的子类。当我实例化子类的实例并调用 predict
; 时,这似乎工作正常;返回值具有预期的窄类型。但是,当我调用基类 (predict_batch
) 上定义的另一个函数(它本身调用 predict
)时,窄返回类型就会丢失。
一些背景:我的程序必须支持使用两种类型的图像分割模型:“实例”和“语义”。这两个模型的输出非常不同,所以我想用对称的类层次结构来存储它们的输出(即
BaseResult
、InstResult
和SemResult
)。当不需要知道使用哪种特定类型的模型时,这将允许某些客户端代码通过使用 BaseResults
变得通用。
这是一个玩具代码示例:
from abc import ABC, abstractmethod
from typing import List
from overrides import overrides
##################
# Result classes #
##################
class BaseResult(ABC):
"""Abstract container class for result of image segmentation"""
pass
class InstResult(BaseResult):
"""Stores the result of instance segmentation"""
pass
class SemResult(BaseResult):
"""Stores the result of semantic segmentation"""
pass
#################
# Model classes #
#################
class BaseModel(ABC):
def predict_batch(self, images: List) -> List[BaseResult]:
return [self.predict(img) for img in images]
@abstractmethod
def predict(self, image) -> BaseResult:
raise NotImplementedError()
class InstanceSegModel(BaseModel):
"""performs instance segmentation on images"""
@overrides
def predict(self, image) -> InstResult:
return InstResult()
class SemanticSegModel(BaseModel):
"""performs semantic segmentation on images"""
@overrides
def predict(self, image) -> SemResult:
return SemResult()
########
# main #
########
# placeholder for illustration
images = [None, None, None]
model = InstanceSegModel()
single_result = model.predict(images[0]) # has type InstResult
batch_result = model.predict_batch(images) # has type List[BaseResult]
在上面的代码中,我希望
batch_result
具有类型 List[InstResult]
。
在运行时,这些都不重要,我的代码执行得很好。但是我的编辑器 (VS Code) 中的静态类型检查器 (Pylance) 不喜欢客户端代码假定
batch_result
是更窄的类型。我只能想到这两种可能的解决方案,但我觉得都不干净:
cast
模块中的 typing
功能predict_batch
,即使逻辑没有改变您可以使用泛型和继承来覆盖/缩小父类中的注释
from typing import List, Generic, TypeVar
T = TypeVar('T')
class BaseModel(ABC, Generic[T]):
def predict_batch(self, images: List) -> List[T]:
return [self.predict(img) for img in images]
@abstractmethod
def predict(self, image) -> T:
raise NotImplementedError()
class InstanceSegModel(BaseModel[InstResult]):
"""performs instance segmentation on images"""
@overrides
def predict(self, image) -> InstResult:
return InstResult()
class SemanticSegModel(BaseModel[SemResult]):
"""performs semantic segmentation on images"""
@overrides
def predict(self, image) -> SemResult:
return SemResult()
images = [None, None, None]
model = InstanceSegModel()
single_result = model.predict(images[0]) # has type InstResult
batch_result = model.predict_batch(images) # has type List[InstResult]