JAVA实现K-means聚类

转载请注明出处:http://blog.csdn.net/xiaojimanman/article/details/51086879

http://www.llwjy.com/blogdetail/bf27dd0be964886d11185743779e40e0.html

个人博客站已经上线了,网址 www.llwjy.com ~欢迎各位吐槽~

————————————————————————————————-

      在开始之前先打一个小小的广告,自己创建一个QQ羣:321903218,点击链接加入羣【Lucene案例开发】,主要用于交流如何使用Lucene来创建站内搜索后台,同时还会不定期的在羣内开相关的公开课,感兴趣的童鞋可以加入交流。

      在上一篇博客中已经介绍了KNN分类算法,这篇博客将重点介绍下K-means聚类算法。K-means算法是比较经典的聚类算法,算法的基本思想是选取K个点(随机)作为中心进行聚类,然后对聚类的结果计算该类的质心,通过迭代的方法不断更新质心,直到质心不变或稍微移动为止,则最后的聚类结果就是最后的聚类结果。下面首先介绍下K-means具体的算法步骤。

K-means算法

      在前面已经大概的介绍了下K-means,下面就介绍下具体的算法描述:

1)选取K个点作为初始质心;

2)对每个样本分别计算到K个质心的相似度或距离,将该样本划分到相似度最高或距离最短的质心所在类;

3)对该轮聚类结果,计算每一个类别的质心,新的质心作为下一轮的质心;

4)判断算法是否满足终止条件,满足终止条件结束,否则继续第2、3、4步。

      在介绍算法之前,我们首先看下K-means算法聚类平面200,000个点聚成34个类别的结果(如下图)

《JAVA实现K-means聚类》

算法实现

      K-means聚类算法整体思想比较简单,下面 就分步介绍如何用JAVA来实现K-means算法。

一、K-means算法基础属性

      在K-means算法中,有几个重要的指标,比如K值、最大迭代次数等,对于这些指标,我们统一把它们设置为类的属性,如下:

private List<T> dataArray;//待分类的原始值
private int K = 3;//将要分成的类别个数
private int maxClusterTimes = 500;//最大迭代次数
private List<List<T>> clusterList;//聚类的结果
private List<T> clusteringCenterT;//质心

二、初始质心的选择

      K-means聚类算法的结果很大程度收到初始质心的选取,这了为了保证有充分的随机性,对于初始质心的选择这里采用完全随机的方法,先把待分类的数据随机打乱,然后把前K个样本作为初始质心(通过多次迭代,会减少初始质心的影响)。

List<T> centerT = new ArrayList<T>(size);
//对数据进行打乱
Collections.shuffle(dataArray);
for (int i = 0; i < size; i++) {
	centerT.add(dataArray.get(i));
}

三、一轮聚类

      在K-means算法中,大部分的时间都在做一轮一轮的聚类,具体功能也很简单,就是对每一个样本分别计算和所有质心的相似度或距离,找到与该样本最相似的质心或者距离最近的质心,然后把该样本划分到该类中,具体逻辑介绍参照代码中的注释。

private void clustering(List<T> preCenter, int times) {
	if (preCenter == null || preCenter.size() < 2) {
		return;
	}
	//打乱质心的顺序
	Collections.shuffle(preCenter);
	List<List<T>> clusterList =  getListT(preCenter.size());
	for (T o1 : this.dataArray) {
		//寻找最相似的质心
		int max = 0;
		double maxScore = similarScore(o1, preCenter.get(0));
		for (int i = 1; i < preCenter.size(); i++) {
			if (maxScore < similarScore(o1, preCenter.get(i))) {
				maxScore = similarScore(o1, preCenter.get(i));
				max = i;
			}
		}
		clusterList.get(max).add(o1);
	}
	//计算本次聚类结果每个类别的质心
	List<T> nowCenter = new ArrayList<T> ();
	for (List<T> list : clusterList) {
		nowCenter.add(getCenterT(list));
	}
	//是否达到最大迭代次数
	if (times >= this.maxClusterTimes || preCenter.size() < this.K) {
		this.clusterList = clusterList;
		return;
	}
	this.clusteringCenterT = nowCenter;
	//判断质心是否发生移动,如果没有移动,结束本次聚类,否则进行下一轮
	if (isCenterChange(preCenter, nowCenter)) {
		clear(clusterList);
		clustering(nowCenter, times + 1);
	} else {
		this.clusterList = clusterList;
	}
}

