我正在努力使用
pydantic
构建函数参数验证库。我们希望能够验证参数的类型和值。类型很简单,但我很难创建一个类来验证值。具体来说,我想要构建的第一个类是要求值位于用户定义的区间内的类。
到目前为止,我已经编写了一个装饰器和一个基于函数的版本的
ValueInInterval
。但是,我更喜欢使用基于类的方法。这是我的问题的 MRE:
from typing import Any
from typing_extensions import Annotated
from pydantic import Field, validate_call
class ValueInInterval:
def __init__(
self,
type_definition: Any,
start: Any,
end: Any,
include_start: bool = True,
include_end: bool = True,
):
self.type_definition = type_definition
self.start = start
self.end = end
self.include_start = include_start
self.include_end = include_end
def __call__(self):
return Annotated[self.type_definition, self.create_field()]
def create_field(self) -> Field:
field_config = {}
if self.include_start:
field_config.update({"ge": self.start})
else:
field_config.update({"gt": self.start})
if self.include_end:
field_config.update({"le": self.end})
else:
field_config.update({"lt": self.end})
return Field(**field_config)
def __get_pydantic_core_schema__(
self,
handler,
):
schema = handler(self.type_definition)
if self.include_start:
schema.update({"ge": self.start})
else:
schema.update({"gt": self.start})
if self.include_end:
schema.update({"le": self.end})
else:
schema.update({"lt": self.end})
return schema
@validate_call()
def test_interval(
value: ValueInInterval(type_definition=int, start=1, end=10),
):
print(value)
test_interval(value=1) # should succeed
test_interval(value=0) # should fail
在 PyCharm 的 Python 控制台中运行此代码,出现以下错误:
Traceback (most recent call last):
File "C:\Program Files\JetBrains\PyCharm 2023.2\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
coro = func()
File "<input>", line 56, in <module>
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\validate_call_decorator.py", line 56, in validate
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_validate_call.py", line 57, in __init__
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 512, in generate_schema
schema = self._generate_schema_inner(obj)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 789, in _generate_schema_inner
return self.match_type(obj)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 856, in match_type
return self._callable_schema(obj)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1692, in _callable_schema
arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1414, in _generate_parameter_schema
schema = self._apply_annotations(source_type, annotations)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1890, in _apply_annotations
schema = get_inner_schema(source_type)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_schema_generation_shared.py", line 83, in __call__
schema = self._handler(source_type)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1869, in inner_handler
from_property = self._generate_schema_from_property(obj, source_type)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 677, in _generate_schema_from_property
schema = get_schema(source)
File "<input>", line 40, in __get_pydantic_core_schema__
TypeError: __call__() takes 1 positional argument but 2 were given
我强烈怀疑我没有写正确
__get_pydantic_core_schema__
。请注意,我使用的是 pydantic
版本 2.7 和 Python 3.11。
您可以为您的模型实现一个工厂来实现所需的行为:
import pydantic
def model_factory(type_: type, ge: int, lt: int):
class Model(pydantic.BaseModel):
value: type_ = pydantic.Field(gt=ge, lt=lt)
return Model
model = model_factory(type_=int, ge=1, lt=10)
model(value=5) # ok
model(value=0) # raise ValidationError
model(value="foo") # raise ValidationError