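The previous post covered an implementation of hierarchical clustering (http://blog.csdn.net/c_son/article/details/43900503). It turned out to be rather slow, because every iteration has to compute the distance between every pair of clusters. The k-means algorithm presented here is more efficient than hierarchical clustering.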
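Algorithm:
1) Randomly pick K elements from the sample set D as the centers of the K clusters.
2) Compute the distance from each remaining element to the K centers and assign each element to the cluster whose center is nearest.
3) Based on the clustering result, recompute the center of each of the K clusters as the arithmetic mean of its members in each dimension.
4) Reassign the elements of D to clusters using the new centers.
5) Repeat steps 3 and 4 until the centers no longer change.
6) Output the result.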
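As a rough illustration of steps 2 and 3, the sketch below performs a single assign-and-update pass over four hard-coded points. The class and method names (`KMeansStepSketch`, `assign`, `updateCenters`) are hypothetical and are not part of the implementation in this post; the two starting centers are simply assumed to have been picked in step 1.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of the implementation in this post.
class KMeansStepSketch {

    // Step 2: index of the nearest center for each point (squared Euclidean distance).
    static int[] assign(double[][] points, double[][] centers) {
        int[] labels = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            double best = Double.MAX_VALUE;
            for (int c = 0; c < centers.length; c++) {
                double dx = points[i][0] - centers[c][0];
                double dy = points[i][1] - centers[c][1];
                double d = dx * dx + dy * dy;
                if (d < best) {
                    best = d;
                    labels[i] = c;
                }
            }
        }
        return labels;
    }

    // Step 3: each new center is the per-dimension mean of its members.
    // A cluster without members keeps its old center.
    static double[][] updateCenters(double[][] points, int[] labels, double[][] oldCenters) {
        int k = oldCenters.length;
        double[][] sums = new double[k][2];
        int[] counts = new int[k];
        for (int i = 0; i < points.length; i++) {
            sums[labels[i]][0] += points[i][0];
            sums[labels[i]][1] += points[i][1];
            counts[labels[i]]++;
        }
        double[][] centers = new double[k][2];
        for (int c = 0; c < k; c++) {
            if (counts[c] == 0) {
                centers[c] = oldCenters[c].clone();
            } else {
                centers[c][0] = sums[c][0] / counts[c];
                centers[c][1] = sums[c][1] / counts[c];
            }
        }
        return centers;
    }

    public static void main(String[] args) {
        double[][] points = {{2, 3}, {1, 2}, {8, 5}, {9, 6}};
        double[][] centers = {{2, 3}, {8, 5}}; // assumed result of step 1
        int[] labels = assign(points, centers);
        double[][] updated = updateCenters(points, labels, centers);
        System.out.println(Arrays.toString(labels));       // [0, 0, 1, 1]
        System.out.println(Arrays.deepToString(updated));  // [[1.5, 2.5], [8.5, 5.5]]
    }
}
```

Steps 4 and 5 amount to calling these two methods in a loop until `updateCenters` returns the centers it was given.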
The data set is the same one used for hierarchical clustering (columns: name, X, Y):
A | 2 | 3 |
B | 2 | 7 |
C | 1 | 2 |
D | 1 | 6 |
E | 2 | 1 |
F | 3 | 5 |
G | 8 | 5 |
H | 9 | 6 |
I | 7 | 7 |
J | 7 | 4 |
K | 8 | 2 |
L | 8 | 22 |
M | 8 | 19 |
N | 7 | 21 |
O | 7 | 17 |
P | 9 | 20 |
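The loader later in this post reads a file named data.txt and splits each line on tabs into X, Y and name, in that order. Assuming that layout, the table above could be turned into file content with a small helper such as the following; `DataFileSketch` and `toLines` are hypothetical names used only for this illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper that renders the table above in the tab-separated
// "X<TAB>Y<TAB>name" layout that the loader in this post parses.
class DataFileSketch {

    static final String[] NAMES = {
        "A", "B", "C", "D", "E", "F", "G", "H",
        "I", "J", "K", "L", "M", "N", "O", "P"
    };

    static final double[][] COORDS = {
        {2, 3}, {2, 7}, {1, 2}, {1, 6}, {2, 1}, {3, 5}, {8, 5}, {9, 6},
        {7, 7}, {7, 4}, {8, 2}, {8, 22}, {8, 19}, {7, 21}, {7, 17}, {9, 20}
    };

    // One line per point: X, Y and name separated by tabs.
    static List<String> toLines() {
        List<String> lines = new ArrayList<String>();
        for (int i = 0; i < NAMES.length; i++) {
            lines.add(COORDS[i][0] + "\t" + COORDS[i][1] + "\t" + NAMES[i]);
        }
        return lines;
    }

    public static void main(String[] args) {
        // Redirect this output into data.txt and place the file on the classpath.
        for (String line : toLines()) {
            System.out.println(line);
        }
    }
}
```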
Plotted on a two-dimensional plane, the points look like this:
[Figure: scatter plot of the data set]
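package kmeansClustering;

/**
 * A named data point in the two-dimensional plane.
 *
 * @author shenchao
 */
public class Point {

    private double X;
    private double Y;
    private String name;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public double getX() {
        return X;
    }

    public void setX(double x) {
        X = x;
    }

    public double getY() {
        return Y;
    }

    public void setY(double y) {
        Y = y;
    }

    /** Two points are equal when their coordinates match; the name is ignored. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // instanceof also rejects null, so the cast below cannot fail
        if (!(obj instanceof Point)) {
            return false;
        }
        Point point = (Point) obj;
        return Double.compare(X, point.X) == 0 && Double.compare(Y, point.Y) == 0;
    }

    @Override
    public String toString() {
        return "(" + X + "," + Y + ")";
    }

    /** Hashes exactly the coordinates that equals compares. */
    @Override
    public int hashCode() {
        return 31 * Double.hashCode(X) + Double.hashCode(Y);
    }
}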
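The Point class wraps a data point and overrides equals and hashCode, so that points with the same coordinates count as the same HashMap key.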
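To see why these overrides matter when points serve as map keys, here is a minimal, self-contained sketch. `PlainPoint` and `ValuePoint` are hypothetical stand-ins for a point class without and with the overrides; they are not the Point class from this post.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo: how equals/hashCode affect HashMap lookups.
class MapKeySketch {

    // No overrides: inherits identity-based equals/hashCode from Object.
    static final class PlainPoint {
        final double x, y;
        PlainPoint(double x, double y) { this.x = x; this.y = y; }
    }

    // Overrides both methods, so equality is based on the coordinates.
    static final class ValuePoint {
        final double x, y;
        ValuePoint(double x, double y) { this.x = x; this.y = y; }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof ValuePoint)) {
                return false;
            }
            ValuePoint other = (ValuePoint) obj;
            return Double.compare(x, other.x) == 0 && Double.compare(y, other.y) == 0;
        }

        @Override
        public int hashCode() {
            return 31 * Double.hashCode(x) + Double.hashCode(y);
        }
    }

    // Looks up a freshly created key with the same coordinates.
    static boolean plainLookupWorks() {
        Map<PlainPoint, String> map = new HashMap<PlainPoint, String>();
        map.put(new PlainPoint(1.5, 2.5), "center");
        return map.containsKey(new PlainPoint(1.5, 2.5));
    }

    static boolean valueLookupWorks() {
        Map<ValuePoint, String> map = new HashMap<ValuePoint, String>();
        map.put(new ValuePoint(1.5, 2.5), "center");
        return map.containsKey(new ValuePoint(1.5, 2.5));
    }

    public static void main(String[] args) {
        System.out.println(plainLookupWorks()); // false
        System.out.println(valueLookupWorks()); // true
    }
}
```

A convergence test that asks whether a recomputed center is already a key of the previous map relies on the second behavior.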
package kmeansClustering;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
/**
 * K-means clustering.
 *
 * @author shenchao
 */
public class KmeansClustering {

    private List<Point> dataset = null;

    public KmeansClustering() throws IOException {
        initDataSet();
    }
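    /**
     * Loads the data set from data.txt on the classpath. Each line holds the
     * X coordinate, the Y coordinate and the point name, separated by tabs.
     *
     * @throws IOException if the file cannot be read
     */
    private void initDataSet() throws IOException {
        dataset = new ArrayList<Point>();
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader bufferedReader = new BufferedReader(
                new InputStreamReader(KmeansClustering.class.getClassLoader()
                        .getResourceAsStream("data.txt")))) {
            String line = null;
            while ((line = bufferedReader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // skip blank lines
                }
                String[] s = line.split("\t");
                Point point = new Point();
                point.setX(Double.parseDouble(s[0]));
                point.setY(Double.parseDouble(s[1]));
                point.setName(s[2]);
                dataset.add(point);
            }
        }
    }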
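    /**
     * Runs k-means on the data set.
     *
     * @param k
     *            number of clusters
     * @return map from each final cluster center to the points assigned to it
     */
    public Map<Point, List<Point>> kcluster(int k) {
        if (k <= 0 || k > dataset.size()) {
            throw new IllegalArgumentException("k must be between 1 and " + dataset.size());
        }
        // Randomly pick k distinct samples as the initial cluster centers.
        // Each center maps to the list of points currently assigned to it.
        Map<Point, List<Point>> nowClusterCenterMap = new HashMap<Point, List<Point>>();
        List<Point> candidates = new ArrayList<Point>(dataset);
        Random random = new Random();
        // Drawing without replacement avoids picking the same sample twice,
        // which would silently leave fewer than k centers in the map.
        while (nowClusterCenterMap.size() < k && !candidates.isEmpty()) {
            Point candidate = candidates.remove(random.nextInt(candidates.size()));
            nowClusterCenterMap.put(candidate, new ArrayList<Point>());
        }
        // Cluster centers from the previous iteration
        Map<Point, List<Point>> lastClusterCenterMap = null;
        while (true) {
            // Add every point to the list of its nearest center.
            for (Point point : dataset) {
                double shortest = Double.MAX_VALUE;
                Point key = null;
                for (Entry<Point, List<Point>> entry : nowClusterCenterMap.entrySet()) {
                    double distance = distance(point, entry.getKey());
                    if (distance < shortest) {
                        shortest = distance;
                        key = entry.getKey();
                    }
                }
                nowClusterCenterMap.get(key).add(point);
            }
            // Stop once the centers are the same as in the previous iteration.
            if (isEqualCenter(lastClusterCenterMap, nowClusterCenterMap)) {
                break;
            }
            lastClusterCenterMap = nowClusterCenterMap;
            nowClusterCenterMap = new HashMap<Point, List<Point>>();
            // Move each center to the mean position of its members. A center
            // without members stays where it is, which avoids a 0/0 (NaN) mean.
            for (Entry<Point, List<Point>> entry : lastClusterCenterMap.entrySet()) {
                List<Point> members = entry.getValue();
                Point center = members.isEmpty() ? entry.getKey() : getNewCenterPoint(members);
                nowClusterCenterMap.put(center, new ArrayList<Point>());
            }
        }
        return nowClusterCenterMap;
    }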
    /**
     * Checks whether two consecutive iterations produced the same cluster
     * centers. The algorithm stops when they did and keeps iterating until
     * they do.
     *
     * @param lastClusterCenterMap centers of the previous iteration, null on the first pass
     * @param nowClusterCenterMap centers of the current iteration
     * @return true if both maps contain the same set of centers
     */
    private boolean isEqualCenter(Map<Point, List<Point>> lastClusterCenterMap,
            Map<Point, List<Point>> nowClusterCenterMap) {
        if (lastClusterCenterMap == null) {
            return false;
        }
        // Set equality checks both the size and every key, in both directions.
        return lastClusterCenterMap.keySet().equals(nowClusterCenterMap.keySet());
    }
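    /**
     * Computes a new center as the per-dimension mean of the given points.
     *
     * @param value members of one cluster, must not be empty
     * @return Point located at the mean position of the members
     */
    private Point getNewCenterPoint(List<Point> value) {
        if (value.isEmpty()) {
            // dividing 0.0 by 0 would produce a NaN center
            throw new IllegalArgumentException("cannot average an empty cluster");
        }
        double sumX = 0.0, sumY = 0.0;
        for (Point point : value) {
            sumX += point.getX();
            sumY += point.getY();
        }
        Point point = new Point();
        point.setX(sumX / value.size());
        point.setY(sumY / value.size());
        return point;
    }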
    /**
     * Computes the Euclidean distance between two points.
     *
     * @param point1 first point
     * @param point2 second point
     * @return distance between the two points
     */
    private double distance(Point point1, Point point2) {
        double dx = point1.getX() - point2.getX();
        double dy = point1.getY() - point2.getY();
        return Math.sqrt(dx * dx + dy * dy);
    }
    public static void main(String[] args) throws IOException {
        KmeansClustering kmeansClustering = new KmeansClustering();
        Map<Point, List<Point>> result = kmeansClustering.kcluster(3);
        // Print each cluster center followed by the names of its members.
        for (Entry<Point, List<Point>> entry : result.entrySet()) {
            System.out.println("===============Cluster center: " + entry.getKey() + "================");
            for (Point point : entry.getValue()) {
                System.out.println(point.getName());
            }
        }
    }
}
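The distance between sample points is again computed as the Euclidean distance. The program output is shown below and matches the grouping visible in the plot.
[Figure: program output]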
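As a small worked example of the Euclidean distance, the sketch below measures two pairs of points from the data set. `DistanceSketch` and `euclidean` are hypothetical names; the method uses `Math.hypot`, which computes the same quantity as the square root of the summed squares.

```java
// Hypothetical demo of the Euclidean distance between two 2D points.
class DistanceSketch {

    // sqrt((x1 - x2)^2 + (y1 - y2)^2), computed via Math.hypot
    static double euclidean(double x1, double y1, double x2, double y2) {
        return Math.hypot(x1 - x2, y1 - y2);
    }

    public static void main(String[] args) {
        // G(8,5) and K(8,2) differ only in Y, so the distance is 3.
        System.out.println(euclidean(8, 5, 8, 2));
        // A(2,3) and C(1,2) differ by 1 in each dimension: sqrt(2), about 1.414.
        System.out.println(euclidean(2, 3, 1, 2));
    }
}
```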
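Limitations: the final result depends heavily on how the K initial centers are chosen, and the algorithm is easily disturbed by noise, because every center is a mean and a few outliers can pull it away from the bulk of its cluster.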
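The effect of noise can be illustrated with a tiny experiment: take the points G, H, I, J and K from the data set, compute their mean, then add one far-away point and compute it again. The noise point (100, 100) is made up for this illustration and is not part of the data set; `OutlierSketch` and `mean` are hypothetical names.

```java
// Hypothetical demo: a single outlier shifts the mean of a cluster.
class OutlierSketch {

    // Per-dimension arithmetic mean of a non-empty set of 2D points.
    static double[] mean(double[][] points) {
        if (points.length == 0) {
            throw new IllegalArgumentException("cannot average an empty cluster");
        }
        double sumX = 0.0, sumY = 0.0;
        for (double[] p : points) {
            sumX += p[0];
            sumY += p[1];
        }
        return new double[] {sumX / points.length, sumY / points.length};
    }

    public static void main(String[] args) {
        // G, H, I, J, K from the data set
        double[][] cluster = {{8, 5}, {9, 6}, {7, 7}, {7, 4}, {8, 2}};
        // Same points plus one invented noise point
        double[][] withNoise = {{8, 5}, {9, 6}, {7, 7}, {7, 4}, {8, 2}, {100, 100}};
        double[] clean = mean(cluster);
        double[] noisy = mean(withNoise);
        System.out.println("without noise: (" + clean[0] + ", " + clean[1] + ")");
        System.out.println("with noise:    (" + noisy[0] + ", " + noisy[1] + ")");
    }
}
```

The center moves from about (7.8, 4.8) to about (23.2, 20.7), far from all five original members.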
OK, if you have any questions, feel free to get in touch so we can learn and discuss together.