KD tree算法(1)-简介&构建KD tree

KD tree算法是KNN(K-nearest neighbor)实现的重要算法之一,下面我们先简单介绍一些KNN的知识,然后开始我们KD tree的讲解。

KNN分类算法

KNN是一种简单的分类方法:

分类时,对新的实例,根据k个最近邻的训练实例的类别,通过多数表决等方式进行预测。

以上是《统计学习方法》中对KNN的解释,简单明了。那么问题就在于如何快速有效找到K个近邻就是该算法的关键了。关于KNN算法的详细知识,请看另一篇文章《统计学习方法之K近邻法》,下面我们来简单学习一下如何快速找到K个近邻的样本。

KD Tree

为了对训练数据进行快速k近邻搜索,我们使用特殊的数据结构存储训练数据-kd Tree方法。

kd Tree是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分(partition),构造kd Tree相当于奖k维空间划分,构成一系列的k维超矩形区域。

以上是《统计学习方法》对KD Tree的解释,同样简单明了。
KD Tree方法即将空间根据坐标不断划分,划分成k块,如下图所示:

《KD tree算法(1)-简介&构建KD tree》

  1. 构建 KD-Tree
    《KD tree算法(1)-简介&构建KD tree》
    《KD tree算法(1)-简介&构建KD tree》

    以上是《统计学习方法》书中对构造平衡kd Tree的算法描述。我们进行分析。
    1.1 为什么要构造“平衡“kd Tree?
    学过快速排序的读者们都知道,当我们取得的K值刚好保证左右两边的数的个数相等时,快排算法的时间复杂度最低,而当K值取到最大或最小时算法时间复杂度最高。为此提出了随机快速排序算法。以此为基础,我们可以分析知道为什么要构造“平衡“kd Tree了。很简单,为了提高搜索效率。
    为了更加简单明了的解释,我们据如下例子:
    1 2 3 4 5 6 7 8 9
    我们将以上数字分为两部分:
    1) >=5;
    2) <5;
    为了找到数字3,如果我们知道3比5小,那么我们只需要搜索5左边的四个数字即可。
    如果我们将以上数字分成以下两部分:
    1) >=9;
    2) <9;
    我们知道3比9小,那么我们需要搜索9左边的8个数字,这样的效率明显没有上一种的效率高。
    2.2 如何构造平衡kd Tree?
    中位数法。

    中位数:将一组数据按大小顺序排列,处在最中间位置的一个数叫做这组数据的中位数 。

    这就是中位数的定义,初中的知识了。那么如何快速找到中位数呢?先排序后查找?还是找到前n/2个数然后得到中位数呢?显然效率不高。我们利用快速排序 的思想进行求解。先给出代码:

private KDNode findMiD(int begin, int end, int flag) {
        if (begin >= end) {return null;}
        KDNode lastNode = set.get(end-1);//得到该树节点样本节点集最后一个节点
        int dia = flag%(KDNode.dimension);//得到该树节点分割的维数
        float keyValue = lastNode.get(dia);//得到分割点
//一趟快排算法中的交换部分
        int LastSmall = begin-1;
        for (int i=begin; i<end-1; i++) {
            if (set.get(i).get(dia) < keyValue) {
                exchange(++LastSmall, i);
            }
        }
        exchange(end-1, ++LastSmall);
//快排算法中的交换部分结束
//如果该分割点正好为中位数,则返回该位置的样本节点,如果中位数点小于分割点,则说明中位数在前半部分,故递归搜索前半部分,否则搜索后半部分。
        if (midPos == LastSmall)
            return set.get(midPos);
        else if (midPos < LastSmall)
            return findMiD(begin, LastSmall, flag);
        else
            return findMiD(LastSmall+1, end, flag);
    }

这是java实现的代码。我们先不管各种node是什么意思,先来看看思想是什么。
《KD tree算法(1)-简介&构建KD tree》
根据上图及注释应该很好理解了。
下面我们来说一下我的关于KD tree的设计思想。
结构图:
《KD tree算法(1)-简介&构建KD tree》

样本信息:
《KD tree算法(1)-简介&构建KD tree》

class KDNode {//每一个数据
    public static int dimension = 1;//存储维度信息,每个数据维度相同,故使用static。但我还不知道如何保证一旦设置不需修改。
    float[] coordinate;//存储每一维度的数值

    KDNode() {
        coordinate = new float[1];
    }//默认数据为一维

    KDNode(int dimension) {
        this.dimension = dimension;
        coordinate = new float[dimension];
    }
    void set(int pos, float val) { coordinate[pos] = val; }
    float get(int pos) { return coordinate[pos]; }

    @Override
    public String toString() {
        String s = " ";
        for (int i=0; i<dimension; i++)
            s += (" " + coordinate[i]);
        return s;
    }
}

比较简单,一目了然。

样本集:
《KD tree算法(1)-简介&构建KD tree》

