聚类算法:K-Means算法及其实现

        首先声明该文章部分参考文章

        K-means算法是很典型的基于距离的聚类算法,采用距离作为相似性的评价指标,即认为两个对象的距离越近,其相似度就越大。该算法认为簇是由距离靠近的对象组成的,因此把得到紧凑且独立的簇作为最终目标。
        对于聚类问题,我们事先并不知道给定的一个训练数据集到底具有哪些类别(即没有指定类标签),而是根据需要设置指定个数类标签的数量(但不知道具体的类标签是什么),然后通过K-means算法将具有相同特征,或者基于一定规则认为某一些对象相似,与其它一些组明显的不同的数据聚集到一起,自然形成分组。之后,我们可以根据每一组的数据的特点,给定一个合适的类标签(当然,可能给出类标签对实际应用没有实际意义,例如可能我们就想看一下聚类得到的各个数据集的相似性)。
        首先说明一个概念:质心(Centroid)。质心可以认为就是一个样本点,或者可以认为是数据集中的一个数据点P,它是具有相似性的一组数据的中心,即该组中每个数据点到P的距离都比到其他质心的距离近(与其他质心相似性比较低)。
        k个初始类聚类质心(Centroid)的选取对聚类结果具有较大的影响,因为在该算法第一步中是随机的选取任意k个对象作为初始聚类的质心,初始地代表一个聚类结果,当然这个结果一般情况不是合理的,只是随便地将数据集进行了一次随机的划分,具体进行修正这个质心还需要进行多轮的计算,来一步步逼近我们期望的聚类结果:具有相似性的对象聚集到一个组中,它们都具有共同的一个质心。
另外,因为初始质心选择的随机性,可能未必使最终的结果达到我们的期望,所以我们可以多次迭代,每次迭代都重新随机得到初始质心,直到最终的聚类结果能够满足我们的期望为止。

下面,我们描述一下K-means算法的过程:

(1)首先输入k的值,即我们希望将数据集D = {P1, P2, …, Pn}经过聚类得到k个分类(分组)。

(2)从数据集D中随机选择k个数据点作为质心,质心集合定义为:Centroid = {Cp1, Cp2, …, Cpk},排除质心以后数据集O={O1, O2, …, Om}。

(3)对集合O中每一个数据点Oi,计算Oi与Cpj(j=1, 2, …,k)的距离,得到一组距离Si={si1, si2, …, sik},计算Si中距离最小值,则该该数据点Oi就属于该最小距离值对应的质心。

(4)每个数据点Oi都已经属于其中一个质心,然后根据每个质心所包含的数据点的集合,重新计算得到一个新的质心。

(5)如果新计算的质心和原来的质心之间的距离达到某一个设置的阈值(表示重新计算的质心的位置变化不大,趋于稳定,或者说收敛),可以认为我们进行的聚类已经达到期望的结果,算法终止。

(6)如果新质心和原来之心距离变化很大,需要迭代2~5步骤。

        下面,根据参考链接,我们给出一个表达K-means聚类过程的图,描述了k=2时聚类的过程,更加直观一些,如图所示(问题太抽象了可以想象成对星空中的星星进行类聚):

《聚类算法:K-Means算法及其实现》

上图表示的聚类过程,简述如下:
(1)给定一个数据集,包含多个数据点;
(2)随机选择两个质心;
(3)计算数据集中数据点分别属于哪一个质心所在的组中,将数据集中所有数据点聚成2个组;
(4)根据上一步计算得到的2组数据点,分别重新计算出一个新的质心;
(5)重复步骤3,再进行一次聚类过程,得到2组数据点;
(6)再次计算新的质心,该次计算得到的质心与上一次计算得到的质心的距离变化很小(满足指定阈值,或收敛),则结果符合期望,停止聚类过程。

KMeans算法的实现:

数据集
3 3
4 10
9 6
14 8
18 11
21 7
/**
 * 二维平面,坐标点
 */

public class Point {
    private double x;
    private double y;

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    //读文件初始化时需要用到
    public Point(String x, String y){
        this.x = Double.parseDouble(x);
        this.y = Double.parseDouble(y);
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }
}
import java.io.*;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * KMeans算法的实现
 * 包装成工具类,直接传参数调用即可使用
 */

public class KMeans {
    //聚类的个数
    int cluterNum;
    //数据集中的点
    List<Point> points = new ArrayList<>();
    //簇的中心点
    List<Point> centerPoints = new ArrayList<>();
    //聚类结果的集合簇,key为聚类中心点在centerPoints中的下标,value为该类簇下的数据点
    HashMap<Integer, List<Point>> clusters = new HashMap<>();

