我编写了自己的会话逻辑并使用以下装饰器来检查请求:
def require_session(func):
@wraps(func)
async def wrapper(*args, **kwargs):
request = kwargs['request']
session_key = request.headers.get('X-Session-Key')
if session_key is not None:
session = Session.check(session_key) # Check in db
if session is not None:
return await func(*args, **kwargs)
raise exc.access_denied()
return wrapper
我就是这样使用的
import fastapi as fa
employees_router = fa.APIRouter(default_response_class=fa.responses.JSONResponse)
@employees_router.get('/{id}')
@require_session
async def get_employee(request: fa.Request, id: int) -> EmployeeDBSchema:
employee = Employee.get_or_none(id=id)
if employee is None:
raise exc.instance_by_field_not_found(Employee, id)
return employee
我只是先在标头中检查会话密钥,然后在数据库中检查。一切正常,直到我尝试在其他方法中调用这些方法。
@administrators_router.post('')
@require_session
async def create_administrator(request: fa.Request, data: AdministratorSchema) -> AdministratorDBSchema:
data_dump = data.model_dump()
if isinstance(data_dump['employee'], int):
await get_employee(request, data_dump['employee']) # Inner call
try:
admin = Administrator.create(**data_dump)
except pw.IntegrityError as e:
raise exc.integrity(e)
return admin
这会导致会话被检查两次。 如何避免第二次会话检查?
我正在寻找类似 “只能在 fastapi 应用程序内部设置的参数” 来定义一个附加参数
inner_call=false
告诉装饰器不要检查会话,但我没有找到任何东西。
第二种解决方案是将端点与其实现分开,如下所示:
@employees_router.get('/{id}')
@require_session
async def get_employee(request: fa.Request, id: int) -> EmployeeDBSchema:
return get_employee_impl(id)
async def get_employee_impl(id: int) -> EmployeeDBSchema:
employee = Employee.get_or_none(id=id)
if employee is None:
raise exc.instance_by_field_not_found(Employee, id)
return employee
只需在我想避免会话检查的地方调用
get_employee_impl
,但重写所有函数可能会花费很多时间。另外,这个解决方案对我来说似乎不太干净。
使用contextvar,它可以在请求上下文中存储变量
修改您的代码,例如:
from contextvars import ContextVar
ctx_session = ContextVar('ctx_session', default=None)
def require_session(func):
@wraps(func)
async def wrapper(*args, **kwargs):
if ctx_session.get() is None:
request = kwargs['request']
session_key = request.headers.get('X-Session-Key')
if session_key is not None:
session = Session.check(session_key) # Check in db
if session is not None:
ctx_session.set(session)
return await func(*args, **kwargs)
raise exc.access_denied()
else:
return await func(*args, **kwargs)
return wrapper