Pydantic:从 ValidationError loc 字段中删除可区分的联合鉴别符值

问题描述 投票:0回答:1

我目前正在将 FastAPI 用于一个大学项目,它作为我的后端。

我的应用程序的流程之一要求我将有效负载从前端发送到后端,其中一个字段有许多变体。该有效负载需要在后端使用 pydantic 进行验证。为此,我决定加入受歧视的工会。

我们假设类是这样构造的。

class ClassA(pydantic.BaseModel):
    type: Literal["a"]
    field_a: int

class ClassB(pydantic.BaseModel):
    type: Literal["b"]
    ...

ClassUnion = Annotated[Union[ClassA, ClassB], pydantic.Field(discriminator="type")]

class Parent(pydantic.BaseModel):
    child: ClassUnion
    field_parent: str

我有一个父类,它有一个子字段,其中包含 ClassA 和 ClassB 的联合,按类型区分。因此,如果 JSON 负载中的类型字段包含“a”,pydantic 会将其解析为 ClassA。受歧视的联合在快乐的路径中工作(当没有错误时),但是,如果我用无效值(例如字符串)填充 child.field_a ,我希望引发 ValidationError ,以便后端可以通知前端了解错误所在位置。

为此,我有一个自定义错误处理程序,它可以捕获来自 FastAPI 的任何 RequestValidationError,将 pydantic 的 ValidationError 中的各种错误转换为更简单的形状。错误处理程序的细节并不重要;重要的是错误处理程序取决于 ValidationError 上的“loc”字段。

例如,如果“field_parent”发生错误,那么“loc”字段将指向错误的位置;那是(“body”,“field_parent”)。请注意,“body”正是 FastAPI 请求中发生错误的位置。为了简单起见,我们可以假设 loc 只是 (“field_parent”)。

但是,如果 an 出现在“field_a”中,那么我期望 loc 应该是:(“child”,“field_a”)。这样,我可以将字段转换为与输入匹配的形状:

{
    "child": {
        "field_a": "Error message"
    }
}

然而,pydantic 似乎也将受歧视的联合鉴别器包含到“loc”字段中。请记住,“field_a”位于可区分联合(具体而言,ClassA)中,并且按“类型”进行区分。如果 type =“a”,则“child”变为 ClassA。该鉴别器值最终被包含在“loc”中,成为(“child”,“a”,“field_a”)。这显然会扰乱我的前端异常处理程序,因为字段不再与输入匹配。

有什么方法可以从 loc 字段中删除鉴别器值吗?我在 FastAPI 判别联合文档中没有找到任何可以禁用此行为的选项(如果它甚至可以完全禁用)。这不是如何自定义FastAPI请求验证错误的问题;而是 Pydantic 如何格式化“loc”字段的问题

替代解决方案:

  • 我考虑过使用WrapValidators来捕获ValidationError,然后以某种方式修改它,但是关于ValidationError内部的文档很少(查看python代码也没有帮助,因为它可能是在其他地方实现的)。

  • 手动过滤受歧视的联合值似乎是一个肮脏的解决方案,而且还存在判别器与字段共享相同名称的风险,所以我也决定反对它。

  • 我不能只为一个或两个特定路线自定义错误验证,因为模型将在多个路线上重复使用。

有什么建议吗?

python validation fastapi pydantic discriminated-union
1个回答
0
投票

我用可重复使用的“包装”解决了这个问题

model_validator
。解决方案的关键部分:

  1. 验证器需要了解鉴别器参数才能稍后将其弹出
  2. 使用
    wrap
    验证器来捕获标准
    ValidationError
    并使用新的
    loc
  3. 重新加注

在您的用例中,它确实需要创建

RootModel
的子类才能应用 model_validator。我无法找到一种以仅注释形式应用验证器的方法。

以下是重写代码的方法:

class ClassA(pydantic.BaseModel):
    type: Literal["a"]
    field_a: int

class ClassB(pydantic.BaseModel):
    type: Literal["b"]
    ...

class ClassUnion(pydantic.RootModel[Annotated[Union[ClassA, ClassB], pydantic.Field(discriminator="type")]]):
    
    _error_rewriter = discriminated_union_model_validator("type")

class Parent(pydantic.BaseModel):
    child: ClassUnion
    field_parent: str

这是代码:

