Celery 任务 ECS 终止问题 - 需要帮助更新装饰器以处理 ProtectionEnabled 状态更改

问题描述 投票:0回答:1

说明:

我有一个 Django 应用程序,在 AWS Elastic Container Service (ECS) 上运行多个 Celery 任务,使用 SQS 作为代理。我遇到一个问题,即一旦前一个任务完成,就会在现有 ECS 任务中启动 Celery 任务。出现此问题的原因是我的装饰器将 ProtectionEnabled 的状态从 true 更改为 false,几秒钟后,ECS 任务终止。新启动的任务将无法工作。下面是我运行来启动 celery 任务的命令。

celery -A myapp_settings.celery worker --concurrency=1 -l info -Q sqs-celery

我正在使用 CloudWatch 上的警报来检查代理中的消息并终止那些已完成的 ECS 任务。问题是,一旦前一个任务完成,celery 就会在现有 ECS 任务中启动任务。这不是问题,但我的装饰器将 ProtectionEnabled 的状态从 true 更改为 false,20 秒后 ECS 任务终止,新启动的任务不再工作。

问题:

我正在考虑更新我的装饰器,以便在新的 Celery 任务启动时将 ProtectionEnabled 值从 false 改回 true,但我不确定如何实现这一点。

代码:

container_decorator.py

