Scala语言实现Kmeans聚类算法

 Kmeans算法是一种简单的聚类算法,现在仍然广泛使用,其优点就是收敛速度快,人为干涉少,但是缺点也很明显:需要提前了解K值,以及聚类结果不稳定,

其原理可以参照:

http://blog.csdn.net/qll125596718/article/details/8243404/

这个博客关于原理解释的很清楚,并且给出了C/C++版和Java版,大家可以在深入了解一下,不过我仍然想强调一下我理解的一些细节:

       1.Kmeans首先随机找到K个聚类中心,是一个预分配的过程,然后通过计算新的聚类中心才是聚类过程。

        2.Kmeans算法得出结果以后,最后的聚类中心并不一定是给定的点集中的点,而是人为计算得到。

这里我将会利用Scala语言实现Kmeans算法,程序运行所需的数据,以及中间结果,还有大家最关心的代码注释,我都会详细给出解释。

我是采用eclipse开发的,结构图如下:

《Scala语言实现Kmeans聚类算法》

大家也可以在下面这个链接直接下载,不需要积分

http://download.csdn.net/detail/u014512572/9677346

       由于笔者本人现在仍然是菜鸟身份,所以很了解一个人摸索的痛苦,所以代码注释主要针对新手,各位大神不要见笑,不过在此之前,我有必要解释一下map()方法和reduce()方法:

1.  map()方法是映射方法,当你需要逐个处理数据集的时候就可以使用,例如:
points.map(_.toDouble)
points.map(a => a+1)
points.map(a => {
函数块
}) 上面的不管是通配符“_”还是a,都代表数据集里的每一个数据.        

2.  reduce()方法是规约方法,当你需要将数据集进行合并操作时可以使用,我一般喜欢使用reduceLeft(),这让我自己有一定的顺序感,例如下面这段在程序中出现的reduce()方法:

centers.reduceLeft((a, b) => 
     if ((vectorDis(a,point)) < (vectorDis(b,point))) a else b        这里的a,b就是数据集里的数据,既然需要进行规约,所以至少需要挑出两个数据,这里的(a,b)就代表数据集,然后通过规约条件,挑出你想要的数据

下面请看代码:   



/**
 * @author weixu_000
 */

import java.util.Random
import scala.io.Source
import java.io._

object Kmeans {

  val k = 5
  val dim = 41                  //这是我的数据集中每一组数据的维度
  val shold = 0.0000000001      //人为设定的阈值,最后用于判断偏移量 
  
  val centers =new Array[Vector[Double]](k)
  
  def main(args:Array[String]){
      
      //------------------------------------input data ------------------------

      val fileName = "data/testData.txt"
      val lines = Source.fromFile(fileName).getLines()
      val points =lines.map(line => {
             val parts = line.split(" ").map(_.toDouble)     //这里需要了解map()函数的特性,为了能够一次性调度一组数据,我们必须采用Vector类型数据
             var vector = Vector[Double]()                   //Vector类型是不可更改类型,但是可变长,可以利用这个特点将文本数据转为以Vector为元素的数组,即Array[Vector[Double]]类型
             for( i <- 0 to dim-1)                           //“_”这是通配符,使用map(),reduce()以及一些其他方法时经常用到,它表示你当前取出的元素,可以表示任何类型,所以称为通配符 
             vector ++= Vector(parts(i))
             vector
      }).toArray
 
      findCenters(points)
      kmeans(points,centers)
      putout(points,centers)
      
    }
  
  //-------------------------find centers----------------------------------  
  def findCenters(points:Array[Vector[Double]])={
     val rand = new Random(System.currentTimeMillis())
     val pointsNum = points.length
     for(i <- 0 to k-1){
        centers(i) =  points(rand.nextInt(points.length)-1)
     }

     val writerCenters = new PrintWriter(new File("data/centers.txt"))
     for(i <- 0 to k-1){
     writerCenters.println(centers(i))
     }
     writerCenters.close()
   }
   
  //-----------------------------doing cluster---------------------------- 
  def kmeans(points:Array[Vector[Double]],centers:Array[Vector[Double]])={
     var bool = true
     var index = 0
     while(bool){                                                
      
       //这里我们根据聚类中心利用groupBy()进行分组,最后得到的cluster是Map(Vector[Double],Array[Vector[Double]])类型
       //cluster共五个元素,Map中key值就是聚类中心,Value就是依赖于这个中心的点集
       val cluster = points.groupBy { closestCenter(centers,_) } 
       
       //通过Map集合的get()方法取出每一个簇,然后采用匹配方法match()进行求取新的中心,这里再强调一遍,Vector类型是不可更改类型,即数据存入Vector以后就不能改变
       //所以需要你人为的定义Vector类型的加减乘除运算
       val newCenters = centers.map { oldCenter => 
         cluster.get(oldCenter) match{
           case Some(pointsInCluster) => 
             vectorDivide(pointsInCluster.reduceLeft(vectorAdd(_,_)),pointsInCluster.length)
           case None => oldCenter
         }
        }
    
       var movement = 0d
       for(i <- 0 to k-1){
         movement += math.sqrt(vectorDis(centers(i),newCenters(i)))
         centers(i) = newCenters(i) 
       }
       if(movement <= shold){
         bool = false
       }
      index += 1
     }
   }
  
  //---------------------------putout----------------------------------------- 
   //我们最终需要输出的是聚类结果,我将每个点以“1,2,3,4,5”的形式输出,属于同一类的就是相同的数字
   //实在想不出更好的方法,只能再算一遍
   
  def putout(points:Array[Vector[Double]],centers:Array[Vector[Double]])={
     val pointsNum = points.length
     val pointLable = new Array[Int](pointsNum)
     for(i <- 0 to pointsNum-1){
        val temp = centers.reduceLeft((a,b) => 
        if ((vectorDis(a,points(i))) < (vectorDis(b,points(i))))  a
        else  b)
        pointLable(i) = centers.indexOf(temp)
     }

     val writerLable = new PrintWriter(new File("data/output.txt"))
     for(i <- 0 to pointsNum-1){
     writerLable.println(pointLable(i))
     }
      writerLable.close()
     
   }
    
  def vectorDis(v1:Vector[Double],v2:Vector[Double]):Double={
     var distance = 0d
        for(i <- 0 to dim-1){    
           distance += (v1(i)-v2(i))*(v1(i)-v2(i))
        }
        val distance = math.sqrt(t)                          
        distance
      }
   
  def vectorAdd(v1:Vector[Double],v2:Vector[Double])={
      val len=v1.length
      val av1=v1.toArray
      val av2=v2.toArray
      val av3=Array.fill(len)(0.0)
      var vector = Vector[Double]()
      for(i<-0 to len-1){
        av3(i)=av1(i)+av2(i)
        vector ++= Vector(av3(i))
      }
      vector
   }
   
  def vectorDivide(v1:Vector[Double],num:Int)={
      val av1=v1.toArray
      val len=v1.size
      val av2=Array.fill(len)(0.0)
      var vector = Vector[Double]()
      for(i<-0 to len-1){
        av2(i)=av1(i)/num
        vector ++= Vector(av2(i))
      }
      vector
   }
   
   /*
   def vectorAdd(v1:Vector[Double],v2:Vector[Double])={
     val  sumVector = Vector.fill(dim)(0.0)
        for(i <- 0 to dim-1){
          sumVector.updated(i, v1(i)+v2(i))
        }
     sumVector
   }

   def vectorDivide(v1:Vector[Double],num:Int)={
      for(i <- 0 to dim-1){
        v1.updated(i, v1(i)/num)
      }
      v1
   }
   * 
   */

  def closestCenter(centers:Array[Vector[Double]],point:Vector[Double])
   :Vector[Double]={
           centers.reduceLeft((a, b) => 
            if ((vectorDis(a,point)) < (vectorDis(b,point))) a else b
        )
        
   } 
   
  
}











        

tips:

1.写这个程序时,大家可以看到我经常使用println()输出中间结果,这也是一种找bug的过程

2.Vector类型的特性我当初不了解,就像最后我定义的那两个错的vectorAdd()和vectorDivide()方法

3.为了程序的可读性,我采用的函数块的方法编写,但是这样不得不设置一些全局变量,导致所需内存比较大,大家如果需要考虑到资源使用量,可以自行修改

    原文作者:聚类算法
    原文地址: https://blog.csdn.net/u014512572/article/details/53096465
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