二叉搜索树作用、原理和实现(C和Python)

二叉搜索树(Binary Search Tree)是干什么用的?

我知道的主要作用是搜索和动态排序,二叉树进行插入/查询/删除的时间复杂度为O(log(n))。但是实际使用的时候通常不会有这么快,因为你插入顺序所用的middle通常不是那么准,尤其是在插入数据的顺序是有序或者基本有序的时候,这颗二叉树会严重的不平衡,最糟糕的情况下会下降到和链表一样。而解决这个问题的办法是平衡二叉树,但是今天这篇文章不会讨论它,只会关注于基本的二叉搜索树。

二叉搜索树的基本原理

如果你不想看下面这些话,只想一句话明白搜索二叉树怎么实现:(key比自身小放左面,key比自身大放右边),如果你看完这句话就明白了,那恭喜你,因为这个事情本来就是这么简单,我之所以写这篇文章其实就是想节省一些不喜欢看长篇大论的人的宝贵时间。。。

  1. 简单来说搜索二叉树的思想和二分插入排序的思想是一致的,就是用一个数做middle,然后将数据分成两份,一份是比自己大的,一份是比自己小的。并且在插入和搜索时遵循相同的规则,就可以达到一个快速查找和排序的目的。
  2. 典型的二叉树的数据结构有3个部分:(查找键、左子节点指针、右子节点指针)。不过使用的时候通常会再加上:(数据指针、父节点指针)。
  3. 数据指针:通常在有key-value-store的需求时使用。(在我的示例中简化了,数据简单的用int来表示)
  4. 父节点指针:这不是必须的,但是加上这个指针会让你在删除节点等操作时更方便,空间换时间。
  5. 通常搜索二叉树遵循如下原则:
  6. 左子节点小于自身。 右子节点大于自身。
  7. 所有子节点均遵从第一条规则。

如何实现一个二叉搜索树

插入:

  1. 插入操作是比较简单的,就跟上面说的一样,比自己小就插到左边,比自己大就插入到右边,相等就更新自身。
  2. 我是使用递归实现的,因为我觉得这比较美观。但这不是必须,而且从效率上来讲可能不用递归会好一些,因为递归会带来很高的函数调用开销以及栈的疯狂提高。(但是在现代的编译器中这两个问题可能不是问题,因为这些是可以被尾递归优化的)
  3. 具体实现看nAdd函数。(这是不是一句废话?如果是请告诉我,我会尝试尽量不说废话)

查找:

  1. 查找操作和插入操作是几乎一样。
  2. 实现请看nSearch.

删除

  1. 删除操作就稍微麻烦一些,而且我这段代码写的很丑,但是我懒得在写了。(如果传进去的是Node**类型,那么删除时就会直观一些,因为那样就可以直接修改主函数栈上的根节点指针)
  2. 删除节点时分三种情况:
  3. 没有子节点:这个时候,直接将父节点的指针置空就可以了。并删除自身。
  4. 有一个子节点,这个时候,需要将父节点和子节点互相指向。并删除自身。
  5. 有两个子节点时,需要挑选一个继承者(后继节点),然后将继承者复制到要删除的几点。然后删除继承者原来在的节点。
  6. 前驱节点和后继节点。
  7. 前驱节点:比当前节点小的节点中最大的节点。
  8. 后继节点:比当前节点大的节点中最小的节点。
  9. 至于怎么找到他们你可能已经想到了。
    1. 根据BST的特性,最大的节点肯定在“树的最右边”,最小的节点肯定在“树的最左边”。
    2. 查找前驱节点:就在当前节点的左子树中查找最大的。
    3. 查找后继节点:就在当前节点的右子树中查找最小的。(实现请看nSuccessor)
  10. 实现请看nRemove

遍历

  1. 前面又说BST可以用来动态排序,一般的排序方式都是一次性的排序,对于一直有动态插入的数据就不是一个类型的问题了。而BST用来解决这个问题是很好的,因为你只需要中序遍历就可以得到一个有序的数据。
  2. 实现请看nWalk。