    public KMeans(String path, int cluterNum){
        this.cluterNum = cluterNum;
        loadData(path);
    }

    //加载数据集
    public void loadData(String path) {
        File file = new File(path);
        try {
            BufferedReader br = new BufferedReader(new FileReader(file));
            String line;
            while ((line = br.readLine()) != null) {
                String[] strs = line.split(" ");
                points.add(new Point(strs[0], strs[1]));
            }
            br.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        //初始化KMeans模型,这里选数据集前classNum个点作为初始中心点
        for (int i = 0; i < cluterNum; i++) {
            centerPoints.add(points.get(i));
            clusters.put(i, new ArrayList<>());
        }
    }

    //KMeans聚类
    public void doKMeans(){
        double err = Integer.MAX_VALUE;
        while (err > 0.01 * cluterNum){
            for (int key : clusters.keySet()){
                List<Point> list = clusters.get(key);
                list.clear();
                clusters.put(key, list);
            }
            //计算每个点所属类簇
            for (int i=0; i<points.size(); i++){
                dispatchPointToCluster(points.get(i), centerPoints);
            }
            //计算每个簇的中心点,并得到中心点偏移误差
            err = getClusterCenterPoint(centerPoints, clusters);
            show(centerPoints, clusters);
            System.out.println("*************************");
        }
    }

    //计算点对应的中心点,并将该点划分到距离最近的中心点的簇中
    public void dispatchPointToCluster(Point point, List<Point> centerPoints){
        int index = 0;
        double tmpMinDistance = Double.MAX_VALUE;
        for (int i=0; i<centerPoints.size(); i++){
            double distance = calDistance(point, centerPoints.get(i));
            if (distance < tmpMinDistance){
                tmpMinDistance = distance;
                index = i;
            }
        }
        List<Point> list = clusters.get(index);
        list.add(point);
        clusters.put(index, list);
    }

    //计算每个类簇的中心点,并返回中心点偏移误差
    public double getClusterCenterPoint(List<Point> centerPoints, HashMap<Integer, List<Point>> clusters){
        double error = 0;
        for (int i=0; i<centerPoints.size(); i++){
            Point tmpCenterPoint = centerPoints.get(i);
            double centerX = 0, centerY = 0;
            List<Point> lists = clusters.get(i);
            for (int j=0; j<lists.size(); j++){
                centerX += lists.get(j).getX();
                centerY += lists.get(j).getY();
            }
            centerX /= lists.size();
            centerY /= lists.size();
            error += Math.abs(centerX - tmpCenterPoint.getX());
            error += Math.abs(centerY - tmpCenterPoint.getY());
            centerPoints.set(i, new Point(centerX, centerY));
        }
        return error;
    }

    //计算点之间的距离,这里计算欧氏距离(不开方)
    public double calDistance(Point point1, Point point2){
        return Math.pow((point1.getX() - point2.getX()), 2) + Math.pow((point1.getY() - point2.getY()), 2);
    }

    //打印簇中心点坐标,及簇中其他点坐标
    public void show(List<Point> centerPoints, HashMap<Integer, List<Point>> clusters){
        for (int i=0; i<centerPoints.size(); i++){
            System.out.print(MessageFormat.format("类{0}的中心点: <{1}, {2}>",(i+1), centerPoints.get(i).getX(), centerPoints.get(i).getY()));
            List<Point> lists = clusters.get(i);
            System.out.print("\t类中成员点有:");
            for (int j=0; j<lists.size(); j++){
                System.out.print("<"+lists.get(j).getX()+ ","+ lists.get(j).getY()+">\t");
            }
            System.out.println();
        }
    }
}
/**
 * 客户端实现KMeans工具类的调用
 */

public class Client {
    public static void main(String[] args){
        String path = "D:\\Program\\MachineLearn\\src\\main\\java\\KMeans\\Data\\dataset.txt";
        KMeans kMeans = new KMeans(path, 3);
        kMeans.doKMeans();
    }
}

程序运行结果:
类1的中心点: <3, 3>	类中成员点有:<3.0,3.0>	
类2的中心点: <4, 10>	类中成员点有:<4.0,10.0>	
类3的中心点: <15.5, 8>	类中成员点有:<9.0,6.0>	<14.0,8.0>	<18.0,11.0>	<21.0,7.0>	
*************************
类1的中心点: <3, 3>	类中成员点有:<3.0,3.0>	
类2的中心点: <6.5, 8>	类中成员点有:<4.0,10.0>	<9.0,6.0>	
类3的中心点: <17.667, 8.667>	类中成员点有:<14.0,8.0>	<18.0,11.0>	<21.0,7.0>	
*************************
类1的中心点: <3, 3>	类中成员点有:<3.0,3.0>	
类2的中心点: <6.5, 8>	类中成员点有:<4.0,10.0>	<9.0,6.0>	
类3的中心点: <17.667, 8.667>	类中成员点有:<14.0,8.0>	<18.0,11.0>	<21.0,7.0>	
*************************

K-means算法的优点:

(1)算法框架清晰,简单,容易理解。
(2)本算法确定的k个划分到达平方误差最小。当聚类是密集的,且类与类之间区别明显时,效果较好。
(3)对于处理大数据集,这个算法是相对可伸缩和高效的,计算的复杂度为O(NKt),其中N是数据对象的数目,t是迭代的次数。一般来说,K<<N,t<<N 。

K-means算法的缺点:

(1)K-means算法中k是事先给定的,这个k值的选定是非常难以估计的。很多时候,事先并不知道给定的数据集应该分成多少个类别才最合适。这也是K-means算法的一个不足。有的算法是通过类的自动合并和分裂,得到较为合理的类型数目k,例如ISODATA算法。关于K-means算法中聚类数目k值的确定,有些文献中,是根据方差分析理论,应用混合F统计量来确定最佳分类数,并应用了模糊划分熵来验证最佳分类数的正确性,它使用了一种结合全协方差矩阵的RPCL算法,并逐步删除那些只包含少量训练数据的类,这是一种称为次胜者受罚的竞争学习规则,来自动决定类的适当数目。它的思想是:对每个输入而言,不仅竞争获胜单元的权值被修正以适应输入值,而且对次胜单元采用惩罚的方法使之远离输入值。
(2)在K-means算法中,首先需要根据初始聚类中心来确定一个初始划分,然后对初始划分进行优化。这个初始聚类中心的选择对聚类结果有较大的影响,一旦初始值选择的不好,可能无法得到有效的聚类结果,这也成为K-means算法的一个主要问题。对于该问题的解决,许多算法采用遗传算法(GA),以内部聚类准则作为评价指标。
(3)从K-means算法框架可以看出,该算法需要不断地进行样本分类调整,不断地计算调整后的新的聚类中心,因此当数据量非常大时,算法的时间开销是非常大的。所以需要对算法的时间复杂度进行分析、改进,提高算法应用范围,例如,可以从该算法的时间复杂度进行分析考虑,通过一定的相似性准则来去掉聚类中心的侯选集。在有些文献中,使用的K-means算法是对样本数据进行聚类,无论是初始点的选择还是一次迭代完成时对数据的调整,都是建立在随机选取的样本数据的基础之上,这样可以提高算法的收敛速度。
(4)K-means算法对异常数据很敏感。在计算质心的过程中,如果某个数据很异常,在计算均值的时候,会对结果影响非常大

K均值(K-Means)算法与K中心(K-Mediods)算法的区别:

以前总是搞混淆K-Means和K-Mediods算法,现在来区分一下。首先K-Means算法在上面讲的很清楚了,即输入聚类个数K以初始数据集,然后输出满足最小方差标准的K个聚类,过程如下图所示:

《聚类算法:K-Means算法及其实现》

由上图可以看到,KMeans算法不采用簇中对象作为簇中心,而KMediods算法为了减轻KMeans算法对孤立点的敏感性选用簇中距离平均值最近的对象作为簇的中心。其过程大致可以描述为,输入包含N个对象的初始数据集和簇数目K,输出划分好的K个簇,具体过程如下:

(1)随机选择K个代表对象作为初始的中心点;

(2)计算每个剩余对象到各个中心点的距离并分别划分到距离最近的中心点所代表的簇;

(3)随机选择一个非中心点对象Y,计算用Y代替中心点对象的总代价S;

(4)如果S为负则可用Y代替原中心点形成新的中心点,否则不改变原中心点;

(5)重复计算,赶到K个中心点都不再发生变化


参考资料:

http://shiyanjun.cn/archives/539.html

    原文作者:聚类算法
    原文地址: https://blog.csdn.net/u012050154/article/details/47834665
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