Spark --基于DataFrame API实现KNN算法

Spark — 基于DataFrame API实现KNN算法

KNN简介

KNN(k-Nearest Neighbors)又称作k-近邻,核心思想用一句古话解释就是“近朱者赤,近墨者黑”,k-nn就是把未标记分类的案列归为与它们最相似的带有分类标记的案例所在的类。k-nn虽然简单,但是很强大。

KNN的特点

优点缺点
简单且有效不产生模型
训练阶段很快分类过程比较慢
对数据分布无要求模型解释性较差
适合稀疏时间和多分类问题名义变量和缺失数据需要额外处理
…………

实现步骤

  1. 计算距离:计算待测案例与训练样本之间的距离,常用的距离有欧式距离曼哈顿距离余弦距离等。
  2. 选择一个合适的k:确定用于KNN算法的邻居数量,一般用交叉验证或仅凭经验选择一个合适的k值,待测案例与训练样本之间距离最小的k个样本组成一个案例池。
  3. 类别判定:根据案例池的数据采用投票法或者加权投票法等方法来决定待测案例所属的类别。

代码实现

// 创建一个 KNN的方法 
def runKnn(trainSet: DataFrame, testSet: DataFrame, k: Int, cl: String): DataFrame = {

    val testFetures = testSet
      .drop(cl).map(row => {
      val fetuers: Seq[Double] = row.mkString(",").split(",").map(_.toDouble)
      fetuers
    }).toDF("fetuers")

    val trainFetures = trainSet.map(row => {
      val cla = row.getAs[String](cl)
      val fetuers: Seq[Double] = row.mkString(",")
        .split(",").filter(NumberUtils.isNumber(_)).map(_.toDouble)
      (cla, fetuers)
    }).toDF("class", "tfetuers")

    val spec = Window.partitionBy($"fetuers").orderBy($"distans")
    val spec2 = Window.partitionBy($"fetuers").orderBy($"count".desc)

    testFetures.crossJoin(trainFetures)
      .withColumn("distans", distanceUDF($"fetuers", $"tfetuers"))
      .drop("tfetuers")
      .withColumn("r", row_number().over(spec))
      .where($"r" <= k)
      .groupBy($"fetuers", $"class").count()
      .withColumn("r", row_number().over(spec2))
      .where($"r" === 1).drop("r", "count")  // 投票选择

  }

本案例使用的又是鸢尾花数据集

val iris = spark.read
      .option("header", true)
      .option("inferSchema", true)
      .csv(inputFile)

   // 将鸢尾花分成两部分:训练集和测试集
    val Array(testSet, trainSet) = iris.randomSplit(Array(0.3, 0.7), 1234L)

    val knnModel = new KNNModel(spark)
    // 调用方法,设置k为10,要进行分类的属性为“class”
    val res = knnModel.runKnn(trainSet, testSet, 10, "class")

完成分类之后,我们对结果做一个检测

 // 结果检验
    val testFetures = testSet
      .map(row => {
        val id = row.getAs[String]("class")
        val fetuers = row.mkString(",").split(",")
          .filter(NumberUtils.isNumber(_))
          .map(_.toDouble)
        (id, fetuers)
      }).toDF("yclass", "fetuers")

    val check = udf((f1: String, f2: String) => {
      if (f1.equals(f2)) 1 else 0
    })


     res.show()
     
    res.join(testFetures, "fetuers")
      .withColumn("check", check($"class", $"yclass"))
      .groupBy("check").count().show()

+--------------------+---------------+
|             fetuers|          class|
+--------------------+---------------+
|[5.5, 3.5, 1.3, 0.2]|    Iris-setosa|
|[6.9, 3.1, 5.4, 2.1]| Iris-virginica|
|[5.6, 2.5, 3.9, 1.1]|Iris-versicolor|
|[4.9, 3.0, 1.4, 0.2]|    Iris-setosa|
|[5.1, 3.5, 1.4, 0.2]|    Iris-setosa|
|[5.6, 2.7, 4.2, 1.3]|Iris-versicolor|
|[7.2, 3.2, 6.0, 1.8]| Iris-virginica|
|[5.0, 3.5, 1.3, 0.3]|    Iris-setosa|
|[6.1, 3.0, 4.6, 1.4]|Iris-versicolor|
|[5.5, 2.4, 3.7, 1.0]|Iris-versicolor|
|[5.2, 3.4, 1.4, 0.2]|    Iris-setosa|

从分类结果看,当k取10的时候,准确率可以达到96%

+-----+-----+
|check|count|
+-----+-----+
|    1|   53|
|    0|    2|
+-----+-----+

不足之处

  1. 很多时候数据在进行训练之前,需要做一定的 标准化处理,以消除量纲等的影响
  2. KNN算法在分类的过程中效率比较慢,而且在本文使用的方法当中用到笛卡尔积以及大规模的排序,对整个过程的效率都有较大影响。
  3. 当样本的各个类数量不平衡的时候,会造成结果误差,因此需要预先进行平衡采样。
  4. 在做分类的时候,没有加入权重等因素,仅根据投票数量来决定分类结果。

写在最后:如有不当之处,欢迎指正。

    原文作者:k_wzzc
    原文地址: https://www.jianshu.com/p/644299fff27d
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