四、质心是否移动

      在第三步中,提到了一个重要的步骤:每轮聚类结束后,都要重新计算质心,并且计算质心是否发生移动。对于新质心的计算、样本之间的相似度和判断两个样本是否相等这几个功能由于并不知道样本的具体数据类型,因此把他们定义成抽象方法,供子类来实现。下面就重点介绍如何判断质心是否发生移动。

private boolean isCenterChange(List<T> preT, List<T> nowT) {
	if (preT == null || nowT == null) {
		return false;
	}
	for (T t1 : preT) {
		boolean bol = true;
		for (T t2 : nowT) {
			if (equals(t1, t2)) {//t1在t2中有相等的,认为该质心未移动
				bol = false;
				break;
			}
		}
		//有一个质心发生移动,认为需要进行下一次计算
		if (bol) {
			return bol;
		}
	}
	return false;
}

      从上述代码可以看到,算法的思想就是对于前后两个质心数组分别前一组的质心是否在后一个质心组中出现,有一个没有出现,就认为质心发生了变动。

完整代码

      上面四步已经完整的介绍了K-means算法的具体算法思想,下面就看下完整的代码实现。

 /**  
 *@Description:  K-means聚类
 */ 
package com.lulei.datamining.knn;  

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
  
public abstract class KMeansClustering <T>{
	private List<T> dataArray;//待分类的原始值
	private int K = 3;//将要分成的类别个数
	private int maxClusterTimes = 500;//最大迭代次数
	private List<List<T>> clusterList;//聚类的结果
	private List<T> clusteringCenterT;//质心
	
	public int getK() {
		return K;
	}
	public void setK(int K) {
		if (K < 1) {
			throw new IllegalArgumentException("K must greater than 0");
		}
		this.K = K;
	}
	public int getMaxClusterTimes() {
		return maxClusterTimes;
	}
	public void setMaxClusterTimes(int maxClusterTimes) {
		if (maxClusterTimes < 10) {
			throw new IllegalArgumentException("maxClusterTimes must greater than 10");
		}
		this.maxClusterTimes = maxClusterTimes;
	}
	public List<T> getClusteringCenterT() {
		return clusteringCenterT;
	}
	/**
	 * @return
	 * @Author:lulei  
	 * @Description: 对数据进行聚类
	 */
	public List<List<T>> clustering() {
		if (dataArray == null) {
			return null;
		}
		//初始K个点为数组中的前K个点
		int size = K > dataArray.size() ? dataArray.size() : K;
		List<T> centerT = new ArrayList<T>(size);
		//对数据进行打乱
		Collections.shuffle(dataArray);
		for (int i = 0; i < size; i++) {
			centerT.add(dataArray.get(i));
		}
		clustering(centerT, 0);
		return clusterList;
	}
	
	/**
	 * @param preCenter
	 * @param times
	 * @Author:lulei  
	 * @Description: 一轮聚类
	 */
	private void clustering(List<T> preCenter, int times) {
		if (preCenter == null || preCenter.size() < 2) {
			return;
		}
		//打乱质心的顺序
		Collections.shuffle(preCenter);
		List<List<T>> clusterList =  getListT(preCenter.size());
		for (T o1 : this.dataArray) {
			//寻找最相似的质心
			int max = 0;
			double maxScore = similarScore(o1, preCenter.get(0));
			for (int i = 1; i < preCenter.size(); i++) {
				if (maxScore < similarScore(o1, preCenter.get(i))) {
					maxScore = similarScore(o1, preCenter.get(i));
					max = i;
				}
			}
			clusterList.get(max).add(o1);
		}
		//计算本次聚类结果每个类别的质心
		List<T> nowCenter = new ArrayList<T> ();
		for (List<T> list : clusterList) {
			nowCenter.add(getCenterT(list));
		}
		//是否达到最大迭代次数
		if (times >= this.maxClusterTimes || preCenter.size() < this.K) {
			this.clusterList = clusterList;
			return;
		}
		this.clusteringCenterT = nowCenter;
		//判断质心是否发生移动,如果没有移动,结束本次聚类,否则进行下一轮
		if (isCenterChange(preCenter, nowCenter)) {
			clear(clusterList);
			clustering(nowCenter, times + 1);
		} else {
			this.clusterList = clusterList;
		}
	}
	
	/**
	 * @param size
	 * @return
	 * @Author:lulei  
	 * @Description: 初始化一个聚类结果
	 */
	private List<List<T>> getListT(int size) {
		List<List<T>> list = new ArrayList<List<T>>(size);
		for (int i = 0; i < size; i++) {
			list.add(new ArrayList<T>());
		}
		return list;
	}
	
