数据, 术→技巧, 自然语言处理, 观点, 读书笔记

LightGBM的模型保存

钱魏Way · · 262 次浏览
!            文章内容如有格式错误,请反馈,谢谢...

平时在使用LightGMB,需要保存训练好的模型。以下是梳理的几种方式:

使用LightGBM 自带的save_model 方法

import lightgbm as lgb

# 假设已经训练好的模型是 model
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 保存模型
model.booster_.save_model('model.txt')

# 加载模型
loaded_model = lgb.Booster(model_file='model.txt')
此方法将模型保存为一个文本文件 model.txt,可以在后续使用 lgb.Booster 加载。
由于你使用了 lightgbm.train 方法训练模型,而不是 Scikit-learn 风格的 LGBMClassifier。
在 LightGBM 的原生接口下,model 不需要访问 .booster_,因为 model 本身就是一个 Booster 对象。
# 直接保存模型
model.save_model('data/model.txt')

优点:

  • 原生支持:LightGBM 自带方法,专为保存 LightGBM 模型设计。
  • 速度快:保存和加载的效率非常高,因为只处理模型本身。
  • 文件体积小:以文本格式保存,包含模型的所有必要信息(如树结构、超参数等)。
  • 无额外依赖:无需安装其他库。

缺点:

  • 功能有限:仅支持保存模型本身,不能保存完整的工作流(如数据预处理)。
  • 跨工具兼容性差:保存为 .txt 文件,仅限于 LightGBM 解析。

适用场景:

  • 轻量级任务:只需保存和加载 LightGBM 模型(不涉及复杂数据预处理)。
  • 模型复现和调试:快速保存和加载模型,用于调试或轻量化部署。

性能:

  • 保存/加载效率:高效。
  • 文件大小:小。
  • 易用性:非常简单,直接调用 API 即可。

使用 joblib 保存和加载

joblib 是一种快速、压缩的数据序列化方式,特别适合保存训练好的模型对象。

from joblib import dump, load
import lightgbm as lgb

# 假设已经训练好的模型是 model
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 保存模型
joblib.dump(model, 'model.pkl')

# 加载模型
loaded_model = joblib.load('model.pkl')

这种方法保存的是整个 LGBM 模型对象,适合与 Scikit-learn 管道一起使用。

优点:

  • 快速序列化:适合保存较大的对象,包括模型和上下文。
  • 兼容性好:与 Scikit-learn 流程和管道高度兼容。
  • 支持压缩:可以通过压缩减少存储空间。

缺点:

  • 对模型单一性要求高:仅适用于 Python 环境,不适合跨语言使用。
  • 文件较大:相比 save_model 的文本文件,保存的二进制文件体积稍大。

适用场景:

  • 与 Scikit-learn 管道集成:当模型与其他预处理组件一起保存。
  • 中小型项目:需要快速保存并在 Python 环境中复现。

性能:

  • 保存/加载效率:高效,但略低于 save_model。
  • 文件大小:中等。
  • 易用性:使用简单,适合 Scikit-learn 用户。

使用 Python 原生的 pickle 序列化

如果模型比较小,也可以使用pickle 保存。

import pickle
import lightgbm as lgb

# 假设已经训练好的模型是 model
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 保存模型
with open('model.pkl', 'wb') as f:
pickle.dump(model, f)

# 加载模型
with open('model.pkl', 'rb') as f:
loaded_model = pickle.load(f)

优点:

  • Python 原生支持:无需安装额外库。
  • 灵活性高:可以序列化几乎任何 Python 对象。

缺点:

  • 效率较低:序列化速度比 joblib 慢。
  • 文件较大:文件体积通常较 save_model 和 joblib 大。
  • 兼容性差:仅适用于 Python,且对跨平台支持较弱。

适用场景:

  • 简单保存任务:快速保存模型,不涉及其他复杂需求。
  • 开发调试:快速保存模型以便调试和测试。

性能:

  • 保存/加载效率:慢于 save_model 和 joblib。
  • 文件大小:较大。
  • 易用性:使用简单,但不适合大规模生产。

使用 JSON 格式保存模型参数

LightGBM 模型还可以将模型参数以 JSON 格式导出,便于跨平台应用或在其他语言中加载。

# 导出为 JSON 格式
model_json = model.booster_.dump_model()

# 保存为 JSON 文件
import json
with open('model.json', 'w') as f:
json.dump(model_json, f)

