对于K-Means算法想必做机器学习和数据挖掘的广大同胞们已经不再陌生,做为数据挖据的十大经典算法之一,k-Means做聚类分析上有得天独厚的优势。对于其原理进行简单的描述:
k-Means算法是典型的基于距离的聚类算法,采用的是距离作为相似性指标。经过n次迭代后,当中心的位置不在发生变换的时候即是收敛完成。
算法:
1. 从n个文档中随机的选择出k个文档作为质心
2.从剩余的文档中测量出每个文档到质心的距离,并归类到最小质心的一类中
3. 重新计算质心的位置
4.重复2-3步,直到迭代完成。
由以上步骤,可以有java实现K-Means算法。随机产生100个点,设置k=5后进行聚类操作:
1.主函数:
package KMeans;
import java.util.ArrayList;
/**
* K-Means算法
* @author Administrator
*
*/
public class k_means {
/**
* @param args
*/
public static void main(String[] args) {
//1.创建二维数组 10x10的数组
int num_1[]=new int[100];
int num_2[]=new int[100];
//随机赋值
for(int i=0;i<100;i++){
num_1[i]=(int)( Math.random()*100);
}
for(int i=0;i<100;i++){
num_2[i]=(int)( Math.random()*100);
}
// 2.创建点坐标
ArrayList<pointBean> list=new ArrayList<pointBean>();
pointBean bean;
for(int i=0;i<100;i++){
bean=new pointBean();
bean.point_x=num_1[i];
bean.point_y=num_2[i];
list.add(bean);
}
// 执行k-means算法
getDataKMeans gg=new getDataKMeans();
gg.setData(list);
}
}
2.设置点的x,y坐标的Bean
package KMeans;
public class pointBean {
int point_x;
int point_y;
public int getPoint_x() {
return point_x;
}
public void setPoint_x(int point_x) {
this.point_x = point_x;
}
public int getPoint_y() {
return point_y;
}
public void setPoint_y(int point_y) {
this.point_y = point_y;
}
@Override
public String toString() {
return "pointBean [point_x=" + point_x + ", point_y=" + point_y + "]";
}
public pointBean(int point_x, int point_y) {
super();
this.point_x = point_x;
this.point_y = point_y;
}
public pointBean() {
super();
}
}
3.聚类的计算部分:
package KMeans;
import java.util.ArrayList;
public class getDataKMeans {
int k=5;//k值
//第一个中心点x,y
static double con1_x;
static double con1_y;
//第一个中心点x,y
static double con2_x;
static double con2_y;
//第一个中心点x,y
static double con3_x;
static double con3_y;
//第一个中心点x,y
static double con4_x;
static double con4_y;
//第一个中心点x,y
static double con5_x;
static double con5_y;
//创建5个list装各个点
ArrayList<pointBean> list1=new ArrayList<pointBean>();
ArrayList<pointBean> list2=new ArrayList<pointBean>();
ArrayList<pointBean> list3=new ArrayList<pointBean>();
ArrayList<pointBean> list4=new ArrayList<pointBean>();
ArrayList<pointBean> list5=new ArrayList<pointBean>();
public void setData(ArrayList<pointBean> list){
con1_x=list.get(0).point_x;
con1_y=list.get(0).point_y;
con2_x=list.get(1).point_x;
con2_y=list.get(1).point_y;
con3_x=list.get(2).point_x;
con3_y=list.get(2).point_y;
con4_x=list.get(3).point_x;
con4_y=list.get(3).point_y;
con5_x=list.get(4).point_x;
con5_y=list.get(4).point_y;
//分别加入list中
list1.add(list.get(0));
list2.add(list.get(1));
list3.add(list.get(2));
list4.add(list.get(3));
list5.add(list.get(4));
//循环操作
for(int i=5;i<list.size();i++){
getLength(list.get(i));
}
// 打印出对应的中心点 、聚类的值
System.out.println("-------1-------");
System.out.println("1的中心点:"+con1_x+" "+con1_y);
for(int i=0;i<list1.size();i++){
System.out.println(list1.get(i).point_x+" "+list1.get(i).point_y);
}
System.out.println("-------2-------");
System.out.println("2的中心点:"+con2_x+" "+con2_y);
for(int i=0;i<list2.size();i++){
System.out.println(list2.get(i).point_x+" "+list2.get(i).point_y);
}
System.out.println("-------3-------");
System.out.println("3的中心点:"+con3_x+" "+con3_y);
for(int i=0;i<list3.size();i++){
System.out.println(list3.get(i).point_x+" "+list3.get(i).point_y);
}
System.out.println("-------4-------");
System.out.println("4的中心点:"+con4_x+" "+con4_y);
for(int i=0;i<list4.size();i++){
System.out.println(list4.get(i).point_x+" "+list4.get(i).point_y);
}
System.out.println("-------5-------");
System.out.println("5的中心点:"+con5_x+" "+con5_y);
for(int i=0;i<list5.size();i++){
System.out.println(list5.get(i).point_x+" "+list5.get(i).point_y);
}
}
/**
* 求出每个点到中心点距离
* @param point
*/
public void getLength(pointBean point) {
int x=point.point_x;
int y=point.point_y;
double s1=(x-con1_x)*(x-con1_x)+(y-con1_y)*(y-con1_y);
double s2=(x-con2_x)*(x-con2_x)+(y-con2_y)*(y-con2_y);
double s3=(x-con3_x)*(x-con3_x)+(y-con3_y)*(y-con3_y);
double s4=(x-con4_x)*(x-con4_x)+(y-con4_y)*(y-con4_y);
double s5=(x-con5_x)*(x-con5_x)+(y-con5_y)*(y-con5_y);
double nn[]={s1,s2,s3,s4,s5};
// 找出最小的一个
double temp=nn[0];
for(int i=1;i<nn.length;i++){
if(nn[i]<=temp)
temp=nn[i];
}
// 添加点
if(temp==s1){
list1.add(point);
upDataPoint(list1,con1_x,con1_y);
}
if(temp==s2){
list2.add(point);
upDataPoint(list2,con2_x,con2_x);
}
if(temp==s3){
list3.add(point);
upDataPoint(list3,con3_x,con3_x);
}
if(temp==s4){
list4.add(point);
upDataPoint(list4,con4_x,con4_x);
}
if(temp==s5){
list5.add(point);
upDataPoint(list5,con5_x,con5_x);
}
}
/**
* 更新中心点坐标
* @param list
*/
private void upDataPoint(ArrayList<pointBean> list,double x,double y) {
double up_x=0;
double up_y=0;
for(int i=0;i<list.size();i++){
up_x+=list.get(i).point_x;
up_y+=list.get(i).point_y;
}
x=up_x/(list.size());
y=up_y/(list.size());
}
}
得到的测试结果:
-------1-------
1的中心点:37.0 80.0
37 80
54 88
10 50
45 85
40 95
51 87
47 90
42 97
30 61
20 63
60 80
37 93
47 79
37 96
58 86
-------2-------
2的中心点:55.0 57.0
55 57
89 81
56 58
49 53
58 62
42 52
26 49
94 95
21 44
1 19
27 53
59 74
61 77
32 56
49 54
10 39
53 55
48 58
8 36
63 63
4 26
49 62
63 80
45 62
-------3-------
3的中心点:79.0 22.0
79 22
84 41
68 17
79 2
99 33
69 11
70 29
52 8
94 25
81 8
54 20
81 32
81 34
48 2
22 1
89 27
57 18
42 11
50 6
74 28
98 27
98 36
-------4-------
4的中心点:59.0 51.0
59 51
64 40
21 8
5 19
29 32
62 40
7 5
16 25
53 36
28 29
33 19
80 55
50 40
98 76
81 53
23 23
92 62
85 63
65 36
48 44
25 30
11 15
97 79
16 9
60 43
59 51
67 43
-------5-------
5的中心点:10.0 97.0
10 97
14 74
13 89
1 60
4 94
6 72
1 73
4 86
18 80
2 81
11 70
19 97