class ContainerAgent:
    """Small HTTP client for the ECS container agent's task scale-in protection endpoint."""

    class Error(Exception):
        """Base class for every error raised by this client."""

    class RequestError(Error, IOError):
        """The HTTP request failed or its response could not be used."""

    def __init__(
        self,
        ecs_agent_uri: str,
        timeout: int = 10,
        session: Optional["requests.Session"] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        :param ecs_agent_uri: Base URI of the agent; request paths are appended to it.
        :param timeout: Timeout handed to the HTTP session for each request.
        :param session: HTTP session to use; a new ``requests.Session`` is created when omitted.
        :param logger: Logger to use; defaults to a logger named after this class.
        """
        self._ecs_agent_uri = ecs_agent_uri
        self._timeout = timeout

        self._session = session or requests.Session()
        self._logger = logger or logging.getLogger(self.__class__.__name__)

    def _request(self, *, path: str, data: Optional[dict] = None) -> dict:
        """
        PUT ``data`` as JSON to ``path`` on the agent and return the decoded JSON body.

        :raises ContainerAgent.RequestError: If the request fails, the status is an
            error status, or the body is not valid JSON.
        """
        url = f"{self._ecs_agent_uri}{path}"
        self._logger.info(f"Performing request... {url=} {data=}")

        try:
            response = self._session.put(url=url, json=data, timeout=self._timeout)
            self._logger.info(f"Got response. {response.status_code=} {response.content=}")

            response.raise_for_status()
            return response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures, which older requests
            # releases raise outside the RequestException hierarchy.
            failed_response = getattr(e, "response", None)
            response_body = failed_response.text if failed_response is not None else None
            self._logger.warning(f"Request error! {url=} {data=} {e=} {response_body=}")

            raise self.RequestError(str(e)) from e

    def toggle_scale_in_protection(self, *, enable: bool = True, expire_in_minutes: int = 2880):
        """
        Ask the agent to switch scale-in protection on or off for this task.

        :param enable: Desired ``ProtectionEnabled`` value.
        :param expire_in_minutes: Value sent as ``ExpiresInMinutes``.
        :return: The ``ProtectionEnabled`` value reported back by the agent.
        :raises ContainerAgent.RequestError: If the HTTP call fails.
        :raises ContainerAgent.Error: If the response carries no protection state.
        """
        response = self._request(
            path="/task-protection/v1/state",
            data={"ProtectionEnabled": enable, "ExpiresInMinutes": expire_in_minutes},
        )

        try:
            return response["protection"]["ProtectionEnabled"]
        except (KeyError, TypeError) as e:
            # TypeError: the body is not an object, or "protection" is null.
            raise self.Error(f"Task scale-in protection endpoint error: {response=}") from e


def enable_scale_in_protection(*, logger: Optional[logging.Logger] = None):
    """
    Build a decorator that holds ECS task scale-in protection while a function runs.

    Protection is requested before each call and released afterwards. When the
    ``ECS_AGENT_URI`` environment variable is unset at decoration time, the
    function is returned undecorated.

    :param logger: Logger for warnings; the ``logging`` module is used when omitted.
    :return: A decorator preserving the wrapped function's result and exceptions.
    """
    log = logger or logging

    def decorator(f):
        if not (ecs_agent_uri := os.getenv("ECS_AGENT_URI")):
            log.warning(f"Scale-in protection not enabled. {ecs_agent_uri=}")
            return f

        client = ContainerAgent(ecs_agent_uri=ecs_agent_uri, logger=logger)

        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                client.toggle_scale_in_protection(enable=True)
            except client.Error as e:
                log.warning(f"Scale-in protection not enabled. {e}")
                protection_set = False
            else:
                protection_set = True

            try:
                return f(*args, **kwargs)
            finally:
                if protection_set:
                    try:
                        client.toggle_scale_in_protection(enable=False)
                    except client.Error as e:
                        # Raising from ``finally`` would replace the function's
                        # result (or hide its real exception), so only log.
                        log.warning(f"Scale-in protection not disabled. {e}")

        return wrapper
    return decorator

celery_tasks.py

@shared_task(name="add_spider_schedule", base=AbortableTask)
@enable_scale_in_protection(logger=get_task_logger(__name__))
def add_spider_schedule(user_id, spider_id):
    """
    Celery entry point that runs a spider schedule under the production settings.

    :param user_id: ID of the user, forwarded to the production implementation.
    :param spider_id: ID of the spider, forwarded to the production implementation.
    :return: The production implementation's result, or ``None`` when
        ``DJANGO_SETTINGS_MODULE`` is not the production settings module.
    """
    if os.environ.get('DJANGO_SETTINGS_MODULE') != 'myapp_settings.settings.production':
        print('Unknown settings module')
        return None

    # The task passes itself along so the implementation can read its request
    # and poll its abort state.
    return add_spider_schedule_production(user_id, spider_id, add_spider_schedule)

def add_spider_schedule_production(user_id, spider_id, task_object):
    """
    Run the spider configured for ``spider_id`` until it finishes or the task is aborted.

    :param user_id: The ID of the user the run belongs to.
    :param spider_id: The ID of the spider to schedule.
    :param task_object: The task whose ``request.id`` names the logger and whose
        ``is_aborted()`` is polled while the spider runs.
    :return: A string representation of the spider and task IDs.
    :raises Exception: If the spider fails or the task is aborted.
    """
    # Route every print() (here and inside the spider script) to the task logger.
    logger = get_task_logger(task_object.request.id)
    old_outs = sys.stdout, sys.stderr
    rlevel = add_spider_schedule.app.conf.worker_redirect_stdouts_level
    add_spider_schedule.app.log.redirect_stdouts_to_logger(logger, rlevel)

    try:
        # Get the Spider model instance
        spider = Spider.objects.get(id=spider_id)

        # The user object itself is unused; the lookup is kept because it
        # presumably raises when the user does not exist.
        User.objects.get(id=user_id)

        # Get the names of the relevant files from the model instance
        spider_config_file = spider.spider_config_file.file
        yaml_config_file = spider.yaml_config_file.file
        template_file = spider.template_file.file
        mongodb_database_name = spider.mongodb_collection.database_name
        mongodb_collection_name = spider.mongodb_collection.collection_name

        # Read the contents of the files from the S3 bucket
        spider_config_file_contents = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{spider_config_file}")
        yaml_config_path = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{yaml_config_file}")
        input_file_path = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{template_file}")

        # Convert the JSON-encoded keyword arguments to a dictionary
        kwargs = json.loads(spider.kwargs) if spider.kwargs else {}

        # Create a module from the contents of the spider_config_file
        spider_module = import_module(spider_config_file_contents, "spider_config")

        is_scraping_finished = False

        async def run_spider():
            nonlocal is_scraping_finished
            try:
                await spider_module.run(
                    yaml_config_path=yaml_config_path,
                    input_file_path=input_file_path,
                    mongodb_name=mongodb_database_name,
                    mongodb_collection_name=mongodb_collection_name,
                    task_object=task_object,
                    mode="sf-lab",
                    **kwargs
                )
            except Exception as e:
                raise Exception(f"An error occurred while running the spider: {e}") from e
            is_scraping_finished = True

        async def check_if_aborted():
            # Poll the abort flag until the spider reports completion.
            while True:
                if task_object.is_aborted():
                    print("Parralel function detected that task was cancelled.")
                    raise Exception("task was cancelled")
                elif is_scraping_finished:
                    break
                await asyncio.sleep(0.1)

        async def run_until_done_or_aborted():
            jobs = [
                asyncio.ensure_future(run_spider()),
                asyncio.ensure_future(check_if_aborted()),
            ]
            try:
                await asyncio.gather(*jobs)
            finally:
                # gather() reports the first failure but leaves the sibling
                # pending; cancel it so it cannot resume the next time this
                # event loop is run.
                for job in jobs:
                    job.cancel()
                await asyncio.gather(*jobs, return_exceptions=True)

        loop = asyncio.get_event_loop()
        loop.run_until_complete(run_until_done_or_aborted())
    finally:
        # Undo the stdout/stderr redirection even when the run fails.
        sys.stdout, sys.stderr = old_outs

    return f"[spider: {spider_id}, task_id: {task_object.request.id}]"

python django celery amazon-ecs
1个回答
0
投票

您的思路是对的。不要在受保护的任务完成后总是禁用缩容保护,而是应该先检查当前 Worker 是否还有其他需要保护的任务,只有在没有剩余任务时才禁用缩容保护(假设每个 ECS 任务中只运行一个 Worker)。

我处理这个问题的方法是:为那些需要受到保护、免受缩容事件影响的 Celery 任务设置一个专用队列:

from celery import Celery
from kombu import Exchange, Queue

# Create the Celery application (constructor arguments elided in this example).
app = Celery(...)

# Routing defaults: the queue, exchange type and routing key are all set to
# the "default" / "direct" values below.
app.conf.update(
  task_default_queue="default",
  task_default_exchange_type="direct",
  task_default_routing_key="default",
)
# Declare two queues: the default one and a dedicated "scale-in-protection"
# queue whose routing key has the same name as the queue.
app.conf.task_queues = (
  Queue("default", Exchange("default"), routing_key="default"),
  Queue(
    "scale-in-protection",
    Exchange("scale-in-protection"),
    routing_key="scale-in-protection",
  ),
)

接下来,我使用 celeryd_after_setup 信号把 Celery Worker 的 ID 存入环境变量,以便在任务内部访问它。

@celeryd_after_setup.connect
def store_celery_worker_id(sender, instance, **kwargs):
    """Expose the signal's ``sender`` as the ``CELERY_CURRENT_WORKER_ID`` environment variable."""
    # ``instance`` and the extra keyword arguments from the signal are not used.
    os.environ.update(CELERY_CURRENT_WORKER_ID=sender)

最后我将你写的装饰器修改如下:

def _has_other_protected_tasks(replies, worker_id, current_task_id):
    """
    Tell whether a worker still holds a protected task besides the current one.

    :param replies: Iterable of ``{worker_id: [task, ...]}`` mappings from the inspect calls.
    :param worker_id: Name of the worker to look at.
    :param current_task_id: ID of the task that is finishing; it is ignored.
    :return: ``True`` if another task with the ``scale-in-protection`` routing key is found.
    """
    for reply in replies:
        for entry in reply.get(worker_id) or ():
            # scheduled() nests the task description under "request";
            # active() and reserved() return it directly.
            task = entry.get("request", entry)
            delivery_info = task.get("delivery_info") or {}
            if (
                task.get("id") != current_task_id
                and delivery_info.get("routing_key") == "scale-in-protection"
            ):
                return True
    return False


def task_with_scale_in_protection(*, logger: Optional[logging.Logger] = None):
    """
    Build a decorator that registers a function as a protected Celery task.

    The task is bound to the ``scale-in-protection`` queue. Scale-in protection
    is requested before the function runs and released afterwards only when the
    current worker has no other protected task active, reserved or scheduled.

    :param logger: Logger for messages; the ``logging`` module is used when omitted.
    :return: A decorator. When ``ECS_AGENT_URI`` is unset it returns the function unchanged.
    """
    log = logger or logging

    def decorator(f):
        ecs_agent_uri = os.getenv("ECS_AGENT_URI")
        if not ecs_agent_uri:
            # NOTE(review): in this branch the function is not registered as a
            # Celery task at all - confirm that is intended outside ECS.
            log.warning(f"Scale-in protection not enabled. {ecs_agent_uri}")
            return f
        client = ContainerAgent(ecs_agent_uri=ecs_agent_uri, logger=logger)

        def release_protection_if_idle(task_id):
            """Disable protection unless this worker still has protected work."""
            current_worker_id = os.getenv("CELERY_CURRENT_WORKER_ID")
            if not current_worker_id:
                log.warning("Current Worker ID not set. Leaving scale-in protection enabled.")
                return

            # NOTE: This is vulnerable to a race condition. If two or more
            # protected tasks finish and run this code at the same time, they
            # may see each other as active and skip disabling protection.
            inspector = app.control.inspect()
            replies = [inspector.active(), inspector.reserved(), inspector.scheduled()]
            # An inspect call yields None when no worker answered.
            # NOTE(review): brokers without remote-control support (SQS is
            # documented as one) presumably always end up here - confirm.
            if any(reply is None or current_worker_id not in reply for reply in replies):
                log.warning("Worker state unavailable. Leaving scale-in protection enabled.")
                return

            if _has_other_protected_tasks(replies, current_worker_id, task_id):
                log.info("Worker has remaining protected tasks. Leaving scale-in protection enabled.")
                return

            log.info("Worker has no remaining protected tasks. Disabling scale-in protection.")
            client.toggle_scale_in_protection(enable=False)

        @app.task(bind=True, queue="scale-in-protection")
        @wraps(f)
        def wrapper(self, *args, **kwargs):
            try:
                client.toggle_scale_in_protection(enable=True)
            except client.Error as e:
                log.warning(f"Scale-in protection not enabled. {e}")
                protection_set = False
            else:
                protection_set = True

            try:
                # Run the celery task
                return f(*args, **kwargs)
            finally:
                if protection_set:
                    try:
                        release_protection_if_idle(self.request.id)
                    except Exception as e:
                        # Cleanup is best effort: raising from ``finally`` would
                        # replace the task's result or hide its real exception.
                        log.warning(f"Could not release scale-in protection. {e}")

        return wrapper

    return decorator

使用方法如下:

# The factory accepts only keyword arguments, so it must be called (note the
# parentheses) to obtain the actual decorator.
@task_with_scale_in_protection()
def protected_task():
    ...
© www.soinside.com 2019 - 2024. All rights reserved.