C实现(修改后的版本)

#include "StdAfx.h"

struct Node {
    int key;
    int value;
    // Node* parent;
    Node* left;
    Node* right;
};

Node* NewNode(int key, int value) {
    Node* p = (Node*)malloc(sizeof(Node));   // 删除节点时要注意放掉内存
    p->key = key;
    p->value = value;
    // p->parent = parent;
    p->left = NULL;
    p->right = NULL;
    return p;
}

Node** nSearch(Node** n, int key) {
    if (*n == NULL) {
        return n;
    }
    else if (key == (*n)->key) {
        return n;
    }
    else if (key < (*n)->key) {
        return nSearch(&(*n)->left, key);
    }
    else {
        return nSearch(&(*n)->right, key);
    }
}

void nAdd(Node** r, int key, int value) {
    Node** n = nSearch(r, key);
    if (*n == NULL) {
        *n = NewNode(key, value);
    }
    else {
        (*n)->value = value;
    }
}

int nGet(Node** n, int key) {
    Node** result = nSearch(n, key);
    if ((*result) == NULL) {
        return NULL;
    }
    else {
        return (*result)->value;
    }
}

void nWalk(Node* n) {
    if (n == NULL) {
        return;
    }
    if (n->left != NULL) {
        nWalk(n->left);
    }
    printf("%d\n", n->key);
    if (n->right != NULL) {
        nWalk(n->right);
    }
}

Node** nMin(Node** n) {
    if ((*n)->left == NULL) {
        return n;
    }
    else {
        return nMin(&(*n)->left);
    }
}

Node** nSuccessor(Node** n) {
    if ((*n)->right == NULL) {
        return NULL;
    }
    else {
        return nMin(&(*n)->right);
    }
}

int nRemove(Node** n, int key) {
    Node** d = nSearch(n, key);
    if ((*d) == NULL) {
        return -1; // 要删除的几点不存在
    }


    if ((*d)->left != NULL && (*d)->right != NULL) {
        Node** suc = nSuccessor(d);
        (*d)->key = (*suc)->key;
        (*d)->value = (*suc)->value;
        return nRemove(suc, (*suc)->key);
    }
    else {
        // 因为使用二级指针,不在需要判断要删除的节点本身是处在左边还是右边,因为指针中已经指向了原始我们要修改的位置。
        Node* d_child;
        if ((*d)->left != NULL) {
            d_child = (*d)->left;
            // d_child->parent = (*d)->parent;
        }
        else if ((*d)->right != NULL) {
            d_child = (*d)->right;
            // d_child->parent = (*d)->parent;
        }
        else {
            d_child = NULL;
        }

        free(*d); 
        // 同样的,这里也不再需要判断是不是根节点
        *d = d_child;
        return 0;
    }
}

int main() {
    int testData[20][2] = {
        {61, 6161},
        {30, 3030},
        {98, 9898},
        {3, 33},
        {36, 3636},
        {30, 3030},
        {6, 66},
        {54, 5454},
        {63, 6363},
        {93, 9393},
        {93, 9393},
        {76, 7676},
        {84, 8484},
        {16, 1616},
        {13, 1313},
        {76, 7676},
        {78, 7878},
        {29, 2929},
        {9, 99},
        {76, 7676}
    };

    Node* root = NULL;
    Node** rootP = &root;

    int i;
    for (i = 0; i < 20; i++) {
        nAdd(rootP, testData[i][0], testData[i][1]);
    }
    nWalk(root);

    for (i = 0; i < 20; i++) {
        nRemove(rootP, testData[i][0]);
    }
    printf("\n\n");
    nWalk(root);

    return 0;
}

C实现(一开始的版本,留在这里为了让大家对比看看修改后的)

#include "StdAfx.h"


struct Node {
    int key;
    int value;
    Node* parent;
    Node* left;
    Node* right;
};

