#ifndef __BINARY_SEARCH_H__
#define __BINARY_SEARCH_H__

#include <assert.h>
#include <iostream>
#include <stdexcept>
#include <utility>

template <typename Key, typename Value>
class BinarySearchTree;

template <typename Key, typename Value>
std::ostream& operator<<(std::ostream& out, const BinarySearchTree<Key, Value>& tree);

/// Unbalanced binary search tree mapping Key -> Value.
///
/// Keys are ordered with `operator<` only; two keys are considered equal when
/// neither is less than the other. The tree owns its nodes: they are freed by
/// the destructor, copies are deep, and moves transfer ownership.
template <typename Key, typename Value>
class BinarySearchTree {
    friend std::ostream& operator<< <Key, Value>(std::ostream& out,
                                                 const BinarySearchTree<Key, Value>& tree);

    struct Node {
        Key key;
        Value val;
        Node* left;
        Node* right;
        Node(Key pkey, Value pval)
            : key(std::move(pkey)), val(std::move(pval)), left(nullptr), right(nullptr) {}
    };

    Node* root;

    // In-order walk that writes "(key, val) " for every node of the subtree.
    void traverse(const Node* node, std::ostream& out) const {
        if (node == nullptr) return;
        traverse(node->left, out);
        out << "(" << node->key << ", " << node->val << ") ";
        traverse(node->right, out);
    }

    // Frees every node of the subtree without recursion: left children are
    // rotated up until the current node has none, then it is deleted and the
    // walk continues to its right. This keeps degenerate (list-shaped) trees
    // from overflowing the call stack.
    static void destroy(Node* node) {
        while (node != nullptr) {
            if (node->left != nullptr) {
                Node* leftChild = node->left;
                node->left = leftChild->right;
                leftChild->right = node;
                node = leftChild;
            } else {
                Node* next = node->right;
                delete node;
                node = next;
            }
        }
    }

    // Deep-copies a subtree. If an allocation or a Key/Value copy throws,
    // the partially built copy is freed before the exception propagates.
    static Node* clone(const Node* node) {
        if (node == nullptr) return nullptr;
        Node* copy = new Node(node->key, node->val);
        try {
            copy->left = clone(node->left);
            copy->right = clone(node->right);
        } catch (...) {
            destroy(copy);
            throw;
        }
        return copy;
    }

    // Unlinks (but does not free) the minimum node of the subtree rooted at
    // `head`. It returns the new subtree root and stores the unlinked node in
    // `minOut` (nullptr when the subtree is empty).
    static Node* detachMin(Node* head, Node*& minOut) {
        if (head == nullptr) {
            minOut = nullptr;
            return nullptr;
        }
        // Walk the chain of links leading to the leftmost node. `link` may be
        // `&head` itself when the head is already the minimum.
        Node** link = &head;
        while ((*link)->left != nullptr) link = &(*link)->left;
        minOut = *link;
        *link = minOut->right;  // splice the minimum's right subtree in its place
        minOut->right = nullptr;
        return head;
    }

public:
    /// Creates an empty tree.
    BinarySearchTree() : root(nullptr) {}

    /// Creates a tree holding a single (pkey, pval) entry.
    BinarySearchTree(Key pkey, Value pval) : root(new Node(std::move(pkey), std::move(pval))) {}

    /// Deep copy: the new tree owns independent copies of every node.
    BinarySearchTree(const BinarySearchTree& other) : root(clone(other.root)) {}

    /// Takes ownership of `other`'s nodes; `other` is left empty.
    BinarySearchTree(BinarySearchTree&& other) noexcept : root(other.root) {
        other.root = nullptr;
    }

    /// Copy/move assignment via copy-and-swap (strong exception guarantee).
    BinarySearchTree& operator=(BinarySearchTree other) noexcept {
        std::swap(root, other.root);
        return *this;
    }

    ~BinarySearchTree() { destroy(root); }

    /// Inserts (key, val). If the key already exists, its value is replaced.
    void put(Key key, Value val) {
        Node** link = &root;
        while (*link != nullptr) {
            if (key < (*link)->key) {
                link = &(*link)->left;
            } else if ((*link)->key < key) {
                link = &(*link)->right;
            } else {
                (*link)->val = std::move(val);
                return;
            }
        }
        *link = new Node(std::move(key), std::move(val));
    }

    /// Returns a copy of the value stored under `key`.
    /// @throws std::out_of_range if the key is not present.
    Value get(Key key) const {
        const Node* node = root;
        while (node != nullptr) {
            if (key < node->key)
                node = node->left;
            else if (node->key < key)
                node = node->right;
            else
                return node->val;
        }
        throw std::out_of_range("BinarySearchTree::get: key not found");
    }

    /// Removes and frees the minimum node of the subtree rooted at `head`.
    /// Returns the new subtree root, which the caller must store in place of
    /// `head` (it differs from `head` when `head` itself was the minimum).
    /// Returns nullptr for an empty subtree.
    Node* deleteMin(Node* head) {
        Node* smallest = nullptr;
        Node* remaining = detachMin(head, smallest);
        delete smallest;
        return remaining;
    }

    /// Returns the minimum node of the subtree rooted at `head`, or nullptr
    /// if the subtree is empty.
    Node* min(Node* head) {
        if (head == nullptr) return nullptr;
        while (head->left != nullptr) head = head->left;
        return head;
    }

    /// Removes `key` from the tree (Hibbard deletion). Does nothing if the key
    /// is absent. Works for the root, leaves, and nodes with one or two children.
    void deleteNode(Key key) {
        // `link` is the parent's pointer to the current node (or &root), so
        // the root needs no special-casing when it is the node being removed.
        Node** link = &root;
        while (*link != nullptr) {
            Node* node = *link;
            if (key < node->key) {
                link = &node->left;
            } else if (node->key < key) {
                link = &node->right;
            } else {
                if (node->left == nullptr) {
                    // No children, or only a right child.
                    *link = node->right;
                } else if (node->right == nullptr) {
                    // Only a left child.
                    *link = node->left;
                } else {
                    // Two children: the in-order successor (the minimum of
                    // the right subtree) takes the removed node's place.
                    Node* successor = nullptr;
                    Node* rightRest = detachMin(node->right, successor);
                    successor->right = rightRest;
                    successor->left = node->left;
                    *link = successor;
                }
                delete node;
                return;
            }
        }
    }
};

/// Writes the tree's entries in ascending key order as "(key, val) " pairs.
template <typename Key, typename Value>
std::ostream& operator<<(std::ostream& out, const BinarySearchTree<Key, Value>& tree) {
    tree.traverse(tree.root, out);
    return out;
}

#endif