KD tree算法是KNN(K-nearest neighbor)实现的重要算法之一,下面我们先简单介绍一些KNN的知识,然后开始我们KD tree的讲解。
KNN分类算法
KNN是一种简单的分类方法:
分类时,对新的实例,根据k个最近邻的训练实例的类别,通过多数表决等方式进行预测。
以上是《统计学习方法》中对KNN的解释,简单明了。那么问题就在于如何快速有效找到K个近邻就是该算法的关键了。关于KNN算法的详细知识,请看另一篇文章《统计学习方法之K近邻法》,下面我们来简单学习一下如何快速找到K个近邻的样本。
KD Tree
为了对训练数据进行快速k近邻搜索,我们使用特殊的数据结构存储训练数据-kd Tree方法。
kd Tree是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分(partition),构造kd Tree相当于奖k维空间划分,构成一系列的k维超矩形区域。
以上是《统计学习方法》对KD Tree的解释,同样简单明了。
KD Tree方法即将空间根据坐标不断划分,划分成k块,如下图所示:
构建 KD-Tree
以上是《统计学习方法》书中对构造平衡kd Tree的算法描述。我们进行分析。
1.1 为什么要构造“平衡“kd Tree?
学过快速排序的读者们都知道,当我们取得的K值刚好保证左右两边的数的个数相等时,快排算法的时间复杂度最低,而当K值取到最大或最小时算法时间复杂度最高。为此提出了随机快速排序算法。以此为基础,我们可以分析知道为什么要构造“平衡“kd Tree了。很简单,为了提高搜索效率。
为了更加简单明了的解释,我们据如下例子:
1 2 3 4 5 6 7 8 9
我们将以上数字分为两部分:
1) >=5;
2) <5;
为了找到数字3,如果我们知道3比5小,那么我们只需要搜索5左边的四个数字即可。
如果我们将以上数字分成以下两部分:
1) >=9;
2) <9;
我们知道3比9小,那么我们需要搜索9左边的8个数字,这样的效率明显没有上一种的效率高。
2.2 如何构造平衡kd Tree?
中位数法。中位数:将一组数据按大小顺序排列,处在最中间位置的一个数叫做这组数据的中位数 。
这就是中位数的定义,初中的知识了。那么如何快速找到中位数呢?先排序后查找?还是找到前n/2个数然后得到中位数呢?显然效率不高。我们利用快速排序 的思想进行求解。先给出代码:
private KDNode findMiD(int begin, int end, int flag) {
if (begin >= end) {return null;}
KDNode lastNode = set.get(end-1);//得到该树节点样本节点集最后一个节点
int dia = flag%(KDNode.dimension);//得到该树节点分割的维数
float keyValue = lastNode.get(dia);//得到分割点
//一趟快排算法中的交换部分
int LastSmall = begin-1;
for (int i=begin; i<end-1; i++) {
if (set.get(i).get(dia) < keyValue) {
exchange(++LastSmall, i);
}
}
exchange(end-1, ++LastSmall);
//快排算法中的交换部分结束
//如果该分割点正好为中位数,则返回该位置的样本节点,如果中位数点小于分割点,则说明中位数在前半部分,故递归搜索前半部分,否则搜索后半部分。
if (midPos == LastSmall)
return set.get(midPos);
else if (midPos < LastSmall)
return findMiD(begin, LastSmall, flag);
else
return findMiD(LastSmall+1, end, flag);
}
这是java实现的代码。我们先不管各种node是什么意思,先来看看思想是什么。
根据上图及注释应该很好理解了。
下面我们来说一下我的关于KD tree的设计思想。
结构图:
样本信息:
class KDNode {//每一个数据
public static int dimension = 1;//存储维度信息,每个数据维度相同,故使用static。但我还不知道如何保证一旦设置不需修改。
float[] coordinate;//存储每一维度的数值
KDNode() {
coordinate = new float[1];
}//默认数据为一维
KDNode(int dimension) {
this.dimension = dimension;
coordinate = new float[dimension];
}
void set(int pos, float val) { coordinate[pos] = val; }
float get(int pos) { return coordinate[pos]; }
@Override
public String toString() {
String s = " ";
for (int i=0; i<dimension; i++)
s += (" " + coordinate[i]);
return s;
}
}
比较简单,一目了然。
样本集:
class KDNodeSet {//用来存节点的集合
ArrayList<KDNode> set;//样本集
int midPos;//中位数位置
KDNodeSet() {
set = new ArrayList<KDNode>();
}
KDNodeSet(KDNodeSet Nodeset) {
this.set = Nodeset.set;
}//拷贝构造函数
void add(KDNode node) {
set.add(node);
}//添加
KDNode findMiD(int flag) {
midPos = set.size()/2 ;
return findMiD(0, set.size(), flag);
}//提供给外部的找中位数的方法,flag表示维度
private KDNode findMiD(int begin, int end, int flag) {
if (begin >= end) {return null;}
KDNode lastNode = set.get(end-1);
int dia = flag%(KDNode.dimension);
float keyValue = lastNode.get(dia);
int LastSmall = begin-1;
for (int i=begin; i<end-1; i++) {
if (set.get(i).get(dia) < keyValue) {
exchange(++LastSmall, i);
}
}
exchange(end-1, ++LastSmall);
if (midPos == LastSmall)
return set.get(midPos);
else if (midPos < LastSmall)
return findMiD(begin, LastSmall, flag);
else
return findMiD(LastSmall+1, end, flag);
}
KDNodeSet findLeft() {
return getSubSet(0, midPos);
}//返回左子集
KDNodeSet findRight() {
return getSubSet(midPos+1, set.size());
}//返回右子集
KDNodeSet getSubSet(int begin, int end) {
KDNodeSet subSet = new KDNodeSet();
for (int i=begin; i<end; i++)
subSet.add(set.get(i));
return subSet;
}//返回子集
KD Tree的每个节点:
class KDTreeNode {//KD树的每一个节点,保存有中位数的值,在该层有的集合和父子节点引用
int flag = 0;
KDNode value;
KDTreeNode father;
KDNodeSet set;
KDTreeNode left;
KDNodeSet leftSet;
KDTreeNode right;
KDNodeSet rightSet;
KDTreeNode(KDTreeNode father, KDNodeSet set) {
this.father = father;
this.set = set;
run();
}
KDTreeNode(KDTreeNode father, KDNodeSet set, int flag) {
this.father = father;
this.set = set;
this.flag = flag;
run();
}
void run() {
if (set.set.size() == 0) return;
value = set.findMiD(flag);
leftSet = set.findLeft();
rightSet = set.findRight();
left = new KDTreeNode(this, leftSet, flag+1);
right = new KDTreeNode(this, rightSet, flag+1);
}
KDTreeNode getFather() {return father;}
KDTreeNode getLeft() {return left;}
KDTreeNode getRight() {return right;}
}
KD-Tree:
public class KDTree {
KDNodeSet set;
KDTreeNode root;
KDTree() {
set = new KDNodeSet();
}
void addNode(KDNode node) {
set.add(node);
}
void BuildTree() {
root = new KDTreeNode(null, set);//已经建好的KDTree
}
void find(KDNode node) {
KDTreeNode position = find(root,node);
}
private KDTreeNode find(KDTreeNode Tnode, KDNode node) {
int dia = Tnode.flag%KDNode.dimension;
if (Tnode == null) { return Tnode.getFather();//如果找到叶子节点还未找到,那就把他父节点设为最近邻
} else if (node.get(dia) < Tnode.value.get(dia)) {
return find(Tnode.getLeft(), node);
}else
return find(Tnode.getRight(), node);
}
测试用例:
public static void main(String[] args) {
KDTree tree = new KDTree();
KDNode node1 = new KDNode(3);
node1.set(0,3);
node1.set(1,2);
node1.set(2,5);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,4);
node1.set(1,5);
node1.set(2,1);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,2);
node1.set(1,9);
node1.set(2,0);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,7);
node1.set(1,4);
node1.set(2,5);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,1);
node1.set(1,2);
node1.set(2,5);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,3);
node1.set(1,5);
node1.set(2,5);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,3);
node1.set(1,2);
node1.set(2,8);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,3);
node1.set(1,2);
node1.set(2,1);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,4);
node1.set(1,6);
node1.set(2,5);
tree.addNode(node1);
node1 = new KDNode(3);
node1.set(0,3);
node1.set(1,1);
node1.set(2,14);
tree.addNode(node1);
tree.BuildTree();
}