Defining an AST for Boolean expressions with Pydantic

Question (votes: 0, answers: 1)

I want to use Pydantic to define and validate a query AST that will be applied to a Pandas DataFrame. Here is my code:

from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
import pandas as pd

class ColumnCondition(BaseModel):
    """
    A class that represents a condition that is applied to a single column.
    """
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str = Field(..., title="The name of the column to apply the condition to.")
    operator: Literal["==", "!=", "<", ">", "<=", ">="] = Field(
        ..., title="The operator of the condition."
    )
    value: Optional[str] = Field(None, title="The value to compare the column to.")


class AndCondition(BaseModel):
    """
    A class that represents an 'and' condition that is applied to two or more conditions.
    """
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


Condition = Union[ColumnCondition, AndCondition]

class ConditionModel(BaseModel):
    condition: Condition = Field(discriminator="tag")

def get_column_metadata(df: pd.DataFrame) -> dict:
    return {col: str(dtype) for col, dtype in df.dtypes.items()}


if __name__ == "__main__":
    """
    Example
    """
    condition_json = {
        "tag": "AndCondition",
        "conditions": [
            {
                "tag": "ColumnCondition",
                "column": "original_amount.currency",
                "operator": ">=",
                "value": "100",
            },
            {
                "tag": "ColumnCondition",
                "column": "original_amount.currency",
                "operator": "<=",
                "value": "1000",
            },
        ],
    }

    cond = ConditionModel.model_validate({"condition": condition_json})
    print(cond.model_dump_json(indent=2))

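This works well, but I have a few questions:

  • Is there a way to remove the ConditionModel wrapper class? I couldn't figure out how.
  • What is the best way to handle the value type? Should I add another field for the type in the ColumnCondition class? Or perhaps keep a list of columns and their types?
  • What is the best way to convert such a condition into a string for the DataFrame.query method? Should I implement __str__ in each class, or write a method that traverses the AST and builds the string?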
python pandas abstract-syntax-tree pydantic
1 Answer

0 votes

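To answer your first question about the ConditionModel wrapper class: yes, you can remove it by putting the discriminator directly on the Condition union type (with typing.Annotated) instead of on a wrapper field. Because a Union is not a model class, it has no model_validate method, so validate it with Pydantic v2's TypeAdapter instead. If you would rather have a model class, you can also define a custom root model with Pydantic's RootModel.

Here's how to do it: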

from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field, TypeAdapter

class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str = Field(..., title="The name of the column to apply the condition to.")
    operator: Literal["==", "!=", "<", ">", "<=", ">="] = Field(
        ..., title="The operator of the condition."
    )
    value: Optional[str] = Field(None, title="The value to compare the column to.")

class AndCondition(BaseModel):
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]

# Put the discriminator on the union itself instead of on a wrapper field;
# nested conditions inside AndCondition are discriminated the same way
Condition = Annotated[Union[ColumnCondition, AndCondition], Field(discriminator="tag")]

# Resolve the "Condition" forward reference now that it is defined
# (Pydantic v2 name; update_forward_refs() is the deprecated v1 API)
AndCondition.model_rebuild()

# A Union is not a model and has no model_validate(); TypeAdapter validates any type
ConditionAdapter = TypeAdapter(Condition)

# Example usage without a wrapper
condition_json = {
    "tag": "AndCondition",
    "conditions": [
        {
            "tag": "ColumnCondition",
            "column": "original_amount.currency",
            "operator": ">=",
            "value": "100",
        },
        {
            "tag": "ColumnCondition",
            "column": "original_amount.currency",
            "operator": "<=",
            "value": "1000",
        },
    ],
}

cond = ConditionAdapter.validate_python(condition_json)
print(cond)
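Since RootModel was mentioned as the other option, here is a minimal sketch of that variant, assuming Pydantic v2. ConditionRoot is a hypothetical name and the field titles are omitted for brevity. It is still a class, but it validates the bare condition JSON directly, so callers no longer need the extra "condition" key; the parsed node is available as .root.

```python
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field, RootModel


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str
    operator: Literal["==", "!=", "<", ">", "<=", ">="]
    value: Optional[str] = None


class AndCondition(BaseModel):
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


# Discriminated union, usable both at the top level and inside AndCondition
Condition = Annotated[Union[ColumnCondition, AndCondition], Field(discriminator="tag")]
AndCondition.model_rebuild()


class ConditionRoot(RootModel):
    """Hypothetical root model: the whole input *is* the condition."""

    root: Condition


cond = ConditionRoot.model_validate(
    {
        "tag": "AndCondition",
        "conditions": [
            {"tag": "ColumnCondition", "column": "original_amount.currency",
             "operator": ">=", "value": "100"},
        ],
    }
).root  # unwrap to get the AndCondition instance itself
print(type(cond).__name__)
```

Compared with TypeAdapter, a RootModel also gives you model_dump_json() and model_json_schema() on a named class, which can be handy if other code expects a BaseModel.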

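For your second question, a clean approach is to extend ColumnCondition to include the column type. Once the type is part of the model, value can be checked or converted against it (for example with a validator), which makes validation stricter and helps avoid comparing a column with a value of the wrong type. For example: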

class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str = Field(..., title="The name of the column to apply the condition to.")
    column_type: Literal["int", "float", "str"] = Field(
        ..., title="The data type of the column."
    )
    operator: Literal["==", "!=", "<", ">", "<=", ">="] = Field(
        ..., title="The operator of the condition."
    )
    value: Union[int, float, str] = Field(
        ..., title="The value to compare the column to."
    )
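Declaring column_type on its own does not make Pydantic check value against it; the Union simply accepts any int, float, or str. A minimal sketch of one way to enforce it, assuming Pydantic v2: an after-mode model_validator converts value to the declared type. CASTERS and coerce_value_to_column_type are hypothetical names, and the field titles are omitted for brevity. Because int() would silently truncate 1.5 to 1, the sketch rejects non-whole floats for int columns.

```python
from typing import Literal, Union

from pydantic import BaseModel, ValidationError, model_validator

# Hypothetical mapping from the declared column_type to a Python converter
CASTERS = {"int": int, "float": float, "str": str}


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str
    column_type: Literal["int", "float", "str"]
    operator: Literal["==", "!=", "<", ">", "<=", ">="]
    value: Union[int, float, str]

    @model_validator(mode="after")
    def coerce_value_to_column_type(self) -> "ColumnCondition":
        target = CASTERS[self.column_type]
        # int(1.5) would quietly become 1, so refuse lossy conversions
        if target is int and isinstance(self.value, float) and not self.value.is_integer():
            raise ValueError(f"{self.value!r} is not a whole number")
        # int("abc") / float("abc") raise ValueError, reported as a ValidationError
        self.value = target(self.value)
        return self


ok = ColumnCondition(column="amount", column_type="int", operator=">=", value="100")
print(repr(ok.value))  # 100

try:
    ColumnCondition(column="amount", column_type="int", operator=">=", value="abc")
except ValidationError as exc:
    print(exc.error_count(), "validation error")
```

Converting in the model means the value already has the right Python type when you later render the query string, so a numeric column is compared with 100 rather than '100'.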

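For your last question, you're already heading in the right direction. Implementing __str__ in each class amounts to a recursive traversal of the AST: each ColumnCondition renders a single comparison, and AndCondition renders its child conditions and joins them with and. A standalone function that walks the tree and builds the string works just as well if you would rather keep formatting out of the models. Here is the __str__ version: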

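class ColumnCondition(BaseModel):
    # existing fields...

    def __str__(self):
        # Backticks let DataFrame.query accept column names that are not valid
        # Python identifiers, such as "original_amount.currency"
        return f"`{self.column}` {self.operator} {self.value!r}"

class AndCondition(BaseModel):
    # existing fields...

    def __str__(self):
        # Recursively render each child condition and join them with "and"
        return " and ".join(str(cond) for cond in self.conditions)

# Usage (cond is the AndCondition validated in the first example):
query_string = str(cond)
print(query_string)

The example above produces the following query string:

`original_amount.currency` >= '100' and `original_amount.currency` <= '1000'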
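If you prefer the standalone-function approach, a minimal sketch follows, assuming Pydantic v2 and a pandas version that supports backtick-quoted column names in DataFrame.query. The to_query function, the sample DataFrame, and the numeric values are hypothetical illustrations; the values are numbers (as suggested for the second question) so they compare correctly against a numeric column.

```python
from typing import Annotated, List, Literal, Union

import pandas as pd
from pydantic import BaseModel, Field, TypeAdapter


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str
    operator: Literal["==", "!=", "<", ">", "<=", ">="]
    value: Union[int, float, str]


class AndCondition(BaseModel):
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


Condition = Annotated[Union[ColumnCondition, AndCondition], Field(discriminator="tag")]
AndCondition.model_rebuild()
ConditionAdapter = TypeAdapter(Condition)


def to_query(node: Union[ColumnCondition, AndCondition]) -> str:
    """Hypothetical helper: walk the AST and build a DataFrame.query string."""
    if isinstance(node, ColumnCondition):
        # Backticks quote column names that are not valid identifiers
        return f"`{node.column}` {node.operator} {node.value!r}"
    if isinstance(node, AndCondition):
        # Parenthesize children so the output stays correct if other node
        # types (e.g. a hypothetical OrCondition) are added later
        return " and ".join(f"({to_query(child)})" for child in node.conditions)
    raise TypeError(f"Unsupported node: {type(node).__name__}")


df = pd.DataFrame({"original_amount.currency": [50, 100, 500, 1000, 2000]})
cond = ConditionAdapter.validate_python(
    {
        "tag": "AndCondition",
        "conditions": [
            {"tag": "ColumnCondition", "column": "original_amount.currency",
             "operator": ">=", "value": 100},
            {"tag": "ColumnCondition", "column": "original_amount.currency",
             "operator": "<=", "value": 1000},
        ],
    }
)
query_string = to_query(cond)
result = df.query(query_string)
print(query_string)
print(result)
```

Keeping the traversal in one function leaves the models as plain data, and adding another output target (SQL, for instance) means writing another function rather than touching every class.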

Hope this helps.

© www.soinside.com 2019 - 2024. All rights reserved.