Create a custom Pydantic class that requires a value to lie within an interval?

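I'm trying to build a function-argument validation library with pydantic. We want to be able to validate both the type and the value of each argument. Types are straightforward, but I'm struggling to write a class that validates values. Specifically, the first class I want to build is one that requires a value to lie within a user-defined interval.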

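So far I have written a decorator and a function-based version of `ValueInInterval`. However, I would prefer a class-based approach. Here is an MRE (minimal reproducible example) of my problem: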

from typing import Any
from typing_extensions import Annotated
from pydantic import Field, validate_call


class ValueInInterval:
    def __init__(
        self,
        type_definition: Any,
        start: Any,
        end: Any,
        include_start: bool = True,
        include_end: bool = True,
    ):
        self.type_definition = type_definition
        self.start = start
        self.end = end
        self.include_start = include_start
        self.include_end = include_end

    def __call__(self):
        return Annotated[self.type_definition, self.create_field()]

    def create_field(self) -> Field:
        field_config = {}
        if self.include_start:
            field_config.update({"ge": self.start})
        else:
            field_config.update({"gt": self.start})
        if self.include_end:
            field_config.update({"le": self.end})
        else:
            field_config.update({"lt": self.end})

        return Field(**field_config)

    def __get_pydantic_core_schema__(
        self,
        handler,
    ):
        schema = handler(self.type_definition)

        if self.include_start:
            schema.update({"ge": self.start})
        else:
            schema.update({"gt": self.start})

        if self.include_end:
            schema.update({"le": self.end})
        else:
            schema.update({"lt": self.end})

        return schema


@validate_call()
def test_interval(
    value: ValueInInterval(type_definition=int, start=1, end=10),
):
    print(value)


test_interval(value=1)  # should succeed
test_interval(value=0)  # should fail

Running this code in the PyCharm Python console produces the following error:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2023.2\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 56, in <module>
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\validate_call_decorator.py", line 56, in validate
    validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_validate_call.py", line 57, in __init__
    schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 512, in generate_schema
    schema = self._generate_schema_inner(obj)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 789, in _generate_schema_inner
    return self.match_type(obj)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 856, in match_type
    return self._callable_schema(obj)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1692, in _callable_schema
    arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1414, in _generate_parameter_schema
    schema = self._apply_annotations(source_type, annotations)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1890, in _apply_annotations
    schema = get_inner_schema(source_type)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_schema_generation_shared.py", line 83, in __call__
    schema = self._handler(source_type)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1869, in inner_handler
    from_property = self._generate_schema_from_property(obj, source_type)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 677, in _generate_schema_from_property
    schema = get_schema(source)
  File "<input>", line 40, in __get_pydantic_core_schema__
TypeError: __call__() takes 1 positional argument but 2 were given
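The last pydantic frame (`schema = get_schema(source)`) shows the hook being called with a single argument. Because the hook in the MRE declares only `handler`, that argument lands in `handler`, and here it appears to be the `ValueInInterval` instance itself (the annotation). Calling `handler(...)` then goes through `__call__`, which accepts nothing besides `self`. The pure-Python sketch below reproduces that sequence without pydantic; the direct call with the instance as the only argument is an assumption standing in for pydantic's internals.

```python
class ValueInInterval:
    """Hypothetical stripped-down stand-in for the class in the question."""

    def __call__(self):
        # Like the MRE's __call__, this accepts no arguments besides self.
        return "an Annotated type would be built here"

    def __get_pydantic_core_schema__(self, handler):
        # Same one-parameter signature as in the MRE.
        return handler(int)


marker = ValueInInterval()

try:
    # Assumption: mimic pydantic passing the annotation object itself as the only argument,
    # so `handler` is `marker` and handler(int) becomes marker.__call__(int).
    marker.__get_pydantic_core_schema__(marker)
except TypeError as exc:
    print(exc)  # message ends with "takes 1 positional argument but 2 were given"
```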

I strongly suspect I have not written `__get_pydantic_core_schema__` correctly. Note that I am using pydantic 2.7 and Python 3.11.
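A minimal sketch of a class-based marker for pydantic 2.x, assuming it is used inside `Annotated[...]` rather than as the annotation itself. In this form the hook takes `source_type` and `handler`: `handler(source_type)` builds the normal schema for the underlying type (here `int`), and `core_schema.no_info_after_validator_function` wraps it so the interval check runs after type validation. Because the underlying type now comes from `Annotated`, this version drops `type_definition`; the names `_check` and `print_value` are illustrative, not part of the question.

```python
from typing import Annotated, Any

from pydantic import GetCoreSchemaHandler, ValidationError, validate_call
from pydantic_core import core_schema


class ValueInInterval:
    """Annotated marker requiring start <= value <= end (each bound optionally exclusive)."""

    def __init__(self, start: Any, end: Any, include_start: bool = True, include_end: bool = True):
        self.start = start
        self.end = end
        self.include_start = include_start
        self.include_end = include_end

    def _check(self, value: Any) -> Any:
        # Compare against each bound, honouring whether it is inclusive.
        low_ok = value >= self.start if self.include_start else value > self.start
        high_ok = value <= self.end if self.include_end else value < self.end
        if not (low_ok and high_ok):
            left = "[" if self.include_start else "("
            right = "]" if self.include_end else ")"
            # pydantic converts a ValueError raised here into a ValidationError.
            raise ValueError(f"{value!r} is not in {left}{self.start}, {self.end}{right}")
        return value

    def __get_pydantic_core_schema__(
        self, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # source_type is the first argument of Annotated (e.g. int); build its usual
        # schema, then run the interval check on the already-validated value.
        return core_schema.no_info_after_validator_function(self._check, handler(source_type))


@validate_call
def print_value(value: Annotated[int, ValueInInterval(start=1, end=10)]):
    print(value)


print_value(value=1)  # succeeds: the lower bound is inclusive by default
try:
    print_value(value=0)
except ValidationError as exc:
    print(exc)
```

An after-validator works for any type that supports the comparison operators, not only types whose core schema accepts `ge`/`lt` keys; the trade-off is that the bounds do not appear in the generated JSON schema.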

1 Answer

You can implement a factory for your model to get the desired behavior:

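import pydantic


def model_factory(type_: type, ge: int, lt: int):
    # Build a model class whose single field is constrained to ge <= value < lt.
    class Model(pydantic.BaseModel):
        value: type_ = pydantic.Field(ge=ge, lt=lt)

    return Model


model = model_factory(type_=int, ge=1, lt=10)

model(value=5)  # ok
model(value=1)  # ok: the lower bound is inclusive (ge)
model(value=0)  # raises ValidationError
model(value="foo")  # raises ValidationError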
model(value="foo")  # raise ValidationError