1.1 Predict.scala
/**
 * Prediction attached to a tree node.
 *
 * @param predict predicted value
 * @param prob    predicted probability (classification); defaults to 0.0
 */
class Predict(
    val predict: Double,
    val prob: Double = 0.0) extends Serializable
1.2 Split.scala
/**
 * Binary split applied at a tree node.
 *
 * @param feature     index of the feature being split on
 * @param threshold   split threshold for a continuous feature: values less than or equal to
 *                    it go left, all other values go right
 * @param featureType type of the feature
 * @param categories  split values for a categorical feature: a value contained in this list
 *                    goes left, all other values go right
 */
case class Split(
    feature: Int,
    threshold: Double,
    featureType: FeatureType,
    categories: List[Double])
/**
 * Split on `feature` whose threshold is fixed at `Double.MinValue` and whose category list
 * is empty.
 *
 * NOTE(review): presumably a sentinel marking the lower bound of a feature's candidate
 * splits; confirm against callers.
 *
 * @param feature     index of the feature
 * @param featureType type of the feature
 */
class DummyLowSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MinValue, featureType, List.empty[Double])
/**
 * Split on `feature` whose threshold is fixed at `Double.MaxValue` and whose category list
 * is empty.
 *
 * NOTE(review): presumably a sentinel marking the upper bound of a feature's candidate
 * splits; confirm against callers.
 *
 * @param feature     index of the feature
 * @param featureType type of the feature
 */
class DummyHighSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MaxValue, featureType, List.empty[Double])
/**
 * Split on `feature` with an empty category list; the threshold is set to
 * `Double.MaxValue`.
 *
 * NOTE(review): presumably a placeholder split for categorical features; confirm against
 * callers.
 *
 * @param feature     index of the feature
 * @param featureType type of the feature
 */
class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MaxValue, featureType, List.empty[Double])
1.3 InformationGainStats.scala
/**
 * Information gain statistics recorded for a tree node.
 *
 * @param gain          information gain value
 * @param impurity      impurity of the current node
 * @param leftImpurity  impurity of the left side of the split
 * @param rightImpurity impurity of the right side of the split
 * @param leftPredict   prediction for the left side of the split
 * @param rightPredict  prediction for the right side of the split
 */
class InformationGainStats(
    val gain: Double,
    val impurity: Double,
    val leftImpurity: Double,
    val rightImpurity: Double,
    val leftPredict: Predict,
    val rightPredict: Predict) extends Serializable
/**
 * Impurity statistics for a node and a candidate split of it.
 *
 * @param gain                    information gain of the split
 * @param impurity                impurity of the current node
 * @param impurityCalculator      impurity calculator for the current node
 * @param leftImpurityCalculator  impurity calculator for the left side; may be null
 * @param rightImpurityCalculator impurity calculator for the right side; may be null
 * @param valid                   whether the current split satisfies the minimum information
 *                                gain or the minimum number of samples on a node; defaults
 *                                to true
 */
class ImpurityStats(
    val gain: Double,
    val impurity: Double,
    val impurityCalculator: ImpurityCalculator,
    val leftImpurityCalculator: ImpurityCalculator,
    val rightImpurityCalculator: ImpurityCalculator,
    val valid: Boolean = true) extends Serializable {

  /** Impurity of the left side, or -1.0 when there is no left calculator. */
  def leftImpurity: Double =
    if (leftImpurityCalculator != null) {
      leftImpurityCalculator.calculate()
    } else {
      -1.0
    }

  /** Impurity of the right side, or -1.0 when there is no right calculator. */
  def rightImpurity: Double =
    if (rightImpurityCalculator != null) {
      rightImpurityCalculator.calculate()
    } else {
      -1.0
    }
}
object ImpurityStats {

  /**
   * Builds stats flagged as invalid: gain is `Double.MinValue`, the impurity comes from
   * `impurityCalculator`, and both child calculators are null.
   *
   * @param impurityCalculator impurity calculator of the node
   * @return stats with `valid` set to false
   */
  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats =
    new ImpurityStats(
      gain = Double.MinValue,
      impurity = impurityCalculator.calculate(),
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null,
      valid = false)

