一般的二叉搜索树在使用过程中可能会逐渐变得不平衡（极端情况下会退化成一条链表），从而降低查找、插入等操作的效率，因此需要借助算法来维持树的平衡。AVL树是一种经典的平衡二叉搜索树，它规定每个结点的左子树和右子树的高度差最多为1。代码如下。（参考博客中remove函数的写法存在问题：忘记在最后更新每个结点的height。本文代码对此做了修改，并用实例测试进行了验证。）
avl.h
#ifndef AVL_H
#define AVL_H
#include <iostream>
using namespace std;
/*
 * AVLTree: a self-balancing binary search tree. Insert and remove rebalance
 * with rotations so that the heights of the left and right subtrees of every
 * node differ by at most 1.
 *
 * Height convention: an empty subtree has height 0 and a leaf has height 1.
 * Keys are compared with <, > and ==; insert() rejects duplicate keys.
 *
 * The tree owns its nodes (allocated with new, released by destroy() and by
 * the destructor). Copying is therefore disabled: the implicit member-wise
 * copy would share mRoot between two trees and free the same nodes twice.
 */
template <class T = int>
class AVLTree
{
private:
    struct AVLTreeNode
    {
        T key;
        int height;          // height of the subtree rooted at this node
        AVLTreeNode* left;
        AVLTreeNode* right;
        // A fresh node is a leaf, so it starts at height 1 (the same value
        // height bookkeeping yields for a node with two empty children).
        AVLTreeNode(T val, AVLTreeNode* l, AVLTreeNode* r):
            key(val), height(1), left(l), right(r) {}
    };
public:
    AVLTree();
    ~AVLTree();
    // Owning raw pointers: copying would double-free, so forbid it.
    AVLTree(const AVLTree&) = delete;
    AVLTree& operator=(const AVLTree&) = delete;

