数据, 术→技巧, 自然语言处理, 观点, 读书笔记

LightGBM的模型保存

钱魏Way · · 15 次浏览

平时在使用LightGMB,需要保存训练好的模型。以下是梳理的几种方式:

使用 LightGBM 自带的 save_model 方法

import lightgbm as lgb

# 假设已经训练好的模型是 model
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 保存模型
model.booster_.save_model('model.txt')

# 加载模型
loaded_model = lgb.Booster(model_file='model.txt')
此方法将模型保存为一个文本文件 model.txt,可以在后续使用 lgb.Booster 加载。
由于你使用了 lightgbm.train 方法训练模型,而不是 Scikit-learn 风格的 LGBMClassifier。
在 LightGBM 的原生接口下,model 不需要访问 .booster_,因为 model 本身就是一个 Booster 对象。
# 直接保存模型
model.save_model('data/model.txt')

优点:

  • 原生支持:LightGBM 自带方法,专为保存 LightGBM 模型设计。
  • 速度快:保存和加载的效率非常高,因为只处理模型本身。
  • 文件体积小:以文本格式保存,包含模型的所有必要信息(如树结构、超参数等)。
  • 无额外依赖:无需安装其他库。

缺点:

  • 功能有限:仅支持保存模型本身,不能保存完整的工作流(如数据预处理)。
  • 跨工具兼容性差:保存为 .txt 文件,仅限于 LightGBM 解析。

适用场景:

  • 轻量级任务:只需保存和加载 LightGBM 模型(不涉及复杂数据预处理)。
  • 模型复现和调试:快速保存和加载模型,用于调试或轻量化部署。

性能:

  • 保存/加载效率:高效。
  • 文件大小:小。
  • 易用性:非常简单,直接调用 API 即可。

使用 joblib 保存和加载

joblib 是一种快速、压缩的数据序列化方式,特别适合保存训练好的模型对象。

from joblib import dump, load
import lightgbm as lgb

# 假设已经训练好的模型是 model
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 保存模型
joblib.dump(model, 'model.pkl')

# 加载模型
loaded_model = joblib.load('model.pkl')

这种方法保存的是整个 LGBM 模型对象,适合与 Scikit-learn 管道一起使用。

优点:

  • 快速序列化:适合保存较大的对象,包括模型和上下文。
  • 兼容性好:与 Scikit-learn 流程和管道高度兼容。
  • 支持压缩:可以通过压缩减少存储空间。

缺点:

  • 对模型单一性要求高:仅适用于 Python 环境,不适合跨语言使用。
  • 文件较大:相比 save_model 的文本文件,保存的二进制文件体积稍大。

适用场景:

  • 与 Scikit-learn 管道集成:当模型与其他预处理组件一起保存。
  • 中小型项目:需要快速保存并在 Python 环境中复现。

性能:

  • 保存/加载效率:高效,但略低于 save_model。
  • 文件大小:中等。
  • 易用性:使用简单,适合 Scikit-learn 用户。

使用 Python 原生的 pickle 序列化

如果模型比较小,也可以使用 pickle 保存。

import pickle
import lightgbm as lgb

# 假设已经训练好的模型是 model
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 保存模型
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

