在 add_where 方法中的每个 api 请求列表上附加它不会重置

问题描述 投票:0回答:1

我正在尝试在 django Rest 框架中创建与 athena 交互的自定义管理器。尽管我面临着奇怪的问题,每次我使用过滤器到达端点时,它都会附加到每个请求的列表 where_conditions=[] 中。 不应该发生实际上使转换为 SQL 的查询无效的情况。我想重置每个请求的 where_conditions 列表。

filter_class/base.py

from apps.analytics.models.filter.fields.filter import FilterField
from apps.analytics.models.filter.fields.ordering import OrderingField
from apps.analytics.models.filter.fields.search import SearchField
from apps.analytics.models.manager.query.utils import ModelQueryUtils


class BaseFilterClass:
    _ordering_filter_query_key = "ordering"
    _search_filter_query_key = "search"

    def __init__(self, request):
        self._request = request

    @property
    def filter_params(self):
        if getattr(self.Meta, "filter_fields", None):
            req_query_params = {**self._request.query_params}
            for i in req_query_params:
                if isinstance(req_query_params[i], list):
                    req_query_params[i] = req_query_params[i][0]
            if req_query_params and req_query_params.get("search"):
                del req_query_params["search"]
            req_query_params_key = set(req_query_params.keys())
            valid_filter_params = self.Meta.filter_fields.declared_filter_params.intersection(req_query_params_key)
            params = {key: self.Meta.filter_fields.get_processed_value(key.split('__')[-1],
                                                                       req_query_params[key]) if "__" in key else
            req_query_params[key] for key in valid_filter_params}
            return ModelQueryUtils.FilterInput(params, ModelQueryUtils.FilterInput.JoinOperator.AND)

    @property
    def ordering_params(self):
        if getattr(self.Meta, "ordering_fields", None):
            req_query_params_ordering_params = self._request.query_params and self._request.query_params.get(
                self._ordering_filter_query_key)
            if req_query_params_ordering_params:
                declared_ordering_params = self.Meta.ordering_fields.declared_ordering_params
                return ModelQueryUtils.OrderInput(
                    [str(param) for param in req_query_params_ordering_params.split(",") if
                     ((param.startswith("-") and param[1:]) or param) in declared_ordering_params])

    @property
    def search_params(self):
        if getattr(self.Meta, "search_fields", None):
            req_query_params_search_param = self._request.query_params and self._request.query_params.get(
                self._search_filter_query_key)
            params = {}
            if req_query_params_search_param:
                for key in self.Meta.search_fields.declared_search_params:
                    params[key] = req_query_params_search_param
            return ModelQueryUtils.SearchInput(params, ModelQueryUtils.FilterInput.JoinOperator.OR)

    class Meta:
        filter_fields = FilterField([])
        search_fields = SearchField([])
        ordering_fields = OrderingField([])


经理/查询/base.py