    int height() const;                         // height of the whole tree (0 if empty)
    int max(int a, int b) const;                // only compares heights, hence plain int
    void preOrder() const;                      // print keys in pre-order
    void inOrder() const;                       // print keys in in-order
    void postOrder() const;                     // print keys in post-order
    AVLTreeNode* search(T key) const;           // recursive lookup; nullptr if absent
    AVLTreeNode* iterativeSearch(T key) const;  // non-recursive lookup; nullptr if absent
    T* minimum() const;                         // smallest key, or nullptr if empty
    T* maximum() const;                         // largest key, or nullptr if empty
    void insert(T key);
    void remove(T key);
    void destroy();                             // free every node
    void print() const;                         // print each node's parent/child relation
private:
    AVLTreeNode* mRoot;
    int height(AVLTreeNode* tree) const;
    void preOrder(AVLTreeNode* tree) const;
    void inOrder(AVLTreeNode* tree) const;
    void postOrder(AVLTreeNode* tree) const;
    AVLTreeNode* search(AVLTreeNode* x, T key) const;          // recursive lookup
    AVLTreeNode* iterativeSearch(AVLTreeNode* x, T key) const; // non-recursive lookup
    AVLTreeNode* minimum(AVLTreeNode* tree) const;
    AVLTreeNode* maximum(AVLTreeNode* tree) const;
    AVLTreeNode* leftLeftRotation(AVLTreeNode* k2);
    AVLTreeNode* rightRightRotation(AVLTreeNode* k1);
    AVLTreeNode* leftRightRotation(AVLTreeNode* k3);
    AVLTreeNode* rightLeftRotation(AVLTreeNode* k1);
    AVLTreeNode* insert(AVLTreeNode* &tree, T key);
    AVLTreeNode* remove(AVLTreeNode* &tree, AVLTreeNode* z);
    void destroy(AVLTreeNode* &tree);
    void print(AVLTreeNode* tree, T key, int direction) const;
};
#endif // AVL_H
avl.cpp
#include "avl.h"
#include <iomanip>
//using namespace std;
template <class T>
AVLTree<T>::AVLTree():mRoot(nullptr) {}
// Releases every node still reachable from the root.
template <class T>
AVLTree<T>::~AVLTree()
{
    this->destroy(this->mRoot);
}
// Height of the subtree rooted at `tree`. A null subtree counts as 0, which
// lets callers ask for the height of a missing child without a null check.
template <class T>
int AVLTree<T>::height(AVLTreeNode* tree) const
{
    return (tree == nullptr) ? 0 : tree->height;
}
// Height of the whole tree; 0 when the tree is empty.
template <class T>
int AVLTree<T>::height() const
{
    const int rootHeight = height(mRoot);
    return rootHeight;
}
// Larger of two heights (heights are plain ints, so no template is needed).
template <class T>
int AVLTree<T>::max(int a, int b) const
{
    if (a > b)
    {
        return a;
    }
    return b;
}
// Pre-order traversal (node, left, right); prints each key followed by a space.
template <class T>
void AVLTree<T>::preOrder(AVLTreeNode* tree) const
{
    if (tree == nullptr)
    {
        return;
    }
    cout << tree->key << " ";
    preOrder(tree->left);
    preOrder(tree->right);
}
// Prints the whole tree in pre-order.
template <class T>
void AVLTree<T>::preOrder() const
{
    this->preOrder(mRoot);
}
// In-order traversal (left, node, right); for a binary search tree this
// prints the keys in ascending order, each followed by a space.
template <class T>
void AVLTree<T>::inOrder(AVLTreeNode* tree) const
{
    if (tree == nullptr)
    {
        return;
    }
    inOrder(tree->left);
    cout << tree->key << " ";
    inOrder(tree->right);
}
// Prints the whole tree in in-order.
template <class T>
void AVLTree<T>::inOrder() const
{
    this->inOrder(mRoot);
}
// Post-order traversal (left, right, node); prints each key followed by a space.
template <class T>
void AVLTree<T>::postOrder(AVLTreeNode* tree) const
{
    if (tree == nullptr)
    {
        return;
    }
    postOrder(tree->left);
    postOrder(tree->right);
    cout << tree->key << " ";
}
// Prints the whole tree in post-order.
template <class T>
void AVLTree<T>::postOrder() const
{
    this->postOrder(mRoot);
}
// Recursive lookup: returns the node of the subtree rooted at `x` whose key
// equals `key`, or nullptr if no such node exists.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::search(AVLTreeNode* x, T key) const
{
    if (x == nullptr || x->key == key)
    {
        return x;
    }
    // Descend into the only subtree that can contain `key`.
    AVLTreeNode* next = (key < x->key) ? x->left : x->right;
    return search(next, key);
}
// Recursive lookup from the root; returns the matching node or nullptr.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::search(T key) const
{
    AVLTreeNode* const found = search(mRoot, key);
    return found;
}
// Non-recursive lookup: returns the node of the subtree rooted at `x` whose
// key equals `key`, or nullptr if no such node exists.
//
// Only operator< and operator== are used, exactly like the recursive search().
// A plain else (rather than `else if (key > x->key)`) guarantees the loop always
// advances, even for keys that are neither less, greater nor equal (e.g. NaN),
// which would otherwise spin forever.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::iterativeSearch(AVLTreeNode* x, T key) const
{
    while (nullptr != x && !(x->key == key))
    {
        x = (key < x->key) ? x->left : x->right;
    }
    return x;
}
// Non-recursive lookup from the root; returns the matching node or nullptr.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::iterativeSearch(T key) const
{
    AVLTreeNode* const found = iterativeSearch(mRoot, key);
    return found;
}
// Left-most (smallest-key) node of the subtree rooted at `tree`, or nullptr
// when the subtree is empty.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::minimum(AVLTreeNode* tree) const
{
    AVLTreeNode* node = tree;
    while (node != nullptr && node->left != nullptr)
    {
        node = node->left;
    }
    return node;
}
// Pointer to the smallest key in the tree, or nullptr if the tree is empty.
// The pointer aliases a node's key, so do not keep it across insert/remove.
template <class T>
T* AVLTree<T>::minimum() const
{
    AVLTreeNode* const smallest = minimum(mRoot);
    return (smallest == nullptr) ? nullptr : &smallest->key;
}
// Right-most (largest-key) node of the subtree rooted at `tree`, or nullptr
// when the subtree is empty.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::maximum(AVLTreeNode* tree) const
{
    AVLTreeNode* node = tree;
    while (node != nullptr && node->right != nullptr)
    {
        node = node->right;
    }
    return node;
}
// Pointer to the largest key in the tree, or nullptr if the tree is empty.
// The pointer aliases a node's key, so do not keep it across insert/remove.
template <class T>
T* AVLTree<T>::maximum() const
{
    AVLTreeNode* const largest = maximum(mRoot);
    return (largest == nullptr) ? nullptr : &largest->key;
}
// LL case (the imbalance lies in the left child's left subtree): a single
// right rotation around `k2`. Returns the new subtree root, k2's former left child.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::leftLeftRotation(AVLTreeNode* k2)
{
    AVLTreeNode* const pivot = k2->left;

    // Re-link: pivot's right subtree moves under k2, then k2 hangs off pivot.
    k2->left = pivot->right;
    pivot->right = k2;

    // k2 is now the lower node, so refresh it first. height() is used instead
    // of ->height because the children may be null.
    k2->height = max(height(k2->left), height(k2->right)) + 1;
    pivot->height = max(height(pivot->left), k2->height) + 1;
    return pivot;
}
// RR case (the imbalance lies in the right child's right subtree): a single
// left rotation around `k1`. Returns the new subtree root, k1's former right child.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::rightRightRotation(AVLTreeNode* k1)
{
    AVLTreeNode* const pivot = k1->right;

    // Re-link: pivot's left subtree moves under k1, then k1 hangs off pivot.
    k1->right = pivot->left;
    pivot->left = k1;

    // k1 is now the lower node, so refresh it before pivot.
    k1->height = max(height(k1->left), height(k1->right)) + 1;
    pivot->height = max(height(pivot->right), k1->height) + 1;
    return pivot;
}
// LR case: first a left rotation on the left child (rightRightRotation) turns
// it into an LL shape, then a right rotation on `k3` (leftLeftRotation).
// Returns the new subtree root.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::leftRightRotation(AVLTreeNode* k3)
{
    AVLTreeNode* const straightened = rightRightRotation(k3->left);
    k3->left = straightened;
    return leftLeftRotation(k3);
}
// RL case: first a right rotation on the right child (leftLeftRotation) turns
// it into an RR shape, then a left rotation on `k1` (rightRightRotation).
// Returns the new subtree root.
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::rightLeftRotation(AVLTreeNode* k1)
{
    AVLTreeNode* const straightened = leftLeftRotation(k1->right);
    k1->right = straightened;
    return rightRightRotation(k1);
}
/*
 * Inserts `key` into the subtree rooted at `tree`, rebalancing on the way
 * back up the recursion.
 *
 * Parameters:
 *   tree  root of the subtree; updated in place to the new subtree root
 *   key   key to insert; a duplicate is rejected with a message on cout
 * Returns:
 *   the (possibly new) root of the subtree, i.e. the final value of `tree`
 */
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::insert(AVLTreeNode* &tree, T key)
{
    if (nullptr == tree)
    {
        // Plain `new` reports failure by throwing std::bad_alloc; it never
        // returns nullptr, so no null check is needed here.
        tree = new AVLTreeNode(key, nullptr, nullptr);
    }
    else if (key < tree->key) // insert into the left subtree
    {
        tree->left = insert(tree->left, key);
        // Only the left side can have grown, so only a left-heavy imbalance is possible.
        if (height(tree->left) - height(tree->right) == 2)
        {
            if (key < tree->left->key)
            {
                tree = leftLeftRotation(tree);   // new key went left-left
            }
            else if (key > tree->left->key)
            {
                tree = leftRightRotation(tree);  // new key went left-right
            }
            // key == tree->left->key cannot grow the tree (duplicate rejected),
            // so it never reaches this imbalance check with a height difference of 2.
        }
    }
    else if (key > tree->key) // insert into the right subtree
    {
        tree->right = insert(tree->right, key);
        // Only the right side can have grown, so only a right-heavy imbalance is possible.
        if (height(tree->right) - height(tree->left) == 2)
        {
            if (key < tree->right->key)
            {
                tree = rightLeftRotation(tree);  // new key went right-left
            }
            else if (key > tree->right->key)
            {
                tree = rightRightRotation(tree); // new key went right-right
            }
        }
    }
    else if (key == tree->key)
    {
        cout << "error: cannot add node with same key" << endl;
    }
    // Refresh the height of whichever node now roots this subtree
    // (also initialises a freshly created leaf).
    tree->height = max(height(tree->left), height(tree->right)) + 1;
    return tree;
}
// Inserts `key` into the tree. insert(mRoot, key) updates mRoot through its
// reference parameter, so its return value is not needed here.
template <class T>
void AVLTree<T>::insert(T key)
{
    (void)insert(mRoot, key);
}
/*
 * Removes node `z` from the subtree rooted at `tree` and rebalances on the
 * way back up the recursion.
 *
 * Parameters:
 *   tree  root of the subtree; updated in place to the new subtree root
 *   z     node to remove. It is located by comparing z->key, so it is expected
 *         to belong to this subtree (the public remove() obtains it via search()).
 * Returns:
 *   the (possibly new) root of the subtree, i.e. the final value of `tree`;
 *   nullptr only when the subtree is or becomes empty
 */
