我目前正在将 FastAPI 用于一个大学项目,它作为我的后端。
我的应用程序的流程之一要求我将有效负载从前端发送到后端,其中一个字段有许多变体。该有效负载需要在后端使用 pydantic 进行验证。为此,我决定加入受歧视的工会。
我们假设类是这样构造的。
class ClassA(pydantic.BaseModel):
type: Literal["a"]
field_a: int
class ClassB(pydantic.BaseModel):
type: Literal["b"]
...
ClassUnion = Annotated[Union[ClassA, ClassB], pydantic.Field(discriminator="type")]
class Parent(pydantic.BaseModel):
child: ClassUnion
field_parent: str
我有一个父类,它有一个子字段,其中包含 ClassA 和 ClassB 的联合,按类型区分。因此,如果 JSON 负载中的类型字段包含“a”,pydantic 会将其解析为 ClassA。受歧视的联合在快乐的路径中工作(当没有错误时),但是,如果我用无效值(例如字符串)填充 child.field_a ,我希望引发 ValidationError ,以便后端可以通知前端了解错误所在位置。
为此,我有一个自定义错误处理程序,它可以捕获来自 FastAPI 的任何 RequestValidationError,将 pydantic 的 ValidationError 中的各种错误转换为更简单的形状。错误处理程序的细节并不重要;重要的是错误处理程序取决于 ValidationError 上的“loc”字段。
例如,如果“field_parent”发生错误,那么“loc”字段将指向错误的位置;那是(“body”,“field_parent”)。请注意,“body”正是 FastAPI 请求中发生错误的位置。为了简单起见,我们可以假设 loc 只是 (“field_parent”)。
但是,如果 an 出现在“field_a”中,那么我期望 loc 应该是:(“child”,“field_a”)。这样,我可以将字段转换为与输入匹配的形状:
{
"child": {
"field_a": "Error message"
}
}
然而,pydantic 似乎也将受歧视的联合鉴别器包含到“loc”字段中。请记住,“field_a”位于可区分联合(具体而言,ClassA)中,并且按“类型”进行区分。如果 type =“a”,则“child”变为 ClassA。该鉴别器值最终被包含在“loc”中,成为(“child”,“a”,“field_a”)。这显然会扰乱我的前端异常处理程序,因为字段不再与输入匹配。
有什么方法可以从 loc 字段中删除鉴别器值吗?我在 FastAPI 判别联合文档中没有找到任何可以禁用此行为的选项(如果它甚至可以完全禁用)。这不是如何自定义FastAPI请求验证错误的问题;而是 Pydantic 如何格式化“loc”字段的问题
替代解决方案:
我考虑过使用WrapValidators来捕获ValidationError,然后以某种方式修改它,但是关于ValidationError内部的文档很少(查看python代码也没有帮助,因为它可能是在其他地方实现的)。
手动过滤受歧视的联合值似乎是一个肮脏的解决方案,而且还存在判别器与字段共享相同名称的风险,所以我也决定反对它。
我不能只为一个或两个特定路线自定义错误验证,因为模型将在多个路线上重复使用。
有什么建议吗?
我用可重复使用的“包装”解决了这个问题
model_validator
。解决方案的关键部分:
wrap
验证器来捕获标准 ValidationError
并使用新的 loc
在您的用例中,它确实需要创建
RootModel
的子类才能应用 model_validator。我无法找到一种以仅注释形式应用验证器的方法。
以下是重写代码的方法:
class ClassA(pydantic.BaseModel):
type: Literal["a"]
field_a: int
class ClassB(pydantic.BaseModel):
type: Literal["b"]
...
class ClassUnion(pydantic.RootModel[Annotated[Union[ClassA, ClassB], pydantic.Field(discriminator="type")]]):
_error_rewriter = discriminated_union_model_validator("type")
class Parent(pydantic.BaseModel):
child: ClassUnion
field_parent: str
这是代码:
def discriminated_union_model_validator(
discriminator_field: str,
) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
"""
The same as the discriminated_union_error_rewriter just with an automatic model_validator call
Usage example:
class SomeSpecificRequestBody(pydantic.BaseModel):
discriminator_param: Literal["first_option"]
required_param: str
class SomeOtherSpecificRequestBody(pydantic.BaseModel):
discriminator_param: Literal["other_option"]
class UnionRequestBodyReusableModelValidator(
pydantic.RootModel[
Annotated[
SomeSpecificRequestBody | SomeOtherSpecificRequestBody,
pydantic.Field(discriminator="discriminator_param"),
]
]
):
_error_rewriter = discriminated_union_model_validator("discriminator_param")
"""
return model_validator(mode="wrap")(discriminated_union_error_rewriter(discriminator_field))
def discriminated_union_error_rewriter(
discriminator_field: str,
) -> ModelWrapValidator[_ModelType]:
"""
Used exclusively for discriminated union containers to remove the discriminator value from the error 'loc' fields.
Usage example:
class SomeSpecificRequestBody(pydantic.BaseModel):
discriminator_param: Literal["first_option"]
required_param: str
class SomeOtherSpecificRequestBody(pydantic.BaseModel):
discriminator_param: Literal["other_option"]
class UnionRequestBodyExplicitModelValidator(
pydantic.RootModel[
Annotated[
SomeSpecificRequestBody | SomeOtherSpecificRequestBody,
pydantic.Field(discriminator="discriminator_param"),
]
]
):
@model_validator(mode="wrap")
def _remove_discriminator_from_errors(
cls,
values: dict[str, Any],
handler: ModelWrapValidatorHandler[Any],
info: ValidationInfo,
):
return discriminated_union_error_rewriter("discriminator_param")(cls, values, handler, info)
The below example illustrates the problem this validator solves:
from pprint import pprint
from typing import Literal, Union, Annotated, Any
from pydantic import BaseModel, Field, RootModel, ValidationError
from pydantic.main import Model
class Tiger(BaseModel):
animal_type: Literal["tiger"] = "tiger"
ferocity_scale: float = Field(..., ge=0, le=10)
class Shark(BaseModel):
animal_type: Literal["shark"] = "shark"
ferocity_scale: float = Field(..., ge=0, le=10)
class Lion(BaseModel):
animal_type: Literal["lion"] = "lion"
ferocity_scale: float
class WildAnimal(RootModel):
root: Annotated[Union[Tiger, Shark, Lion], Field(..., discriminator='animal_type')]
try:
my_shark = WildAnimal.model_validate({'animal_type': 'shark', 'ferocity_scale': 115})
except ValidationError as exc:
pprint(exc.errors())
when run, this prints the following:
[{'ctx': {'le': 10.0},
'input': 115,
'loc': ('shark', 'ferocity_scale'), <------ note the 'shark' part from 'loc', when 'ferocity_scale' is all we want
'msg': 'Input should be less than or equal to 10',
'type': 'less_than_equal',
'url': 'https://errors.pydantic.dev/2.7/v/less_than_equal'}]
"""
def validate_discriminated_union(
cls: type[_ModelType],
values: dict[str, Any],
handler: ModelWrapValidatorHandler[_ModelType],
info: ValidationInfo,
) -> _ModelType:
""" """
try:
# Let Pydantic handle the main validation
return handler(values)
except ValidationError as exc:
discriminator_value = values.get(discriminator_field)
if not discriminator_value:
raise # No discriminator value, proceed without adjustment
# Adjust the error locations by removing the discriminator value from 'loc'
adjusted_errors: list[InitErrorDetails] = []
for error in exc.errors(): # type: ErrorDetails
loc = error["loc"]
if len(loc) > 1 and loc[0] == discriminator_value:
loc = loc[1:] # Strip the discriminator value
adjusted_error: InitErrorDetails = {
"type": error["type"],
"loc": loc,
"input": error["input"],
}
if "ctx" in error:
adjusted_error["ctx"] = error["ctx"]
adjusted_errors.append(adjusted_error)
# Re-raise with the adjusted errors
raise ValidationError.from_exception_data(
title=exc.title, line_errors=adjusted_errors, input_type=info.mode
) from None
return validate_discriminated_union