Pydantic supports annotating third-party types so that they can be used directly in Pydantic models and serialized to and deserialized from JSON.
For example:
```python
from typing import Annotated, Any

from pydantic import BaseModel, model_validator
from pydantic.functional_validators import ModelWrapValidatorHandler
from typing_extensions import Self


# Pretend this is some third-party class
# we can't modify directly...
class Quantity:
    def __init__(self, value: float, unit: str):
        self.value = value
        self.unit = unit


class QuantityAnnotations(BaseModel):
    value: float
    unit: str

    @model_validator(mode="wrap")
    def _validate(value: Any, handler: ModelWrapValidatorHandler[Self]) -> Quantity:
        if isinstance(value, Quantity):
            return value
        validated = handler(value)
        if isinstance(validated, Quantity):
            return validated
        return Quantity(**dict(validated))


QuantityType = Annotated[Quantity, QuantityAnnotations]


class OurModel(BaseModel):
    quantity: QuantityType = Quantity(value=0.0, unit='m')
```
This works fine because we have only annotated the `Quantity` type, so Pydantic knows how to serialize it to JSON without any problems:
```python
model_instance = OurModel()
print(model_instance.model_dump_json())
# {"quantity":{"value":0.0,"unit":"m"}}
```
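However, if we try to get the JSON Schema describing `OurModel`, we get a warning saying that it doesn't know how to serialize the default value (the very value it just serialized successfully)...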
```python
OurModel.model_json_schema()
# ...lib/python3.10/site-packages/pydantic/json_schema.py:2158: PydanticJsonSchemaWarning:
# Default value <__main__.Quantity object at 0x75fcccab1960> is not JSON serializable;
# excluding default from JSON schema [non-serializable-default]
```
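What am I missing here? Is this a bug in Pydantic, or do I need to add something more to my annotation to tell Pydantic how to serialize the default value in a JSON Schema context?

Does anyone have a good workaround for easily including annotated third-party types as default values in the JSON Schema that Pydantic generates?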
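A quick way to see where the warning comes from, assuming Pydantic v2: while building the schema, the generator has to turn each default into plain JSON data. The sketch below uses pydantic-core's `to_jsonable_python` as a stand-in for that step (an approximation: newer Pydantic versions first try a `TypeAdapter` for the default's type, but the outcome for a plain class is the same). It has no rule for an arbitrary class, whereas it can convert a dataclass field by field. `PlainQuantity`, `DataclassQuantity` and `try_encode` are illustrative names.

```python
from dataclasses import dataclass
from typing import Any

from pydantic_core import PydanticSerializationError, to_jsonable_python


class PlainQuantity:
    """Like the question's third-party class: a plain class with no serialization hooks."""

    def __init__(self, value: float, unit: str):
        self.value = value
        self.unit = unit


@dataclass(frozen=True)
class DataclassQuantity:
    """Same fields, but as a dataclass, whose fields pydantic-core can inspect."""

    value: float
    unit: str


def try_encode(obj: Any) -> Any:
    """Return obj as JSON-compatible data, or None if pydantic-core cannot encode it."""
    try:
        return to_jsonable_python(obj)
    except PydanticSerializationError:
        # Raised for types pydantic-core has no serialization rule for.
        return None


print(try_encode(PlainQuantity(0.0, "m")))
print(try_encode(DataclassQuantity(0.0, "m")))
```

The plain class cannot be encoded, while the dataclass becomes an ordinary dict. That is one reason a dataclass-based `Quantity` is likely to avoid this particular warning.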
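Your annotation looks a bit odd: there is no reason for the annotation class to extend `BaseModel`, and `model_validator` shouldn't be used here. This looks very much like Pydantic 1 style.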
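Consider the following `Quantity` and `Unit` classes, which have useful methods for checking equivalence between units and for parsing a quantity from a string (if you are using `astropy.units`, this is quite close):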
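```python
from __future__ import annotations  # lets the classes refer to themselves in annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Unit:
    unit: str

    def is_equivalent(self, other: Unit) -> bool:
        # Placeholder: compares only the last character of each unit string
        # (slicing avoids an IndexError on an empty unit). Change this to a real check.
        return self.unit[-1:] == other.unit[-1:]


@dataclass(frozen=True)
class Quantity:
    value: float
    unit: Unit

    def __str__(self) -> str:
        # Used by the string serializer below: "3.2 m" round-trips through parse().
        return f"{self.value} {self.unit.unit}"

    @staticmethod
    def parse(value: str) -> Quantity:
        # Split at the first character that is neither a digit nor a decimal point.
        i = next(i for i, c in enumerate(value) if not c.isdigit() and c != ".")
        return Quantity(
            value=float(value[:i].strip()), unit=Unit(unit=value[i:].strip())
        )
```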
Here is an implementation you can use to annotate `Quantity` objects (see `_predicate` and `_make_pattern` below):
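```python
from dataclasses import dataclass
from typing import Any, Callable

from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema


@dataclass(frozen=True)
class QuantityAnnotated:
    unit: Unit  # expected unit; values must be equivalent to it
    suffix: tuple[str, ...]  # unit suffixes advertised in the JSON schema

    def __get_pydantic_core_schema__(
        self, source_type: Any, handler: Callable[[Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        # The same validator handles JSON and Python input; output is always a string.
        return core_schema.json_or_python_schema(
            json_schema=core_schema.no_info_plain_validator_function(
                function=lambda value: _predicate(value, self.unit)
            ),
            python_schema=core_schema.no_info_plain_validator_function(
                function=lambda value: _predicate(value, self.unit)
            ),
            serialization=core_schema.to_string_ser_schema(when_used="always"),
        )

    def __get_pydantic_json_schema__(
        self, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # Accept either {"value": ..., "unit": ...} or a string such as "3.2 m".
        g = GenerateJsonSchema()
        return g.get_flattened_anyof(
            [
                {
                    "type": "object",
                    "properties": {
                        "value": {"type": "number"},
                        "unit": {"type": "string", "enum": list(self.suffix)},
                    },
                },
                {"type": "string", "pattern": _make_pattern(*self.suffix)},
            ]
        )
```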
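Usage example:

```python
from typing import Annotated

from pydantic import BaseModel


# Hypothetical model (its definition is not part of the original answer):
# a quantity that must be equivalent to metres, advertised as "m" or "km".
class Model(BaseModel):
    quantity: Annotated[Quantity, QuantityAnnotated(unit=Unit("m"), suffix=("m", "km"))]


print(Model.model_json_schema())
print(Model.model_validate(dict(quantity=Quantity(3.2, Unit("m")))))
try:
    Model.model_validate(dict(quantity=Quantity(3.2, Unit("kg"))))
except ValueError as err:  # ValidationError is a subclass of ValueError
    print(err)
print(Model.model_validate(dict(quantity="3.2 m")))
print(Model.model_validate(dict(quantity=dict(value=3.2, unit="m"))))
```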
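Below you can find the implementation of `_predicate`, which parses the value into a `Quantity` where possible, and `_make_pattern`, which builds a valid pattern string from a list of unit suffixes.

```python
import re
from typing import Any


def _predicate(value: Any, unit: Unit) -> Quantity:
    if isinstance(value, str):
        try:
            value = Quantity.parse(value)
        except (StopIteration, TypeError, ValueError) as err:
            # parse() raises StopIteration when there is no unit part and
            # ValueError when the numeric part is not a valid float.
            raise ValueError("invalid value encountered") from err
    elif isinstance(value, dict):
        try:
            value = Quantity(value=float(value["value"]), unit=Unit(value["unit"]))
        except (KeyError, TypeError, ValueError) as err:
            raise ValueError("invalid value encountered") from err
    if isinstance(value, Quantity):
        if not unit.is_equivalent(value.unit):
            raise ValueError(f"cannot convert {value} to {unit}")
    else:
        # or value = Quantity(value=value, unit=unit) if you want to allow this
        raise ValueError("expected quantity, found raw value")
    return value


def _make_pattern(*units: str) -> str:
    # Escape the suffixes so characters such as "." are matched literally.
    opts = "|".join(re.escape(u) for u in units)
    # Optional sign, then an integer, a decimal, or a bare fraction such as ".5"
    # (a number is required), optional whitespace, then one of the suffixes.
    return rf"^[+-]?([0-9]+(\.[0-9]+)?|\.[0-9]+)\s*({opts})$"
```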
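If you would rather keep the question's plain (non-dataclass) third-party class, another option is to teach the JSON Schema generator how to encode that default. The sketch below assumes Pydantic v2, where `GenerateJsonSchema` has an overridable `encode_default` hook and `model_json_schema` accepts a `schema_generator` argument. `_QuantityAnnotation`, `QuantityAwareGenerateJsonSchema`, `_to_quantity` and `_to_dict` are hypothetical names, and the annotation class replaces the question's `BaseModel`-based one.

```python
from typing import Annotated, Any

from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema


class Quantity:
    """Stand-in for the question's third-party class (deliberately not a dataclass)."""

    def __init__(self, value: float, unit: str):
        self.value = value
        self.unit = unit


def _to_quantity(value: Any) -> Quantity:
    # Accept an existing Quantity or a {"value": ..., "unit": ...} mapping.
    if isinstance(value, Quantity):
        return value
    if isinstance(value, dict):
        try:
            return Quantity(value=float(value["value"]), unit=str(value["unit"]))
        except (KeyError, TypeError, ValueError) as err:
            raise ValueError("invalid quantity mapping") from err
    raise ValueError("expected a Quantity or a mapping with 'value' and 'unit'")


def _to_dict(quantity: Quantity) -> dict[str, Any]:
    return {"value": quantity.value, "unit": quantity.unit}


class _QuantityAnnotation:
    """Hypothetical annotation: validation, serialization and JSON schema for Quantity."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        return core_schema.no_info_plain_validator_function(
            _to_quantity,
            serialization=core_schema.plain_serializer_function_ser_schema(_to_dict),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        return {
            "type": "object",
            "properties": {"value": {"type": "number"}, "unit": {"type": "string"}},
            "required": ["value", "unit"],
        }


class QuantityAwareGenerateJsonSchema(GenerateJsonSchema):
    """Schema generator that can turn a Quantity default into JSON data.

    Assumption: Pydantic v2 calls encode_default() for field defaults.
    """

    def encode_default(self, dft: Any) -> Any:
        if isinstance(dft, Quantity):
            return _to_dict(dft)
        return super().encode_default(dft)


QuantityType = Annotated[Quantity, _QuantityAnnotation]


class OurModel(BaseModel):
    quantity: QuantityType = Quantity(value=0.0, unit="m")


print(OurModel().model_dump_json())
print(OurModel.model_json_schema(schema_generator=QuantityAwareGenerateJsonSchema))
```

The trade-off is that every `model_json_schema` call has to pass the custom generator; turning `Quantity` into a dataclass, as in the answer, avoids that bookkeeping.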