术→技巧, 研发

FastAPI 学习之依赖注入

钱魏Way · · 72 次浏览

先前学习FastAPI路由的时候已经介绍过依赖注入,但由于篇幅限制并没有梳理的特别详细,这次期望进行一些完整的梳理。

依赖注入简介

FastAPI 的 依赖注入(Dependency Injection) 是一种解耦代码、复用逻辑的核心机制,允许你将共享功能(如数据库连接、权限校验)注入到路由函数中。

核心概念

  • 依赖项(Dependencies)
    • 可以是函数或类,用于封装可复用的逻辑(如认证、数据库连接)。
    • 通过Depends() 声明依赖,FastAPI 自动解析并注入所需参数。
  • 作用
    • 代码复用:避免重复代码(如多个路由共享认证逻辑)。
    • 解耦:分离业务逻辑和基础设施(如数据库操作)。
    • 测试:轻松替换依赖项以模拟外部服务(如模拟数据库)。
    • 层次化校验:在路由执行前完成预处理(如权限检查)。

常见使用场景

场景 依赖示例
数据库会话管理 注入 Session 对象到 CRUD 操作。
用户认证 校验 Token 并返回当前用户信息。
权限控制 根据用户角色限制访问。
请求日志记录 记录请求信息到日志系统。
限流与频率控制 限制接口调用频率。
配置管理 注入全局配置(如 API 密钥)。

注意事项

  • ​执行顺序:依赖项按声明顺序执行,子依赖项优先解析。
  • ​异常处理:依赖项中抛出HTTPException 会直接终止请求。
  • ​文档生成:依赖项的参数(如查询参数)会自动合并到 OpenAPI 文档。

依赖注入基本用法

函数作为依赖项

通过 Depends() 注入函数:

from fastapi import Depends, FastAPI

app = FastAPI()

# 定义一个依赖函数
def get_db_connection():
    return "Database Connection"

@app.get("/items/")
def read_items(db: str = Depends(get_db_connection)):
    return {"db": db}

效果:read_items 函数会自动接收 get_db_connection 的返回值。

类作为依赖项

注入类的实例方法:

class Database:
    def get_conn(self):
        return "Database Connection"

db = Database()

@app.get("/items/")
def read_items(conn: str = Depends(db.get_conn)):
    return {"conn": conn}

依赖的嵌套与复用

依赖可以嵌套其他依赖,形成依赖链。

示例:权限校验链

from fastapi import HTTPException, Depends, status

def get_current_user(token: str = Header(...)):
    if token != "secret":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
    return {"user": "admin"}

def check_admin(user: dict = Depends(get_current_user)):
    if user["user"] != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    return user

@app.get("/admin/")
def admin_panel(user: dict = Depends(check_admin)):
    return {"message": "Admin Panel"}

全局依赖

路由级依赖

为单个路由添加多个依赖:

def log_request(request: Request):
    print(f"Request to {request.url}")

@app.get("/items/", dependencies=[Depends(log_request)])
def read_items():
    return [{"item": "Foo"}]

应用级依赖

为所有路由添加依赖:

app = FastAPI(dependencies=[Depends(log_request)])

异步依赖

依赖函数可以是异步的:

async def async_dependency():
    await asyncio.sleep(1)
    return "Async Data"

@app.get("/async/")
async def read_async(data: str = Depends(async_dependency)):
return {"data": data}

依赖注入的高级模式

依赖的作用域与生命周期

依赖项的作用范围与缓存

  • 默认缓存:同一请求中多次调用同一依赖项时,结果会被缓存。
  • 禁用缓存:设置 use_cache=False。
from fastapi import Depends

# 缓存依赖结果(同一请求中多次调用只执行一次)
def get_heavy_service():
    print("Initializing Heavy Service...")
    return "Heavy Service Data"

@app.get("/data/")
def get_data(
    service1: str = Depends(get_heavy_service),
    service2: str = Depends(get_heavy_service)
):
    return {"service1": service1, "service2": service2}

输出:Initializing Heavy Service… 只会打印一次。

依赖覆盖(测试场景)

在测试中替换依赖实现:

from fastapi.testclient import TestClient

def override_dependency():
    return "Mocked Data"

app.dependency_overrides[get_db_connection] = override_dependency

client = TestClient(app)
response = client.get("/items/")
assert response.json() == {"db": "Mocked Data"}

参数化依赖

通过返回依赖函数实现动态配置:

def get_pagination_params(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100)
):
    return {"skip": skip, "limit": limit}

@app.get("/items/")
def read_items(pagination: dict = Depends(get_pagination_params)):
    return {"skip": pagination["skip"], "limit": pagination["limit"]}

类作为依赖提供者

class AuthChecker:
    def __init__(self, role: str):
        self.role = role

    def __call__(self, token: str = Header(...)):
        if token != self.role:
            raise HTTPException(status_code=403)
        return {"role": self.role}

admin_checker = AuthChecker(role="admin")

@app.get("/admin/")
def admin_panel(user: dict = Depends(admin_checker)):
    return user

常见应用场景

认证与权限

def verify_token(token: str = Header(...)):
    if token != "secret":
        raise HTTPException(status_code=403)
    return "user"

@app.get("/secure")
async def secure_route(user: str = Depends(verify_token)):
    return {"user": user}

数据库会话管理

from sqlalchemy.orm import Session

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

@app.post("/items/")
async def create_item(item: Item, db: Session = Depends(get_db)):
    db.add(item)
    db.commit()
    return item

请求参数预处理

def pagination_params(skip: int = 0, limit: int = 100):
    return {"skip": skip, "limit": limit}

@app.get("/items/")
async def list_items(pagination: dict = Depends(pagination_params)):
    return fetch_items(pagination["skip"], pagination["limit"])

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注