class KDNodeSet {//用来存节点的集合
    ArrayList<KDNode> set;//样本集
    int midPos;//中位数位置
    KDNodeSet() {
        set = new ArrayList<KDNode>();
    }
    KDNodeSet(KDNodeSet Nodeset) {
        this.set = Nodeset.set;
    }//拷贝构造函数
    void add(KDNode node) {
        set.add(node);
    }//添加

    KDNode findMiD(int flag) {
        midPos = set.size()/2 ;
        return findMiD(0, set.size(), flag);
    }//提供给外部的找中位数的方法,flag表示维度

    private KDNode findMiD(int begin, int end, int flag) {
        if (begin >= end) {return null;}
        KDNode lastNode = set.get(end-1);
        int dia = flag%(KDNode.dimension);
        float keyValue = lastNode.get(dia);
        int LastSmall = begin-1;
        for (int i=begin; i<end-1; i++) {
            if (set.get(i).get(dia) < keyValue) {
                exchange(++LastSmall, i);
            }
        }
        exchange(end-1, ++LastSmall);
        if (midPos == LastSmall)
            return set.get(midPos);
        else if (midPos < LastSmall)
            return findMiD(begin, LastSmall, flag);
        else
            return findMiD(LastSmall+1, end, flag);
    }

    KDNodeSet findLeft() {
        return getSubSet(0, midPos);
    }//返回左子集

    KDNodeSet findRight() {
        return getSubSet(midPos+1, set.size());
    }//返回右子集

    KDNodeSet getSubSet(int begin, int end) {
        KDNodeSet subSet = new KDNodeSet();
        for (int i=begin; i<end; i++)
            subSet.add(set.get(i));
        return subSet;
    }//返回子集

KD Tree的每个节点:
《KD tree算法(1)-简介&构建KD tree》

class KDTreeNode {//KD树的每一个节点,保存有中位数的值,在该层有的集合和父子节点引用
    int flag = 0;
    KDNode value;
    KDTreeNode father;
    KDNodeSet set;
    KDTreeNode left;
    KDNodeSet leftSet;
    KDTreeNode right;
    KDNodeSet rightSet;

    KDTreeNode(KDTreeNode father, KDNodeSet set) {
        this.father = father;
        this.set = set;
        run();
    }

    KDTreeNode(KDTreeNode father, KDNodeSet set, int flag) {
        this.father = father;
        this.set = set;
        this.flag = flag;
        run();
    }

    void run() {
        if (set.set.size() == 0) return;

        value = set.findMiD(flag);
        leftSet = set.findLeft();
        rightSet = set.findRight();
        left = new KDTreeNode(this, leftSet, flag+1);
        right = new KDTreeNode(this, rightSet, flag+1);
    }

    KDTreeNode getFather() {return father;}
    KDTreeNode getLeft() {return left;}
    KDTreeNode getRight() {return right;}
}

KD-Tree:
《KD tree算法(1)-简介&构建KD tree》

public class KDTree {
    KDNodeSet set;
    KDTreeNode root;
    KDTree() {
        set = new KDNodeSet();
    }
    void addNode(KDNode node) {
        set.add(node);
    }
    void BuildTree() {
        root = new KDTreeNode(null, set);//已经建好的KDTree
    }

    void find(KDNode node) {
        KDTreeNode position = find(root,node);

    }

    private KDTreeNode find(KDTreeNode Tnode, KDNode node) {
        int dia = Tnode.flag%KDNode.dimension;
        if (Tnode == null) { return Tnode.getFather();//如果找到叶子节点还未找到,那就把他父节点设为最近邻
        } else if (node.get(dia) < Tnode.value.get(dia)) {
            return find(Tnode.getLeft(), node);
        }else
            return find(Tnode.getRight(), node);
    }

测试用例:

    public static void main(String[] args) {
        KDTree tree = new KDTree();
        KDNode node1 = new KDNode(3);
        node1.set(0,3);
        node1.set(1,2);
        node1.set(2,5);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,4);
        node1.set(1,5);
        node1.set(2,1);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,2);
        node1.set(1,9);
        node1.set(2,0);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,7);
        node1.set(1,4);
        node1.set(2,5);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,1);
        node1.set(1,2);
        node1.set(2,5);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,3);
        node1.set(1,5);
        node1.set(2,5);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,3);
        node1.set(1,2);
        node1.set(2,8);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,3);
        node1.set(1,2);
        node1.set(2,1);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,4);
        node1.set(1,6);
        node1.set(2,5);
        tree.addNode(node1);
         node1 = new KDNode(3);
        node1.set(0,3);
        node1.set(1,1);
        node1.set(2,14);
        tree.addNode(node1);
        tree.BuildTree();
    }
    原文作者:约瑟夫环问题
    原文地址: https://blog.csdn.net/butterfly9844/article/details/74587042
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