I want to use Pydantic to define and validate a query AST that will be applied to a Pandas DataFrame. Here is my code:
```python
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
import pandas as pd


class ColumnCondition(BaseModel):
    """
    A class that represents a condition that is applied to a single column.
    """

    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str = Field(..., title="The name of the column to apply the condition to.")
    operator: Literal["==", "!=", "<", ">", "<=", ">="] = Field(
        ..., title="The operator of the condition."
    )
    value: Optional[str] = Field(None, title="The value to compare the column to.")


class AndCondition(BaseModel):
    """
    A class that represents an 'and' condition that is applied to two or more conditions.
    """

    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


Condition = Union[ColumnCondition, AndCondition]


class ConditionModel(BaseModel):
    condition: Condition = Field(discriminator="tag")


def get_column_metadata(df: pd.DataFrame) -> dict:
    return {col: str(dtype) for col, dtype in df.dtypes.items()}


if __name__ == "__main__":
    """
    Example
    """
    condition_json = {
        "tag": "AndCondition",
        "conditions": [
            {
                "tag": "ColumnCondition",
                "column": "original_amount.currency",
                "operator": ">=",
                "value": "100",
            },
            {
                "tag": "ColumnCondition",
                "column": "original_amount.currency",
                "operator": "<=",
                "value": "1000",
            },
        ],
    }
    cond = ConditionModel.model_validate({"condition": condition_json})
    print(cond.model_dump_json(indent=2))
```
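This works well, but I have a few questions:

1. Can I avoid the `ConditionModel` wrapper class? I couldn't work out how.
2. How should I take the column types into account? Should I have another field in the `ColumnCondition` class, or perhaps keep a list of columns and their types?
3. What is the best way to generate the string used by the `DataFrame.query` method? Should I implement `__str__` in each class, or write a method that traverses the AST and builds the string?

To answer your first question about the `ConditionModel` wrapper class: yes, you can drop the wrapper by attaching the discriminator directly to the `Condition` union, i.e. `Annotated[Union[...], Field(discriminator="tag")]`. A bare `Union` is not a model and has no `model_validate` method, so you validate it with Pydantic's `TypeAdapter` instead, or wrap it in a custom root model built on the `RootModel` class.

Here's how: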
```python
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str = Field(..., title="The name of the column to apply the condition to.")
    operator: Literal["==", "!=", "<", ">", "<=", ">="] = Field(
        ..., title="The operator of the condition."
    )
    value: Optional[str] = Field(None, title="The value to compare the column to.")


class AndCondition(BaseModel):
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


# The discriminator lives on the union itself, so no wrapper model is needed
Condition = Annotated[Union[ColumnCondition, AndCondition], Field(discriminator="tag")]

# Resolve the "Condition" forward reference inside AndCondition
# (model_rebuild is the Pydantic v2 replacement for update_forward_refs)
AndCondition.model_rebuild()

# A bare Union has no model_validate; a TypeAdapter can validate any type
condition_adapter = TypeAdapter(Condition)

# Example usage without a wrapper
condition_json = {
    "tag": "AndCondition",
    "conditions": [
        {
            "tag": "ColumnCondition",
            "column": "original_amount.currency",
            "operator": ">=",
            "value": "100",
        },
        {
            "tag": "ColumnCondition",
            "column": "original_amount.currency",
            "operator": "<=",
            "value": "1000",
        },
    ],
}

cond = condition_adapter.validate_python(condition_json)
print(cond)
```
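If you would rather keep a model class with the usual `model_validate` / `model_dump_json` methods, the `RootModel` route mentioned above also works. A minimal sketch, assuming Pydantic v2; the class name `ConditionRoot` is made up for illustration:

```python
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field, RootModel


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str
    operator: Literal["==", "!=", "<", ">", "<=", ">="]
    value: Optional[str] = None


class AndCondition(BaseModel):
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


# Discriminated union, reused inside AndCondition and as the root type
Condition = Annotated[Union[ColumnCondition, AndCondition], Field(discriminator="tag")]
AndCondition.model_rebuild()


class ConditionRoot(RootModel):
    # The root value is the condition itself, so no {"condition": ...} key is needed
    root: Condition


condition_json = {
    "tag": "AndCondition",
    "conditions": [
        {"tag": "ColumnCondition", "column": "original_amount.currency", "operator": ">=", "value": "100"},
        {"tag": "ColumnCondition", "column": "original_amount.currency", "operator": "<=", "value": "1000"},
    ],
}

parsed = ConditionRoot.model_validate(condition_json)
print(type(parsed.root).__name__)  # AndCondition
# Dumps the condition directly, without a wrapper key
print(parsed.model_dump_json(indent=2))
```

The trade-off is one extra `.root` access when you read the parsed value, in exchange for a real model class you can reuse in other models or schemas.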
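For your second question, a clean approach is to extend `ColumnCondition` with the column's type. With that field in place, validation can be made stricter: a validator can reject a `value` whose type does not match `column_type`, which avoids comparing a column against a value of the wrong type. For example, the model could look like this: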
```python
from typing import Literal, Union

from pydantic import BaseModel, Field


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str = Field(..., title="The name of the column to apply the condition to.")
    column_type: Literal["int", "float", "str"] = Field(
        ..., title="The data type of the column."
    )
    operator: Literal["==", "!=", "<", ">", "<=", ">="] = Field(
        ..., title="The operator of the condition."
    )
    value: Union[int, float, str] = Field(
        ..., title="The value to compare the column to."
    )
```
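Declaring `column_type` does not, on its own, stop a mismatched `value` from being accepted. A minimal sketch of how the check could be enforced, assuming Pydantic v2: a `field_validator` on `value` reads the already-validated `column_type` from `info.data`, and a second validator optionally checks the column name against a `{column: dtype}` mapping passed through the validation `context` (here, the output of the question's `get_column_metadata`). The `_PY_TYPES` mapping and the `"columns"` context key are illustrative choices, not part of Pydantic.

```python
from typing import Literal, Union

import pandas as pd
from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator

# Illustrative mapping from column_type labels to the Python types accepted for value
_PY_TYPES = {"int": (int,), "float": (int, float), "str": (str,)}


def get_column_metadata(df: pd.DataFrame) -> dict:
    # Same helper as in the question: {column name: dtype string}
    return {col: str(dtype) for col, dtype in df.dtypes.items()}


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str
    column_type: Literal["int", "float", "str"]
    operator: Literal["==", "!=", "<", ">", "<=", ">="]
    value: Union[int, float, str]

    @field_validator("column")
    @classmethod
    def column_exists(cls, v: str, info: ValidationInfo) -> str:
        # Only checked when a {"columns": {...}} mapping is passed as context
        columns = (info.context or {}).get("columns")
        if columns is not None and v not in columns:
            raise ValueError(f"unknown column {v!r}")
        return v

    @field_validator("value")
    @classmethod
    def value_matches_column_type(cls, v, info: ValidationInfo):
        # info.data holds the fields declared (and already validated) before "value"
        column_type = info.data.get("column_type")
        if column_type is not None and not isinstance(v, _PY_TYPES[column_type]):
            raise ValueError(f"{v!r} is not a valid {column_type} value")
        return v


if __name__ == "__main__":
    df = pd.DataFrame({"amount": [50, 150, 2000]})
    ctx = {"columns": get_column_metadata(df)}
    ok = ColumnCondition.model_validate(
        {"column": "amount", "column_type": "int", "operator": ">=", "value": 100},
        context=ctx,
    )
    print(ok)
    try:
        # "100" stays a str under Union[int, float, str], so it is rejected for an int column
        ColumnCondition.model_validate(
            {"column": "amount", "column_type": "int", "operator": ">=", "value": "100"},
            context=ctx,
        )
    except ValidationError as exc:
        print(exc)
```

Passing the metadata through `context` keeps the models independent of any particular DataFrame; without a context, only the type check runs.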
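For your last question, you are already heading in the right direction. You can walk the AST and generate the query string from it. If each class implements `__str__`, every node renders itself, and because `AndCondition.__str__` calls `str()` on its children, the recursion covers the whole tree. For example: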
```python
class ColumnCondition(BaseModel):
    # ... existing fields ...

    def __str__(self):
        return f"{self.column} {self.operator} {repr(self.value)}"


class AndCondition(BaseModel):
    # ... existing fields ...

    def __str__(self):
        return " and ".join([str(cond) for cond in self.conditions])


# Usage:
query_string = str(cond)
print(query_string)
```
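The example above produces a query string like this:

```
original_amount.currency >= '100' and original_amount.currency <= '1000'
```

Before passing it to `DataFrame.query`, note two things: a column name containing a dot, such as `original_amount.currency`, has to be wrapped in backticks, otherwise it is parsed as attribute access; and because `value` is a `str` in the original model, the comparisons are against the strings `'100'` and `'1000'` rather than numbers.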
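The other option from the question, a single function that walks the AST, keeps the string formatting out of the models. A minimal sketch, assuming Pydantic v2 and a reasonably recent pandas; `to_query` is a hypothetical helper name. It wraps column names in backticks so `DataFrame.query` accepts names that are not valid Python identifiers, and it parenthesizes each child so the output stays unambiguous if `or` nodes are added later. `value` is typed as a number or string here so that numeric comparisons stay numeric.

```python
from typing import List, Literal, Union

import pandas as pd
from pydantic import BaseModel


class ColumnCondition(BaseModel):
    tag: Literal["ColumnCondition"] = "ColumnCondition"
    column: str
    operator: Literal["==", "!=", "<", ">", "<=", ">="]
    value: Union[int, float, str]


class AndCondition(BaseModel):
    tag: Literal["AndCondition"] = "AndCondition"
    conditions: List["Condition"]


Condition = Union[ColumnCondition, AndCondition]
AndCondition.model_rebuild()


def to_query(cond: Condition) -> str:
    """Recursively turn a condition AST into a DataFrame.query expression."""
    if isinstance(cond, ColumnCondition):
        # Backticks let query() accept names such as "original_amount.currency"
        return f"`{cond.column}` {cond.operator} {cond.value!r}"
    if isinstance(cond, AndCondition):
        # Parenthesize each child so precedence stays explicit
        return " and ".join(f"({to_query(c)})" for c in cond.conditions)
    raise TypeError(f"unsupported condition: {type(cond).__name__}")


cond = AndCondition(
    conditions=[
        ColumnCondition(column="original_amount.currency", operator=">=", value=100),
        ColumnCondition(column="original_amount.currency", operator="<=", value=1000),
    ]
)
query = to_query(cond)
print(query)
# (`original_amount.currency` >= 100) and (`original_amount.currency` <= 1000)

df = pd.DataFrame({"original_amount.currency": [50, 150, 2000]})
print(df.query(query))  # keeps only the row with 150
```

Compared with `__str__`, a standalone function makes it easy to add other renderers later (for example, a SQL `WHERE` clause) without touching the models.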
Hope this helps.