Node* NewNode(int key, int value, Node* parent) {
    Node* p = (Node*)malloc(sizeof(Node));
    p->key = key;
    p->value = value;
    p->parent = parent;
    p->left = NULL;
    p->right = NULL;
    // 删除节点时要注意free掉内存
    return p;
}

void nAdd(Node* n, int key, int value) {
    if (key == n->key) {
        n->value = value;
    }
    else if (key < n->key) {
        if (n->left == NULL) {
            n->left = NewNode(key, value, n);
        }
        else {
            nAdd(n->left, key, value);
        }
    }
    else {
        if (n->right == NULL) {
            n->right = NewNode(key, value, n);
        }
        else {
            nAdd(n->right, key, value);
        }
    }

}

Node* nSearch(Node* n, int key) {
    if (key == n->key) {
        return n;
    }
    else if (key < n->key) {
        if (n->left == NULL) {
            return NULL;
        }
        else {
            return nSearch(n->left, key);
        }
    }
    else {
        if (n->right == NULL) {
            return NULL;
        }
        else {
            return nSearch(n->right, key);
        }
    }
}


int nGet(Node* n, int key) {
    Node* result = nSearch(n, key);
    if (result == NULL) {
        return NULL;
    }
    else {
        return result->value;
    }
}

void nWalk(Node* n) {
    if (n->left != NULL) {
        nWalk(n->left);
    }
    printf("%d\n", n->key);
    if (n->right != NULL) {
        nWalk(n->right);
    }
}

Node* nMin(Node* n) {
    if (n->left == NULL) {
        return n;
    }
    else {
        return nMin(n->left);
    }
}

Node* nSuccessor(Node* n) {
    if (n->right == NULL) {
        return NULL;
    }
    else {
        return nMin(n->right);
    }
}


int nRemove(Node* n, int key) {
    Node* d = nSearch(n, key);
    if (d == NULL) {
        return -1;
    }

    if (d->left != NULL && d->right != NULL) {
        Node* suc = nSuccessor(d);
        d->key = suc->key;
        d->value = suc->value;
        d = suc;
    }
    
    Node* d_child;
    if (d->left != NULL) {
        d_child = d->left;
        d_child->parent = d->parent;
    }
    else if (d->right != NULL) {
        d_child = d->right;
        d_child->parent = d->parent;
    }
    else {
        d_child = NULL;
    }

    if (d->parent != NULL) {
        if (d->parent->left == d) {
            d->parent->left = d_child;
        }
        else {
            d->parent->right = d_child;
        }

        free(d);
    }
    else {
        if (d_child == NULL) {
            // 这是最后一个节点,这个时候就不能free掉内存了,否则main中root将指向野地址
            n->key = 0;
            n->value = 0;
            return -2;
        }
        else {
            // 当被删除的是根节点的时候,要把继承者的数据复制到根节点上。
            // 因为在此函数内部无法改变main作用于中root的指向,如果删除掉根节点,mian中的root将指向野地址
            *n = *d_child;
            free(d_child);
            return 0;
        }
        
    }
    return 0;
}

int main() {
    int testData[20][2] = {
        {61, 6161},
        {30, 3030},
        {98, 9898},
        {3, 33},
        {36, 3636},
        {30, 3030},
        {6, 66},
        {54, 5454},
        {63, 6363},
        {93, 9393},
        {93, 9393},
        {76, 7676},
        {84, 8484},
        {16, 1616},
        {13, 1313},
        {76, 7676},
        {78, 7878},
        {29, 2929},
        {9, 99},
        {76, 7676}
    };
    Node* root = NewNode(50, 10000, NULL);
    int i;
    for (i = 0; i < 20; i++) {
        nAdd(root, testData[i][0], testData[i][1]);
    }

    nWalk(root);

    nRemove(root, 50);
    for (i = 0; i < 20; i++) {
        nRemove(root, testData[i][0]);
    }
    nWalk(root);
    return 0;
}

Python实现

