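// Author's note: I can't help remarking that Scala is a wonderfully quirky language. It is so powerful that, with a little careful thought, the code you write comes out really elegant, and Java suddenly feels low-end by comparison.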
package Utils
import com.google.common.math.{DoubleMath, IntMath}
/**
* Created by fhqplzj on 16-8-24 at 下午2:12.
*/
/**
  * External evaluation metrics for a clustering, computed from two parallel label arrays:
  * `labelsTrue(i)` is the reference class of item `i`, `labelsPred(i)` is the cluster it was
  * assigned to.
  *
  * All public metrics throw `IllegalArgumentException` (via `require`) when the two arrays
  * differ in length or contain fewer than two labels.
  *
  * Only the Scala/Java standard library is used inside this object.
  */
object Evaluation {

  /** Natural logarithm of 2, used to convert natural logs into base-2 logs. */
  private val Ln2: Double = math.log(2.0)

  /** Base-2 logarithm of `x`. */
  private def log2(x: Double): Double = math.log(x) / Ln2

  /**
    * Number of unordered pairs that can be drawn from `n` items, i.e. n * (n - 1) / 2.
    *
    * The result is a `Long` because the value exceeds `Int.MaxValue` as soon as `n` is larger
    * than 65536. Yields 0 for `n` equal to 0 or 1.
    */
  private def pairsOf(n: Int): Long = n.toLong * (n - 1) / 2

  /** Counts how many times each distinct element occurs in `items`. */
  private def frequencies[A](items: Seq[A]): Map[A, Int] =
    items.groupBy(identity).map { case (key, group) => key -> group.size }

  /** Sum, over all groups, of the number of item pairs that fall inside the same group. */
  private def pairsWithin[A](groupSizes: Map[A, Int]): Long =
    groupSizes.valuesIterator.map(size => pairsOf(size)).sum

  /**
    * Validates a pair of label arrays.
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @throws IllegalArgumentException if the lengths differ or there are fewer than two labels
    */
  private def labelChecker(labelsTrue: Array[Int], labelsPred: Array[Int]): Unit = {
    // Two separate checks so that the message names the condition that actually failed.
    require(labelsTrue.length == labelsPred.length, "The length must be equal!")
    require(labelsTrue.length >= 2, "The size of labels must be greater than 1!")
  }

  /**
    * Purity: every predicted cluster is credited with the size of its most frequent reference
    * class; the credits are summed and divided by the number of items.
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return purity in (0, 1]
    */
  def purity(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    // Key is (class, cluster); value is the number of items carrying that combination.
    val cellCounts = frequencies(labelsTrue.zip(labelsPred).toSeq)
    // The majority must be taken per *cluster* (second component of the key); grouping by the
    // class instead would yield the inverse purity.
    val majoritySum = cellCounts
      .groupBy { case ((_, cluster), _) => cluster }
      .valuesIterator
      .map(_.valuesIterator.max)
      .sum
    majoritySum.toDouble / labelsTrue.length
  }

  /**
    * Mutual information (in bits) between the two labelings.
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return sum over all non-empty (class, cluster) cells of
    *         (n_ij / N) * log2(N * n_ij / (n_i * n_j))
    */
  private def mutualInformation(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    val n = labelsTrue.length
    val classSizes = frequencies(labelsTrue.toSeq)
    val clusterSizes = frequencies(labelsPred.toSeq)
    frequencies(labelsTrue.zip(labelsPred).toSeq).iterator.map {
      case ((cls, cluster), count) =>
        val joint = count.toDouble
        // Multiply as Double: the product of two group sizes can exceed Int.MaxValue.
        val marginalProduct = classSizes(cls).toDouble * clusterSizes(cluster)
        joint / n * log2(n * joint / marginalProduct)
    }.sum
  }

  /**
    * Shannon entropy (in bits) of the label distribution.
    *
    * @param labels labels of all items
    * @return -sum(p * log2(p)) over the relative frequency p of each distinct label
    */
  private def entropy(labels: Array[Int]): Double = {
    val n = labels.length
    frequencies(labels.toSeq).valuesIterator.map { size =>
      val p = 1.0 * size / n
      -p * log2(p)
    }.sum
  }

  /**
    * Normalized mutual information: 2 * I(true; pred) / (H(true) + H(pred)).
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return the normalized mutual information; `NaN` when both labelings consist of a single
    *         label each, because both entropies are then 0 and the ratio is 0 / 0
    */
  def normalizedMutualInformation(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    2 * mutualInformation(labelsTrue, labelsPred) / (entropy(labelsTrue) + entropy(labelsPred))
  }

  /**
    * Pair-counting confusion matrix.
    *
    * @param TP pairs with the same class and the same cluster
    * @param FP pairs with different classes but the same cluster
    * @param FN pairs with the same class but different clusters
    * @param TN pairs with different classes and different clusters
    */
  case class Table(TP: Int, FP: Int, FN: Int, TN: Int)

  /** Same four pair counts as [[Table]], held as `Long` so that large inputs do not overflow. */
  private case class PairConfusion(tp: Long, fp: Long, fn: Long, tn: Long)

  /**
    * Counts item pairs by whether they share a reference class and/or a predicted cluster.
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return the four pair counts
    */
  private def pairConfusion(labelsTrue: Array[Int], labelsPred: Array[Int]): PairConfusion = {
    labelChecker(labelsTrue, labelsPred)
    val sameCluster = pairsWithin(frequencies(labelsPred.toSeq)) // TP + FP
    val sameClass = pairsWithin(frequencies(labelsTrue.toSeq)) // TP + FN
    val tp = pairsWithin(frequencies(labelsTrue.zip(labelsPred).toSeq))
    val fp = sameCluster - tp
    val fn = sameClass - tp
    val tn = pairsOf(labelsTrue.length) - sameCluster - fn
    PairConfusion(tp, fp, fn, tn)
  }

  /**
    * Computes the pair-counting confusion matrix as a [[Table]].
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return the confusion matrix
    * @throws ArithmeticException if a count does not fit into the `Int` fields of [[Table]]
    */
  private def contingencyTable(labelsTrue: Array[Int], labelsPred: Array[Int]): Table = {
    val counts = pairConfusion(labelsTrue, labelsPred)
    // toIntExact fails loudly instead of silently truncating or saturating a count.
    Table(
      Math.toIntExact(counts.tp),
      Math.toIntExact(counts.fp),
      Math.toIntExact(counts.fn),
      Math.toIntExact(counts.tn)
    )
  }

  /**
    * Rand index: the fraction of item pairs on which the two labelings agree,
    * (TP + TN) / (TP + FP + FN + TN).
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return Rand index in [0, 1]
    */
  def randIndex(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    val counts = pairConfusion(labelsTrue, labelsPred)
    (counts.tp + counts.tn).toDouble / (counts.tp + counts.fp + counts.fn + counts.tn)
  }

  /**
    * Pairwise precision: TP / (TP + FP).
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return pairwise precision; `NaN` when no two items share a predicted cluster (0 / 0)
    */
  def precision(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    val counts = pairConfusion(labelsTrue, labelsPred)
    counts.tp.toDouble / (counts.tp + counts.fp)
  }

  /**
    * Pairwise recall: TP / (TP + FN).
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @return pairwise recall; `NaN` when no two items share a reference class (0 / 0)
    */
  def recall(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    val counts = pairConfusion(labelsTrue, labelsPred)
    counts.tp.toDouble / (counts.tp + counts.fn)
  }

  /**
    * F-measure: (beta^2 + 1) * P * R / (beta^2 * P + R), with P and R the pairwise precision
    * and recall.
    *
    * Note that `beta` is an implicit parameter: if the call site has an implicit `Double` in
    * scope it is used as `beta`; otherwise the default 1.0 applies. Pass it explicitly
    * (`FMeasure(t, p)(2.0)`) to avoid surprises.
    *
    * @param labelsTrue reference labels
    * @param labelsPred predicted labels
    * @param beta       weight of recall relative to precision
    * @return the F-measure; `NaN` when the denominator is 0 or precision/recall is `NaN`
    */
  def FMeasure(labelsTrue: Array[Int], labelsPred: Array[Int])(implicit beta: Double = 1.0): Double = {
    // Build the pair counts once and derive both precision and recall from them.
    val counts = pairConfusion(labelsTrue, labelsPred)
    val precision1 = counts.tp.toDouble / (counts.tp + counts.fp)
    val recall1 = counts.tp.toDouble / (counts.tp + counts.fn)
    val betaSquared = math.pow(beta, 2)
    (betaSquared + 1) * precision1 * recall1 / (betaSquared * precision1 + recall1)
  }

  /** Demo entry point: prints the confusion matrix of a small 17-item example. */
  def main(args: Array[String]): Unit = {
    val labelTrue = Array.fill(8)(1) ++ Array.fill(5)(2) ++ Array.fill(4)(3)
    val labelPred = Array(1, 1, 1, 1, 1, 2, 3, 3, 1, 2, 2, 2, 2, 2, 3, 3, 3)
    println(contingencyTable(labelTrue, labelPred))
  }
}