def discriminated_union_model_validator(
    discriminator_field: str,
) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
    """
    The same as the discriminated_union_error_rewriter just with an automatic model_validator call

    Usage example:

        class SomeSpecificRequestBody(pydantic.BaseModel):
            discriminator_param: Literal["first_option"]
            required_param: str


        class SomeOtherSpecificRequestBody(pydantic.BaseModel):
            discriminator_param: Literal["other_option"]

        class UnionRequestBodyReusableModelValidator(
            pydantic.RootModel[
                Annotated[
                    SomeSpecificRequestBody | SomeOtherSpecificRequestBody,
                    pydantic.Field(discriminator="discriminator_param"),
                ]
            ]
        ):

        _error_rewriter = discriminated_union_model_validator("discriminator_param")
    """

    return model_validator(mode="wrap")(discriminated_union_error_rewriter(discriminator_field))


def discriminated_union_error_rewriter(
    discriminator_field: str,
) -> ModelWrapValidator[_ModelType]:
    """
    Used exclusively for discriminated union containers to remove the discriminator value from the error 'loc' fields.

    Usage example:

        class SomeSpecificRequestBody(pydantic.BaseModel):
            discriminator_param: Literal["first_option"]
            required_param: str


        class SomeOtherSpecificRequestBody(pydantic.BaseModel):
            discriminator_param: Literal["other_option"]


        class UnionRequestBodyExplicitModelValidator(
            pydantic.RootModel[
                Annotated[
                    SomeSpecificRequestBody | SomeOtherSpecificRequestBody,
                    pydantic.Field(discriminator="discriminator_param"),
                ]
            ]
        ):

        @model_validator(mode="wrap")
        def _remove_discriminator_from_errors(
            cls,
            values: dict[str, Any],
            handler: ModelWrapValidatorHandler[Any],
            info: ValidationInfo,
        ):
            return discriminated_union_error_rewriter("discriminator_param")(cls, values, handler, info)


    The below example illustrates the problem this validator solves:

        from pprint import pprint
        from typing import Literal, Union, Annotated, Any

        from pydantic import BaseModel, Field, RootModel, ValidationError
        from pydantic.main import Model


        class Tiger(BaseModel):
            animal_type: Literal["tiger"] = "tiger"
            ferocity_scale: float = Field(..., ge=0, le=10)


        class Shark(BaseModel):
            animal_type: Literal["shark"] = "shark"
            ferocity_scale: float = Field(..., ge=0, le=10)


        class Lion(BaseModel):
            animal_type: Literal["lion"] = "lion"
            ferocity_scale: float


        class WildAnimal(RootModel):
            root: Annotated[Union[Tiger, Shark, Lion], Field(..., discriminator='animal_type')]


        try:
            my_shark = WildAnimal.model_validate({'animal_type': 'shark', 'ferocity_scale': 115})
        except ValidationError as exc:
            pprint(exc.errors())

    when run, this prints the following:

        [{'ctx': {'le': 10.0},
          'input': 115,
          'loc': ('shark', 'ferocity_scale'),  <------ note the 'shark' part from 'loc', when 'ferocity_scale' is all we want
          'msg': 'Input should be less than or equal to 10',
          'type': 'less_than_equal',
          'url': 'https://errors.pydantic.dev/2.7/v/less_than_equal'}]

    """

    def validate_discriminated_union(
        cls: type[_ModelType],
        values: dict[str, Any],
        handler: ModelWrapValidatorHandler[_ModelType],
        info: ValidationInfo,
    ) -> _ModelType:
        """ """
        try:
            # Let Pydantic handle the main validation
            return handler(values)
        except ValidationError as exc:
            discriminator_value = values.get(discriminator_field)
            if not discriminator_value:
                raise  # No discriminator value, proceed without adjustment

            # Adjust the error locations by removing the discriminator value from 'loc'
            adjusted_errors: list[InitErrorDetails] = []
            for error in exc.errors():  # type: ErrorDetails
                loc = error["loc"]
                if len(loc) > 1 and loc[0] == discriminator_value:
                    loc = loc[1:]  # Strip the discriminator value

                adjusted_error: InitErrorDetails = {
                    "type": error["type"],
                    "loc": loc,
                    "input": error["input"],
                }
                if "ctx" in error:
                    adjusted_error["ctx"] = error["ctx"]

                adjusted_errors.append(adjusted_error)

            # Re-raise with the adjusted errors
            raise ValidationError.from_exception_data(
                title=exc.title, line_errors=adjusted_errors, input_type=info.mode
            ) from None

    return validate_discriminated_union
© www.soinside.com 2019 - 2024. All rights reserved.