二叉查找树
定义(来自百度百科):
二叉排序树或者是一棵空树,或者是具有下列性质的二叉树:
(1)若左子树不空,则左子树上所有结点的值均小于它的根结点的值;
(2)若右子树不空,则右子树上所有结点的值均大于它的根结点的值;
(3)左、右子树也分别为二叉排序树;
(4)没有键值相等的节点。
下面代码中主要实现了求以某结点为根的子树中的结点个数,二叉查找树的查找和排序方法,二叉查找树中的最大键,最小键,向上取整和向下取整,二叉查找树的选择和排名操作,二叉查找树的删除键操作,范围查找操作等.
package cn.edu.zzuli.api;
import java.util.NoSuchElementException;
import edu.princeton.cs.algs4.*;
//基于二叉查找树的符号表
public class BST<Key extends Comparable<Key>, Value> {
private Node root;// 二叉查找树的根结点
private class Node {
private Key key;// 键
private Value val;// 值
private Node left, right;// 指向子树的链接
private int N;// 以该结点为根的子树中的结点个数
public Node(Key key, Value val, int N) {
this.key = key;
this.val = val;
this.N = N;
}
}
public BST() {
}
public int size() {
return size(root);
}
private int size(Node x) {
if (x == null)
return 0;
else
return x.N;
}
/*
* 二叉查找树的查找和排序方法的实现
*/
public Value get(Key key) {
return get(root, key);
}
private Value get(Node x, Key key) {
// 在以x为根结点的子树中查找并返回key所对应的值
// 如果找不到则返回null
if (x == null)
return null;
int cmp = key.compareTo(x.key);
if (cmp < 0)
return get(x.left, key);
else if (cmp > 0)
return get(x.right, key);
return x.val;
}
public void put(Key key, Value val) {
// 查找key,找到则更新它的值,否则为它创建一个新的结点
root = put(root, key, val);
}
private Node put(Node x, Key key, Value val) {
if (x == null)
x = new Node(key, val, 1);
int cmp = key.compareTo(x.key);
if (cmp < 0)
x.left = put(x.left, key, val);
else if (cmp > 0)
x.right = put(x.right, key, val);
else
x.val = val;
x.N = size(x.left) + size(x.right) + 1;
return x;
}
/*
* 二叉查找树中max(),min(),floor(),ceiling()方法的实现
*/
public Key min() {
if (size(root) == 0)
throw new NoSuchElementException("calls min() with empty symbol table");
return min(root).key;
}
private Node min(Node x) {
if (x.left == null)
return x;
return min(x.left);
}
public Key max() {
if (size(root) == 0)
throw new NoSuchElementException("calls min() with empty symbol table");
return max(root).key;
}
private Node max(Node x) {
if (x.right != null)
return max(x.right);
return x;
}
public Key floor(Key key) {
Node x = floor(root, key);
if (x == null)
return null;
return x.key;
}
private Node floor(Node x, Key key) {
if (x == null)
return null;
int cmp = key.compareTo(x.key);
if (cmp == 0)
return x;
// 如果key小于根结点root.key,那么小于等于key的最大键floor(key)一定在根结点的左子树中
if (cmp < 0) {
return floor(x.left, key);
}
// 如果key大于根结点root.key,那么只有当根结点右子树中存在小于等于key的结点时小于等于key的最大键才会出现在右子树中
Node t = floor(x.right, key);
if (t != null)
return t;
// 否则根结点就是小于等于key的最大键
else
return x;
}
public Key ceiling(Key key) {
Node x = ceiling(root, key);
if (x == null)
return null;
return x.key;
}
private Node ceiling(Node x, Key key) {
if (x == null)
return null;
int cmp = key.compareTo(x.key);
if (cmp == 0)
return x;
// 如果key大于根结点root.key,那么大于等于key的最小值ceiling(key)一定在根结点的右子树中
if (cmp > 0) {
return ceiling(x.right, key);
}
// 如果key小于根结点root.key,那么只有当根结点左子树中存在大于等于key的结点时大于等于可以的最小键才会出现在左子树中
Node t = ceiling(x.left, key);
if (t != null)
return t;
// 否则根结点就是大于等于key的最小键
else
return x;
}
/*
* 二叉查找树select()和rank()方法的实现
*/
// 找到排名为k的键
public Key select(int k) {
return select(root, k).key;
}
private Node select(Node x, int k) {
if (x == null)
return null;
int t = size(x.left);
// 如果左子树中的结点数t > k,那么我们就继续(递归地)在左子树中查找排名为k的键
if (t > k) {
return select(x.left, k);
}
// 如果t < k,我们就(递归地)在右子树中查找排名为(k-t-1)的键
else if (t < k) {
return select(x.right, k - t - 1);
}
// 如果t == k,我们就返回根结点中的键
else {
return x;
}
}
// 给给定的键排名
public int rank(Key key) {
return rank(root, key);
}
private int rank(Node x, Key key) {
if (x == null)
return 0;
int cmp = key.compareTo(x.key);
// 如果给定的键小于根结点我们会返回该键在左子树中的排名(递归计算)
if (cmp < 0)
return rank(x.left, key);
// 如果给定的键大于根结点,我们会返回t+1(根结点)加它在右子树中的排名(递归计算)
else if (cmp > 0)
return size(x.left) + 1 + rank(x.right, key);
// 如果给定的键和根结点的键相等,我们返回左子树中的结点总数t
else
return size(x.left);
}
/*
* 二叉查找树的delete()方法的实现
*/
public void deleteMin() {
root = deleteMin(root);
}
private Node deleteMin(Node x) {
// 不断深入根结点的左子树中直至遇见一个空链接
// 然后将指向该结点的链接指向该结点的右子树(只需要在递归调用中返回它的右链接即可)
if (x.left == null) {
return x.right;
}
x.left = deleteMin(x.left);
x.N = size(x.left) + size(x.right) + 1;
return x;
}
public void delete(Key key) {
root = delete(root, key);
}
private Node delete(Node x, Key key) {
if (x == null)
return null;
int cmp = key.compareTo(x.key);
if (cmp < 0)
x.left = delete(x.left, key);
else if (cmp > 0)
x.right = delete(x.right, key);
else {
if (x.left == null)
return x.right;
if (x.right == null)
return x.left;
// 将指向即将被删除的结点的链接保存为t
Node t = x;
// 将x指向它的后继结点min(x.right)
x = min(t.right);
x.right = deleteMin(t.right);
x.left = t.left;
}
x.N = size(x.left) + size(x.right) + 1;
return x;
}
public Iterable<Key> keys() {
return keys(min(), max());
}
public Iterable<Key> keys(Key lo, Key hi) {
Queue<Key> queue = new Queue<Key>();
keys(root, queue, lo, hi);
return queue;
}
public void keys(Node x, Queue<Key> queue, Key lo, Key hi) {
if (x == null)
return;
int cmplo = lo.compareTo(x.key);
int cmphi = hi.compareTo(x.key);
if (cmplo < 0)
keys(x.left, queue, lo, hi);
if (cmplo <= 0 && cmphi >= 0)
queue.enqueue(x.key);
if (cmphi > 0)
keys(x.right, queue, lo, hi);
}
public static void main(String[] args) {
BST<String, Integer> st;
st = new BST<String, Integer>();
for (int i = 0; !StdIn.isEmpty(); i++) {
String key = StdIn.readString();
st.put(key, i);
}
for (String s : st.keys())
StdOut.println(s + " " + st.get(s));
StdOut.println("/*****************test min()*****************/");
StdOut.println(st.min());
StdOut.println("/*****************test max()*****************/");
StdOut.println(st.max());
StdOut.println("/*****************test floor()*****************/");
StdOut.println(st.floor("G"));
StdOut.println("/*****************test ceiling()*****************/");
StdOut.println(st.ceiling("G"));
StdOut.println("/*****************test select()*****************/");
StdOut.println(st.select(3));
StdOut.println("/*****************test rank()*****************/");
StdOut.println(st.rank("H"));
StdOut.println("/*****************test delete()*****************/");
st.delete("E");
for (String s : st.keys())
StdOut.println(s + " " + st.get(s));
StdOut.println("/*****************test deleteMin()*****************/");
st.deleteMin();
for (String s : st.keys())
StdOut.println(s + " " + st.get(s));
}
}
/******************************************************************************
* S E A R C H E X A M P L E
*
* A 8 C 4 E 12 H 5 L 11 M 9 P 10 R 3 S 0 X 7
*
******************************************************************************/