调用基类定义的函数时如何获取子类的返回类型?

问题描述 投票:0回答:1

我正在尝试用Python编写一个类层次结构,以便子类可以重写方法

predict
以获得更窄的返回类型,该返回类型本身就是父类返回类型的子类。当我实例化子类的实例并调用
predict
; 时,这似乎工作正常;返回值具有预期的窄类型。但是,当我调用基类 (
predict_batch
) 上定义的另一个函数(它本身调用
predict
)时,窄返回类型就会丢失。

一些背景:我的程序必须支持使用两种类型的图像分割模型:“实例”和“语义”。这两个模型的输出非常不同,所以我想用对称的类层次结构来存储它们的输出(即

BaseResult
InstResult
SemResult
)。当不需要知道使用哪种特定类型的模型时,这将允许某些客户端代码通过使用
BaseResults
变得通用。

这是一个玩具代码示例:

from abc import ABC, abstractmethod
from typing import List

from overrides import overrides

##################
# Result classes #
##################


class BaseResult(ABC):
    """Abstract container class for result of image segmentation"""

    pass


class InstResult(BaseResult):
    """Stores the result of instance segmentation"""

    pass


class SemResult(BaseResult):
    """Stores the result of semantic segmentation"""

    pass


#################
# Model classes #
#################


class BaseModel(ABC):
    def predict_batch(self, images: List) -> List[BaseResult]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> BaseResult:
        raise NotImplementedError()


class InstanceSegModel(BaseModel):
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel):
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()


########
# main #
########

# placeholder for illustration 
images = [None, None, None]

model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[BaseResult]

在上面的代码中,我希望

batch_result
具有类型
List[InstResult]

在运行时,这些都不重要,我的代码执行得很好。但是我的编辑器 (VS Code) 中的静态类型检查器 (Pylance) 不喜欢客户端代码假定

batch_result
是更窄的类型。我只能想到这两种可能的解决方案,但我觉得都不干净:

  1. 使用
    cast
    模块中的
    typing
    功能
  2. 在子类中覆盖
    predict_batch
    ,即使逻辑没有改变
python inheritance python-typing pyright
1个回答
3
投票

您可以使用泛型和继承来覆盖/缩小父类中的注释

from typing import List, Generic, TypeVar

T = TypeVar('T')


class BaseModel(ABC, Generic[T]):
    def predict_batch(self, images: List) -> List[T]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> T:
        raise NotImplementedError()


class InstanceSegModel(BaseModel[InstResult]):
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel[SemResult]):
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()


images = [None, None, None]

model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[InstResult]
© www.soinside.com 2019 - 2024. All rights reserved.