我的 FastAPI 应用程序一直使用同步的 SQLAlchemy 会话。现在我们正在迁移到异步数据库调用,原先用于在每个单元测试结束后回滚数据库操作的代码不再起作用了。
我的代码片段基于 https://stackoverflow.com/a/67348153/6495199 ,并按照这条建议做了修改: https://github.com/sqlalchemy/sqlalchemy/issues/5811#issuecomment-756269881
# Async engine used by the tests below; StaticPool is supplied as its pool
# class and the URL comes from get_db_url().
test_engine = create_async_engine(
    get_db_url(),
    poolclass=StaticPool,
)

# Factory for AsyncSession objects bound to the test engine, with both
# autocommit and autoflush turned off.
TestingSessionLocal = sessionmaker(
    bind=test_engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
)
@pytest.fixture()
async def db_session():
    """Yield an AsyncSession whose work is rolled back when the test ends.

    A connection is opened on ``test_engine``, an outer transaction plus a
    SAVEPOINT are started on it, and a session bound to that connection is
    handed to the test. Whenever a session transaction ends, a listener
    re-opens the SAVEPOINT so the test may commit freely; the outer
    transaction is rolled back during teardown.

    Yields:
        The session created by ``TestingSessionLocal`` and bound to the
        fixture's connection.
    """
    # NOTE(review): an async-generator fixture is only driven by an
    # asyncio-aware pytest plugin (e.g. pytest-asyncio in auto mode, or its
    # own fixture decorator). Under plain pytest the dependent test/fixture
    # would receive the un-iterated async generator itself, which would
    # explain an error such as "'async_generator' object has no attribute
    # 'get'" - confirm the plugin configuration.
    connection = await test_engine.connect()
    try:
        outer_transaction = await connection.begin()
        await connection.begin_nested()
        session = TestingSessionLocal(bind=connection)

        @sqlalchemy.event.listens_for(session.sync_session, "after_transaction_end")
        def end_savepoint(sync_session, ended_transaction):
            # Nothing to restore once the connection has been closed.
            if connection.closed:
                return
            # Re-open the SAVEPOINT so later commits inside the test stay
            # nested within the outer transaction.
            if not connection.in_nested_transaction():
                connection.sync_connection.begin_nested()

        try:
            yield session
        finally:
            # Roll back the overall transaction, restoring the state before
            # the test ran.
            await session.close()
            await outer_transaction.rollback()
    finally:
        # Always release the connection, even if setup or the rollback
        # above raised.
        await connection.close()
@pytest.fixture()
def client(db_session):
    """Yield a TestClient for ``app`` with ``get_db`` overridden.

    While the fixture is active, the ``get_db`` dependency is replaced by a
    generator that yields the ``db_session`` fixture value; the override is
    removed again after the test.
    """

    def _provide_test_session():
        # Hand the fixture-provided session to the application.
        yield db_session

    app.dependency_overrides[get_db] = _provide_test_session
    test_client = TestClient(app)
    yield test_client
    # Teardown: drop the override installed above.
    del app.dependency_overrides[get_db]
def test_get_by_id(client):
    """GET /users/1 is expected to answer with HTTP 404."""
    status = client.get("/users/1").status_code
    assert status == 404
在代码的某处会调用下面这个函数,但调用时出现了错误:
async def find_by_id(cls, session: AsyncSession, pk: int):
    """Return the result of ``session.get(cls, pk)``.

    Args:
        cls: Passed to ``session.get`` as the entity to look up.
        session: Session object used for the lookup.
        pk: Primary-key value passed to ``session.get``.
    """
    record = await session.get(cls, pk)
    return record
# Observed failure when this runs under the tests:
# AttributeError: 'async_generator' object has no attribute 'get'