class BaseModelQuery:
    __slots__ = ['_model']
    # _where_conditions = None
    _order_conditions = None
    _offset: int = None
    _limit: int = None
    _full_query_mtd_name: str = None

    def __init__(self, model):
        self._model = model
        self._where_conditions = []
        self._order_conditions = []
        self._offset = None
        self._limit = None
        self._full_query_mtd_name = None

    @property
    def count_sql(self):
        if self._full_query_mtd_name:
            return getattr(self, self._full_query_mtd_name)()["count"]
        return f"""
            SELECT
                Count(*)
            FROM
                {self._model.table_name}
            {self._where_sql}
        """

    @property
    def full_sql(self):
        if self._full_query_mtd_name:
            return getattr(self, self._full_query_mtd_name)()["full"]
        return f"""
            {self._base_sql}
            {self._where_sql}
            {self._order_by_sql}
            {self._offset_sql}
            {self._limit_sql}
        """

    @property
    def _base_sql(self):
        return f"""
            SELECT
                {','.join(field_name for field_name in self._model.fields.keys())}
            FROM
                {self._model.table_name}
        """

    @property
    def _where_sql(self):
        if self._where_conditions:
            statements = [item.get_sql_statement(self._model) for item in self._where_conditions]
            print(statements)
            statements = [f"({item})" for item in statements if item]
            if statements:
                return f"""
                WHERE
                    {" AND ".join(statements)}
                """
        return ""

    @property
    def _order_by_sql(self):
        if self._order_conditions:
            statements = [item.get_sql_statement() for item in self._order_conditions]
            statements = [item for item in statements if item]
            if statements:
                return f"""
            ORDER BY
                {", ".join(statements)}
            """
        return ""

    @property
    def _offset_sql(self):
        if self._offset is not None:
            return f"""
            OFFSET {self._offset}
            """
        return ""

    @property
    def _limit_sql(self):
        if self._limit is not None:
            return f"""
            LIMIT {self._limit}
            """
        return ""

    def add_where(self, condition):
        self._where_conditions.append(condition)

    def add_order(self, condition):
        self._order_conditions.append(condition)

    def set_offset(self, offset: int):
        self._offset = offset

    def set_limit(self, limit: int):
        self._limit = limit

    def set_full_query_mtd_name(self, name: str):
        self._full_query_mtd_name = name

经理/查询/utils.py

from apps.analytics.models.fields import ModelField
from apps.analytics.models.filter.lookups import FilterFieldLookup


class ModelQueryUtils:
    class FilterInput:
        _filter_lookup_operator_mapping = {
            FilterFieldLookup.Definition.NOT_EQUAL: "<>",
            FilterFieldLookup.Definition.GREATER_THAN_EQUAL_TO: ">=",
            FilterFieldLookup.Definition.LESS_THAN_EQUAL_TO: "<=",
            FilterFieldLookup.Definition.GREATER_THAN: ">",
            FilterFieldLookup.Definition.LESS_THAN: "<",
            FilterFieldLookup.Definition.IS_IN: "IN",
            FilterFieldLookup.Definition.CONTAINS: "LIKE",
            FilterFieldLookup.Definition.ICONTAINS: "LIKE"
        }

        class JoinOperator:
            OR = "OR"
            AND = "AND"

        def __init__(self, filter_params: dict, operator: str):
            self.filter_params = filter_params
            self.operator = operator

        def get_sql_statement(self, model):
            list_conditions = []
            for key, value in self.filter_params.items():
                field, lookup = key.split("__") if "__" in key else (key, None)
                model_field: ModelField = model.fields.get(field)
                if model_field:
                    value = model_field.get_cast_statement_for_value(value)
                    list_conditions.append(
                        f"{field} {self._filter_lookup_operator_mapping.get(lookup) or '='} {value}")
            return list_conditions and f" {self.operator} ".join(list_conditions)

    class SearchInput(FilterInput):

        def get_sql_statement(self, model):
            list_conditions = []
            for key, value in self.filter_params.items():
                field, lookup = key.split("__") if "__" in key else (key, None)
                model_field: ModelField = model.fields.get(field)
                if model_field:
                    value = model_field.get_cast_statement_for_value(value)
                    value = value.strip("'")
                    list_conditions.append(
                        f"LOWER({field}) {self._filter_lookup_operator_mapping.get(lookup) or '='} LOWER('%{value}%')")
            return list_conditions and f" {self.operator} ".join(list_conditions)

    class OrderInput:
        def __init__(self, order_params: list[str]):
            self.order_params = order_params

        def get_sql_statement(self):
            return ", ".join("{}{}".format(param.lstrip("-"), " DESC" if param.startswith("-") else "") for param in
                             self.order_params)

经理/queryset/base.py

from apps.analytics.models.manager.query.base import BaseModelQuery
from apps.analytics.models.manager.query.utils import ModelQueryUtils
from apps.utils.custom_exeption.custom_exceptions import CustomException
from apps.utils.helpers.externals.aws.athena.client import AWSAthenaClient

