文章内容如有错误或排版问题,请提交反馈,非常感谢!
FastAPI 是一个现代化的 Python Web 框架,专为构建高性能 API 而设计。它支持异步请求处理,并自动生成交互式 API 文档。本指南以 LightGBM 模型为例,详细介绍如何开发、优化和部署机器学习模型接口。

基础实现
环境准备
pip install fastapi uvicorn lightgbm numpy pydantic
模型训练与保存
import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_classification

# Build a small synthetic binary-classification dataset (2 features).
X, y = make_classification(n_samples=100, n_features=2, random_state=42)
train_data = lgb.Dataset(X, label=y)

# Training configuration: binary objective, silent logging.
params = {'objective': 'binary', 'verbosity': -1}

# Fit a small boosted-tree model and persist it (binary format recommended).
model = lgb.train(params, train_data, num_boost_round=10)
model.save_model('model.bin')
FastAPI 应用基础实现
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import lightgbm as lgb
import numpy as np


# Request schema: two numeric features, validated by pydantic.
class PredictionRequest(BaseModel):
    feature1: float
    feature2: float


app = FastAPI()

# Load the trained model once at import time; fail fast if it is missing.
try:
    model = lgb.Booster(model_file='model.bin')
except Exception as e:
    raise RuntimeError("模型加载失败") from e

# CORS for development; restrict allow_origins before going to production.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST"],
)


@app.post("/predict")
def predict(request: PredictionRequest):
    """Predict the binary class for a single sample.

    Returns {"prediction": 0.0 or 1.0} using a 0.5 probability threshold.
    """
    try:
        input_data = [[request.feature1, request.feature2]]
        prediction = model.predict(input_data)
        # Threshold the predicted probability at 0.5 for the binary label.
        return {"prediction": float(prediction[0] > 0.5)}
    except Exception as e:
        # Surface failures as a proper HTTP 500 instead of a 200 response
        # carrying an {"error": ...} payload, which clients would mistake
        # for success.
        raise HTTPException(status_code=500, detail=str(e))
启动服务
uvicorn main:app --reload --host 0.0.0.0 --port 8000
接口测试
方法一:Swagger UI
访问 http://localhost:8000/docs,在 /predict 端点进行交互测试。
方法二:Python requests
import requests

# The URL must not contain stray whitespace inside the quotes; requests
# rejects " http://..." with an InvalidURL/MissingSchema error.
data = {"feature1": 0.5, "feature2": -0.3}
response = requests.post("http://localhost:8000/predict", json=data)
print(response.json())
关键优化策略
模型格式优化
| 格式类型 | 文件体积 | 加载速度 | 可读性 | 适用场景 |
| --- | --- | --- | --- | --- |
| 文本格式 | 大 | 慢 | 高 | 调试/模型分析 |
| 二进制格式 | 小 | 快 | 无 | 生产环境推荐 |
| JSON格式 | 中等 | 中等 | 高 | 跨平台兼容 |
最佳实践:
# Save in binary format (smaller file, faster load than text format)
model.save_model("model.bin")
# Load the saved model back from the binary file
model = lgb.Booster(model_file="model.bin")
输入验证增强
from pydantic import Field


class PredictionRequest(BaseModel):
    # Constrained fields: out-of-range payloads are rejected by pydantic
    # with a 422 response before the endpoint body runs.
    city: str = Field(..., min_length=2, max_length=10)
    feature1: float = Field(..., ge=0, le=200)
    feature2: float = Field(..., ge=0)
