我正在尝试在 django Rest 框架中创建与 athena 交互的自定义管理器。尽管我面临着奇怪的问题,每次我使用过滤器到达端点时,它都会附加到每个请求的列表 where_conditions=[] 中。 不应该发生实际上使转换为 SQL 的查询无效的情况。我想重置每个请求的 where_conditions 列表。
filter_class/base.py
from apps.analytics.models.filter.fields.filter import FilterField
from apps.analytics.models.filter.fields.ordering import OrderingField
from apps.analytics.models.filter.fields.search import SearchField
from apps.analytics.models.manager.query.utils import ModelQueryUtils
class BaseFilterClass:
_ordering_filter_query_key = "ordering"
_search_filter_query_key = "search"
def __init__(self, request):
self._request = request
@property
def filter_params(self):
if getattr(self.Meta, "filter_fields", None):
req_query_params = {**self._request.query_params}
for i in req_query_params:
if isinstance(req_query_params[i], list):
req_query_params[i] = req_query_params[i][0]
if req_query_params and req_query_params.get("search"):
del req_query_params["search"]
req_query_params_key = set(req_query_params.keys())
valid_filter_params = self.Meta.filter_fields.declared_filter_params.intersection(req_query_params_key)
params = {key: self.Meta.filter_fields.get_processed_value(key.split('__')[-1],
req_query_params[key]) if "__" in key else
req_query_params[key] for key in valid_filter_params}
return ModelQueryUtils.FilterInput(params, ModelQueryUtils.FilterInput.JoinOperator.AND)
@property
def ordering_params(self):
if getattr(self.Meta, "ordering_fields", None):
req_query_params_ordering_params = self._request.query_params and self._request.query_params.get(
self._ordering_filter_query_key)
if req_query_params_ordering_params:
declared_ordering_params = self.Meta.ordering_fields.declared_ordering_params
return ModelQueryUtils.OrderInput(
[str(param) for param in req_query_params_ordering_params.split(",") if
((param.startswith("-") and param[1:]) or param) in declared_ordering_params])
@property
def search_params(self):
if getattr(self.Meta, "search_fields", None):
req_query_params_search_param = self._request.query_params and self._request.query_params.get(
self._search_filter_query_key)
params = {}
if req_query_params_search_param:
for key in self.Meta.search_fields.declared_search_params:
params[key] = req_query_params_search_param
return ModelQueryUtils.SearchInput(params, ModelQueryUtils.FilterInput.JoinOperator.OR)
class Meta:
filter_fields = FilterField([])
search_fields = SearchField([])
ordering_fields = OrderingField([])
经理/查询/base.py
class BaseModelQuery:
__slots__ = ['_model']
# _where_conditions = None
_order_conditions = None
_offset: int = None
_limit: int = None
_full_query_mtd_name: str = None
def __init__(self, model):
self._model = model
self._where_conditions = []
self._order_conditions = []
self._offset = None
self._limit = None
self._full_query_mtd_name = None
@property
def count_sql(self):
if self._full_query_mtd_name:
return getattr(self, self._full_query_mtd_name)()["count"]
return f"""
SELECT
Count(*)
FROM
{self._model.table_name}
{self._where_sql}
"""
@property
def full_sql(self):
if self._full_query_mtd_name:
return getattr(self, self._full_query_mtd_name)()["full"]
return f"""
{self._base_sql}
{self._where_sql}
{self._order_by_sql}
{self._offset_sql}
{self._limit_sql}
"""
@property
def _base_sql(self):
return f"""
SELECT
{','.join(field_name for field_name in self._model.fields.keys())}
FROM
{self._model.table_name}
"""
@property
def _where_sql(self):
if self._where_conditions:
statements = [item.get_sql_statement(self._model) for item in self._where_conditions]
print(statements)
statements = [f"({item})" for item in statements if item]
if statements:
return f"""
WHERE
{" AND ".join(statements)}
"""
return ""
@property
def _order_by_sql(self):
if self._order_conditions:
statements = [item.get_sql_statement() for item in self._order_conditions]
statements = [item for item in statements if item]
if statements:
return f"""
ORDER BY
{", ".join(statements)}
"""
return ""
@property
def _offset_sql(self):
if self._offset is not None:
return f"""
OFFSET {self._offset}
"""
return ""
@property
def _limit_sql(self):
if self._limit is not None:
return f"""
LIMIT {self._limit}
"""
return ""
def add_where(self, condition):
self._where_conditions.append(condition)
def add_order(self, condition):
self._order_conditions.append(condition)
def set_offset(self, offset: int):
self._offset = offset
def set_limit(self, limit: int):
self._limit = limit
def set_full_query_mtd_name(self, name: str):
self._full_query_mtd_name = name
经理/查询/utils.py
from apps.analytics.models.fields import ModelField
from apps.analytics.models.filter.lookups import FilterFieldLookup
class ModelQueryUtils:
class FilterInput:
_filter_lookup_operator_mapping = {
FilterFieldLookup.Definition.NOT_EQUAL: "<>",
FilterFieldLookup.Definition.GREATER_THAN_EQUAL_TO: ">=",
FilterFieldLookup.Definition.LESS_THAN_EQUAL_TO: "<=",
FilterFieldLookup.Definition.GREATER_THAN: ">",
FilterFieldLookup.Definition.LESS_THAN: "<",
FilterFieldLookup.Definition.IS_IN: "IN",
FilterFieldLookup.Definition.CONTAINS: "LIKE",
FilterFieldLookup.Definition.ICONTAINS: "LIKE"
}
class JoinOperator:
OR = "OR"
AND = "AND"
def __init__(self, filter_params: dict, operator: str):
self.filter_params = filter_params
self.operator = operator
def get_sql_statement(self, model):
list_conditions = []
for key, value in self.filter_params.items():
field, lookup = key.split("__") if "__" in key else (key, None)
model_field: ModelField = model.fields.get(field)
if model_field:
value = model_field.get_cast_statement_for_value(value)
list_conditions.append(
f"{field} {self._filter_lookup_operator_mapping.get(lookup) or '='} {value}")
return list_conditions and f" {self.operator} ".join(list_conditions)
class SearchInput(FilterInput):
def get_sql_statement(self, model):
list_conditions = []
for key, value in self.filter_params.items():
field, lookup = key.split("__") if "__" in key else (key, None)
model_field: ModelField = model.fields.get(field)
if model_field:
value = model_field.get_cast_statement_for_value(value)
value = value.strip("'")
list_conditions.append(
f"LOWER({field}) {self._filter_lookup_operator_mapping.get(lookup) or '='} LOWER('%{value}%')")
return list_conditions and f" {self.operator} ".join(list_conditions)
class OrderInput:
def __init__(self, order_params: list[str]):
self.order_params = order_params
def get_sql_statement(self):
return ", ".join("{}{}".format(param.lstrip("-"), " DESC" if param.startswith("-") else "") for param in
self.order_params)
经理/queryset/base.py
from apps.analytics.models.manager.query.base import BaseModelQuery
from apps.analytics.models.manager.query.utils import ModelQueryUtils
from apps.utils.custom_exeption.custom_exceptions import CustomException
from apps.utils.helpers.externals.aws.athena.client import AWSAthenaClient
class BaseModelQueryset:
__slots__ = ['_query', '_annotations']
def __init__(self, query: BaseModelQuery):
self._query = query
self._annotations = {}
def __iter__(self):
query_client = AWSAthenaClient()
print(" ---- FULL QUERY -----------")
print(self._query.full_sql)
print("--- working ---")
response, _ = query_client.execute_query(self._query.full_sql)
if _:
raise CustomException(
data={
"message": _,
"errors": ["analytic_query_error"]
}
)
for item in response:
item.update(self._annotations)
yield item
def annotate(self, **annotations):
"""Custom method to handle annotations like Sum or other aggregations."""
for alias, annotation in annotations.items():
if isinstance(annotation, str) and annotation.lower() == 'sum':
# Construct and execute a sum query
query_client = AWSAthenaClient()
sum_query = f"SELECT SUM(event_count) AS {alias} FROM {self._query.table_name}"
response, error = query_client.execute_query(sum_query)
if error:
raise CustomException(
data={"message": error, "errors": ["annotation_query_error"]}
)
# Store the result in _annotations
self._annotations[alias] = response[0].get(alias, 0)
return self
def count(self):
query_client = AWSAthenaClient()
response, _ = query_client.execute_query(self._query.count_sql)
if _ or not (response):
raise CustomException(
data={
"message": _,
"errors": ["analytic_query_error"]
}
)
return int(response[0]['_col0'])
def filter(self, filter_input=None, **kwargs):
self._query._where_conditions = []
print(" -- Filter ---")
print(filter_input)
if filter_input:
self._query.add_where(filter_input)
else:
print(" --- I am here ----")
self._query.add_where(ModelQueryUtils.FilterInput(kwargs, ModelQueryUtils.FilterInput.JoinOperator.AND))
return self
def order_by(self, *args, order_input=None):
if order_input:
self._query.add_order(order_input)
else:
self._query.add_order(ModelQueryUtils.OrderInput(args))
return self
def __getitem__(self, key):
if isinstance(key, slice):
if key.start:
self._query.set_offset(key.start)
if key.stop:
self._query.set_limit(key.stop - key.start)
else:
self._query.set_offset(key)
self._query.set_limit(1)
return self
经理/init.py
class LineIntrusionAnalyticsMananger(BaseModelManager):
def __init__(self):
super().__init__(queryset_class=queryset.LineIntrusionAnalyticsQueryset)
像
.add_where(…)
这样的方法在 QuerySet
上没有多大意义。 Django 的 QuerySet
本质上是不可变。这意味着,如果您使用 MyModel.objects.filter(some_filter1).filter(some_filter2)
,每次调用 .filter(…)
都会生成一个 new QuerySet
,它是前一个的修改后的副本。
这可能看起来像一个细节,而且效率低下,因为我们每次都会克隆
QuerySet
,但它是一个重要的细节。事实上,在 Django 中,如果你制作一个 ListView
:
class MyListView(ListView):
queryset = MyModel.objects.all()
然后你给
QuerySet
属性设置一个queryset
,Django每次调用.all()
来cloneQuerySet
,这样就不会使用“parent”QuerySet
的缓存,从而迫使重新评估。出于同样的原因,如果您愿意 .filter(…)
.queryset
,它对原始版本没有任何影响。
因此,您不应该定义更改状态的方法:每个方法都应该创建一个新的QuerySet
。如果不这样做,最终可能会导致“全局状态”,从而更改“父”查询集。