Spark.GBDT学习-GBTClassifier

用于分类的GBT(Gradient-Boosted Trees)算法,基于J.H. Friedman. “Stochastic Gradient Boosting”实现,目前不支持多分类任务。Gradient Boosting vs. TreeBoost:

  • 本实现基于Stochastic Gradient Boosting(随机梯度提升),而不是TreeBoost
  • 两种方法都是通过最小化损失函数,学习树的集成
  • TreeBoost方法相对于原始方法,基于损失函数对叶节点的输出进行了额外的修改
  • Spark考虑未来实现TreeBoost

GBTClassifier

定义

一个唯一标识uid,继承了Predictor类,继承了GBTClassifierParamsDefaultParamsWritableLogging特质。其中Predictor中的三个元素分别代表: 特征类型、学习器、学习到用于预测的模型

class GBTClassifier(override val uid: String) 
extends Predictor[Vector, GBTClassifier, GBTClassificationModel] 
with GBTClassifierParams with DefaultParamsWritable with Logging 
{
    def this() = this(Identifiable.randomUID("gbtc"))
    ...
}

参数

为了兼容JAVA API,覆盖了继承自特质(with trait)的参数setter方法。

  1. TreeClassifierParams参数
  • maxDepth
    树的最大深度,0意味着只有一个叶节点,1意味着有一个内部节点+两个叶节点。
    支持:>=0
    默认:5
  • maxBins
    用于离散连续特征的最大分桶数,用于每个节点特征分裂时分裂点的选择,分桶数越大意味着粒度越高。
    支持:>=2并且>=任一类别特征的分类数
    默认:32
  • minInstancesPerNode
    分裂后每个子节点含有的最小样本数,如果分裂后左孩子或右孩子含有的样本数少于该值,则该分裂无效。
    支持:>=1
    默认:1
  • minInfoGain
    树节点分裂时的最小信息增益。
    支持:>=0.0
    默认:0.0
  • maxMemoryInMB
    每次会对一组节点进行切分,分组是按照树的层次逐步进行。每组需要切分的节点个数视内存大小而定,如果内存太小,每次只能切分一个节点。单位MB
    默认:256MB
  • cacheNodeIds
    如果为true,算法会为每个实例缓存树节点ID;如果为false,算法会将树传递给执行器用于匹配实例和树节点。缓存有利于加速训练深度较大的树,用户可以通过参数checkpointInterval设置缓存被检查的频率或者不检查。
    默认:false
  • checkpointInterval
    表示缓存的树节点ID的检查频率,当cacheNodeIds为true并且检查目录(checkpoint directory)通过sparkContext设置过才有效。
    支持:>=1或者-1代表不检查,10意味着每10次迭代检查一次。
    默认:10
  • impurity
    用于计算信息增益的准则。不支持通过GBTClassifier.setImpurity方法设置该值。
    支持:entropy、gini
    默认:gini
  1. TreeEnsembleParams参数
  • subsamplingRate
    每一次迭代训练基学习器(决策树)时所使用的训练数据集的百分比。
    支持:(0, 1]
    默认:1.0
  • seed
    随机数种子
    默认:this.getClass.getName.hashCode.toLong
  1. GBTParams参数
  • maxIter
    最大迭代次数
    支持:>=0
    默认:20
  • stepSize
    学习率(learning rate/step size)参数,用于缩小(shrinking)每个基学习器的贡献。
    支持:(0, 1]
    默认:0.1
  1. GBTClassifierParams参数
  • lossType
    GBT最小化的损失函数,不区分大小写。
    支持:logistic
    默认:logistic

方法

  1. copy方法
    GBTClassifier的拷贝函数。
  2. train方法
    GBTClassifier类的主要方法,用于训练得到学习好的用于预测的模型。
// @input: 训练数据, DataSet
// @output: 学习到的模型, GBTClassificationModel
override protected def train(dataset: Dataset[_]):
GBTClassificationModel = {
    // 得到类别特征
    val categoricalFeatures: Map[Int, Int] =
    MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    // 转换训练数据并进行验证
    // 将DataSet转换成RDD[LabeledPoint]
    // 只支持二分类,要求label为0或1
    val oldDataset: RDD[LabeledPoint] =
        dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
            case Row(label: Double, features: Vector) =>
                require(label == 0 || label == 1, s"GBTClassifier was given dataset with invalid label $label.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.")
            LabeledPoint(label, features)
        }
    // 获得特征个数及boosting策略
    val numFeatures = oldDataset.first().features.size
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
    // 用于记录日志
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    instr.logNumClasses(2)
    // 调用GradientBoostedTrees训练得到一组学习器及其权重
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))
    // 将学到的模型封装成GBTClassificationModel并返回
    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
    instr.logSuccess(m)
    m
}

GBTClassifier对象

object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
    // final变量,访问支持的损失函数类型
    final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
    // 从目录中加载GBTClassifier
    override def load(path: String): GBTClassifier = super.load(path)
}

GBTClassificationModel

用于分类的GBT模型,仅支持二分类,支持连续特征和类别特征。

定义

继承了PredictionModel类以及多个特质,其中PredictionModel的两个元素分别代表特征类型、学习到用于预测的模型

class GBTClassificationModel private[ml](
    override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    private val _treeWeights: Array[Double],
    override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with GBTClassifierParams 
with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable 
{
    // 检查_trees.nonEmpty
    // 检查_trees.length == _treeWeights.length
    val numTrees: Int = _trees.length
    ...
}

方法

  1. transformImpl方法
    首先将GBTClassificationModel进行广播,然后通过udf进行预测数据,udf中调用predict方法实现。
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    // 广播本类
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
        // udf通过本类的predict方法实现
        bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    // 使用udf将特征数据转换成预测数据
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }
  1. predict方法
    关键的预测方法,先得到每个基学习器的预测值,然后进行融合得到最终的预测结果,最后得到类别结果。可以看到这里得到的预测值不是概率而是类别0/1,因为label被转换成了-1/+1,所以这里通过prediction>0.0得到预测lebel。
override protected def predict(features: Vector): Double = {
    // 得到每棵树的预测结果
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    // 乘以权重之后求和得到融合结果
    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    // 得到预测lebel
    if (prediction > 0.0) 1.0 else 0.0
  }
  1. copy方法
    GBTClassificationModel的拷贝方法。
  2. toOld方法
    将ml的模型转换成mllib中老的API,ml域的私有方法。
private[ml] def toOld: OldGBTModel = {
    new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
  1. write方法
    调用GBTClassificationModel对象的方法保存本模型。
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)

GBTClassificationModel对象

  1. fromOld方法
    从老的API中转换出当前模型
  2. GBTClassificationModelReader
    私有类,其中的load方法用于从目录中读取模型
  3. GBTClassificationModelWriter
    私有类,其中的saveImpl方法用于将本模型保存
  4. read方法
    新建GBTClassificationModelReader
  5. load方法
    原文作者:松鼠胃口好
    原文地址: https://www.jianshu.com/p/421f76b8ac1d
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