单例模式加载
from fastapi import FastAPI
from contextlib import asynccontextmanager
import lightgbm as lgb


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: load the model once at startup, free it on shutdown."""
    app.state.model = lgb.Booster(model_file="model.bin")
    yield
    del app.state.model


app = FastAPI(lifespan=lifespan)


@app.post("/predict")
def predict(request: PredictionRequest):
    model = app.state.model  # one shared Booster instance for every request
    # ... prediction logic
优势:
- 避免重复加载,节省内存
- 减少磁盘I/O开销
- 确保线程安全
特征处理支持
import json
import os

# Category encoding table; unknown cities fall into the "other" bucket (3).
FEATURE_MAPPING = {
    "city": {"北京": 0, "上海": 1, "深圳": 2, "其他": 3}
}


def preprocess_features(data: "PredictionRequest") -> np.ndarray:
    """Turn a validated request into the model's 2-D feature matrix.

    Returns a (1, 3) array: [city_code, scaled_feature1, scaled_feature2].
    NOTE(review): this emits 3 features while the demo model earlier in the
    guide was trained on 2 — keep the pipeline aligned with the deployed model.
    """
    # Encode the categorical city; .get falls back to the "other" code.
    city_code = FEATURE_MAPPING["city"].get(data.city, 3)
    # Hard-coded standardization constants; must match the training pipeline.
    feature1 = (data.feature1 - 100) / 50
    feature2 = data.feature2 * 0.01
    return np.array([[city_code, feature1, feature2]])


# Preferred: load the mapping from a versioned config file when it exists,
# instead of hard-coding it (the literal above remains the fallback so a
# missing file no longer crashes module import).
MODELS_DIR = "models"
_mapping_path = os.path.join(MODELS_DIR, "feature_mapping.json")
if os.path.exists(_mapping_path):
    with open(_mapping_path, encoding="utf-8") as f:
        FEATURE_MAPPING = json.load(f)
模型版本管理
models/
├── v1.0/
│ ├── model.bin
│ └── feature_mapping.json
├── v1.1/
│ ├── model.bin
│ └── feature_mapping.json
from fastapi import HTTPException


@app.post("/predict")
async def predict(request: PredictionRequest, version: str = "v1.0"):
    try:
        # Lazily populate the per-version model cache on first use.
        if version not in app.state.model_cache:
            app.state.model_cache[version] = load_model(version)
        model = app.state.model_cache[version]
        # ... prediction logic
    except FileNotFoundError as e:
        # Unknown version directory -> 404 for the client.
        raise HTTPException(status_code=404, detail="模型版本不存在")
性能优化
异步线程池处理CPU密集型任务:
import asyncio


@app.post("/predict")
async def predict(request: PredictionRequest):
    # Hand the blocking predict call to the default thread pool so the
    # event loop stays responsive while LightGBM runs.
    loop = asyncio.get_event_loop()
    prediction = await loop.run_in_executor(
        None,               # default ThreadPoolExecutor
        _predict_batch,     # synchronous predict function
        processed_features,
    )
    return prediction


def _predict_batch(features: np.ndarray):
    """Thread-safe synchronous prediction helper."""
    return model.predict(features)
LightGBM多线程配置:
# Configure parallelism at load time
model = lgb.Booster(
model_file=model_path,
params={"num_threads": 4} # tune to the number of CPU cores
)
# Or specify it per prediction call
prediction = model.predict(data, num_threads=4)
批量预测支持:
from typing import List
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool


class BatchRequest(BaseModel):
    # NOTE: fastapi.Query() belongs in endpoint signatures, not pydantic
    # models — a plain default is the correct way to make `version` optional
    # inside a request body.
    requests: List[PredictionRequest]
    version: str = "v1.0"


@app.post("/batch_predict")
async def batch_predict(batch: BatchRequest):
    """Predict for many samples in a single call."""
    # preprocess_features returns a (1, n_features) row; stack the rows into
    # one 2-D matrix — a plain np.array() of them would be 3-D (n, 1, k),
    # which Booster.predict rejects.
    features = np.vstack([preprocess_features(r) for r in batch.requests])
    # Run the blocking predict in a worker thread.
    predictions = await run_in_threadpool(model.predict, features)
    return {"results": predictions.tolist()}
资源管控
请求限流:
from slowapi import Limiter
from slowapi.util import get_remote_address
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.post("/predict")
@limiter.limit("100/minute") # 每分钟100次请求
async def predict(request: PredictionRequest):
# ... 业务逻辑
进程级隔离部署:
# 使用gunicorn启动(4个工作进程,每个2线程)
gunicorn -w 4 --threads 2 -k uvicorn.workers.UvicornWorker main:app
资源监控:
import psutil
from fastapi.responses import JSONResponse


@app.middleware("http")
async def resource_monitor(request, call_next):
    """Reject requests early when system memory is nearly exhausted."""
    mem = psutil.virtual_memory()
    if mem.percent > 90:
        # An HTTPException raised inside middleware is NOT routed through
        # FastAPI's exception handlers (it surfaces as a 500); return the
        # 503 response directly instead.
        return JSONResponse(status_code=503, content={"detail": "系统资源紧张"})
    return await call_next(request)
完整优化示例
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from functools import partial
from slowapi import Limiter
from slowapi.util import get_remote_address
import lightgbm as lgb
import numpy as np
import asyncio
import psutil
import os

# Rate limiter keyed on the caller's IP address.
limiter = Limiter(key_func=get_remote_address)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: create shared state at startup, release it at shutdown."""
    app.state.model_cache = {}
    app.state.limiter = limiter
    yield
    app.state.model_cache.clear()


