Kmeans聚类算法 java精简版设计实现编程

网上有许多Kmeans写的java算法,当然依据个人编码风格的不同,导致编写出来的代码,各有不同。所以在理解原理的基础上,最好就是按照自己设计思路将代码自己写出来。

度娘搜Kmeans的基本原理吧,直接上代码,代码中都有注释:

package net.codeal.suanfa.kmeans;

import java.util.Set;

/**
 * 
 * @ClassName: Distancable 
 * @Description: TODO(可计算两点之间距离的可中心化的父类) 
 * @author fuhuaguo
 * @date 2015年9月1日 上午11:41:23 
 *
 */
public class Kmeansable<E> {

	/**
	 * 获取两点之间的距离
	 * @param other
	 * @return
	 */
	public double getDistance(E other){
		return 0;
	}
	/**
	 * 获取新的中心点
	 * @param eSet
	 * @return
	 */
	public E getNewCenter(Set<E> eSet){
		return null;
	}
}

package net.codeal.suanfa.kmeans;

import java.util.Set;

/**
 * 
 * @ClassName: Point 
 * @Description: TODO(聚类的维度信息bean,可以分为K个维度,相似度计算是自身行为,放在bean内部才合适,取消注解使用) 
 * @author fuhuaguo
 * @email [email protected]
 * @date 2015年9月1日 上午10:43:25 
 *
 */
public class Point extends Kmeansable<Point>{

	private String id;
	//维度1
	private double k1;
	//维度2
	private double k2;
	//维度3
	private double k3;
	public Point() {
	}
	public Point(String id,double k1,double k2,double k3) {
		this.id = id;
		this.k1 = k1;
		this.k2 = k2;
		this.k3 = k3;
	}
	
	/**
	 * 计算和另一个点的距离,采用欧几里得算法 ,计算维度算数平方和的sqrt值,即:相异度
	 * @param other
	 * @return
	 */
	@Override
	public double getDistance(Point other){
		return Math.sqrt((this.k1-other.getK1())*(this.k1-other.getK1())
		+ (this.k2-other.getK2())*(this.k2-other.getK2())
		+ (this.k3-other.getK3())*(this.k3-other.getK3()));
	}
	@Override
	public Point getNewCenter(Set<Point> eSet) {
		if(eSet == null || eSet.size() == 0){
			return this;
		}
		Point temp = new Point();
		int count = 0;
		for (Point p : eSet) {
			temp.setK1(temp.getK1() + p.getK1());
			temp.setK2(temp.getK2() + p.getK2());
			temp.setK3(temp.getK3() + p.getK3());
			count++;
		}
		temp.setK1(temp.getK1()/count);
		temp.setK2(temp.getK2()/count);
		temp.setK3(temp.getK3()/count);
		
		return temp;
	}
	
	@Override
	public boolean equals(Object obj) {
		if(obj == null || !(obj instanceof Point))
			return false;
		Point other = (Point) obj;
		
		return (this.k1 == other.getK1()) && (this.k2 == other.getK2()) && (this.k3 == other.getK3());
	}
	
	@Override
	public int hashCode() {
		return new Double(k1+k2+k3).hashCode();
	}
	@Override
	public String toString() {
		return "("+k1+","+k2+","+k3+")";
	} 
	public String getId() {
		return id;
	}
	public void setId(String id) {
		this.id = id;
	}
	public double getK1() {
		return k1;
	}
	public void setK1(double k1) {
		this.k1 = k1;
	}
	public double getK2() {
		return k2;
	}
	public void setK2(double k2) {
		this.k2 = k2;
	}
	public double getK3() {
		return k3;
	}
	public void setK3(double k3) {
		this.k3 = k3;
	}
}

package net.codeal.suanfa.kmeans;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class KmeansAlgorithm<E extends Kmeansable<E>> {

	/**
	 * 对Set进行K个值聚类,计算深度最大为depth
	 */
	public void kmeans(Set<E> dataSet, int k, int depth){
		//分类数设置不合适
		if(k <= 1 || dataSet.size() <= k){
			return;
		}
		Set<E> kSet = new HashSet<E>();
		
		int count = 0;
		//随机确定K个中心点
		for (E e : dataSet) {
			if(count++ >= k)
				break;
			kSet.add(e);
		}
		//计算每个值距离各个中心点的距离,分配到距离最小的那个中心上
		boolean flag = true;
		while(flag && depth > 0){
			Map<E, Set<E>> kMap = new HashMap<E, Set<E>>();
			for (E e : kSet) {
				kMap.put(e, new HashSet<E>());
			}
			//完成聚类
			for (E data : dataSet) {
				double d = Double.MAX_VALUE;
				E e = null;
				for (E center : kSet) {
					double d1 = data.getDistance(center);
					if (d > d1){
						e = center;
						d = d1;
					}
				}
				kMap.get(e).add(data);
			}
			//第一组计算完毕,同时获取新的中心点
			System.out.println("这是第"+depth+"次聚类");
			for (Map.Entry<E, Set<E>> m : kMap.entrySet()) {
				System.out.println(m.getKey()+":"+m.getValue());
			}
			//获取新的聚类中心
			Set<E> oldSet = kSet;
			kSet = getNewCenters(kMap);
			flag = !isSameCenters(kSet,oldSet);
			
			depth--;
		}
	}
	/**
	 * 获取新的中心点 列表
	 */
	public Set<E> getNewCenters(Map<E, Set<E>> kMap){
		Set<E> eSet = new HashSet<E>();
		
		for (Map.Entry<E, Set<E>> m : kMap.entrySet()) {
			eSet.add(m.getKey().getNewCenter(m.getValue()));
		}
		
		return eSet;
	}
	/**
	 * 判断是否为同一个中心列表
	 */
	public boolean isSameCenters(Set<E> oldSet,Set<E> newSet){
		//两个集合只要交集为0就是相同的
		return oldSet.containsAll(newSet);
	}
	public static void main(String[] args) {
		Set<Point> dataSet = new HashSet<Point>();
		dataSet.add(new Point("1",1,1,1));
		dataSet.add(new Point("1",2,2,2));
		dataSet.add(new Point("1",5,6,1));
		dataSet.add(new Point("1",10,10,10));
		dataSet.add(new Point("1",11,11,11));
		
		new KmeansAlgorithm<Point>().kmeans(dataSet, 2,10);
	}
}

结果:

这是第10次聚类
(1.0,1.0,1.0):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]
(10.0,10.0,10.0):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
这是第9次聚类
(10.5,10.5,10.5):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
(2.6666666666666665,3.0,1.3333333333333333):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]

点赞