如何在Python中实现一个基于子泛型进行验证的基方法

问题描述 投票:0回答:3

我有一个基本的 Python (3.8) 抽象基类,有两个类继承自它:

BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)

class BaseDataStore(ABC, Generic[BoundedModel]):
    @abstractmethod
    def get_all(self) -> List[BoundedModel]:
        pass

class MetadataStore(BaseDataStore[Metadata]):
    def get_all(self) -> List[Metadata]:
        items = self.client.get_all()
        return [Metadata(**item) for item in items]
    
class TranscriptStore(BaseDataStore[Transcript]):
    def get_all(self) -> List[Transcript]:
        items = self.client.get_all()
        return [Transcript(**item) for item in items]

CustomBaseModel
绑定
BoundedModel
代表一个学究阶级,意思是
Metadata
Transcript
是用于验证的 pydantic 类模型。

get_all
的具体实现都做同样的事情: 他们使用 Pydantic 有界模型验证数据。这行得通,但迫使我 为每个
BaseDataStore
孩子说出具体的实现方式

有什么方法可以在父级

get_all
中将
BaseDataStore
实现为通用方法(而不是抽象方法),从而消除对子级具体实现的需要?

python mypy pydantic
3个回答
0
投票

是的,您可以在父类 BaseDataStore 中将 get_all 实现为通用方法,方法是使用 Type 对象动态创建绑定模型类的实例并验证返回列表中的每一项。这是一个示例实现:

class BaseDataStore(ABC, Generic[BoundedModel]):
    @abstractmethod
    def get_all(self) -> List[BoundedModel]:
        pass

    def _validate_item(self, item: Dict[str, Any], model_class: Type[BoundedModel]) -> BoundedModel:
        return model_class(**item)
    
    def _get_all(self, model_class: Type[BoundedModel]) -> List[BoundedModel]:
        items = self.client.get_all()
        return [self._validate_item(item, model_class) for item in items]

在这个实现中,我们在BaseDataStore类中添加了两个私有方法:_validate_item和_get_all。 _validate_item 采用表示单个项目的字典和表示 pydantic 模型类的 Type 对象,以字典作为参数创建模型类的实例,并返回经过验证的实例。

_get_all 采用表示 pydantic 模型类的 Type 对象,并返回该类的经过验证的实例列表。它使用客户端对象的 get_all 方法获取代表项目的字典列表,然后使用 _validate_item 来验证列表中的每个项目。

有了这些私有方法,我们可以更新子类中 get_all 的具体实现,以简单地使用适当的模型类调用 _get_all:

class MetadataStore(BaseDataStore[Metadata]):
    def get_all(self) -> List[Metadata]:
        return self._get_all(Metadata)

class TranscriptStore(BaseDataStore[Transcript]):
    def get_all(self) -> List[Transcript]:
        return self._get_all(Transcript)

此实现允许您避免在每个子类中重复相同的代码,而是在父类中使用通用方法来验证和返回每个模型的数据。


0
投票

您可以避免为每个子类重新实现方法,方法是使用类 var 来存储用于实例化项的类型,该类型可以直接从 Generic 参数类型派生。

像这样:

from abc import ABC
from typing import Generic, TypeVar, Type


class CustomBaseModel:
    pass

class Metadata(CustomBaseModel):
    pass

class Transcript(CustomBaseModel):
    pass


class Client:
    def get_all(self) -> list[dict]:
        return [{}]


BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)


class BaseDataStore(ABC, Generic[BoundedModel]):
    _item_cls: Type[BoundedModel]
    
    client = Client()
    
    def get_all(self) -> list[BoundedModel]:
        items = self.client.get_all()
        return [self._item_cls(**item) for item in items]

class MetadataStore(BaseDataStore[Metadata]):
    pass

class TranscriptStore(BaseDataStore[Transcript]):
    pass


metadata_items = MetadataStore().get_all()
# metadata_items: list[Metadata]

这种类型检查:
https://mypy-play.net/?mypy=latest&python=3.11&gist=4f50432739f25ec6ca444e787c8ee0eb

...但不幸的是它在实践中还没有实际工作,因为在运行时没有为

_item_cls
分配任何值。

我们可以通过额外的元编程来解决这个问题……

from abc import ABCMeta
from typing import Generic, TypeVar, Type, get_args


class CustomBaseModel:
    pass

class Metadata(CustomBaseModel):
    pass

class Transcript(CustomBaseModel):
    pass


class Client:
    def get_all(self) -> list[dict]:
        return [{}]


BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)


class GenericDataStoreMetaclass(ABCMeta):
    def __new__(cls, name, bases, dct):
        cls_ = super().__new__(cls, name, bases, dct)
        for base, og_base in zip(cls_.__bases__, cls_.__orig_bases__):
            if base is BaseDataStore:
                # introspect the type param of the Generic alias
                cls_._item_cls = get_args(og_base)[0]
        return cls_

class BaseDataStore(Generic[BoundedModel], metaclass=GenericDataStoreMetaclass):
    _item_cls: Type[BoundedModel]

    client = Client()
    
    def get_all(self) -> list[BoundedModel]:
        items = self.client.get_all()
        return [self._item_cls(**item) for item in items]

class MetadataStore(BaseDataStore[Metadata]):
    pass

class TranscriptStore(BaseDataStore[Transcript]):
    pass


metadata_items = MetadataStore().get_all()
# [<__main__.Metadata at 0x108493520>]

此版本现在可在运行时运行。


0
投票

实际上可以。

通过__orig_bases__使用

这个技巧
来访问提供给特定子类的类型参数。那么
BaseDataStore
上的单个具体实现就足够了,您甚至不需要在子类中的任何地方重复类型参数。

假设您有以下模型:

from pydantic import BaseModel


class CustomBaseModel(BaseModel):
    pass


class Foo(CustomBaseModel):
    x: int


class Bar(CustomBaseModel):
    y: str

这是我提出的解决方案:

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from typing import get_args, get_origin

BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)

class BaseDataStore(Generic[BoundedModel]):
    _type_arg: Optional[Type[BoundedModel]] = None

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Saves the type argument in the `_type_arg` class attribute."""
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, BaseDataStore):
                continue
            type_arg = get_args(base)[0]
            # Do not set the attribute for GENERIC subclasses!
            if not isinstance(type_arg, TypeVar):
                cls._type_arg = type_arg
                return

    @classmethod
    def get_model(cls) -> Type[BoundedModel]:
        if cls._type_arg is None:
            raise AttributeError(f"{cls.__name__} is generic; type argument unspecified")
        return cls._type_arg

    def get_all(self) -> List[BoundedModel]:
        items = self.demo_data  # just for this example
        return [self.get_model()(**item) for item in items]

    demo_data: List[Dict[str, Any]]  # just for this example

用法:

class FooStore(BaseDataStore[Foo]):
    demo_data = [{"x": 1}, {"x": -1}]


class BarStore(BaseDataStore[Bar]):
    demo_data = [{"y": "spam"}, {"y": "eggs"}]


foos = FooStore().get_all()
bars = BarStore().get_all()

print(foos)
print(bars)

输出:

[Foo(x=1), Foo(x=-1)]
[Bar(y='spam'), Bar(y='eggs')]

通过

mypy --strict
。不需要元类魔法。

© www.soinside.com 2019 - 2024. All rights reserved.