template <class T>
typename AVLTree<T>::AVLTreeNode* AVLTree<T>::remove(AVLTreeNode* &tree, AVLTreeNode* z)
{
    if (nullptr == tree || nullptr == z)
    {
        // Nothing to remove: hand the subtree back unchanged. Returning nullptr
        // here would make callers that assign the result drop the whole subtree.
        return tree;
    }
    if (z->key < tree->key) // z lies in the left subtree
    {
        tree->left = remove(tree->left, z);
        // The left side may have shrunk, so the node may now be right-heavy.
        if (height(tree->right) - height(tree->left) == 2)
        {
            AVLTreeNode* r = tree->right;
            // Double rotation only when the right child is strictly left-heavy;
            // a balanced or right-heavy child needs just a single rotation.
            if (height(r->left) > height(r->right))
            {
                tree = rightLeftRotation(tree);
            }
            else
            {
                tree = rightRightRotation(tree);
            }
        }
    }
    else if (z->key > tree->key) // z lies in the right subtree
    {
        tree->right = remove(tree->right, z);
        // The right side may have shrunk, so the node may now be left-heavy.
        if (height(tree->left) - height(tree->right) == 2)
        {
            AVLTreeNode* l = tree->left;
            // Mirror image of the case above. A tie MUST take the single LL
            // rotation: when both children of `l` are equally tall, an LR double
            // rotation can leave `l` itself with a height difference of 2.
            if (height(l->right) > height(l->left))
            {
                tree = leftRightRotation(tree);
            }
            else
            {
                tree = leftLeftRotation(tree);
            }
        }
    }
    else // z->key == tree->key: this is the node to delete
    {
        if (nullptr != tree->left && nullptr != tree->right)
        {
            // Two children: copy the in-order neighbour from the taller subtree
            // into this node, then delete that neighbour. Shrinking the taller
            // (or equal) side by at most 1 keeps this node balanced.
            if (height(tree->left) > height(tree->right))
            {
                AVLTreeNode* predecessor = maximum(tree->left);
                tree->key = predecessor->key;
                tree->left = remove(tree->left, predecessor);
            }
            else
            {
                AVLTreeNode* successor = minimum(tree->right);
                tree->key = successor->key;
                tree->right = remove(tree->right, successor);
            }
        }
        else
        {
            // Zero or one child: splice the child (or nullptr) into this position.
            AVLTreeNode* doomed = tree;
            tree = (nullptr != tree->left) ? tree->left : tree->right;
            delete doomed;
        }
    }
    // Refresh the height of whatever now roots this subtree. Rotations already
    // set it, but the paths without a rotation do not.
    if (nullptr != tree)
    {
        tree->height = max(height(tree->left), height(tree->right)) + 1;
    }
    return tree;
}
// Removes the node holding `key`; a key that is not in the tree is ignored.
template <class T>
void AVLTree<T>::remove(T key)
{
    AVLTreeNode* const target = search(mRoot, key);
    if (target == nullptr)
    {
        return;
    }
    mRoot = remove(mRoot, target);
}
// Frees every node of the subtree rooted at `tree` (children before parent)
// and resets `tree` to nullptr. Without the reset the caller's pointer would
// dangle: calling the public destroy() and then running the destructor would
// free the same nodes twice.
template <class T>
void AVLTree<T>::destroy(AVLTreeNode* &tree)
{
    if (nullptr == tree)
    {
        return;
    }
    destroy(tree->left);    // null children are handled by the guard above
    destroy(tree->right);
    delete tree;
    tree = nullptr;
}
// Empties the tree. mRoot is reset explicitly so the tree stays usable after
// the call and the destructor does not free the nodes a second time.
template <class T>
void AVLTree<T>::destroy()
{
    destroy(mRoot);
    mRoot = nullptr;
}
/*
 * Prints the subtree rooted at `tree` in pre-order, one line per node,
 * describing each node's relation to its parent.
 *
 *   key        key of the parent node (not printed when direction == 0)
 *   direction   0: `tree` is the root of the whole tree
 *              -1: `tree` is its parent's left child
 *               1: `tree` is its parent's right child
 */
