Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现

1. 概要

Gradient Tree Boosting (别名 GBM, GBRT, GBDT, MART)是一类很常用的集成学习算法,在KDD Cup, Kaggle组织的很多数据挖掘竞赛中多次表现出在分类和回归任务上面最好的performance。同时在2010年Yahoo Learning to Rank Challenge中, 夺得冠军的LambdaMART算法也属于这一类算法。因此Tree Boosting算法和深度学习算法DNN/CNN/RNN等等一样在工业界和学术界中得到了非常广泛的应用。

最近研读了UW Tianqi Chen博士写的关于Gradient Tree Boosting 的Slide和Notes, 牛人就是牛人,可以把算法和模型讲的如此清楚,深入浅出,感觉对Tree Boosting算法的理解进一步加深了一些。本来打算写一篇比较详细的算法解析的文章,后来一想不如记录一些阅读心得和关键点,感兴趣的读者可以直接看英文原版资料如下:

A. Introduction to Boosted Tree. https://xgboost.readthedocs.io/en/latest/model.html

B. Introduction to Boosted Trees. By Tianqi Chen. http://homes.cs.washington.edu/~tqchen/data/pdf/BoostedTree.pdf

C. Tianqi Chen and Carlos Guestrin. XGBoost: A Scalable Tree Boosting System. in KDD ’16. http://www.kdd.org/kdd2016/papers/files/rfp0697-chenAemb.pdf

感觉资料B这个slide反倒比资料A Document讲的更细致一些,这个Document跳过了一些slide里面提到的细节。这篇KDD paper的Section 2基本和这个Slide和Dcoument里面提到的公式一样。

2. 阅读笔记 (注释: 这部分摘录的图片出自Tianqi的slide,感谢原作者的精彩分享,我主要加上了一些个人理解性笔记,具体细节可以参考原版slide)

Tianqi的Slide首先给出了监督学习中一些常用基本概念的介绍,然后给出了Tree Ensemble 模型的目标函数定义

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

监督学习算法的目标函数通常包括Loss和Regularization两部分,这里给出的是一般形式,具体Loss的定义可以是Square Loss, Hinge Loss, Logistic Loss等等,关于Regularization可以是模型参数的L2或者 L1 norm等等。 这里针对Tree Ensemble算法,可以用树的节点数量,深度,叶子节点weights的L2 norm,叶子节点的数目等等来定义模型的复杂度。总体目标是学习出既有足够预测能力又不过于复杂/过拟合训练数据的模型。

给定模型目标函数,如何进行优化最小化cost得到最优模型参数? 这里SGD就不适用了,因为模型参数是一些Tree Structure的集合而不是数值向量。我们可以用Additive Training (Boosting)算法来进行训练

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

因此模型的训练分多轮进行,每一轮我们在已经学到的tree的基础上尝试新添加一颗新树,这里显示了每一轮后预测值的变化关系。每一轮我们尝试去寻找最可能最小化目标函数的tree f_t(x_i)加入模型。那么如何寻找这样的tree呢?先来分析一下目标函数:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

如果是square loss,还是很容易展开成如上简洁的形式。对于logistic loss等比较复杂的loss function的一般情形,我们可以使用泰勒展开式:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

这张slide给出泰勒展开式后定义了g_i和h_i,分别是前一轮prediction loss对于前一轮的预测值的一阶导数和二阶导数。square loss下的目标函数形式可以视为这种泰勒展开形式下的特殊情形。我们可以把g_i,h_i带入计算一下很快就可以发现。下面还有另一个问题,如何定义regularization term 呢?

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

slide里面解释很清楚,还给出了具体例子。注意这里q是一个把训练example隐射到叶子节点index的函数。

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

于是可以带入regularization term的定义,然后按照每个叶子节点上面的score重新group目标函数的计算,这个I_j是所有被映射到叶子节点j的example的集合。这样重新group后我们更容易看出最值点和最优值:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

这里只用到了二次函数的最值点(-b/2a)和对应的最小值。当q(x)即树的结构不变时,上面的slide给出了最优的叶子节点的weight和对应的目标函数值,这个目标函数值可以被视为给定q(x)可以达到的最小cost值,因此可以被用来evaluate一个树的结构好不好。下面的slide给出了一个计算实例:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

于是可以用如下算法来搜索最优的待添加树

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

为了避免穷举所有可能的树结构,我们可以采用如下的贪心搜索的策略:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

对于某个feature,如何确定最佳split的点?可以先对examples按照feature值进行排序,然后对每一个可能的切分点计算Gain,选择可以最大化Gain的切分点,然后对所有 d 个feature,所有K个level重复此过程。和决策树的建树算法有点相通之处。

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

注意Gain的计算公式里面的gamma是因为split后增加1个叶子节点导致的。公式里面也可以看出最小化loss和最简化模型中间的trade-off:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

最后,总结以上所有目标函数定义,分析及其优化训练过程就得到了Boosted Tree 算法:

《Gradient Tree Boosting (GBM, GBRT, GBDT, MART)算法解析和基于XGBoost/Scikit-learn的实现》

