我正在构建一个 FastAPI 应用程序,其中有很多 Pydantic 模型。尽管应用程序运行良好,但正如预期的那样,OpenAPI (Swagger UI) 文档并未在
Schemas
部分下显示所有这些模型的架构。
这里是pydantic的内容
schemas.py
import socket
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel, Field, validator
from typing_extensions import Literal
ResponseData = Union[List[Any], Dict[str, Any], BaseModel]
# Not visible in Swagger UI
class PageIn(BaseModel):
page_size: int = Field(default=100, gt=0)
num_pages: int = Field(default=1, gt=0, exclude=True)
start_page: int = Field(default=1, gt=0, exclude=True)
# visible under schemas on Swagger UI
class PageOut(PageIn):
total_records: int = 0
total_pages: int = 0
current_page: int = 1
class Config: # pragma: no cover
@staticmethod
def schema_extra(schema, model) -> None:
schema.get("properties").pop("num_pages")
schema.get("properties").pop("start_page")
# Not visible in Swagger UI
class BaseResponse(BaseModel):
host_: str = Field(default_factory=socket.gethostname)
message: Optional[str]
# Not visible in Swagger UI
class APIResponse(BaseResponse):
count: int = 0
location: Optional[str]
page: Optional[PageOut]
data: ResponseData
# Not visible in Swagger UI
class ErrorResponse(BaseResponse):
error: str
# visible under schemas on Swagger UI
class BaseFaultMap(BaseModel):
detection_system: Optional[str] = Field("", example="obhc")
fault_type: Optional[str] = Field("", example="disk")
team: Optional[str] = Field("", example="dctechs")
description: Optional[str] = Field(
"",
example="Hardware raid controller disk failure found. "
"Operation can continue normally,"
"but risk of data loss exist",
)
# Not visible in Swagger UI
class FaultQueryParams(BaseModel):
f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
hostname: Optional[str]
status: Literal["open", "closed", "all"] = Field("open")
created_by: Optional[str]
environment: Optional[str]
team: Optional[str]
fault_type: Optional[str]
detection_system: Optional[str]
inops_filters: Optional[str] = Field(None)
date_filter: Optional[str] = Field("",)
sort_by: Optional[str] = Field("created",)
sort_order: Literal["asc", "desc"] = Field("desc")
所有这些模型实际上都在 FastAPI 路径中用于验证请求正文。
FaultQueryParams
是一个自定义模型,我用它来验证请求查询参数,用法如下:
query_args: FaultQueryParams = Depends()
其余模型与
Body
字段结合使用。我无法弄清楚为什么只有某些模型在 Schemas
部分中不可见,而其他模型则可见。
我注意到的另一件事是
FaultQueryParams
是描述、示例不会显示在路径端点上,即使它们是在模型中定义的。
编辑1:
我进行了更多研究并意识到,所有在 swagger UI 中不可见的模型都没有直接在路径操作中使用,即这些模型没有被用作
response_model
或 Body
类型,而是间接使用的辅助模型类型。因此,FastAPI 似乎没有为这些模型生成架构。
上述语句的一个例外是
query_args: FaultQueryParams = Depends()
,它直接在路径操作中使用,以根据自定义模型映射端点的 Query
参数。这是一个问题,因为 swagger 没有从该模型的字段中识别像 title
、description
、example
等元参数,并且没有显示在 UI 上,这对于此端点的用户来说很重要。
有没有办法欺骗 FastAPI 为自定义模型
FaultQueryParams
生成模式,就像它为 Body
、Query
等生成模式一样?
FastAPI 将为用作 Request Body 或 Response Model 的模型生成架构。声明
query_args: FaultQueryParams = Depends()
(使用 Depends)时,您的端点不会期望 request body
,而是 query
参数;因此,FaultQueryParams
不会包含在OpenAPI文档的模式中。
要添加其他架构,您可以扩展/修改 OpenAPI 架构。下面给出了示例(确保在定义所有路由之后添加用于修改架构的代码,即在代码末尾)。
class FaultQueryParams(BaseModel):
f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
hostname: Optional[str]
status: Literal["open", "closed", "all"] = Field("open")
...
@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
return query_args
def get_extra_schemas():
return {
"FaultQueryParams": {
"title": "FaultQueryParams",
"type": "object",
"properties": {
"f_id": {
"title": "Fault ID",
"type": "integer",
"description": "id for the host",
"example": 12345
},
"hostname": {
"title": "Hostname",
"type": "string"
},
"status": {
"title": "Status",
"enum": [
"open",
"closed",
"all"
],
"type": "string",
"default": "open"
},
...
}
}
}
from fastapi.openapi.utils import get_openapi
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="FastAPI",
version="1.0.0",
description="This is a custom OpenAPI schema",
routes=app.routes,
)
new_schemas = openapi_schema["components"]["schemas"]
new_schemas.update(get_extra_schemas())
openapi_schema["components"]["schemas"] = new_schemas
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
您可以使用 FastAPI 在代码中添加一个端点(在获取架构后,您将随后删除该端点),从而让 FastAPI 为您完成此操作,而不是手动输入要添加到文档中的额外模型的架构。模型作为请求体或响应模型,例如:
@app.post('/predict')
def predict(query_args: FaultQueryParams):
return query_args
然后,您可以在http://127.0.0.1:8000/openapi.json获取生成的JSON模式,如文档中所述。从那里,您可以将模型的架构复制并粘贴到代码中并直接使用它(如上面的
get_extra_schema()
方法所示),或者将其保存到文件并从文件加载 JSON 数据,如下所示:
import json
...
new_schemas = openapi_schema["components"]["schemas"]
with open('extra_schemas.json') as f:
extra_schemas = json.load(f)
new_schemas.update(extra_schemas)
openapi_schema["components"]["schemas"] = new_schemas
...
要为查询参数声明元数据,例如
description
、example
等,您应该使用 Query
而不是 Field
来定义参数,因为使用 Pydantic 模型无法做到这一点,您需要直接在端点中定义 Query
参数或使用一个自定义依赖类,如下图:
from fastapi import FastAPI, Query, Depends
from typing import Optional
class FaultQueryParams:
def __init__(
self,
f_id: Optional[int] = Query(None, description="id for the host", example=12345)
):
self.f_id = f_id
app = FastAPI()
@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
return query_args
@dataclass
装饰器重写上述内容,如下所示:
from fastapi import FastAPI, Query, Depends
from typing import Optional
from dataclasses import dataclass
@dataclass
class FaultQueryParams:
f_id: Optional[int] = Query(None, description="id for the host", example=12345)
app = FastAPI()
@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
return query_args
不再需要使用自定义依赖类,因为 FastAPI 现在允许使用 Pydantic
BaseModel
通过将 Query()
包装在 Field()
中来定义查询参数;因此,人们应该能够使用 Pydantic 模型来定义多个查询参数并为其声明元数据,即 description
、example
等。相关答案可以在 here 和 here 找到。
感谢@Chris 的指点,最终引导我使用dataclasses 来批量定义查询参数,而且效果很好。
@dataclass
class FaultQueryParams1:
f_id: Optional[int] = Query(None, description="id for the host", example=55555)
hostname: Optional[str] = Query(None, example="test-host1.domain.com")
status: Literal["open", "closed", "all"] = Query(
None, description="fetch open/closed or all records", example="all"
)
created_by: Optional[str] = Query(
None,
description="fetch records created by particular user",
example="user-id",
)