template <class T>
void AVLTree<T>::print(AVLTreeNode *tree, T key, int direction) const
{
    if (nullptr != tree)
    {
        if (direction == 0)
        {
            cout << setw(2) << tree->key << " is root" << endl;
        }
        else
        {
            // " is " needs its leading space, otherwise the key runs into the word.
            cout << setw(2) << tree->key << " is " << setw(2) << key << "'s " << setw(12)
                 << (direction == 1 ? "right child" : "left child") << endl;
        }
        print(tree->left, tree->key, -1);
        print(tree->right, tree->key, 1);
    }
}
// Prints the whole tree, starting from the root; prints nothing when empty.
template <class T>
void AVLTree<T>::print() const
{
    if (mRoot == nullptr)
    {
        return;
    }
    print(mRoot, mRoot->key, 0);
}
main.cpp
//#include <iostream>
#include "avl.cpp"
//using namespace std;
static int arr[] = {3,2,1,4,5,6,7,16,15,14,13,12,11,10,8,9};

// Demo: build a tree from `arr`, print it in several ways, remove a few keys
// and print it again.
int main()
{
    // An automatic object: its destructor frees every node when main returns,
    // so no manual delete (and no separate destroy() call) is needed.
    AVLTree<int> tree;

    cout << "== add element: ";
    for (int value : arr)
    {
        cout << value << " ";
        tree.insert(value);
    }
    cout << "\n== preOrder: ";
    tree.preOrder();
    cout << "\n== inOrder: ";
    tree.inOrder();
    cout << "\n== postOrder: ";
    tree.postOrder();
    cout << endl;
    cout << "== height: " << tree.height() << endl;
    // arr is non-empty, so minimum()/maximum() cannot return nullptr here.
    cout << "== minimum: " << *tree.minimum() << endl;
    cout << "== maximum: " << *tree.maximum() << endl;
    cout << "== print: " << endl;
    tree.print();

    cout << "\n== remove: 1, 3, 5";
    tree.remove(1);
    tree.remove(3);
    tree.remove(5);
    cout << "\n== height: " << tree.height();
    cout << "\n== inOrder: ";
    tree.inOrder();
    cout << "\n== print: " << endl;
    tree.print();
    return 0;
}
运行结果如下: