Such a Big Forest: Spark TreeModel Source Code Analysis (Part 2)

1. model

1.1 Predict.scala

class Predict(
    val predict: Double,       // predicted value
    val prob: Double = 0.0     // predicted probability (classification)
  ) extends Serializable

1.2 Split.scala

case class Split(
    feature: Int,              // feature index
    threshold: Double,         // continuous-feature split threshold: <= threshold goes left, otherwise right
    featureType: FeatureType,  // feature type (Continuous / Categorical)
    categories: List[Double])  // categorical-feature split: values in this set go left, otherwise right
class DummyLowSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MinValue, featureType, List())

class DummyHighSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MaxValue, featureType, List())

class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MaxValue, featureType, List())
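
The three Dummy* splits above are just sentinels for the lowest/highest bins; what matters is the routing rule encoded in Split. A minimal sketch of that rule (goesLeft is a hypothetical helper of mine, not Spark source; Split, FeatureType and Vectors are the MLlib classes quoted above):

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.Split

// Hypothetical helper mirroring the rule described in the Split comments.
def goesLeft(split: Split, features: Vector): Boolean = {
  if (split.featureType == FeatureType.Continuous) {
    features(split.feature) <= split.threshold          // continuous: <= threshold goes left
  } else {
    split.categories.contains(features(split.feature))  // categorical: value in the set goes left
  }
}

val contSplit = Split(0, 2.5, FeatureType.Continuous, List())
val catSplit  = Split(1, 0.0, FeatureType.Categorical, List(3.0, 5.0))

goesLeft(contSplit, Vectors.dense(1.0, 3.0))  // true:  1.0 <= 2.5
goesLeft(catSplit,  Vectors.dense(9.0, 4.0))  // false: 4.0 not in {3.0, 5.0}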

1.3 InformationGainStats.scala

class InformationGainStats(
    val gain: Double,       // information gain of the split
    val impurity: Double,   // impurity of the current node
    val leftImpurity: Double,
    val rightImpurity: Double,
    val leftPredict: Predict,
    val rightPredict: Predict) extends Serializable 

class ImpurityStats(
    val gain: Double,
    val impurity: Double,
    val impurityCalculator: ImpurityCalculator,
    val leftImpurityCalculator: ImpurityCalculator,
    val rightImpurityCalculator: ImpurityCalculator,
    // whether this split satisfies the minimum info gain / minimum instances per node
    val valid: Boolean = true) extends Serializable {

  def leftImpurity: Double = if (leftImpurityCalculator != null) {
    leftImpurityCalculator.calculate()
  } else {
    -1.0
  }

  def rightImpurity: Double = if (rightImpurityCalculator != null) {
    rightImpurityCalculator.calculate()
  } else {
    -1.0
  }
}


object ImpurityStats {

  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
      impurityCalculator, null, null, false)
  }

  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null)
  }

}
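
The gain stored in these stats classes follows the usual decision-tree formula: the parent node's impurity minus the size-weighted impurities of the two children. A self-contained sketch using Gini impurity as an example (gini and infoGain are hypothetical helpers, not Spark source):

// Gini impurity of a node given per-class sample counts.
def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}

// gain = impurity(parent) - wLeft * impurity(left) - wRight * impurity(right)
def infoGain(parent: Array[Double], left: Array[Double], right: Array[Double]): Double = {
  val total = parent.sum
  gini(parent) - (left.sum / total) * gini(left) - (right.sum / total) * gini(right)
}

// Example: 10 positives and 10 negatives, separated perfectly by the split.
infoGain(Array(10.0, 10.0), Array(10.0, 0.0), Array(0.0, 10.0))  // 0.5 - 0 - 0 = 0.5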

1.4 Node.scala

class Node (
    val id: Int,      // node id 
    var predict: Predict,
    var impurity: Double,  // impurity of the current node
    var isLeaf: Boolean,
    var split: Option[Split],
    var leftNode: Option[Node],
    var rightNode: Option[Node],
    var stats: Option[InformationGainStats]) {


  def predict(features: Vector): Double = {
    if (isLeaf) {
      predict.predict
    } else {
      if (split.get.featureType == Continuous) {
        if (features(split.get.feature) <= split.get.threshold) {
          leftNode.get.predict(features)
        } else {
          rightNode.get.predict(features)
        }
      } else {
        if (split.get.categories.contains(features(split.get.feature))) {
          leftNode.get.predict(features)
        } else {
          rightNode.get.predict(features)
        }
      }
    }
  }

  def numDescendants: Int = if (isLeaf) {   // number of descendant nodes below this node; 0 for a leaf
    0 
  } else {
    2 + leftNode.get.numDescendants + rightNode.get.numDescendants
  }


  def subtreeDepth: Int = if (isLeaf) {  // depth of the subtree rooted at this node; 0 for a leaf
    0
  } else {
    1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
  }

  def subtreeIterator: Iterator[Node] = {  // pre-order DFS traversal of the subtree
    Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
      rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
  }

}



object Node {

  def emptyNode(nodeIndex: Int): Node =
    new Node(nodeIndex, new Predict(Double.MinValue), -1.0, false, None, None, None, None)

  def apply(
      nodeIndex: Int,
      predict: Predict,
      impurity: Double,
      isLeaf: Boolean): Node = {
    new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
  }
}
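
Putting the pieces together, a minimal sketch that hand-builds a depth-1 tree and exercises predict / numDescendants / subtreeDepth / subtreeIterator (the tree itself is made up for illustration; the classes are the MLlib ones analyzed above):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.{Node, Predict, Split}

// Two leaves predicting 0.0 and 1.0, plus a root splitting on feature 0 at threshold 2.5.
val left  = new Node(2, new Predict(0.0), 0.0, true, None, None, None, None)
val right = new Node(3, new Predict(1.0), 0.0, true, None, None, None, None)
val root  = new Node(1, new Predict(0.0), 0.5, false,
  Some(Split(0, 2.5, FeatureType.Continuous, List())), Some(left), Some(right), None)

root.predict(Vectors.dense(1.0))       // 0.0: 1.0 <= 2.5, routed to the left leaf
root.predict(Vectors.dense(3.0))       // 1.0: 3.0 >  2.5, routed to the right leaf
root.numDescendants                    // 2
root.subtreeDepth                      // 1
root.subtreeIterator.map(_.id).toList  // List(1, 2, 3): pre-order DFS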

    Original author: 金柔
    Original article: https://zhuanlan.zhihu.com/p/40321230