先前学习FastAPI路由的时候已经介绍过依赖注入,但由于篇幅限制并没有梳理的特别详细,这次期望进行一些完整的梳理。
依赖注入简介
FastAPI 的 依赖注入(Dependency Injection) 是一种解耦代码、复用逻辑的核心机制,允许你将共享功能(如数据库连接、权限校验)注入到路由函数中。
核心概念
- 依赖项(Dependencies):
- 可以是函数或类,用于封装可复用的逻辑(如认证、数据库连接)。
- 通过Depends() 声明依赖,FastAPI 自动解析并注入所需参数。
- 作用:
- 代码复用:避免重复代码(如多个路由共享认证逻辑)。
- 解耦:分离业务逻辑和基础设施(如数据库操作)。
- 测试:轻松替换依赖项以模拟外部服务(如模拟数据库)。
- 层次化校验:在路由执行前完成预处理(如权限检查)。
常见使用场景
场景 | 依赖示例 |
数据库会话管理 | 注入 Session 对象到 CRUD 操作。 |
用户认证 | 校验 Token 并返回当前用户信息。 |
权限控制 | 根据用户角色限制访问。 |
请求日志记录 | 记录请求信息到日志系统。 |
限流与频率控制 | 限制接口调用频率。 |
配置管理 | 注入全局配置(如 API 密钥)。 |
注意事项
- 执行顺序:依赖项按声明顺序执行,子依赖项优先解析。
- 异常处理:依赖项中抛出HTTPException 会直接终止请求。
- 文档生成:依赖项的参数(如查询参数)会自动合并到 OpenAPI 文档。
依赖注入基本用法
函数作为依赖项
通过 Depends() 注入函数:
from fastapi import Depends, FastAPI app = FastAPI() # 定义一个依赖函数 def get_db_connection(): return "Database Connection" @app.get("/items/") def read_items(db: str = Depends(get_db_connection)): return {"db": db}
效果:read_items 函数会自动接收 get_db_connection 的返回值。
类作为依赖项
注入类的实例方法:
class Database: def get_conn(self): return "Database Connection" db = Database() @app.get("/items/") def read_items(conn: str = Depends(db.get_conn)): return {"conn": conn}
依赖的嵌套与复用
依赖可以嵌套其他依赖,形成依赖链。
示例:权限校验链
from fastapi import HTTPException, Depends, status def get_current_user(token: str = Header(...)): if token != "secret": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) return {"user": "admin"} def check_admin(user: dict = Depends(get_current_user)): if user["user"] != "admin": raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return user @app.get("/admin/") def admin_panel(user: dict = Depends(check_admin)): return {"message": "Admin Panel"}
全局依赖
路由级依赖
为单个路由添加多个依赖:
def log_request(request: Request): print(f"Request to {request.url}") @app.get("/items/", dependencies=[Depends(log_request)]) def read_items(): return [{"item": "Foo"}]
应用级依赖
为所有路由添加依赖:
app = FastAPI(dependencies=[Depends(log_request)])
异步依赖
依赖函数可以是异步的:
async def async_dependency(): await asyncio.sleep(1) return "Async Data" @app.get("/async/") async def read_async(data: str = Depends(async_dependency)): return {"data": data}
依赖注入的高级模式
依赖的作用域与生命周期
依赖项的作用范围与缓存
- 默认缓存:同一请求中多次调用同一依赖项时,结果会被缓存。
- 禁用缓存:设置 use_cache=False。
from fastapi import Depends # 缓存依赖结果(同一请求中多次调用只执行一次) def get_heavy_service(): print("Initializing Heavy Service...") return "Heavy Service Data" @app.get("/data/") def get_data( service1: str = Depends(get_heavy_service), service2: str = Depends(get_heavy_service) ): return {"service1": service1, "service2": service2}
输出:Initializing Heavy Service… 只会打印一次。
依赖覆盖(测试场景)
在测试中替换依赖实现:
from fastapi.testclient import TestClient def override_dependency(): return "Mocked Data" app.dependency_overrides[get_db_connection] = override_dependency client = TestClient(app) response = client.get("/items/") assert response.json() == {"db": "Mocked Data"}
参数化依赖
通过返回依赖函数实现动态配置:
def get_pagination_params( skip: int = Query(0, ge=0), limit: int = Query(10, ge=1, le=100) ): return {"skip": skip, "limit": limit} @app.get("/items/") def read_items(pagination: dict = Depends(get_pagination_params)): return {"skip": pagination["skip"], "limit": pagination["limit"]}
类作为依赖提供者
class AuthChecker: def __init__(self, role: str): self.role = role def __call__(self, token: str = Header(...)): if token != self.role: raise HTTPException(status_code=403) return {"role": self.role} admin_checker = AuthChecker(role="admin") @app.get("/admin/") def admin_panel(user: dict = Depends(admin_checker)): return user
常见应用场景
认证与权限
def verify_token(token: str = Header(...)): if token != "secret": raise HTTPException(status_code=403) return "user" @app.get("/secure") async def secure_route(user: str = Depends(verify_token)): return {"user": user}
数据库会话管理
from sqlalchemy.orm import Session def get_db(): db = SessionLocal() try: yield db finally: db.close() @app.post("/items/") async def create_item(item: Item, db: Session = Depends(get_db)): db.add(item) db.commit() return item
请求参数预处理
def pagination_params(skip: int = 0, limit: int = 100): return {"skip": skip, "limit": limit} @app.get("/items/") async def list_items(pagination: dict = Depends(pagination_params)): return fetch_items(pagination["skip"], pagination["limit"])