KNN分类器-Java实现

KNN,即K近邻算法。其基本思想或者说是实现步骤如下:
(1)计算样本数据点到每个已知类别的数据集中点的距离
(2)将(1)中得到的距离按递增顺序排列
(3)选取(2)中前K个点(即与当前样本距离最小的K个已知类别的数据点)
(4)统计(3)中得到的K个点所在类别的出现频率
(5)返回(4)中出现频率最高的类别作为样本点的预测类别
在给出具体实现代码之前,说明一点:Java下的矩阵操作类基于开源jama包,我自己基于它的源码,做了部分必要的扩充和修改。
具体实现代码如下:

/** * Created by Song on 2016/9/30. */
public class KnnHandler implements DMHandler {
    //训练集中,每个特征的最小值
    private Matrix minVals;
    //训练集中,每个特征的最大值
    private Matrix maxVals;
    //训练集中,每个特征的取值范围
    private Matrix ranges;

    public KnnHandler(Matrix dataSet){
        double [][] minMax = dataSet.getMinMax();
        this.minVals = new Matrix(minMax[0],1);
        this.maxVals = new Matrix(minMax[1],1);
        this.ranges = maxVals.minus(minVals);
    }
    /** * 归一化特征值 * @param dataSet 特征集 */
    public Matrix autoNorm(Matrix dataSet){
        double[][] norm = dataSet.getArray();
        for(int j=0;j<dataSet.getColumnDimension();j++){
            for(int i=0;i<norm.length;i++){
                norm[i][j] = (norm[i][j]-minVals.get(0,j))/ranges.get(0,j);
            }
        }
        return new Matrix(norm);
    }

    /** * K近邻算法 * @param sample 待评估样本 * @param dataSet 数据集 * @param labels 数据集中,每行数据对应的类别 * @param rate 将距离按由小至大排列,按比例选择固定数量的类别 */
    public double classify(Matrix sample,Matrix dataSet,Matrix labels,double rate){
        //统计样本频率
        HashMap<Double,Integer> levels = new HashMap<Double, Integer>();
        //遍历类别,得出一共有几类
        for(int i=0;i<labels.getRowDimension();i++){
            if(!levels.containsKey(labels.get(i,0))) levels.put(labels.get(i,0),0);
        }
        //获得距离,并递增排序
        Matrix sortedDistance = sample.distance(dataSet).expand(labels,true).sort();
        //取前num个数据
        int num = (int)Math.ceil(sortedDistance.getRowDimension()*rate);
        for(int i=0;i<num;i++){
            levels.put(sortedDistance.get(i,1),levels.get(sortedDistance.get(i,1))+1);
        }
        //按频率排序
        double targetLevel = 0;
        int count = 0;
        for(double key:levels.keySet()){
            if(levels.get(key)>count) {
                count = levels.get(key);
                targetLevel = key;
            }
        }
        return targetLevel;
    }
    //测试
    public static void main(String [] args){
        //随机生成训练集(已知类别)
        Random random = new Random();
        double [][] dataSet = new double[100][4];
        for(int i=0;i<100;i++){
            for(int j=0;j<4;j++){
                dataSet[i][j]=random.nextInt(10);
            }
        }
        //训练集中100组数据对应的类别
        double [] lables = new double[100];
        for(int i=0;i<100;i++){
            lables[i]=i/10;
        }
        //生成待分类样本
        double [] sample = {1,2,3,4};
        //KNN操作类实例化
        KnnHandler handler = new KnnHandler(new Matrix(dataSet));
        //handler.autoNorm(new Matrix(dataSet)).print(4,3);
        //输出分类结果
        System.out.println(handler.classify(new Matrix(sample,1),new Matrix(dataSet),new Matrix(lables,1).transpose(),0.3));
    }
}

其中部分函数,例如构造器中获得数据集中每个特征的最小最大取值(即一个二维数组中每列值的最小最大值)方法getMinMax()等,都是自己基于jama源码扩充得到的,原理很简单,此处就不列出来了。
可以看出,KNN分类是一种非常基础的分类算法,适用于数值型数据。通过计算未知数据点到已知数据点的距离,来判断其具体分类。

点赞