app = FastAPI(lifespan=lifespan)


class PredictionRequest(BaseModel):
    city: str
    feature1: float
    feature2: float


def _load_model(version: str) -> lgb.Booster:
    """Load one model version (Booster loading is read-only, hence thread-safe)."""
    # Reserve half of the CPU cores, but never configure fewer than 1 thread
    # (os.cpu_count() may return None or 1, and 1 // 2 == 0).
    threads = max(1, (os.cpu_count() or 1) // 2)
    return lgb.Booster(
        model_file=f"models/{version}/model.bin",
        params={"num_threads": threads},
    )


@app.post("/predict")
@limiter.limit("100/minute")
async def predict(request: PredictionRequest, version: str = "v1.0"):
    """Versioned prediction endpoint with resource checks and rate limiting."""
    # Reject early when memory is nearly exhausted.
    if psutil.virtual_memory().percent > 90:
        raise HTTPException(503, "系统资源不足")
    # Lazily cache one Booster per model version.
    if version not in app.state.model_cache:
        app.state.model_cache[version] = _load_model(version)
    features = preprocess_features(request)
    loop = asyncio.get_event_loop()
    try:
        # run_in_executor forwards positional args only, so keyword args must
        # be bound with functools.partial — passing a bare dict positionally
        # would be misread as Booster.predict's start_iteration argument.
        predict_fn = partial(
            app.state.model_cache[version].predict, num_threads=4
        )
        prediction = await loop.run_in_executor(None, predict_fn, features)
        return {"result": float(prediction[0])}
    except Exception as e:
        raise HTTPException(500, str(e))


# System-wide CPU guard, applied to every request.
@app.middleware("http")
async def check_resources(request: Request, call_next):
    # NOTE(review): exceptions raised in middleware bypass FastAPI's handlers;
    # consider returning a JSONResponse(503) instead of raising here.
    if psutil.cpu_percent() > 95:
        raise HTTPException(503, "CPU过载")
    return await call_next(request)
部署建议
服务器选择
- 实例类型:选择计算优化型实例(如AWS C5系列、GCP C2系列)
- CPU要求:确保支持AVX指令集,加速数值计算
- 内存配置:根据模型大小和并发量合理配置
容器化部署
# Dockerfile示例 FROM python:3.9-slim # 安装系统依赖 RUN apt-get update && apt-get install -y libgomp1 # 复制应用代码 COPY . /app WORKDIR /app # 安装Python依赖 RUN pip install -r requirements.txt # 启动命令 CMD ["gunicorn", "-w", "4", "-k", "uvicorn.workers.UvicornWorker", "main:app"]
自动伸缩策略(Kubernetes示例)
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-api-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-api
  minReplicas: 2
  maxReplicas: 10
  metrics:
    # Scale out when average CPU utilization across pods exceeds 70%.
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
总结
本文详细介绍了使用FastAPI开发机器学习模型接口的全过程,从基础实现到生产级优化。关键优化点包括:
- 模型格式选择(推荐二进制格式)
- 单例模式加载避免重复I/O
- 异步处理CPU密集型任务
- 完善的输入验证和错误处理
- 资源管控和监控
- 容器化部署和自动伸缩
遵循这些最佳实践,可以构建出高性能、可扩展、易维护的机器学习API服务。