为了进一步加深对算法的理解,彻底搞懂算法所有细节,最好的方法还是效仿Tianqi那样自己动手实现一下这个算法,看最后一张slide这个算法也没有那么复杂,但是估计实现过程还是有很多坑要踩的。

3. 基于XGBoost/Scikit-learn的实现

如果不想自己造轮子,有很多可用的开源实现,例如Scikit-learn就给出了包括Tree Boosting在内的各种supervise learning算法的实现,下面给出一份实例code,总的来说Scikit-learn还是很全很好用的,因此也广受欢迎。注意这只是一份示例code,我省去了从训练数据测试数据中读取对应X/y的code,有python基础的读者应该很容易加上。

import numpy as np
import sys

# !skip code to read train/test data from files
print 'read data...'
X_train = np.nan_to_num(X_train)
X_test = np.nan_to_num(X_test)

print 'train data size: ', len(X_train)
print 'test data size: ', len(X_test)

# Data normalization
#===================================================
from sklearn import preprocessing
# scale the data attributes
scaler = preprocessing.MinMaxScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

print 'normalized_X: ', X_train

# Feature selection
#===================================================
from sklearn import metrics
# from sklearn.ensemble import ExtraTreesClassifier
# model = ExtraTreesClassifier()
# model.fit(X_train, y_train)
# # display the relative importance of each attribute
# print('feature_importance', model.feature_importances_)

# Classification
#===================================================

# Build Model

print 'build model...'

#AdaBoost, LR, NeuralNet, SVM, RandomForest, Bagging, ExtraTrees
if model_name == 'LR':
    from sklearn.linear_model import LogisticRegression
    model = LogisticRegression()
elif model_name == 'NeuralNet':
     from sklearn.neural_network import MLPClassifier
     model = MLPClassifier(solver='adam', alpha=1e-5, activation='relu',
                           hidden_layer_sizes=(100, 100), random_state=1)
elif model_name == 'SVM':
    from sklearn.svm import LinearSVC
    model = LinearSVC()
elif model_name == 'RandomForest':
    from sklearn.ensemble import RandomForestClassifier
    model = RandomForestClassifier()
elif model_name == 'AdaBoost':
    from sklearn.ensemble import AdaBoostClassifier
    model = AdaBoostClassifier()
elif model_name == 'GBRT':
    from sklearn.ensemble import GradientBoostingRegressor
    model = GradientBoostingRegressor(n_estimators=1000, learning_rate=0.1, loss='ls')
elif model_name == 'Bagging':
     from sklearn.ensemble import BaggingClassifier
     model = BaggingClassifier()
elif model_name == 'ExtraTrees':
     from sklearn.ensemble import ExtraTreesClassifier
     model = ExtraTreesClassifier()
else:
    raise  NameError("wrong model name!")

from sklearn import metrics

model.fit(X_train, y_train)
print(model)
# make predictions
expected = y_test
predicted = model.predict(X_test)
# summarize the fit of the model
print 'classification_report\n', metrics.classification_report(expected, predicted, digits=6)
print 'confusion_matrix\n', metrics.confusion_matrix(expected, predicted)
print 'accuracy\t', metrics.accuracy_score(expected, predicted)

print 'dump the predicted proba and predicted label to files in the folder ', model_res_path
predicted_score = model.predict_proba(X_test)
predicted_label = predicted
output_file_pred_score = model_res_path + data_name + '_' + model_name + '_' + feature_set + '.pred_score'
output_file_pred_label = model_res_path + data_name + '_' + model_name + '_' + feature_set + '.pred_label'
np.savetxt(output_file_pred_score, predicted_score, delimiter='\t')
np.savetxt(output_file_pred_label, predicted_label, delimiter='\t')

if model_name == 'RandomForest' or model_name == 'AdaBoost' or model_name == 'GBRT':
    print('feature importance score\n')
    print(model.feature_importances_)

    feat_import_score_file = model_res_path + model_name + '_' + feature_set + '.featimportance'
    print('save feature importance file to the model_res_path: ', feat_import_score_file)
    np.savetxt(feat_import_score_file, model.feature_importances_, delimiter='\t')

另外XGBoost (https://xgboost.readthedocs.io/en/latest/)也是很好的实现。具体代码示例在  https://xgboost.readthedocs.io/en/latest/get_started/index.html 已经提供了,感兴趣的读者可以参考XGBoost 官网网站的文档。

4 Reference

[1]. Introduction to Boosted Tree. https://xgboost.readthedocs.io/en/latest/model.html

[2]. Introduction to Boosted Trees. By Tianqi Chen. http://homes.cs.washington.edu/~tqchen/data/pdf/BoostedTree.pdf

[3]. Tianqi Chen and Carlos Guestrin. XGBoost: A Scalable Tree Boosting System. in KDD ’16. http://www.kdd.org/kdd2016/papers/files/rfp0697-chenAemb.pdf

[4]. XGBoost. https://xgboost.readthedocs.io/en/latest/get_started/index.html

点赞