Spark — 基于DataFrame API实现KNN算法
KNN简介
KNN(k-Nearest Neighbors)又称作k-近邻,核心思想用一句古话解释就是“近朱者赤,近墨者黑”,k-nn就是把未标记分类的案列归为与它们最相似的带有分类标记的案例所在的类。k-nn虽然简单,但是很强大。
KNN的特点
优点 | 缺点 |
---|---|
简单且有效 | 不产生模型 |
训练阶段很快 | 分类过程比较慢 |
对数据分布无要求 | 模型解释性较差 |
适合稀疏时间和多分类问题 | 名义变量和缺失数据需要额外处理 |
…… | …… |
实现步骤
- 计算距离:计算待测案例与训练样本之间的距离,常用的距离有欧式距离、曼哈顿距离、余弦距离等。
- 选择一个合适的k:确定用于KNN算法的邻居数量,一般用交叉验证或仅凭经验选择一个合适的k值,待测案例与训练样本之间距离最小的k个样本组成一个案例池。
- 类别判定:根据案例池的数据采用投票法或者加权投票法等方法来决定待测案例所属的类别。
代码实现
// 创建一个 KNN的方法
def runKnn(trainSet: DataFrame, testSet: DataFrame, k: Int, cl: String): DataFrame = {
val testFetures = testSet
.drop(cl).map(row => {
val fetuers: Seq[Double] = row.mkString(",").split(",").map(_.toDouble)
fetuers
}).toDF("fetuers")
val trainFetures = trainSet.map(row => {
val cla = row.getAs[String](cl)
val fetuers: Seq[Double] = row.mkString(",")
.split(",").filter(NumberUtils.isNumber(_)).map(_.toDouble)
(cla, fetuers)
}).toDF("class", "tfetuers")
val spec = Window.partitionBy($"fetuers").orderBy($"distans")
val spec2 = Window.partitionBy($"fetuers").orderBy($"count".desc)
testFetures.crossJoin(trainFetures)
.withColumn("distans", distanceUDF($"fetuers", $"tfetuers"))
.drop("tfetuers")
.withColumn("r", row_number().over(spec))
.where($"r" <= k)
.groupBy($"fetuers", $"class").count()
.withColumn("r", row_number().over(spec2))
.where($"r" === 1).drop("r", "count") // 投票选择
}
本案例使用的又是鸢尾花数据集
val iris = spark.read
.option("header", true)
.option("inferSchema", true)
.csv(inputFile)
// 将鸢尾花分成两部分:训练集和测试集
val Array(testSet, trainSet) = iris.randomSplit(Array(0.3, 0.7), 1234L)
val knnModel = new KNNModel(spark)
// 调用方法,设置k为10,要进行分类的属性为“class”
val res = knnModel.runKnn(trainSet, testSet, 10, "class")
完成分类之后,我们对结果做一个检测
// 结果检验
val testFetures = testSet
.map(row => {
val id = row.getAs[String]("class")
val fetuers = row.mkString(",").split(",")
.filter(NumberUtils.isNumber(_))
.map(_.toDouble)
(id, fetuers)
}).toDF("yclass", "fetuers")
val check = udf((f1: String, f2: String) => {
if (f1.equals(f2)) 1 else 0
})
res.show()
res.join(testFetures, "fetuers")
.withColumn("check", check($"class", $"yclass"))
.groupBy("check").count().show()
+--------------------+---------------+
| fetuers| class|
+--------------------+---------------+
|[5.5, 3.5, 1.3, 0.2]| Iris-setosa|
|[6.9, 3.1, 5.4, 2.1]| Iris-virginica|
|[5.6, 2.5, 3.9, 1.1]|Iris-versicolor|
|[4.9, 3.0, 1.4, 0.2]| Iris-setosa|
|[5.1, 3.5, 1.4, 0.2]| Iris-setosa|
|[5.6, 2.7, 4.2, 1.3]|Iris-versicolor|
|[7.2, 3.2, 6.0, 1.8]| Iris-virginica|
|[5.0, 3.5, 1.3, 0.3]| Iris-setosa|
|[6.1, 3.0, 4.6, 1.4]|Iris-versicolor|
|[5.5, 2.4, 3.7, 1.0]|Iris-versicolor|
|[5.2, 3.4, 1.4, 0.2]| Iris-setosa|
从分类结果看,当k取10的时候,准确率可以达到96%
+-----+-----+
|check|count|
+-----+-----+
| 1| 53|
| 0| 2|
+-----+-----+
不足之处
- 很多时候数据在进行训练之前,需要做一定的 标准化处理,以消除量纲等的影响
- KNN算法在分类的过程中效率比较慢,而且在本文使用的方法当中用到笛卡尔积以及大规模的排序,对整个过程的效率都有较大影响。
- 当样本的各个类数量不平衡的时候,会造成结果误差,因此需要预先进行平衡采样。
- 在做分类的时候,没有加入权重等因素,仅根据投票数量来决定分类结果。
写在最后:如有不当之处,欢迎指正。