class BaseModelQueryset:
    __slots__ = ['_query', '_annotations']

    def __init__(self, query: BaseModelQuery):
        self._query = query
        self._annotations = {}
    def __iter__(self):
        query_client = AWSAthenaClient()

        print(" ---- FULL QUERY -----------")
        print(self._query.full_sql)
        print("--- working ---")
        response, _ = query_client.execute_query(self._query.full_sql)

        if _:
            raise CustomException(
                data={
                    "message": _,
                    "errors": ["analytic_query_error"]
                }
            )
        for item in response:
            item.update(self._annotations)
            yield item

    
    def annotate(self, **annotations):
        """Custom method to handle annotations like Sum or other aggregations."""
        for alias, annotation in annotations.items():
            if isinstance(annotation, str) and annotation.lower() == 'sum':
                # Construct and execute a sum query
                query_client = AWSAthenaClient()
                sum_query = f"SELECT SUM(event_count) AS {alias} FROM {self._query.table_name}"
                response, error = query_client.execute_query(sum_query)

                if error:
                    raise CustomException(
                        data={"message": error, "errors": ["annotation_query_error"]}
                    )

                # Store the result in _annotations
                self._annotations[alias] = response[0].get(alias, 0)
        
        return self
    def count(self):
        query_client = AWSAthenaClient()
        response, _ = query_client.execute_query(self._query.count_sql)
        if _ or not (response):
            raise CustomException(
                data={
                    "message": _,
                    "errors": ["analytic_query_error"]
                }
            )
        return int(response[0]['_col0'])

    def filter(self, filter_input=None, **kwargs):
        self._query._where_conditions = []
        print(" -- Filter ---")
        print(filter_input)
        if filter_input:
            self._query.add_where(filter_input)
        else:
            print(" --- I am here ----")
            self._query.add_where(ModelQueryUtils.FilterInput(kwargs, ModelQueryUtils.FilterInput.JoinOperator.AND))
        return self

    def order_by(self, *args, order_input=None):
        if order_input:
            self._query.add_order(order_input)
        else:
            self._query.add_order(ModelQueryUtils.OrderInput(args))
        return self

    def __getitem__(self, key):
        if isinstance(key, slice):
            if key.start:
                self._query.set_offset(key.start)
            if key.stop:
                self._query.set_limit(key.stop - key.start)
        else:
            self._query.set_offset(key)
            self._query.set_limit(1)
        return self

经理/init.py


class LineIntrusionAnalyticsMananger(BaseModelManager):
    def __init__(self):
        super().__init__(queryset_class=queryset.LineIntrusionAnalyticsQueryset)
python django django-rest-framework
1个回答
0
投票

.add_where(…)
这样的方法在
QuerySet
上没有多大意义。 Django 的
QuerySet
本质上是不可变。这意味着,如果您使用
MyModel.objects.filter(some_filter1).filter(some_filter2)
,每次调用
.filter(…)
都会生成一个 new
QuerySet
,它是前一个的修改后的副本。

这可能看起来像一个细节,而且效率低下,因为我们每次都会克隆

QuerySet
,但它是一个重要的细节。事实上,在 Django 中,如果你制作一个
ListView
:

class MyListView(ListView):
    queryset = MyModel.objects.all()

然后你给

QuerySet
属性设置一个
queryset
,Django每次调用
.all()
clone
QuerySet
,这样就不会使用“parent
QuerySet
的缓存,从而迫使重新评估。出于同样的原因,如果您愿意
.filter(…)
.queryset
,它对原始版本没有任何影响。

因此,您不应该定义更改状态的方法:每个方法都应该创建一个新的QuerySet

。如果不这样做,最终可能会导致“全局状态”,从而更改“父”查询集。

© www.soinside.com 2019 - 2024. All rights reserved.