pytest 上异步会话的正确使用

问题描述 投票:0回答:1

鉴于以下实现,我尝试使用异步会话进行测试。我的尝试如下:

models.py

from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection

class Paginator:
    """Bundle a raw SQL query with the connection and paging state used to run it."""

    def __init__(
        self,
        conn: Union[Connection, AsyncConnection],
        query: str,
        params: dict = None,
        batch_size: int = 10
    ):
        """Record the connection, the query and the paging settings.

        Args:
            conn: Connection object whose ``execute`` is used to run queries.
            query: Raw SQL text of the query to paginate.
            params: Bind parameters for ``query``; ``None`` means no parameters.
            batch_size: Stored page size (defaults to 10).
        """
        # Paging state starts at the first row with no total known yet.
        self.total_count = None
        self.current_offset = 0
        self.batch_size = batch_size
        # Query definition and the connection it runs on.
        self.params = params
        self.query = query
        self.conn = conn

    async def _get_total_count_async(self) -> int:
        """Return the row count of the stored query, awaited on ``self.conn``.

        The stored query is wrapped as a subquery of ``SELECT COUNT(*)``,
        bound with the stored parameters (an empty mapping when ``params``
        is falsy), executed, and the scalar of the result is returned.
        """
        wrapped_sql = f"SELECT COUNT(*) FROM ({self.query}) as total"
        bind_values = self.params or {}
        statement = text(wrapped_sql).bindparams(**bind_values)
        outcome = await self.conn.execute(statement)
        return outcome.scalar()

test_models.py

@pytest.fixture(scope='function')
async def async_session():
    """Yield an ``AsyncSession`` with a transaction begun; roll it back afterwards."""
    engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
    # Factory producing AsyncSession objects bound to the engine above.
    session_factory = sessionmaker(
        bind=engine,
        class_=AsyncSession,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )

    async with session_factory() as session:
        # Open the transaction explicitly so it can be rolled back after the yield.
        await session.begin()

        yield session

        await session.rollback()

@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Check that the paginator's async count over ``test_table`` is 0."""
    # Prepare the paginator. The value injected by the fixture is
    # ``async_session``; no name ``session`` exists in this scope.
    paginator = Paginator(
        conn=async_session,
        query="SELECT * FROM test_table",
        batch_size=2
    )

    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()

    # Assertion to verify the result
    assert total_count == 0

当我运行命令

pytest
时,我收到以下错误:
AttributeError: 'async_generator' object has no attribute 'execute'
。我很确定,有一个简单的方法可以做到这一点,但我不知道。

python async-await pytest python-asyncio
1个回答
0
投票

您应该将

AsyncConnection
的实例传递给
Paginator
类,但您直接发送
session
本身。

要解决此问题,有两种可能的方法:

  1. 在测试中迭代异步生成器,取出其中的
    session
    ,从而获得
    AsyncConnection
    的功能:
@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Build a paginator from each item produced by iterating ``async_session``."""
    # ``async_session`` is consumed with ``async for``; every item it yields
    # is handed to the paginator as its connection.
    async for session in async_session:
        paginator = Paginator(
            query="SELECT * FROM test_table",
            batch_size=2,
            conn=session,
        )
        ...
  2. 使用 PyPI 包
    pytest_asyncio
    提供的夹具:
@pytest_asyncio.fixture
async def async_session():
    """Yield an ``AsyncSession`` inside a transaction that is rolled back.

    The engine is created per fixture use and disposed once the fixture
    finishes, so its connection pool does not outlive the test.
    """
    async_engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
    # Named ``session_factory`` so it does not shadow this fixture's own name.
    session_factory = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )

    try:
        async with session_factory() as session:
            await session.begin()
            yield session
            # Undo whatever the test did inside the transaction.
            await session.rollback()
    finally:
        # Release the pooled connections even if setup or teardown failed.
        await async_engine.dispose()

@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Check that the paginator's async count over ``test_table`` is 0."""
    # Prepare the paginator. The fixture value arrives as ``async_session``;
    # ``session`` is not defined in this function and would raise NameError.
    paginator = Paginator(
        conn=async_session,
        query="SELECT * FROM test_table",
        batch_size=2
    )

    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()

    # Assertion to verify the result
    assert total_count == 0

这是关于此问题的帖子

© www.soinside.com 2019 - 2024. All rights reserved.