写到这里其实有些困了,Python版本的删除我记得好像是有一些bug,因为我真的是懒得把一份不直觉的代码改对。而且本质上来讲这两个删除函数都是无可救药的crap code, 请原谅我将这些垃圾发布上来,改天我会尽量重写一份不辣眼睛的版本来赎罪,请原谅。 C语言重写后的版本我已经贴了上来,大家可以对比一下原来的版本,虽然是小修改,但是很满意很美观的改动。

有”tree”的版本:

# Python code:
class Node(object):
    def __init__(self, key, data, parent=None):
        self.key = key
        self.data = data
        self.parent = parent
        self.left = None
        self.right = None

    def add_child(self, key, data):
        if key == self.key:
            self.data = data
            return
        elif key < self.key:
            if self.left is None:
                self.left = Node(key, data, self)
            else:
                self.left.add_child(key, data)
        else:
            if self.right is None:
                self.right = Node(key, data, self)
            else:
                self.right.add_child(key, data)

    def get(self, key):
        if key == self.key:
            return self
        elif key < self.key:
            if self.left is None:
                return None
            else:
                return self.left.get(key)
        else:
            if self.right is None:
                return None
            else:
                return self.right.get(key)

    def get_data(self, key):
        r = self.get(key)
        if r is None:
            return None
        else:
            return r.data

    def in_order_walk(self):
        if self.left is not None:
            self.left.in_order_walk()
        print(self.key)
        if self.right is not None:
            self.right.in_order_walk()

    def get_max(self):
        if self.right is None:
            return self
        else:
            return self.right.get_max()

    def get_min(self):
        if self.left is None:
            return self
        else:
            return self.left.get_min()

    def predecessor(self):
        if self.left is not None:
            return self.left.get_max()
        else:
            return None

    def successor(self):
        if self.right is not None:
            return self.right.get_min()
        else:
            return None

    def remove(self, key):
        d = self.get(key)

        if not d.left and not d.right:
            if d.parent is None:
                # self = None
                return
            elif d.parent.left == d:
                d.parent.left = None
            else:
                d.parent.right = None
        elif bool(d.left) is not bool(d.right):
            if d.left:
                if d.parent.left == d:
                    d.parent.left = d.left
                    d.left.parent = d.parent
                else:
                    d.parent.right = d.left
                    d.left.parent = d.parent
            else:
                if d.parent.right == d:
                    d.parent.right = d.right
                    d.right.parent = d.parent
                else:
                    d.parent.left = d.right
                    d.right.parent = d.parent
        else:
            successor = d.successor()
            d.key = successor.key
            d.data = successor.data
            successor.remove(successor.key)


class Tree(object):
    def __init__(self, root):
        self.root = root

    def remove(self, key):
        d = self.root.get(key)
        if d.left and d.right:
            successor = d.successor()
            d.key = successor.key
            d.data = successor.data
            d = successor
        else:
            pass

        if d.left:
            d_child = d.left
            d_child.parent = d.parent
        elif d.right:
            d_child = d.right
            d_child.parent = d.parent
        else:
            d_child = None

        if d.parent:
            if d.parent.left == d:
                d.parent.left = d_child
            else:
                d.parent.right = d_child
        else:
            self.root = None

    def in_order_walk(self):
        if self.root:
            self.root.in_order_walk()
        else:
            print(None)

    def __getattr__(self, item):
        return getattr(self.root, item)


def main():
    data_list = [
        [20, 2222],
        [30, 3333],
        [40, 4444],
        [50, 5555],
        [60, 6666],
        [70, 7777],
        [80, 8888],
        [90, 9999],
    ]
    middle = data_list[5]
    tree = Tree(Node(middle[0], middle[1]))
    for i in data_list:
        tree.root.add_child(i[0], i[1])
    tree.add_child(1, 111)
    tree.remove(70)
    tree.in_order_walk()


if __name__ == '__main__':
    main()

    原文作者:我爱这世界
    原文地址: https://www.jianshu.com/p/6e05da24905b
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