跳到主要内容

中间件

前言

在Web开发中,通常需要一种机制处理请求前和响应后的一些钩子函数,通常把这类函数称为中间件。FastAPI的中间件是在应用程序处理HTTP请求和响应之前或之后执行功能的一个组件。中间件允许用户对HTTP请求进行重写、过滤、修改或添加信息,以及对HTTP响应进行修改或处理。

HTTP请求中间件

from fastapi import Request

# 为每个HTTP请求计算请求响应耗时, 并在控制台输出
@app.middleware("http")
async def add_process_time(request: Request, call_next):
start_time = time.time()
resp = await call_next(request)
process_time = time.time() - start_time
print(f"url: {request.url}, process_time: {process_time:.4f}")
return resp

CORS跨域中间件

在前后端分离架构中,经常会遇到跨域请求的问题。比如前端部署所在域名为web.ab.com,后端API部署的域名为api.ab.com,这是典型的非同源环境。在web.ab.com域名下访问api.ab.com,就会触发浏览器同源安全策略。这种安全策略策略是浏览器自带的基本安全策略。

在这种情况下,我们应该开启部分域名(或IP)来发起访问请求。常见跨域请求处理方法有:

  • 使用代理机制,也就是通过同源服务器下的后端进行代理请求以获取非同源服务下的资源数据。
  • 使用jsonp方式,但是jsonp方式仅限于GET请求。
  • 使用CORS方式,相较于jsonp,CORS方式的优势在于支持的请求方式更多,浏览器兼容性更好。

FastAPI提供了CORSMiddleware中间件

from fastapi.middleware.cors import CORSMiddleware

origins = [
"https://api.ab.com",
"https://api2.ab.com",
"https://api3.ab.com:8080",
]

app.add_middleware(
CORSMiddleware,
allow_origins=origins, # 允许跨域请求的域名列表
allow_credentials=True, # 跨域请求时是否允许发送跨域凭证, 一般为cookie
allow_methods=["*"], # 跨域允许的HTTP请求方法
allow_headers=["*"], # 跨域允许的HTTP请求头
)

HTTPSRedirectMiddleware

该内置中间件用于强制所有请求使用HTTPS。

from fastapi.middlerware.httpsredirect import HTTPSRedirectMiddlerware

app.add_middleware(HTTPSRedirectMiddlerware)

TrustedHostMiddlerware

某些场景下,需要强制请求Header中的host选项必须来自于指定的host才允许访问指定的地址。

示例

from fastapi.middleware.trustedhost import TrustedHostMiddleware

allowed_hosts = [
"exaple.com",
"*.example.com",
]

app.add_middleware(TrustedHostMiddleware, allowed_hosts = allowed_hosts)

自定义中间件

FastAPI内置的中间件未必满足所有业务场景需求,因此还需要自定义中间件。

基本示例

将中间件的代码放到单独的模块中,编辑middlewares/demo.py

from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
import time

class TimeCalculate(BaseHTTPMiddleware):
# dispatch必须实现
async def dispatch(self, request: Request, call_next):
print(">>> TimeCalculate Middleware <<<")
start_time = time.time()
resp = await call_next(request)
elapsed_time = round(time.time() - start_time, 4)
print(f"URL: {request.url}, Elapsed time: {elapsed_time}s")
return resp

引入自定义中间件

from middlewares import demo

app.add_middleware(demo.TimeCalculate)

日志追踪链路ID

通过链路追踪,可以在一个请求服务过程中把涉及多个的能源服务或其他第三方请求的日志都关联起来,这样可以快速进行问题定位及排错。要实现链路追踪,就需要为请求打标签。一般的做法是,每进来一个请求,就在当前请求上下文中生成一个链路ID,这个链路ID可以关联第三方请求或其他服务日志,并在整个请求链路上下文中进行传递,直到请求完成响应处理。

import contextvars
import uuid

# 定义一个上下文变量对象,主要用于存储当前请求的上下文信息
request_context = contextvars.ContextVar("request_context")

class TraceID(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_context.set(request) # 将当前请求传入当前request_context的上下文变量对象中
request.state.traceid = uuid.uuid4() # 生成traceid,并写入request.state中
resp = await call_next(request)
return resp


# 创建一个视图函数测试
@app.get("/")
async def get_response():
request: Request = request_context.get()
print(f"index-request: {request.state.traceid}")
return {"message": "Hello World"}

# 添加到中间件
app.add_middleware(TraceID)

自定义类实现中间件

除了通过集成BaseHTTPMiddleware类来自定义中间件外,还可以基于自定义类来实现。

from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.requests import HTTPConnection
import typing


# 自定义类实现黑名单IP中间件
class BlackIPMiddleware(BaseHTTPMiddleware):
# ASGIApp对象必须要有,其它参数根据实际需求设置
def __init__(self, app: ASGIApp, denied_ip: typing.Sequence[str] = ()):
self.app = app
self.denied_ip = denied_ip
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope['type'] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
conn = HTTPConnection(scope=scope)
if self.denied_ip and conn.client.host in self.denied_ip:
resp = JSONResponse({'message': 'IP Denied'}, status_code=403)
await resp(scope, receive, send)
return
await self.app(scope, receive, send)
else:
await self.app(scope, receive, send)

# 注册到app实例
app.add_middleware(BlackIPMiddleware, denied_ip = ["192.168.1.108"])

基于中间件获取响应内容

在某种场景下,需要在中间件中获取对应请求的响应报文,如常见的日志记录场景。

to be continue