
fastFM: A Python Factorization Machine Library

钱魏Way

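Introduction to fastFM

The main feature of fastFM is that it wraps factorization machines behind a scikit-learn style API. The core code is written in C, so performance is reasonably well assured.

fastFM provides solvers for three kinds of problems: regression, classification, and ranking. Three optimizers are available (als, mcmc, and sgd), and the loss function depends on the problem being solved.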

Task            | Solver          | Loss
Regression      | als, mcmc, sgd  | Square Loss
Classification  | als, mcmc, sgd  | Probit (MAP), Probit, Sigmoid
Ranking         | sgd             | BPR
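For readers who have not met the model before, here is a minimal pure-Python sketch of the second-order factorization machine score that all of these solvers fit. It is not fastFM's implementation (that lives in the C core); the function names and the toy numbers are made up for illustration. The second function uses the usual linear-time rewrite of the pairwise term, and the two should agree.

```python
from itertools import combinations


def fm_predict(x, w0, w, V):
    """Second-order FM score for one sample, written out directly.

    x  : list of n feature values
    w0 : global bias
    w  : list of n linear weights
    V  : n lists of length k (one latent factor vector per feature)
    """
    linear = sum(wi * xi for wi, xi in zip(w, x))
    # Pairwise term: <v_i, v_j> * x_i * x_j over all pairs i < j.
    pairwise = 0.0
    for i, j in combinations(range(len(x)), 2):
        dot = sum(a * b for a, b in zip(V[i], V[j]))
        pairwise += dot * x[i] * x[j]
    return w0 + linear + pairwise


def fm_predict_fast(x, w0, w, V):
    """Same score, computed in O(n * k) instead of O(n^2 * k)."""
    linear = sum(wi * xi for wi, xi in zip(w, x))
    pairwise = 0.0
    for f in range(len(V[0])):
        s = sum(V[i][f] * x[i] for i in range(len(x)))
        s_sq = sum((V[i][f] * x[i]) ** 2 for i in range(len(x)))
        pairwise += 0.5 * (s * s - s_sq)
    return w0 + linear + pairwise


# Toy example with n = 3 features and rank k = 2 (illustrative values).
x = [1.0, 0.0, 2.0]
w0, w = 0.5, [0.1, 0.2, 0.3]
V = [[1.0, 0.0], [0.5, 0.5], [1.0, 2.0]]
print(fm_predict(x, w0, w, V), fm_predict_fast(x, w0, w, V))
```

The rank argument used later in this post corresponds to k, the length of each latent vector.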

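How do you choose between the optimizers?

  • ALS:
    • Pros: fast prediction; fewer hyperparameters than SGD
    • Cons: regularization must be set manually
  • SGD:
    • Pros: fast prediction; can iterate over large datasets
    • Cons: regularization and the other hyperparameters must be set manually, including the step size (step_size)
  • MCMC:
    • Pros: very few hyperparameters, usually just the number of iterations, the initialization variance, and the rank; regularization is automatic
    • Cons: the test set must be available during training, because predictions are made as part of the training run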
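The trade-offs above can be summarised as the set of hyperparameters each solver asks for. The sketch below is only a lookup table with placeholder values, not tuned recommendations. The names n_iter, rank, l2_reg_w, l2_reg_V and step_size appear elsewhere in this post; init_stdev is my assumption for the name of the "initialization variance" argument, so check it against the fastFM version you have installed.

```python
# Hyperparameters per solver, following the pros and cons listed above.
# All values are placeholders for illustration.
SOLVER_HYPERPARAMS = {
    "als": {"n_iter": 100, "rank": 8, "l2_reg_w": 0.1, "l2_reg_V": 0.1},
    "sgd": {"n_iter": 100, "rank": 8, "l2_reg_w": 0.1, "l2_reg_V": 0.1,
            "step_size": 0.01},
    # init_stdev is an assumed argument name (see the note above).
    "mcmc": {"n_iter": 100, "rank": 8, "init_stdev": 0.1},
}


def manual_settings(solver):
    """Names of the regularization and step-size settings that need hand tuning."""
    params = SOLVER_HYPERPARAMS[solver]
    return sorted(
        name for name in params
        if name.startswith("l2_reg") or name == "step_size"
    )


for solver in SOLVER_HYPERPARAMS:
    print(solver, manual_settings(solver))
```

With fastFM installed, a dictionary like this could be unpacked into the matching estimator, for example als.FMRegression(**SOLVER_HYPERPARAMS["als"]).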

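Installing fastFM

My Anaconda environment runs inside the Windows Subsystem for Linux on Windows 10. I ran into problems installing fastFM yesterday and spent a lot of time before getting it to work. I am writing the steps down here for anyone who hits the same problem.

Running pip install fastFM returned the following error: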

Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting fastfm
Downloading http://mirrors.aliyun.com/pypi/packages/41/31/7fb81ab6b11bd35f085eb9b4d0e2e48158056e4a639a1e67057519259512/fastFM-0.2.10.tar.gz (1.6MB)
|████████████████████████████████| 1.6MB 2.7MB/s eta 0:00:01     |████████████████████████████▎   | 1.4MB 2.7MB/s eta 0:00:01
Requirement already satisfied: numpy in /home/qw/anaconda3/lib/python3.7/site-packages (from fastfm) (1.16.4)
Requirement already satisfied: scikit-learn in /home/qw/anaconda3/lib/python3.7/site-packages (from fastfm) (0.21.2)
Requirement already satisfied: scipy in /home/qw/anaconda3/lib/python3.7/site-packages (from fastfm) (1.3.0)
Requirement already satisfied: cython in /home/qw/anaconda3/lib/python3.7/site-packages (from fastfm) (0.29.12)
Requirement already satisfied: joblib>=0.11 in /home/qw/anaconda3/lib/python3.7/site-packages (from scikit-learn->fastfm) (0.13.2)
Building wheels for collected packages: fastfm
Building wheel for fastfm (setup.py) ... error
ERROR: Complete output from command /home/qw/anaconda3/bin/python -u -c 'import setuptools, tokenize;__file__='"'"'/tmp/pip-install-e6au9j7a/fastfm/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-n7dqx8xr --python-tag cp37:
ERROR: running bdist_wheel
running build
running build_py
creating build
creating build/lib.linux-x86_64-3.7
creating build/lib.linux-x86_64-3.7/fastFM
copying fastFM/__init__.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/als.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/base.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/bpr.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/datasets.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/mcmc.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/sgd.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/utils.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/validation.py -> build/lib.linux-x86_64-3.7/fastFM
running build_ext
skipping 'fastFM/ffm.c' Cython extension (up-to-date)
building 'ffm' extension
creating build/temp.linux-x86_64-3.7
creating build/temp.linux-x86_64-3.7/fastFM
gcc -pthread -B /home/qw/anaconda3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -IfastFM/ -IfastFM-core/include/ -IfastFM-core/externals/CXSparse/Include/ -I/home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include -I/home/qw/anaconda3/include/python3.7m -c fastFM/ffm.c -o build/temp.linux-x86_64-3.7/fastFM/ffm.o
In file included from /home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/ndarraytypes.h:1824:0,
from /home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,
from /home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/arrayobject.h:4,
from fastFM/ffm.c:528:
/home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
#warning "Using deprecated NumPy API, disable it with " \
^~~~~~~
fastFM/ffm.c: In function ‘__Pyx__ExceptionSave’:
fastFM/ffm.c:25496:21: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
*type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25497:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
*value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25498:19: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
*tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c: In function ‘__Pyx__ExceptionReset’:
fastFM/ffm.c:25505:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tmp_type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25506:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tmp_value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25507:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tmp_tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c:25508:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tstate->exc_type = type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25509:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tstate->exc_value = value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25510:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tstate->exc_traceback = tb;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c: In function ‘__Pyx__GetException’:
fastFM/ffm.c:25580:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tmp_type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25581:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tmp_value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25582:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tmp_tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c:25583:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tstate->exc_type = local_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25584:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tstate->exc_value = local_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25585:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tstate->exc_traceback = local_tb;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c: In function ‘__Pyx__ExceptionSwap’:
fastFM/ffm.c:25822:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tmp_type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25823:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tmp_value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25824:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tmp_tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c:25825:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tstate->exc_type = *type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25826:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tstate->exc_value = *value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25827:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tstate->exc_traceback = *tb;
^~~~~~~~~~~~~
curexc_traceback
error: command 'gcc' failed with exit status 1
----------------------------------------
ERROR: Failed building wheel for fastfm
Running setup.py clean for fastfm
Failed to build fastfm
Installing collected packages: fastfm
Running setup.py install for fastfm ... error
ERROR: Complete output from command /home/qw/anaconda3/bin/python -u -c 'import setuptools, tokenize;__file__='"'"'/tmp/pip-install-e6au9j7a/fastfm/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-9srvrii3/install-record.txt --single-version-externally-managed --compile:
ERROR: running install
running build
running build_py
creating build
creating build/lib.linux-x86_64-3.7
creating build/lib.linux-x86_64-3.7/fastFM
copying fastFM/__init__.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/als.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/base.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/bpr.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/datasets.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/mcmc.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/sgd.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/utils.py -> build/lib.linux-x86_64-3.7/fastFM
copying fastFM/validation.py -> build/lib.linux-x86_64-3.7/fastFM
running build_ext
skipping 'fastFM/ffm.c' Cython extension (up-to-date)
building 'ffm' extension
creating build/temp.linux-x86_64-3.7
creating build/temp.linux-x86_64-3.7/fastFM
gcc -pthread -B /home/qw/anaconda3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -IfastFM/ -IfastFM-core/include/ -IfastFM-core/externals/CXSparse/Include/ -I/home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include -I/home/qw/anaconda3/include/python3.7m -c fastFM/ffm.c -o build/temp.linux-x86_64-3.7/fastFM/ffm.o
In file included from /home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/ndarraytypes.h:1824:0,
from /home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,
from /home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/arrayobject.h:4,
from fastFM/ffm.c:528:
/home/qw/anaconda3/lib/python3.7/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
#warning "Using deprecated NumPy API, disable it with " \
^~~~~~~
fastFM/ffm.c: In function ‘__Pyx__ExceptionSave’:
fastFM/ffm.c:25496:21: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
*type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25497:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
*value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25498:19: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
*tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c: In function ‘__Pyx__ExceptionReset’:
fastFM/ffm.c:25505:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tmp_type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25506:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tmp_value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25507:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tmp_tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c:25508:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tstate->exc_type = type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25509:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tstate->exc_value = value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25510:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tstate->exc_traceback = tb;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c: In function ‘__Pyx__GetException’:
fastFM/ffm.c:25580:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tmp_type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25581:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tmp_value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25582:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tmp_tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c:25583:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tstate->exc_type = local_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25584:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tstate->exc_value = local_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25585:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tstate->exc_traceback = local_tb;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c: In function ‘__Pyx__ExceptionSwap’:
fastFM/ffm.c:25822:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tmp_type = tstate->exc_type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25823:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tmp_value = tstate->exc_value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25824:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tmp_tb = tstate->exc_traceback;
^~~~~~~~~~~~~
curexc_traceback
fastFM/ffm.c:25825:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
tstate->exc_type = *type;
^~~~~~~~
curexc_type
fastFM/ffm.c:25826:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’; did you mean ‘curexc_value’?
tstate->exc_value = *value;
^~~~~~~~~
curexc_value
fastFM/ffm.c:25827:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’; did you mean ‘curexc_traceback’?
tstate->exc_traceback = *tb;
^~~~~~~~~~~~~
curexc_traceback
error: command 'gcc' failed with exit status 1
----------------------------------------
ERROR: Command "/home/qw/anaconda3/bin/python -u -c 'import setuptools, tokenize;__file__='"'"'/tmp/pip-install-e6au9j7a/fastfm/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-9srvrii3/install-record.txt --single-version-externally-managed --compile" failed with error code 1 in /tmp/pip-install-e6au9j7a/fastfm/

The error messages

The key errors in the output are:

  • error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’; did you mean ‘curexc_type’?
  • error: command ‘gcc’ failed with exit status 1
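With a log this long it helps to collapse the repeated compiler diagnostics. The following is a small sketch of my own (not part of pip or fastFM) that pulls the distinct gcc error messages out of a saved log. The sample text is shortened from the output above, with the curly quotes replaced by plain ones.

```python
import re

# gcc diagnostics have the form "<file>:<line>:<col>: error: <message>".
GCC_ERROR = re.compile(
    r"^(?P<file>[^:\s]+):(?P<line>\d+):(?P<col>\d+): error: (?P<msg>.+)$"
)


def distinct_errors(log_text):
    """Return the distinct gcc error messages in first-seen order."""
    seen = []
    for line in log_text.splitlines():
        match = GCC_ERROR.match(line.strip())
        if match and match.group("msg") not in seen:
            seen.append(match.group("msg"))
    return seen


sample = """\
fastFM/ffm.c:25496:21: error: 'PyThreadState {aka struct _ts}' has no member named 'exc_type'; did you mean 'curexc_type'?
fastFM/ffm.c:25505:24: error: 'PyThreadState {aka struct _ts}' has no member named 'exc_type'; did you mean 'curexc_type'?
fastFM/ffm.c:25497:22: error: 'PyThreadState {aka struct _ts}' has no member named 'exc_value'; did you mean 'curexc_value'?
error: command 'gcc' failed with exit status 1
"""

for message in distinct_errors(sample):
    print(message)
```

The final "command 'gcc' failed" line is deliberately not matched, since it carries no file position and only reports that the compile step failed.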

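After some searching I settled on the following tentative fixes:

  • The first error is reportedly a Cython bug. The advice online is to install either the latest Cython or an older release. In total I tried:
    • pip install Cython --upgrade
    • pip install Cython==0.27.3
    • pip install git+https://github.com/cython/cython.git
  • The second error seemed to mean that the Python dev package was missing. The fixes I found online were:
    • sudo apt-get install python-dev
    • sudo apt-get install python3-dev libevent-dev

After running all of the above, the error messages were exactly the same, and I was completely stuck. When I got up today I reread the material and finally solved the problem.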

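Solution

Looking at the fastFM project again, I found that its prebuilt packages only cover Python up to 3.6; there is none for 3.7. The error is probably caused by pip install fastFM fetching a package that does not suit this interpreter, so building locally is worth trying. The log fits this reading: the line skipping 'fastFM/ffm.c' Cython extension (up-to-date) shows that the build compiles the ffm.c shipped in the tarball instead of regenerating it, which would explain why changing the Cython version made no difference. That file reads exc_type, exc_value and exc_traceback directly from PyThreadState, and Python 3.7 no longer has those members there.

So I followed the official build procedure, and it finally worked.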
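As a quick pre-check before running pip, the observation above can be turned into a few lines of Python. This is a hypothetical helper, and the (3, 6) cut-off is only what I saw on the project page at the time of writing; it may well have changed since.

```python
import sys

# Assumption taken from the paragraph above: prebuilt fastFM packages
# were only published up to Python 3.6 when this post was written.
LAST_PREBUILT = (3, 6)


def needs_source_build(version=None):
    """True when the interpreter is newer than the last prebuilt version."""
    major_minor = tuple((version or sys.version_info)[:2])
    return major_minor > LAST_PREBUILT


print(sys.version_info[:2], needs_source_build())
```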

# Install cblas and python-dev header (Linux only).
# - cblas can be installed with libatlas-base-dev or libopenblas-dev (Ubuntu)
$ sudo apt-get install python-dev libopenblas-dev
# Clone the repo including submodules (or clone + `git submodule update --init --recursive`)
$ git clone --recursive https://github.com/ibayer/fastFM.git
# Enter the root directory
$ cd fastFM
# Install Python dependencies (Cython>=0.22, numpy, pandas, scipy, scikit-learn)
$ pip install -r ./requirements.txt
# Compile the C extension.
$ make                      # build with default python version (python)
$ PYTHON=python3 make       # build with custom python version (python3)
# Install fastFM
$ pip install .
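Once the build finishes, a cheap sanity check is to ask Python whether it can locate the package without importing it. The sketch below uses only the standard library; module_status is a name I made up for this post.

```python
import importlib.util


def module_status(name):
    """Report whether a top-level module can be located, without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return f"{name} is not importable in this environment"
    return f"{name} found at {spec.origin}"


print(module_status("fastFM"))
```

Run this from outside the cloned fastFM directory, otherwise the source folder itself may be the one that is found. Locating the package does not prove that the compiled extension loads, so a real import of fastFM.als is still the final test.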

Using fastFM

Regression with ALS

from fastFM import als
from fastFM.datasets import make_user_item_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
from matplotlib import pyplot as plt

X, y, coef = make_user_item_regression(label_stdev=.4)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

n_iter = 20
step_size = 1

fm = als.FMRegression(n_iter=0, l2_reg_w=0.1, l2_reg_V=0.1, rank=4)
# Allocates and initializes the model parameters.
fm.fit(X_train, y_train)

rmse_train = []
rmse_test = []
r2_score_train = []
r2_score_test = []

for i in range(1, n_iter):
    # Warm start: run step_size more ALS iterations on the existing model.
    fm.fit(X_train, y_train, n_more_iter=step_size)
    y_pred_train = fm.predict(X_train)
    y_pred_test = fm.predict(X_test)
    # scikit-learn metrics take (y_true, y_pred); the order matters for R^2.
    rmse_train.append(np.sqrt(mean_squared_error(y_train, y_pred_train)))
    rmse_test.append(np.sqrt(mean_squared_error(y_test, y_pred_test)))
    r2_score_train.append(r2_score(y_train, y_pred_train))
    r2_score_test.append(r2_score(y_test, y_pred_test))

# Plot the learning curves: RMSE on the left, R^2 on the right.
fig, axes = plt.subplots(ncols=2, figsize=(15, 4))
x = np.arange(1, n_iter) * step_size
with plt.style.context('fivethirtyeight'):
    axes[0].plot(x, rmse_train, label='RMSE-train', color='r', ls="--")
    axes[0].plot(x, rmse_test, label='RMSE-test', color='r')
    axes[1].plot(x, r2_score_train, label='R^2-train', color='b', ls="--")
    axes[1].plot(x, r2_score_test, label='R^2-test', color='b')
axes[0].set_ylabel('RMSE', color='r')
axes[1].set_ylabel('R^2', color='b')
axes[0].legend()
axes[1].legend()
plt.show()
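The two metrics tracked in the loop are simple enough to write out by hand, which also makes one detail visible: RMSE is symmetric in its arguments, but R² is not, so the (y_true, y_pred) order expected by scikit-learn's r2_score matters. The sketch below is a plain-Python illustration with made-up numbers, not a replacement for the scikit-learn functions.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.5, 3.0]

print("RMSE:", rmse(y_true, y_pred), rmse(y_pred, y_true))
print("R^2 (y_true, y_pred):", r2(y_true, y_pred))
print("R^2 (swapped):", r2(y_pred, y_true))
```

Swapping the arguments changes SS_tot, because the variance is then measured on the predictions instead of the targets.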

Regression with MCMC

from fastFM.datasets import make_user_item_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np
from matplotlib import pyplot as plt
from fastFM import mcmc

n_iter = 100
step_size = 10
seed = 123
rank = 3

X, y, coef = make_user_item_regression(label_stdev=.4)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33)

fm = mcmc.FMRegression(n_iter=0, rank=rank, random_state=seed)
# Allocates and initializes the model and hyperparameters.
fm.fit_predict(X_train, y_train, X_test)

rmse_test = []
hyper_param = np.zeros((n_iter - 1, 3 + 2 * rank), dtype=np.float64)

for nr, i in enumerate(range(1, n_iter)):
    fm.random_state = i * seed
    # MCMC predicts on the test set as part of each training call.
    y_pred = fm.fit_predict(X_train, y_train, X_test, n_more_iter=step_size)
    rmse_test.append(np.sqrt(mean_squared_error(y_test, y_pred)))
    hyper_param[nr, :] = fm.hyper_param_

values = np.arange(1, n_iter)
x = values * step_size
# Drop the first few recorded samples as burn-in before plotting.
burn_in = 5
x = x[burn_in:]

fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=(15, 8))
axes[0, 0].plot(x, rmse_test[burn_in:], label='test rmse', color="r")
axes[0, 0].legend()
axes[0, 1].plot(x, hyper_param[burn_in:,0], label='alpha', color="b")
axes[0, 1].legend()
axes[1, 0].plot(x, hyper_param[burn_in:,1], label='lambda_w', color="g")
axes[1, 0].legend()
axes[1, 1].plot(x, hyper_param[burn_in:,3], label='mu_w', color="g")
axes[1, 1].legend()
plt.show()

Note: fastFM expects its features in CSR format (a sparse matrix). If your data is in a pandas DataFrame, convert it first.

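import scipy.sparse

# X_train and X_test are pandas DataFrames; .values gives the dense NumPy array.
X_train = scipy.sparse.csr_matrix(X_train.values)
X_test = scipy.sparse.csr_matrix(X_test.values)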
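To make the CSR layout less abstract, here is a pure-Python sketch that builds the three arrays a CSR matrix stores (data, indices, indptr) for one-hot encoded (user, item) pairs, the kind of input a factorization machine typically sees. The helper name and the toy ids are invented for this example.

```python
def one_hot_csr(rows, n_cols):
    """Build CSR arrays for rows given as lists of active column indices.

    Returns (data, indices, indptr): row r owns the slice
    indptr[r]:indptr[r + 1] of data and indices.
    """
    data, indices, indptr = [], [], [0]
    for active in rows:
        for col in sorted(active):
            if not 0 <= col < n_cols:
                raise ValueError(f"column {col} is out of range")
            indices.append(col)
            data.append(1.0)
        indptr.append(len(indices))
    return data, indices, indptr


# 3 users and 2 items give 5 columns: user columns first, then item columns.
n_users, n_items = 3, 2
pairs = [(0, 1), (2, 0), (1, 1)]  # (user id, item id), illustrative values
rows = [[user, n_users + item] for user, item in pairs]

data, indices, indptr = one_hot_csr(rows, n_cols=n_users + n_items)
print(data)
print(indices)
print(indptr)
```

If SciPy is available, the same three arrays can be handed to scipy.sparse.csr_matrix((data, indices, indptr), shape=(3, 5)) to get the actual matrix object.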
