How can I extend the lifetime of an object managed by asynccontextmanager into a FastAPI background task?


In a FastAPI endpoint I use an object that needs a setup and teardown process (e.g. loading from / saving to a cache). I use `asynccontextmanager` to manage the object's context, but I also want to work with that object later, in a background task.

Right now, in my environment (`fastapi==0.115.5`), the object's context ends before the request is responded to, and that is normally earlier than the end of the background task, so part of the background task executes outside the context. For example, if the teardown part of the context manager contains a "save to cache" step, subsequent changes made in the background task will not be saved, because they run after the teardown.
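
Stripped down to just the ordering (my own minimal illustration, separate from the real example below), the observed sequence is: setup, teardown, response sent, background task:

```python
# Minimal illustration of the ordering problem (a hypothetical app,
# not the real example below).
from contextlib import asynccontextmanager
from fastapi import BackgroundTasks, FastAPI

demo = FastAPI()


@asynccontextmanager
async def ctx():
    print("setup")
    yield
    print("teardown")  # runs as soon as the endpoint returns ...


async def bg():
    print("background task")  # ... but this runs even later, outside the context


@demo.get("/")
async def endpoint(tasks: BackgroundTasks):
    async with ctx():
        tasks.add_task(bg)
    return "ok"
```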

There is a minimal (but still ~150-line) working example in a gist; I'll paste it here as well:

```python
from fastapi import FastAPI, Depends, BackgroundTasks, Request
from typing import Annotated, AsyncIterator
from pydantic import BaseModel, Field
from uuid import uuid4
from contextlib import asynccontextmanager
import random
import asyncio

app = FastAPI()


class Chat(BaseModel):
    """
    This is an over-simplified chat history manager, of the kind that can be used in e.g. a LangChain-like system.
    There is an additional `total` field because history messages are serialized and cached individually, and we don't want to load the full history when deserializing a session from the cache/database.
    """

    id: str = Field(default_factory=lambda: uuid4().hex)
    meta: str = "some meta information"
    history: list[str] = []
    total: int = 0
    uncached: int = 0

    def add_message(self, msg: str):
        self.history.append(msg)
        self.total += 1
        self.uncached += 1

    async def save(self, cache: dict):
        # cache history that are not cached
        for imsg in range(-self.uncached, 0):
            cache[f"msg:{self.id}:{self.total + imsg}"] = self.history[imsg]
        self.uncached = 0
        # cache everything except history
        cache[f"sess:{self.id}"] = self.model_dump(exclude={"history"})

        print(f"saved: {self}")

    @classmethod
    async def load(cls, sess_id: str, cache: dict, max_read: int = 30):
        sess_key = f"sess:{sess_id}"
        obj = cls.model_validate(cache.get(sess_key))
        for imsg in range(max(0, obj.total - max_read), obj.total):
            obj.history.append(cache.get(f"msg:{obj.id}:{imsg}"))

        print(f"loaded: {obj}")
        return obj

    async def chat(self, msg: str, cache: dict):
        """Record the user message, then return an async generator that streams the response."""
        self.add_message(msg)

        async def get_chat():
            resp = []
            for i in range(random.randint(3, 5)):
                # simulate long network IO
                await asyncio.sleep(0.5)
                chunk = f"resp{i}:{random.randbytes(2).hex()};"

                resp.append(chunk)
                yield chunk

            self.add_message("".join(resp))

            # NOTE to make the message cache work properly, we have to manually save this:
            # await self.save(cache)

        return get_chat()


# use a simple dict to mimic an actual cache, e.g. Redis
cache = {}


async def get_cache():
    return cache


# I didn't figure out how to make Chat a dependable.
# I have read https://fastapi.tiangolo.com/advanced/advanced-dependencies/#parameterized-dependencies but still have no clue:
# the problem is that `sess_id` comes from the user, not something we can fix in advance the way that tutorial shows.
# As an alternative, I use this async context manager.
# In theory this automatically saves the Chat object after exiting the `async with` block.
@asynccontextmanager
async def get_chat_from_cache(sess_id: str, cache: dict):
    """
    get object from cache (possibly create one), yield it, then save it back to cache
    """
    sess_key = f"sess:{sess_id}"
    if sess_key not in cache:
        obj = Chat()
        obj.id = sess_id
        await obj.save(cache)
    else:
        obj = await Chat.load(sess_id, cache)

    yield obj

    await obj.save(cache)


async def task(sess_id: str, task_id: int, resp_gen: AsyncIterator[str], cache: dict):
    """Consume the response generator and cache each chunk."""
    async for chunk in resp_gen:
        # do something with chunk, e.g. stream it to the client via a websocket
        await asyncio.sleep(0.5)
        cache[f"chunk:{sess_id}:{task_id}"] = chunk
        task_id += 1


@app.get("/{sess_id}/{task_id}/{prompt}")
async def get_chat(
    req: Request,
    sess_id: str,
    task_id: int,
    prompt: str,
    background_task: BackgroundTasks,
    cache: Annotated[dict, Depends(get_cache)],
):
    print(f"req incoming: {req.url}")
    async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
        resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)

        background_task.add_task(
            task, sess_id=sess_id, task_id=task_id, resp_gen=resp_gen, cache=cache
        )

    return "success"


@app.get("/{sess_id}")
async def get_sess(
    req: Request, sess_id: str, cache: Annotated[dict, Depends(get_cache)]
):
    print(f"req incoming: {req.url}")
    return (await Chat.load(sess_id=sess_id, cache=cache)).model_dump()
```
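
To reproduce (assuming the file is saved as `main.py`), run `uvicorn main:app` and request e.g. `GET /somesess/0/hello`, then `GET /somesess`. You should see the teardown's `saved:` line printed before the background task finishes, and the response message appended at the end of `Chat.chat`'s generator never reaches the cache.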

I found a close (but not identical) discussion about the lifespan of dependables. It seems the lifetime of a dependable can be forwarded/extended into a background task, although the participants consider that shaky behavior. I did have the idea of turning `get_chat_from_cache` into a yield-based dependable, though I don't know how to do that properly. In any case, the FastAPI developers don't seem to recommend this approach, since the exact moment at which a dependable is torn down is undocumented behavior and may change in future versions.
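
For reference, a yield-based dependable version would look roughly like this (an untested sketch; the `/dep/...` route and the names `chat_dependable` / `get_chat_dep` are mine). `sess_id` is resolved from the path parameter, which answers the parameterization question, but with FastAPI >= 0.106.0 the code after `yield` still runs before the background task, so the save still happens too early:

```python
# Sketch: get_chat_from_cache rewritten as a yield-based dependable.
# CAVEAT: since FastAPI 0.106.0 the code after `yield` runs before
# background tasks start, so the final save would still be too early.
async def chat_dependable(
    sess_id: str, cache: Annotated[dict, Depends(get_cache)]
) -> AsyncIterator[Chat]:
    sess_key = f"sess:{sess_id}"
    if sess_key not in cache:
        obj = Chat(id=sess_id)
        await obj.save(cache)
    else:
        obj = await Chat.load(sess_id, cache)
    yield obj
    await obj.save(cache)


@app.get("/dep/{sess_id}/{task_id}/{prompt}")
async def get_chat_dep(
    sess_id: str,
    task_id: int,
    prompt: str,
    background_task: BackgroundTasks,
    cache: Annotated[dict, Depends(get_cache)],
    chat: Annotated[Chat, Depends(chat_dependable)],
):
    resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)
    background_task.add_task(
        task, sess_id=sess_id, task_id=task_id, resp_gen=resp_gen, cache=cache
    )
    return "success"
```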

I know I could probably repeat the teardown step manually in the background task, but that feels like a hack. I'm asking whether there is a more elegant way to do this. If there is a better design pattern that avoids the problem entirely, please tell me.
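
For completeness, the manual repetition I mean looks roughly like this (a sketch; `task_with_save` is a name I made up). The endpoint would register this wrapper, passing `chat` along, instead of `task`:

```python
# Sketch of the "manual teardown" hack: wrap the original task so the Chat
# object is saved again once the response generator is fully consumed.
async def task_with_save(
    chat: Chat, sess_id: str, task_id: int, resp_gen: AsyncIterator[str], cache: dict
):
    try:
        await task(sess_id, task_id, resp_gen, cache)
    finally:
        await chat.save(cache)  # duplicates the context manager's teardown
```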

python fastapi lifecycle contextmanager
1 Answer
Background tasks run after your endpoint has finished executing, so you cannot keep the context manager open until the background task completes.

Turning `get_chat_from_cache` into a dependency will not help you either: that worked before FastAPI 0.106.0, but the behavior changed, and you can no longer use a dependency with `yield` inside a background task.

You need to redesign your application with this in mind.
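
One way to apply that advice (my own sketch, not part of the answer; `chat_task` and the `/bg/...` route are made-up names): enter the context manager inside the background task itself, so the teardown save runs only after the whole response has been consumed:

```python
# Sketch: move all context-managed work into the background task, so the
# teardown of get_chat_from_cache (the final save) runs after streaming ends.
async def chat_task(sess_id: str, task_id: int, prompt: str, cache: dict):
    async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
        resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)
        await task(sess_id, task_id, resp_gen, cache)


@app.get("/bg/{sess_id}/{task_id}/{prompt}")
async def get_chat_bg(
    sess_id: str,
    task_id: int,
    prompt: str,
    background_task: BackgroundTasks,
    cache: Annotated[dict, Depends(get_cache)],
):
    # the endpoint only schedules the work; the task owns the object's lifetime
    background_task.add_task(chat_task, sess_id, task_id, prompt, cache)
    return "success"
```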
