说明:
我有一个 Django 应用程序,在 AWS Elastic Container Service (ECS) 上运行多个 Celery 任务,使用 SQS 作为代理(broker)。我遇到的问题是:前一个 Celery 任务完成后,Celery 会立即在同一个(现有的)ECS 任务中启动下一个 Celery 任务;但此时我的装饰器已经把 ProtectionEnabled 从 true 改为 false,几秒钟后该 ECS 任务就会被终止,新启动的 Celery 任务因此无法完成。下面是我用来启动 Celery worker 的命令。
celery -A myapp_settings.celery worker --concurrency=1 -l info -Q sqs-celery
我使用 CloudWatch 警报监控代理(SQS)中的消息,并据此终止已经完成工作的 ECS 任务。问题在于,前一个 Celery 任务完成后,Celery 会在同一个 ECS 任务中启动下一个 Celery 任务。这本身没有问题,但我的装饰器会把 ProtectionEnabled 从 true 改为 false,约 20 秒后该 ECS 任务被终止,新启动的 Celery 任务也就无法继续运行。
问题:
我正在考虑更新我的装饰器,以便在新的 Celery 任务启动时将 ProtectionEnabled 值从 false 改回 true,但我不确定如何实现这一点。
代码:
container_decorator.py
class ContainerAgent:
    """Minimal client for the ECS container agent's task scale-in protection endpoint.

    Requests are sent with ``PUT`` and a JSON body to ``ecs_agent_uri + path``;
    every failure is surfaced as ``ContainerAgent.Error`` (or its subclass
    ``RequestError``) so callers can handle it with a single ``except``.
    """

    class Error(Exception):
        """Base class for every error raised by this client."""

    class RequestError(Error, IOError):
        """The HTTP request failed, returned a non-2xx status, or had a non-JSON body."""

    def __init__(
        self,
        ecs_agent_uri: str,
        timeout: int = 10,
        session: Optional[requests.Session] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        :param ecs_agent_uri: Base URI of the ECS agent; request paths are appended verbatim.
        :param timeout: Per-request timeout passed to ``requests`` (seconds).
        :param session: ``requests.Session`` to reuse; a new one is created when omitted.
        :param logger: Logger for request/response tracing; defaults to one named after the class.
        """
        self._ecs_agent_uri = ecs_agent_uri
        self._timeout = timeout
        self._session = session or requests.Session()
        self._logger = logger or logging.getLogger(self.__class__.__name__)

    def _request(self, *, path: str, data: Optional[dict] = None) -> dict:
        """PUT ``data`` as JSON to ``ecs_agent_uri + path`` and return the decoded JSON body.

        :raises ContainerAgent.RequestError: on any ``requests`` failure, a non-2xx
            status, or a response body that is not valid JSON.
        """
        url = f"{self._ecs_agent_uri}{path}"
        self._logger.info(f"Performing request... {url=} {data=}")
        try:
            response = self._session.put(url=url, json=data, timeout=self._timeout)
            self._logger.info(f"Got response. {response.status_code=} {response.content=}")
            response.raise_for_status()
            return response.json()
        # ValueError covers JSON decode failures on requests versions where they are not
        # a RequestException subclass; otherwise they would escape as a bare ValueError
        # that callers catching ContainerAgent.Error do not handle.
        except (requests.RequestException, ValueError) as e:
            error_response = getattr(e, "response", None)
            response_body = error_response.text if error_response is not None else None
            self._logger.warning(f"Request error! {url=} {data=} {e=} {response_body=}")
            raise self.RequestError(str(e)) from e

    def toggle_scale_in_protection(self, *, enable: bool = True, expire_in_minutes: int = 2880):
        """Set this ECS task's scale-in protection via ``/task-protection/v1/state``.

        :param enable: Value sent as ``ProtectionEnabled``.
        :param expire_in_minutes: Value sent as ``ExpiresInMinutes`` (default 2880 = 48 hours).
        :return: The ``ProtectionEnabled`` value found under the response's ``protection`` key.
        :raises ContainerAgent.RequestError: if the request itself fails.
        :raises ContainerAgent.Error: if the response does not contain
            ``protection.ProtectionEnabled``.
        """
        response = self._request(
            path="/task-protection/v1/state",
            data={"ProtectionEnabled": enable, "ExpiresInMinutes": expire_in_minutes},
        )
        try:
            return response["protection"]["ProtectionEnabled"]
        # TypeError: the body decoded to something other than nested objects (e.g. null or a list).
        except (KeyError, TypeError) as e:
            raise self.Error(f"Task scale-in protection endpoint error: {response=}") from e
def enable_scale_in_protection(*, logger: Optional[logging.Logger] = None):
    """Decorator factory: hold ECS scale-in protection while the wrapped callable runs.

    If ``ECS_AGENT_URI`` is not set when the decorator is applied, a warning is
    logged and the function is returned unchanged. Otherwise each call enables
    protection first (best-effort: a failure is logged and the call still runs),
    then disables it afterwards if enabling succeeded.

    NOTE: protection is disabled after every call, regardless of whether this ECS
    task has picked up further work in the meantime.

    :param logger: Logger for diagnostics; module-level ``logging`` (root logger)
        is used when omitted.
    """
    log = logger or logging

    def decorator(f):
        if not (ecs_agent_uri := os.getenv("ECS_AGENT_URI")):
            log.warning(f"Scale-in protection not enabled. {ecs_agent_uri=}")
            return f

        client = ContainerAgent(ecs_agent_uri=ecs_agent_uri, logger=logger)

        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                client.toggle_scale_in_protection(enable=True)
            except client.Error as e:
                log.warning(f"Scale-in protection not enabled. {e}")
                protection_set = False
            else:
                protection_set = True
            try:
                return f(*args, **kwargs)
            finally:
                if protection_set:
                    # An exception raised inside `finally` would replace the wrapped
                    # function's own return value or exception, so a failed disable
                    # is only logged.
                    try:
                        client.toggle_scale_in_protection(enable=False)
                    except client.Error as e:
                        log.warning(f"Scale-in protection could not be disabled. {e}")

        return wrapper

    return decorator
celery_tasks.py
@shared_task(name="add_spider_schedule", base=AbortableTask)
@enable_scale_in_protection(logger=get_task_logger(__name__))
def add_spider_schedule(user_id, spider_id):
    """Celery entry point: run the spider for ``spider_id``, but only under production settings.

    With any other ``DJANGO_SETTINGS_MODULE`` it prints a notice and returns ``None``.
    The task object itself is passed on so the worker function can read the request
    id and poll for an abort.
    """
    in_production = (
        os.environ.get('DJANGO_SETTINGS_MODULE') == 'myapp_settings.settings.production'
    )
    if not in_production:
        print('Unknown settings module')
        return None
    return add_spider_schedule_production(user_id, spider_id, add_spider_schedule)
def add_spider_schedule_production(user_id, spider_id, task_object):
    """
    Run the configured spider for the given Spider model instance.

    Loads the spider's config, YAML and template files from S3, imports the spider
    config as a module, and runs its ``run`` coroutine alongside a watcher that polls
    ``task_object.is_aborted()`` every 0.1 s. All ``print`` output (here and in the
    spider script) is redirected to a task-specific logger for the duration of the call.

    :param user_id: Primary key of the User; it is looked up but not otherwise used here.
    :param spider_id: The ID of the spider to schedule.
    :param task_object: The bound Celery task; provides ``request.id`` and ``is_aborted()``.
    :return: A string representation of the spider and task IDs.
    :raises Exception: if the spider fails, or if the task is aborted while it runs.
    """
    # Redirect stdout/stderr (prints here and in the spider script) into a task logger.
    logger = get_task_logger(task_object.request.id)
    old_outs = sys.stdout, sys.stderr
    rlevel = add_spider_schedule.app.conf.worker_redirect_stdouts_level
    add_spider_schedule.app.log.redirect_stdouts_to_logger(logger, rlevel)
    try:
        # Get the Spider model instance
        spider = Spider.objects.get(id=spider_id)
        # NOTE(review): the User instance is not used below; the lookup is kept so the
        # behaviour for an unknown user_id is unchanged.
        User.objects.get(id=user_id)

        # Names of the relevant files and Mongo targets from the model instance
        spider_config_file = spider.spider_config_file.file
        yaml_config_file = spider.yaml_config_file.file
        template_file = spider.template_file.file
        mongodb_database_name = spider.mongodb_collection.database_name
        mongodb_collection_name = spider.mongodb_collection.collection_name

        # Read the contents of the files from the S3 bucket
        spider_config_file_contents = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{spider_config_file}")
        yaml_config_path = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{yaml_config_file}")
        input_file_path = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{template_file}")

        # Convert the JSON-encoded keyword arguments to a dictionary
        kwargs = json.loads(spider.kwargs) if spider.kwargs else {}

        # Create a module from the contents of the spider_config_file
        spider_module = import_module(spider_config_file_contents, "spider_config")

        is_scraping_finished = False

        async def run_spider():
            nonlocal is_scraping_finished
            try:
                await spider_module.run(
                    yaml_config_path=yaml_config_path,
                    input_file_path=input_file_path,
                    mongodb_name=mongodb_database_name,
                    mongodb_collection_name=mongodb_collection_name,
                    task_object=task_object,
                    mode="sf-lab",
                    **kwargs
                )
                is_scraping_finished = True
            except Exception as e:
                raise Exception(f"An error occurred while running the spider: {e}") from e

        async def check_if_aborted():
            # Stops on its own once the spider finishes; raises if the task is aborted.
            while True:
                if task_object.is_aborted():
                    print("Parralel function detected that task was cancelled.")
                    raise Exception("task was cancelled")
                elif is_scraping_finished:
                    break
                await asyncio.sleep(0.1)

        # NOTE(review): asyncio.get_event_loop() outside a running loop is deprecated on
        # recent Python versions; it is kept so the loop is still reused across tasks.
        loop = asyncio.get_event_loop()
        spider_task = loop.create_task(run_spider())
        abort_watch_task = loop.create_task(check_if_aborted())
        try:
            loop.run_until_complete(asyncio.gather(spider_task, abort_watch_task))
        finally:
            # gather() does not cancel the sibling when one task fails (spider error or
            # abort). Left pending on this reused loop, it would resume inside the next
            # task's run_until_complete, so cancel whatever is still running and let the
            # cancellation complete here. cancel() is a no-op on finished tasks.
            for pending in (spider_task, abort_watch_task):
                pending.cancel()
            loop.run_until_complete(
                asyncio.gather(spider_task, abort_watch_task, return_exceptions=True)
            )
    finally:
        # Restore even on failure; otherwise this worker process keeps writing to the
        # previous task's logger.
        sys.stdout, sys.stderr = old_outs
    return f"[spider: {spider_id}, task_id: {task_object.request.id}]"
您的思路是对的。与其在每个受保护任务完成后都禁用缩容保护,不如先检查当前 worker 是否还有其他需要保护的剩余任务,只有在没有时才禁用缩容保护(这里假设每个 ECS 任务中只运行一个 worker)。
我的做法是为需要防止被缩容(scale-in)事件终止的 Celery 任务设置一个专用队列:
from celery import Celery
from kombu import Exchange, Queue
# Celery application instance (constructor arguments elided in this example).
app = Celery(...)

# Default routing goes to a direct "default" exchange/queue; a second, dedicated
# queue ("scale-in-protection") carries the tasks that must hold ECS scale-in
# protection while they run.
app.conf.update(
    task_default_queue="default",
    task_default_exchange_type="direct",
    task_default_routing_key="default",
    task_queues=(
        Queue("default", Exchange("default"), routing_key="default"),
        Queue(
            "scale-in-protection",
            Exchange("scale-in-protection"),
            routing_key="scale-in-protection",
        ),
    ),
)
接下来,我使用 celeryd_after_setup 信号把当前 Celery worker 的 ID(即该信号的 sender,也就是 worker 的节点名)保存到环境变量中,以便在任务中读取它。
@celeryd_after_setup.connect
def store_celery_worker_id(sender, instance, **kwargs):
    """Publish this worker's identity as CELERY_CURRENT_WORKER_ID for task code to read.

    ``sender`` is whatever the ``celeryd_after_setup`` signal delivers (the worker's
    node name, e.g. ``celery@host`` — confirm against the Celery version in use);
    ``instance`` and any extra signal kwargs are ignored.
    """
    os.environ.update({"CELERY_CURRENT_WORKER_ID": sender})
最后我将你写的装饰器修改如下:
def _worker_has_other_protected_tasks(inspect_replies, worker_id, current_task_id):
    """Return True if ``worker_id`` holds a task other than ``current_task_id`` that was
    routed to the ``scale-in-protection`` queue, according to ``inspect_replies``.

    Each reply is an ``app.control.inspect()`` result: a ``{worker_name: [entry, ...]}``
    mapping, or ``None`` when no worker answered. ``scheduled()`` entries nest the task
    info under a ``"request"`` key, while ``active()``/``reserved()`` entries are the
    task info itself, so both shapes are accepted.
    """
    for reply in inspect_replies:
        for entry in (reply or {}).get(worker_id) or ():
            task_info = entry.get("request", entry)
            if task_info.get("id") == current_task_id:
                continue
            delivery_info = task_info.get("delivery_info") or {}
            if delivery_info.get("routing_key") == "scale-in-protection":
                return True
    return False


def task_with_scale_in_protection(func=None, *, logger: Optional[logging.Logger] = None):
    """Register a function as a Celery task on the ``scale-in-protection`` queue that holds
    ECS scale-in protection while it runs.

    Usable both as ``@task_with_scale_in_protection`` and
    ``@task_with_scale_in_protection(logger=...)``.

    Each run enables protection first (best-effort: a failure is logged and the task
    still runs). Afterwards protection is disabled only if this worker has no other
    active, reserved or scheduled task routed to ``scale-in-protection``. The worker is
    identified by ``CELERY_CURRENT_WORKER_ID``, falling back to ``self.request.hostname``.
    If ``ECS_AGENT_URI`` is not set when the decorator is applied, the function is still
    registered as a task on the same queue, without any protection handling.

    NOTE: the remaining-task check is racy — two protected tasks finishing at the same
    time may each still see the other and both leave protection enabled.

    :param func: The task function, when used without parentheses.
    :param logger: Logger for diagnostics; module-level ``logging`` (root logger) is
        used when omitted.
    """
    log = logger or logging

    def decorator(f):
        ecs_agent_uri = os.getenv("ECS_AGENT_URI")
        if not ecs_agent_uri:
            log.warning(f"Scale-in protection not enabled. {ecs_agent_uri}")
            # Still return a Celery task (on the same queue) so callers can use
            # .delay()/.apply_async() in every environment; the bare function is not a task.
            return app.task(queue="scale-in-protection")(f)

        client = ContainerAgent(ecs_agent_uri=ecs_agent_uri, logger=logger)

        def release_protection_if_idle(task):
            """Disable protection unless this worker still holds other protected tasks.

            Errors from worker inspection or the agent call are logged, not raised: this
            runs in the wrapper's ``finally``, where an exception would replace the
            task's own result or error.
            """
            current_worker_id = os.getenv("CELERY_CURRENT_WORKER_ID") or task.request.hostname
            if not current_worker_id:
                log.warning("Current Worker ID not set. Leaving scale-in protection enabled.")
                return
            try:
                inspector = app.control.inspect()
                has_other_protected = _worker_has_other_protected_tasks(
                    (inspector.active(), inspector.reserved(), inspector.scheduled()),
                    current_worker_id,
                    task.request.id,
                )
                if has_other_protected:
                    log.info(
                        "Worker has remaining protected tasks. Leaving scale-in protection enabled."
                    )
                else:
                    log.info(
                        "Worker has no remaining protected tasks. Disabling scale-in protection."
                    )
                    client.toggle_scale_in_protection(enable=False)
            except Exception as e:  # best-effort cleanup; see docstring
                log.warning(
                    f"Could not update scale-in protection after task {task.request.id}: {e!r}"
                )

        @app.task(bind=True, queue="scale-in-protection")
        @wraps(f)
        def wrapper(self, *args, **kwargs):
            try:
                client.toggle_scale_in_protection(enable=True)
            except client.Error as e:
                log.warning(f"Scale-in protection not enabled. {e}")
                protection_set = False
            else:
                protection_set = True
            try:
                # Run the celery task
                return f(*args, **kwargs)
            finally:
                if protection_set:
                    release_protection_if_idle(self)

        return wrapper

    if func is not None:
        return decorator(func)
    return decorator
使用方法如下:
# The factory's options are keyword-only, so it must be called (with or without
# options) before decorating; applying it bare passes the function positionally.
@task_with_scale_in_protection()
def protected_task():
    """Placeholder body for a task that should hold ECS scale-in protection while running."""
    ...