  /**
   * Builds stats with no split information: gain is `Double.NaN`, the impurity comes from
   * `impurityCalculator`, and both child calculators are null. `valid` keeps its default.
   *
   * @param impurityCalculator impurity calculator of the node
   * @return stats carrying only the node's own impurity
   */
  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats =
    new ImpurityStats(
      gain = Double.NaN,
      impurity = impurityCalculator.calculate(),
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null)
}
1.4 Node.scala
/**
 * Node of a binary decision tree.
 *
 * @param id        node id
 * @param predict   prediction stored at this node
 * @param impurity  impurity of this node
 * @param isLeaf    whether this node is a leaf
 * @param split     split used to route a sample to a child, if any
 * @param leftNode  left child, if any
 * @param rightNode right child, if any
 * @param stats     information gain statistics, if any
 */
class Node(
    val id: Int,
    var predict: Predict,
    var impurity: Double,
    var isLeaf: Boolean,
    var split: Option[Split],
    var leftNode: Option[Node],
    var rightNode: Option[Node],
    var stats: Option[InformationGainStats]) {

  /**
   * Predicts a value for `features` by walking down the tree from this node.
   *
   * A leaf returns its stored prediction. Otherwise the sample goes left when a continuous
   * feature is less than or equal to the split threshold, or when a non-continuous feature
   * value is contained in the split's categories; it goes right in every other case.
   *
   * Throws `NoSuchElementException` when a non-leaf node has no split or lacks the child
   * the sample is routed to.
   *
   * @param features feature vector of the sample (the `Vector` type is declared outside
   *                 this block)
   * @return the predicted value
   */
  def predict(features: Vector): Double =
    if (isLeaf) {
      predict.predict
    } else {
      val nodeSplit = split.get
      val goesLeft =
        if (nodeSplit.featureType == Continuous) {
          features(nodeSplit.feature) <= nodeSplit.threshold
        } else {
          nodeSplit.categories.contains(features(nodeSplit.feature))
        }
      val child = if (goesLeft) leftNode else rightNode
      child.get.predict(features)
    }

  /** Number of nodes below this node; 0 for a leaf. */
  def numDescendants: Int =
    if (isLeaf) {
      0
    } else {
      val below = leftNode.get.numDescendants + rightNode.get.numDescendants
      // The two direct children are counted on top of their own descendants.
      below + 2
    }

  /** Depth of the subtree rooted at this node; 0 for a leaf. */
  def subtreeDepth: Int =
    if (isLeaf) {
      0
    } else {
      val leftDepth = leftNode.get.subtreeDepth
      val rightDepth = rightNode.get.subtreeDepth
      math.max(leftDepth, rightDepth) + 1
    }

  /**
   * Depth-first iterator over the subtree rooted at this node: this node first, then the
   * left subtree, then the right subtree. Missing children contribute nothing.
   */
  def subtreeIterator: Iterator[Node] = {
    def descend(child: Option[Node]): Iterator[Node] = child match {
      case Some(node) => node.subtreeIterator
      case None => Iterator.empty
    }

    Iterator.single(this) ++ descend(leftNode) ++ descend(rightNode)
  }
}
object Node {

  /**
   * Creates a placeholder node: prediction `Double.MinValue`, impurity -1.0, not a leaf,
   * and without split, children or stats.
   *
   * @param nodeIndex id of the new node
   * @return the placeholder node
   */
  def emptyNode(nodeIndex: Int): Node =
    apply(nodeIndex, new Predict(Double.MinValue), impurity = -1.0, isLeaf = false)

  /**
   * Creates a node without split, children or stats.
   *
   * @param nodeIndex id of the new node
   * @param predict   prediction stored at the node
   * @param impurity  impurity of the node
   * @param isLeaf    whether the node is a leaf
   * @return the new node
   */
  def apply(
      nodeIndex: Int,
      predict: Predict,
      impurity: Double,
      isLeaf: Boolean): Node =
    new Node(
      nodeIndex,
      predict,
      impurity,
      isLeaf,
      split = None,
      leftNode = None,
      rightNode = None,
      stats = None)
}