# 加载模型
with open('model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

优点:

  • Python 原生支持:无需安装额外库。
  • 灵活性高:可以序列化几乎任何 Python 对象。

缺点:

  • 效率较低:序列化速度比 joblib 慢。
  • 文件较大:文件体积通常较 save_model 和 joblib 大。
  • 兼容性差:仅适用于 Python,且对跨平台支持较弱。

适用场景:

  • 简单保存任务:快速保存模型,不涉及其他复杂需求。
  • 开发调试:快速保存模型以便调试和测试。

性能:

  • 保存/加载效率:慢于 save_model 和 joblib。
  • 文件大小:较大。
  • 易用性:使用简单,但不适合大规模生产。

使用 JSON 格式保存模型参数

LightGBM 模型还可以将模型参数以 JSON 格式导出,便于跨平台应用或在其他语言中加载。

# 导出为 JSON 格式
model_json = model.booster_.dump_model()

# 保存为 JSON 文件
import json
with open('model.json', 'w') as f:
    json.dump(model_json, f)

对比总结

  • save_model: 推荐用于 LightGBM 自身的训练和推理,简单高效。
  • joblib: 推荐用于与 Scikit-learn 一起使用,便于保存更多上下文信息。
  • pickle: 通用方式,但效率和灵活性不如 joblib。
  • JSON 格式: 适合跨平台需求,但仅保存模型结构和参数。

使用PMML文件进行保存

LightGBM 模型也可以导出为 PMML 文件,PMML (Predictive Model Markup Language) 是一种开放标准,用于表示数据挖掘模型。通过将 LightGBM 模型导出为 PMML 文件,可以实现模型的跨平台部署和与其他系统的集成。

下面是如何保存 LightGBM 模型为 PMML 文件的方法。

使用 sklearn2pmml 工具

sklearn2pmml 是一个流行的 Python 库,可以将兼容 Scikit-learn 接口的模型(例如 LightGBM 的 LGBMClassifier 和 LGBMRegressor)导出为 PMML 文件。

安装依赖

pip install sklearn2pmml
pip install pypmml  # 如果需要加载 PMML 文件

保存 PMML 文件

import pandas as pd
from sklearn2pmml import PMMLPipeline, sklearn2pmml
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 加载数据集
data = load_iris()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target, name="target")

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建 LightGBM 模型
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 包装模型到 PMMLPipeline
pipeline = PMMLPipeline([
    ("model", model)
])

# 导出为 PMML 文件
sklearn2pmml(pipeline, "model.pmml", with_repr=True)

加载 PMML 文件并进行预测

如果需要加载 PMML 文件,可以使用 pypmml。

from pypmml import Model

# 加载 PMML 模型
pmml_model = Model.load('model.pmml')

# 进行预测
predictions = pmml_model.predict(X_test)
print(predictions)

使用 nyoka 工具

nyoka 是另一个用于生成 PMML 文件的 Python 库。

安装依赖

pip install nyoka

保存 PMML 文件

from nyoka import lgb_to_pmml
import lightgbm as lgb

# 假设已经训练好的模型
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 导出为 PMML 文件
lgb_to_pmml(model, feature_names=X_train.columns.tolist(), target_name="target", pmml_f_name="model.pmml")

sklearn2pmml和nyoka的对比

  • sklearn2pmml:
    • 优势:支持完整的 Scikit-learn 管道,适合需要进行数据预处理的场景。
    • 限制:需要将模型封装为 PMMLPipeline。
  • nyoka:
    • 优势:更直接地支持 LightGBM,适合仅保存模型的情况。
    • 限制:不支持复杂的管道,需要自行处理数据预处理。

优点:

  • 跨平台支持:PMML 是开放标准,可在各种工具中加载(如 KNIME、RapidMiner)。
  • 便于部署:适合模型的跨平台和跨语言部署。
  • 模型透明度:PMML 文件可视化模型结构和参数。

缺点:

  • 开发复杂度高:需要借助 sklearn2pmml 或 nyoka,额外步骤多。
  • 性能较低:生成文件的时间和文件解析的开销较高。
  • 文件较大:PMML 文件体积通常较大。

适用场景:

  • 跨平台部署:模型需在多种语言或平台(如 Java)中使用。
  • 与外部工具集成:需要与 BI 工具或外部分析工具交互。

性能:

  • 保存/加载效率:最慢。
  • 文件大小:最大。
  • 易用性:复杂,需要额外开发工作。

综合对比表

特性 LightGBM save_model joblib pickle PMML (sklearn2pmml/nyoka)
保存速度 快速 快速 中等 较慢
加载速度 快速 快速 中等 较慢
文件大小 中等 最大
易用性 非常简单 简单 简单 较复杂
跨平台支持
保存对象 模型本身 模型及上下文 模型及上下文 模型及部分上下文
适用场景 轻量部署、调试 Python 项目 Python 项目 跨平台部署、大型生产环境

推荐选择

  • 首选 save_model:如果仅需要保存 LightGBM 模型本身,这是性能和易用性的最佳选择。
  • 选用 joblib:如果需要与 Scikit-learn 或其他组件配合,joblib 提供了更高的灵活性。
  • 备用 pickle:适用于小型项目或临时保存,但不推荐用于生产环境。
  • 选用 PMML:如果需要跨语言或跨平台部署,PMML 是唯一合适的选择,尽管开发复杂度和性能较低。

如果不考虑跨平台兼容性,save_model 的效率和易用性最佳;若有更复杂的需求,可以视场景选择其他方式。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注