	/**
	 * @param lists
	 * @Author:lulei  
	 * @Description: 清空无用数组
	 */
	private void clear(List<List<T>> lists) {
		for (List<T> list : lists) {
			list.clear();
		}
		lists.clear();
	}
	
	/**
	 * @param value
	 * @Author:lulei  
	 * @Description: 向模型中添加记录
	 */
	public void addRecord(T value) {
		if (dataArray == null) {
			dataArray = new ArrayList<T>();
		}
		dataArray.add(value);
	}
	
	/**
	 * @param preT
	 * @param nowT
	 * @return
	 * @Author:lulei  
	 * @Description: 判断质心是否发生移动
	 */
	private boolean isCenterChange(List<T> preT, List<T> nowT) {
		if (preT == null || nowT == null) {
			return false;
		}
		for (T t1 : preT) {
			boolean bol = true;
			for (T t2 : nowT) {
				if (equals(t1, t2)) {//t1在t2中有相等的,认为该质心未移动
					bol = false;
					break;
				}
			}
			//有一个质心发生移动,认为需要进行下一次计算
			if (bol) {
				return bol;
			}
		}
		return false;
	}
	
	/**
	 * @param o1
	 * @param o2
	 * @return
	 * @Author:lulei  
	 * @Description: o1 o2之间的相似度
	 */
	public abstract double similarScore(T o1, T o2);
	
	/**
	 * @param o1
	 * @param o2
	 * @return
	 * @Author:lulei  
	 * @Description: 判断o1 o2是否相等
	 */
	public abstract boolean equals(T o1, T o2);
	
	/**
	 * @param list
	 * @return
	 * @Author:lulei  
	 * @Description: 求一组数据的质心
	 */
	public abstract T getCenterT(List<T> list);
}

二维数聚类实现

      在算法描述中,介绍了一个200,000个点聚成34个类别的效果图,下面就针对二维座标数据实现其具体子类。

一、相似度

      对于二维座标的相似度,这里我们采取两点间聚类的相反数,具体实现如下:

	@Override
	public double similarScore(XYbean o1, XYbean o2) {
		double distance = Math.sqrt((o1.getX() - o2.getX()) * (o1.getX() - o2.getX()) + (o1.getY() - o2.getY()) * (o1.getY() - o2.getY()));
		return distance * -1;
	}

二、样本/质心是否相等

      判断样本/质心是否相等只需要判断两点的座标是否相等即可,具体实现如下:

	@Override
	public boolean equals(XYbean o1, XYbean o2) {
		return o1.getX() == o2.getX() && o1.getY() == o2.getY();
	}

三、获取一个分类下的新质心

      对于二维座标数据,可以使用所有点的重心作为分类的质心,具体如下:

	@Override
	public XYbean getCenterT(List<XYbean> list) {
		int x = 0;
		int y = 0;
		try {
			for (XYbean xy : list) {
				x += xy.getX();
				y += xy.getY();
			}
			x = x / list.size();
			y = y / list.size();
		} catch(Exception e) {
			
		}
		return new XYbean(x, y);
	}

四、main方法

      对于具体二维座标的源码这里就不再贴出来,就是实现前面介绍的抽象类,并实现其中的3个抽象方法,下面我们就随机产生200,000个点,然后聚成34个类别,具体代码如下:

	public static void main(String[] args) {
		
		int width = 600;
		int height = 400;
		int K = 34;
		XYCluster xyCluster = new XYCluster();
		for (int i = 0; i < 200000; i++) {
			int x = (int)(Math.random() * width) + 1;
			int y = (int)(Math.random() * height) + 1;
			xyCluster.addRecord(new XYbean(x, y));
		}
		xyCluster.setK(K);
		long a = System.currentTimeMillis();
		List<List<XYbean>> cresult = xyCluster.clustering();
		List<XYbean> center = xyCluster.getClusteringCenterT();
		System.out.println(JsonUtil.parseJson(center));
		long b = System.currentTimeMillis();
		System.out.println("耗时:" + (b - a) + "ms");
		new ImgUtil().drawXYbeans(width, height, cresult, "d:/2.png", 0, 0);
	}

      对于这随机产生的200,000个点聚成34类,总耗时5485ms。(计算机配置:i5 + 8G内存)

————————————————————————————————-
小福利
————————————————————————————————-
      个人在极客学院上《Lucene案例开发》课程已经上线了,欢迎大家吐槽~

第一课:Lucene概述

第二课:Lucene 常用功能介绍

第三课:网络爬虫

第四课:数据库连接池

第五课:小说网站的采集

第六课:小说网站数据库操作

第七课:小说网站分布式爬虫的实现

第八课:Lucene实时搜索

第九课:索引的基础操作

点赞