网上有许多Kmeans写的java算法,当然依据个人编码风格的不同,导致编写出来的代码,各有不同。所以在理解原理的基础上,最好就是按照自己设计思路将代码自己写出来。
度娘搜Kmeans的基本原理吧,直接上代码,代码中都有注释:
package net.codeal.suanfa.kmeans;
import java.util.Set;
/**
*
* @ClassName: Distancable
* @Description: TODO(可计算两点之间距离的可中心化的父类)
* @author fuhuaguo
* @date 2015年9月1日 上午11:41:23
*
*/
public class Kmeansable<E> {
/**
* 获取两点之间的距离
* @param other
* @return
*/
public double getDistance(E other){
return 0;
}
/**
* 获取新的中心点
* @param eSet
* @return
*/
public E getNewCenter(Set<E> eSet){
return null;
}
}
package net.codeal.suanfa.kmeans;
import java.util.Set;
/**
*
* @ClassName: Point
* @Description: TODO(聚类的维度信息bean,可以分为K个维度,相似度计算是自身行为,放在bean内部才合适,取消注解使用)
* @author fuhuaguo
* @email [email protected]
* @date 2015年9月1日 上午10:43:25
*
*/
public class Point extends Kmeansable<Point>{
private String id;
//维度1
private double k1;
//维度2
private double k2;
//维度3
private double k3;
public Point() {
}
public Point(String id,double k1,double k2,double k3) {
this.id = id;
this.k1 = k1;
this.k2 = k2;
this.k3 = k3;
}
/**
* 计算和另一个点的距离,采用欧几里得算法 ,计算维度算数平方和的sqrt值,即:相异度
* @param other
* @return
*/
@Override
public double getDistance(Point other){
return Math.sqrt((this.k1-other.getK1())*(this.k1-other.getK1())
+ (this.k2-other.getK2())*(this.k2-other.getK2())
+ (this.k3-other.getK3())*(this.k3-other.getK3()));
}
@Override
public Point getNewCenter(Set<Point> eSet) {
if(eSet == null || eSet.size() == 0){
return this;
}
Point temp = new Point();
int count = 0;
for (Point p : eSet) {
temp.setK1(temp.getK1() + p.getK1());
temp.setK2(temp.getK2() + p.getK2());
temp.setK3(temp.getK3() + p.getK3());
count++;
}
temp.setK1(temp.getK1()/count);
temp.setK2(temp.getK2()/count);
temp.setK3(temp.getK3()/count);
return temp;
}
@Override
public boolean equals(Object obj) {
if(obj == null || !(obj instanceof Point))
return false;
Point other = (Point) obj;
return (this.k1 == other.getK1()) && (this.k2 == other.getK2()) && (this.k3 == other.getK3());
}
@Override
public int hashCode() {
return new Double(k1+k2+k3).hashCode();
}
@Override
public String toString() {
return "("+k1+","+k2+","+k3+")";
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public double getK1() {
return k1;
}
public void setK1(double k1) {
this.k1 = k1;
}
public double getK2() {
return k2;
}
public void setK2(double k2) {
this.k2 = k2;
}
public double getK3() {
return k3;
}
public void setK3(double k3) {
this.k3 = k3;
}
}
package net.codeal.suanfa.kmeans;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class KmeansAlgorithm<E extends Kmeansable<E>> {
/**
* 对Set进行K个值聚类,计算深度最大为depth
*/
public void kmeans(Set<E> dataSet, int k, int depth){
//分类数设置不合适
if(k <= 1 || dataSet.size() <= k){
return;
}
Set<E> kSet = new HashSet<E>();
int count = 0;
//随机确定K个中心点
for (E e : dataSet) {
if(count++ >= k)
break;
kSet.add(e);
}
//计算每个值距离各个中心点的距离,分配到距离最小的那个中心上
boolean flag = true;
while(flag && depth > 0){
Map<E, Set<E>> kMap = new HashMap<E, Set<E>>();
for (E e : kSet) {
kMap.put(e, new HashSet<E>());
}
//完成聚类
for (E data : dataSet) {
double d = Double.MAX_VALUE;
E e = null;
for (E center : kSet) {
double d1 = data.getDistance(center);
if (d > d1){
e = center;
d = d1;
}
}
kMap.get(e).add(data);
}
//第一组计算完毕,同时获取新的中心点
System.out.println("这是第"+depth+"次聚类");
for (Map.Entry<E, Set<E>> m : kMap.entrySet()) {
System.out.println(m.getKey()+":"+m.getValue());
}
//获取新的聚类中心
Set<E> oldSet = kSet;
kSet = getNewCenters(kMap);
flag = !isSameCenters(kSet,oldSet);
depth--;
}
}
/**
* 获取新的中心点 列表
*/
public Set<E> getNewCenters(Map<E, Set<E>> kMap){
Set<E> eSet = new HashSet<E>();
for (Map.Entry<E, Set<E>> m : kMap.entrySet()) {
eSet.add(m.getKey().getNewCenter(m.getValue()));
}
return eSet;
}
/**
* 判断是否为同一个中心列表
*/
public boolean isSameCenters(Set<E> oldSet,Set<E> newSet){
//两个集合只要交集为0就是相同的
return oldSet.containsAll(newSet);
}
public static void main(String[] args) {
Set<Point> dataSet = new HashSet<Point>();
dataSet.add(new Point("1",1,1,1));
dataSet.add(new Point("1",2,2,2));
dataSet.add(new Point("1",5,6,1));
dataSet.add(new Point("1",10,10,10));
dataSet.add(new Point("1",11,11,11));
new KmeansAlgorithm<Point>().kmeans(dataSet, 2,10);
}
}
结果:
这是第10次聚类
(1.0,1.0,1.0):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]
(10.0,10.0,10.0):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
这是第9次聚类
(10.5,10.5,10.5):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
(2.6666666666666665,3.0,1.3333333333333333):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]