对比总结

  • save_model: 推荐用于 LightGBM 自身的训练和推理,简单高效。
  • joblib: 推荐用于与 Scikit-learn 一起使用,便于保存更多上下文信息。
  • pickle: 通用方式,但效率和灵活性不如 joblib。
  • JSON 格式: 适合跨平台需求,但仅保存模型结构和参数。

使用 PMML 文件进行保存

LightGBM 模型也可以导出为PMML文件,PMML (Predictive Model Markup Language)是一种开放标准,用于表示数据挖掘模型。通过将LightGBM模型导出为PMML文件,可以实现模型的跨平台部署和与其他系统的集成。

下面是如何保存LightGBM模型为PMML文件的方法。

使用sklearn2pmml工具

sklearn2pmml是一个流行的Python库,可以将兼容Scikit-learn接口的模型(例如LightGBM的LGBMClassifier和LGBMRegressor)导出为PMML文件。

安装依赖

pip install sklearn2pmml
pip install pypmml # 如果需要加载PMML文件

保存PMML文件

import pandas as pd
from sklearn2pmml import PMMLPipeline, sklearn2pmml
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 加载数据集
data = load_iris()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target, name="target")

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建LightGBM模型
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 包装模型到PMMLPipeline
pipeline = PMMLPipeline([
("model", model)
])

# 导出为PMML文件
sklearn2pmml(pipeline, "model.pmml", with_repr=True)

加载PMML文件并进行预测

如果需要加载PMML文件,可以使用pypmml。

from pypmml import Model

# 加载PMML模型
pmml_model = Model.load('model.pmml')

# 进行预测
predictions = pmml_model.predict(X_test)
print(predictions)

使用nyoka工具

nyoka是另一个用于生成PMML文件的Python库。

安装依赖

pip install nyoka

保存PMML文件

from nyoka import lgb_to_pmml
import lightgbm as lgb

# 假设已经训练好的模型
model = lgb.LGBMClassifier()
model.fit(X_train, y_train)

# 导出为PMML文件
lgb_to_pmml(model, feature_names=X_train.columns.tolist(), target_name="target", pmml_f_name="model.pmml")

sklearn2pmml和nyoka的对比

  • sklearn2pmml:
    • 优势:支持完整的Scikit-learn管道,适合需要进行数据预处理的场景。
    • 限制:需要将模型封装为PMMLPipeline。
  • nyoka:
    • 优势:更直接地支持LightGBM,适合仅保存模型的情况。
    • 限制:不支持复杂的管道,需要自行处理数据预处理。

优点:

  • 跨平台支持:PMML是开放标准,可在各种工具中加载(如KNIME、RapidMiner)。
  • 便于部署:适合模型的跨平台和跨语言部署。
  • 模型透明度:PMML文件可视化模型结构和参数。

缺点:

  • 开发复杂度高:需要借助sklearn2pmml或nyoka,额外步骤多。
  • 性能较低:生成文件的时间和文件解析的开销较高。
  • 文件较大:PMML文件体积通常较大。

适用场景:

  • 跨平台部署:模型需在多种语言或平台(如Java)中使用。
  • 与外部工具集成:需要与BI工具或外部分析工具交互。

性能:

  • 保存/加载效率:最慢。
  • 文件大小:最大。
  • 易用性:复杂,需要额外开发工作。

综合对比表

特性 LightGBM save_model joblib pickle PMML (sklearn2pmml/nyoka)
保存速度 快速 快速 中等 较慢
加载速度 快速 快速 中等 较慢
文件大小 中等 最大
易用性 非常简单 简单 简单 较复杂
跨平台支持
保存对象 模型本身 模型及上下文 模型及上下文 模型及部分上下文
适用场景 轻量部署、调试 Python项目 Python项目 跨平台部署、大型生产环境

推荐选择

  • 首选save_model:如果仅需要保存LightGBM模型本身,这是性能和易用性的最佳选择。
  • 选用joblib:如果需要与Scikit-learn或其他组件配合,joblib提供了更高的灵活性。
  • 备用pickle:适用于小型项目或临时保存,但不推荐用于生产环境。
  • 选用PMML:如果需要跨语言或跨平台部署,PMML是唯一合适的选择,尽管开发复杂度和性能较低。

如果不考虑跨平台兼容性,save_model的效率和易用性最佳;若有更复杂的需求,可以视场景选择其他方式。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注