Last active
October 3, 2021 21:28
-
-
Save upsuper/6332576 to your computer and use it in GitHub Desktop.
A red-black tree implemented in C++ (requires C++11)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef RBTREE_RBTREE_H_ | |
#define RBTREE_RBTREE_H_ | |
#include <cstddef> | |
#include <cassert> | |
#include <utility> | |
namespace upsuper { | |
namespace learning { | |
// Disallows copying of TypeName by declaring its copy constructor and copy
// assignment operator as deleted (C++11). Because the functions are deleted
// rather than merely declared and left undefined, every attempt to copy,
// including one made by a member or friend of the class, is rejected by the
// compiler instead of surfacing later as a linker error.
// Use inside the class body; the invocation supplies the trailing ';'.
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;      \
  void operator=(const TypeName&) = delete
/// An ordered set of unique keys stored in a red-black tree.
///
/// One heap-allocated sentinel node (`nil_`) stands in for every missing
/// child and for the parent of the root. The sentinel is black and its
/// parent/left/right links point back to itself.
///
/// Requirements on `Key`: default-constructible (the sentinel contains one),
/// copy-constructible, and comparable with `==`, `!=` and `<`.
///
/// The tree owns its nodes and cannot be copied.
template <class Key>
class RBTree {
 public:
  typedef std::size_t size_type;

  /// Creates an empty tree: the root is the sentinel.
  RBTree() : nil_(new Node), count_(0) {
    nil_->parent = nil_;
    nil_->left = nil_;
    nil_->right = nil_;
    nil_->color = kBlack;
    root_ = nil_;
  }

  /// Frees every node and then the sentinel.
  ~RBTree() {
    FreeSubtree(root_);
    delete nil_;
  }

  // Copying is not supported. Spelled with the injected class name because a
  // template-id such as `RBTree<Key>(...)` is not a valid constructor
  // declarator in C++20.
  RBTree(const RBTree&) = delete;
  RBTree& operator=(const RBTree&) = delete;

  /// Inserts `key`; returns false (and changes nothing) if it is present.
  bool Put(const Key& key);

  /// Removes `key`; returns false (and changes nothing) if it is absent.
  bool Remove(const Key& key);

  /// Returns true if `key` is stored in the tree.
  bool Contains(const Key& key) const {
    const Node *hit = FindNodeOrParent(key);
    return !IsNil(hit) && hit->key == key;
  }

  /// Number of keys currently stored.
  size_type Count() const { return count_; }

  /// True when no key is stored.
  bool Empty() const { return count_ == 0; }

 protected:
  enum Color { kRed, kBlack };

  // Kept an aggregate: Put() creates nodes with brace initialization in this
  // exact member order.
  struct Node {
    Node *parent;
    Node *left;
    Node *right;
    Color color;
    Key key;
  };

  /// Root node, or the sentinel when the tree is empty.
  const Node *GetRoot() const { return root_; }

  /// True if `node` is the sentinel.
  bool IsNil(const Node *node) const { return node == nil_; }

  bool IsRed(const Node *node) const { return node->color == kRed; }

  bool IsBlack(const Node *node) const { return node->color == kBlack; }

 private:
  /// Colors `node` red; the sentinel must never be recolored red.
  void SetRed(Node *node) {
    assert(node != nil_);
    node->color = kRed;
  }

  void SetBlack(Node *node) { node->color = kBlack; }

  bool IsLeftChild(const Node *node) const {
    return node->parent->left == node;
  }

  bool IsRightChild(const Node *node) const {
    return node->parent->right == node;
  }

  /// Links `child` as the left child of `node`. The sentinel's parent link
  /// is deliberately left untouched when `child` is the sentinel.
  void SetLeft(Node *node, Node *child) {
    assert(!IsNil(node));
    node->left = child;
    if (!IsNil(child))
      child->parent = node;
  }

  /// Links `child` as the right child of `node` (see SetLeft).
  void SetRight(Node *node, Node *child) {
    assert(!IsNil(node));
    node->right = child;
    if (!IsNil(child))
      child->parent = node;
  }

  /// Returns the other child of `node`'s parent. `node` must be linked from
  /// its parent; otherwise this asserts and, in builds without assertions,
  /// returns the sentinel instead of running off the end of the function.
  Node *GetSibling(const Node *node) const {
    if (IsLeftChild(node))
      return node->parent->right;
    if (IsRightChild(node))
      return node->parent->left;
    assert(false);
    return nil_;
  }

  /// Puts `new_child` where `child` currently hangs (below its parent, or as
  /// the root when `child` has no parent) and returns `new_child`.
  Node *ReplaceChild(Node *child, Node *new_child) {
    if (IsNil(child->parent)) {
      root_ = new_child;
      new_child->parent = nil_;
    } else if (IsLeftChild(child)) {
      SetLeft(child->parent, new_child);
    } else if (IsRightChild(child)) {
      SetRight(child->parent, new_child);
    } else {
      assert(false);
    }
    return new_child;
  }

  /// Rotates `node` down to the left so that its right child takes its place.
  /// The two nodes also exchange colors. Returns the node now on top.
  Node *LeftRotate(Node *node) {
    assert(node != nil_ && node->right != nil_);
    Node *child = node->right;
    ReplaceChild(node, child);
    SetRight(node, child->left);
    SetLeft(child, node);
    std::swap(node->color, child->color);
    return child;
  }

  /// Mirror image of LeftRotate: the left child takes `node`'s place.
  Node *RightRotate(Node *node) {
    assert(node != nil_ && node->left != nil_);
    Node *child = node->left;
    ReplaceChild(node, child);
    SetLeft(node, child->right);
    SetRight(child, node);
    std::swap(node->color, child->color);
    return child;
  }

  /// Rotates `node`'s parent in whichever direction lifts `node` into the
  /// parent's position and returns the node now on top. `node` must be linked
  /// from its parent; otherwise this asserts and, in builds without
  /// assertions, returns `node` unchanged.
  Node *ReverseRotate(Node *node) {
    if (IsLeftChild(node))
      return RightRotate(node->parent);
    if (IsRightChild(node))
      return LeftRotate(node->parent);
    assert(false);
    return node;
  }

  /// Searches for `key`. Returns the node holding it if present; otherwise
  /// the last node visited (the would-be parent of `key`), which is the
  /// sentinel when the tree is empty.
  Node *FindNodeOrParent(const Key& key) const {
    Node *parent = nil_;
    Node *node = root_;
    while (!IsNil(node)) {
      if (node->key == key)
        return node;
      parent = node;
      // Ordered with `<`, the same operator Put() uses to pick a side.
      node = key < node->key ? node->left : node->right;
    }
    return parent;
  }

  void FixInsert(Node *node);
  void FixRemove(Node *node);

  /// Deletes every node of the subtree rooted at `root` (post-order).
  void FreeSubtree(Node *root) {
    if (root == nil_)
      return;
    FreeSubtree(root->left);
    FreeSubtree(root->right);
    delete root;
  }

  Node *root_;
  Node *nil_;
  size_type count_;
};
/* Public */ | |
template <class Key> | |
bool RBTree<Key>::Put(const Key& key) { | |
Node *parent = FindNodeOrParent(key); | |
if (!IsNil(parent) && parent->key == key) | |
return false; | |
Node *node = new Node{nil_, nil_, nil_, kRed, key}; | |
if (IsNil(parent)) { | |
root_ = node; | |
} else { // !IsNil(parent) | |
if (key < parent->key) | |
SetLeft(parent, node); | |
else | |
SetRight(parent, node); | |
} | |
FixInsert(node); | |
++count_; | |
return true; | |
} | |
template <class Key> | |
bool RBTree<Key>::Remove(const Key& key) { | |
Node *node = FindNodeOrParent(key); | |
Node *child; | |
if (IsNil(node) || node->key != key) | |
return false; | |
if (IsNil(node->right)) { | |
child = node->left; | |
} else if (IsNil(node->left)) { | |
child = node->right; | |
} else { | |
Node *sub = node->right; | |
while (!IsNil(sub->left)) | |
sub = sub->left; | |
node->key = std::move(sub->key); | |
node = sub; | |
child = sub->right; | |
} | |
child = IsNil(child) ? node : ReplaceChild(node, child); | |
if (IsBlack(node)) | |
FixRemove(child); | |
if (node == child) | |
ReplaceChild(node, nil_); | |
delete node; | |
--count_; | |
return true; | |
} | |
/* Private */ | |
template <class Key> | |
void RBTree<Key>::FixInsert(Node *node) { | |
while (!IsBlack(node) && !IsBlack(node->parent)) { | |
Node *parent = node->parent; | |
Node *uncle = GetSibling(parent); | |
if (IsRed(uncle)) { | |
SetBlack(uncle); | |
SetBlack(parent); | |
SetRed(parent->parent); | |
node = parent->parent; | |
} else { // IsBlack(uncle) | |
if (IsLeftChild(node) != IsLeftChild(parent)) | |
parent = ReverseRotate(node); | |
node = ReverseRotate(parent); | |
} | |
} | |
if (IsNil(node->parent)) | |
SetBlack(node); | |
} | |
template <class Key> | |
void RBTree<Key>::FixRemove(Node *node) { | |
while (!IsRed(node) && !IsNil(node->parent)) { | |
Node *sibling = GetSibling(node); | |
if (IsRed(sibling)) { | |
ReverseRotate(sibling); | |
sibling = GetSibling(node); | |
} | |
if (IsBlack(sibling->left) && IsBlack(sibling->right)) { | |
SetRed(sibling); | |
node = node->parent; | |
} else { | |
if (IsLeftChild(sibling) && !IsRed(sibling->left)) | |
sibling = LeftRotate(sibling); | |
else if (IsRightChild(sibling) && !IsRed(sibling->right)) | |
sibling = RightRotate(sibling); | |
ReverseRotate(sibling); | |
node = GetSibling(node->parent); | |
} | |
} | |
SetBlack(node); | |
} | |
} // namespace learning | |
} // namespace upsuper | |
#endif // RBTREE_RBTREE_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <set> | |
#include <random> | |
#include <iostream> | |
#include "rbtree_tester.hpp" | |
using std::cout; | |
using std::cerr; | |
using std::ends; | |
using std::endl; | |
using upsuper::learning::RBTreeTester; | |
template <class Key> | |
void __attribute__ ((noreturn)) PrintTreeAndExit( | |
const std::vector<std::string>& orig_tree, const RBTreeTester<Key>& tree) { | |
for (auto line : orig_tree) | |
cout << line << endl; | |
auto lines = tree.PrintTree(); | |
for (auto line : lines) | |
cout << line << endl; | |
exit(1); | |
} | |
// Inserts 0, 1, ..., n-1 in ascending order into an empty tree. After every
// insertion the element count, the presence of the new key and the tree
// invariants are checked; on failure both renderings are printed and the
// process exits.
void SequenceInsert(RBTreeTester<int>& tree, const int n) {
  assert(tree.Empty());
  int key = 0;
  while (key < n) {
    const auto before = tree.PrintTree();
    tree.Put(key);
    const bool healthy =
        tree.Count() == key + 1 && tree.Contains(key) && tree.Verify();
    if (!healthy) {
      cout << "SequenceInsert: " << key << endl;
      PrintTreeAndExit(before, tree);
    }
    ++key;
  }
}
// Removes 0, 1, ..., n-1 in ascending order from a tree holding exactly n
// keys. After every removal the element count, the absence of the removed
// key and the tree invariants are checked; on failure both renderings are
// printed and the process exits.
void SequenceRemove(RBTreeTester<int>& tree, const int n) {
  assert(tree.Count() == n);
  for (int i = 0; i < n; ++i) {
    auto orig_tree = tree.PrintTree();
    tree.Remove(i);
    if (tree.Count() != n - i - 1 || tree.Contains(i) || !tree.Verify()) {
      // Report under this routine's own name (was mislabelled SequenceInsert).
      cout << "SequenceRemove: " << i << endl;
      PrintTreeAndExit(orig_tree, tree);
    }
  }
}
// Inserts n-1, n-2, ..., 0 in descending order into an empty tree, checking
// the element count, the presence of the new key and the tree invariants
// after every insertion. On failure both renderings are printed and the
// process exits.
void ReverseInsert(RBTreeTester<int>& tree, const int n) {
  assert(tree.Empty());
  int key = n - 1;
  while (key >= 0) {
    const auto before = tree.PrintTree();
    tree.Put(key);
    const bool healthy =
        tree.Count() == n - key && tree.Contains(key) && tree.Verify();
    if (!healthy) {
      cout << "ReverseInsert: " << key << endl;
      PrintTreeAndExit(before, tree);
    }
    --key;
  }
}
// Removes n-1, n-2, ..., 0 in descending order from a tree holding exactly n
// keys, checking the element count, the absence of the removed key and the
// tree invariants after every removal. On failure both renderings are
// printed and the process exits.
void ReverseRemove(RBTreeTester<int>& tree, const int n) {
  assert(tree.Count() == n);
  int key = n - 1;
  while (key >= 0) {
    const auto before = tree.PrintTree();
    tree.Remove(key);
    const bool healthy =
        tree.Count() == key && !tree.Contains(key) && tree.Verify();
    if (!healthy) {
      cout << "ReverseRemove: " << key << endl;
      PrintTreeAndExit(before, tree);
    }
    --key;
  }
}
// Runs 5*n random operations (80% insertions, 20% removals) with values drawn
// uniformly from [0, n-1] against both the tree and a std::set reference.
// After every operation the tree must agree with the reference on:
//   - whether the operation changed anything (Put/Remove return value),
//   - membership of the value just handled,
//   - the number of stored elements,
// and the tree invariants must hold. On a mismatch the seed, the operation
// and both renderings are printed and the process exits.
void RandomOperations(RBTreeTester<int>& tree, const int n) {
  assert(tree.Empty());
  // uniform_int_distribution requires a <= b; nothing to do for n <= 0.
  if (n <= 0)
    return;
  std::set<int> ref;
  std::random_device rd;
  auto seed = rd();
  std::mt19937 gen(seed);
  std::bernoulli_distribution op_dist(0.8);
  std::uniform_int_distribution<> val_dist(0, n - 1);
  for (int i = 0; i < n * 5; ++i) {
    auto orig_tree = tree.PrintTree();
    bool add_item = op_dist(gen);
    int val = val_dist(gen);
    bool expected_change;
    bool actual_change;
    if (add_item) {
      expected_change = ref.insert(val).second;
      actual_change = tree.Put(val);
    } else {
      expected_change = ref.erase(val) != 0;
      actual_change = tree.Remove(val);
    }
    // After an insertion the value must be present, after a removal absent.
    const bool membership_ok = tree.Contains(val) == add_item;
    if (actual_change != expected_change || !membership_ok ||
        tree.Count() != ref.size() || !tree.Verify()) {
      cout << "(Seed: " << seed << ") " <<
          (add_item ? "Add " : "Remove ") << val << endl;
      PrintTreeAndExit(orig_tree, tree);
    }
  }
}
int main() { | |
RBTreeTester<int> tree; | |
const int n = 1000; | |
cout << "SequenceInsert & SequenceRemove" << endl; | |
SequenceInsert(tree, n); | |
SequenceRemove(tree, n); | |
cout << "SequenceInsert & ReverseRemove" << endl; | |
SequenceInsert(tree, n); | |
ReverseRemove(tree, n); | |
cout << "ReverseInsert & SequenceRemove" << endl; | |
ReverseInsert(tree, n); | |
SequenceRemove(tree, n); | |
cout << "ReverseInsert & ReverseRemove" << endl; | |
ReverseInsert(tree, n); | |
ReverseRemove(tree, n); | |
cout << "RandomOperations" << endl; | |
RandomOperations(tree, n); | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef RBTREE_RBTREE_TESTER_H_ | |
#define RBTREE_RBTREE_TESTER_H_ | |
#include <string> | |
#include <vector> | |
#include "rbtree.hpp" | |
namespace upsuper { | |
namespace learning { | |
template <class Key> | |
class RBTreeTester : public RBTree<Key> { | |
public: | |
using size_type = typename RBTree<Key>::size_type; | |
using RBTree<Key>::Count; | |
bool Verify() const; | |
inline std::vector<std::string> PrintTree() const { | |
std::vector<std::string> tree; | |
if (!IsNil(GetRoot())) { | |
tree.push_back(""); | |
BuildPrintTree(GetRoot(), tree); | |
} | |
return std::move(tree); | |
} | |
private: | |
using Node = typename RBTree<Key>::Node; | |
using RBTree<Key>::GetRoot; | |
using RBTree<Key>::IsNil; | |
using RBTree<Key>::IsRed; | |
using RBTree<Key>::IsBlack; | |
int Traverse(const Node *node, const Key *min, const Key *max, | |
size_type *count) const; | |
void BuildPrintTree(const Node *node, std::vector<std::string>& tree) const; | |
}; | |
/* Public */ | |
template <class Key> | |
bool RBTreeTester<Key>::Verify() const { | |
const Node *root = GetRoot(); | |
if (!IsBlack(root) || !IsNil(root->parent)) | |
return false; | |
const Node *nil = root->parent; | |
if (!IsBlack(nil) || !IsNil(nil->parent) || | |
!IsNil(nil->left) || !IsNil(nil->right)) | |
return false; | |
size_type count = 0; | |
size_type bh = Traverse(root, nullptr, nullptr, &count); | |
if (bh == -1) return false; | |
if (count != Count()) return false; | |
return true; | |
} | |
/* Private */ | |
template <class Key> | |
int RBTreeTester<Key>::Traverse(const Node *node, | |
const Key *min, const Key *max, | |
size_type *count) const { | |
if (IsNil(node)) return 0; | |
if (min != nullptr && node->key <= *min) return -1; | |
if (max != nullptr && node->key >= *max) return -1; | |
if (IsRed(node)) | |
if (!IsBlack(node->left) || !IsBlack(node->right)) | |
return -1; | |
int left_bh = Traverse(node->left, min, &node->key, count); | |
int right_bh = Traverse(node->right, &node->key, max, count); | |
if (left_bh == -1 || right_bh == -1 || left_bh != right_bh) | |
return -1; | |
++*count; | |
return IsBlack(node) ? left_bh + 1 : left_bh; | |
} | |
template<class Key> | |
void RBTreeTester<Key>::BuildPrintTree( | |
const Node *node, std::vector<std::string>& tree) const { | |
auto& line = tree.back(); | |
line.append(" ").append(std::to_string(node->key)) | |
.append(IsRed(node) ? "(R)" : "(B)").append(" "); | |
auto len = line.size(); | |
if (!IsNil(node->left)) { | |
line.append("-"); | |
BuildPrintTree(node->left, tree); | |
} | |
if (!IsNil(node->right)) { | |
for (auto iter = tree.rbegin(); (*iter)[len] == ' '; ++iter) | |
(*iter)[len] = '|'; | |
std::string line2(len, ' '); | |
line2.append("\\"); | |
tree.push_back(line2); | |
BuildPrintTree(node->right, tree); | |
} | |
} | |
} // namespace learning | |
} // namespace upsuper | |
#endif // RBTREE_RBTREE_TESTER_H_ |
why are the copy constructor and assignment operator disallowed in this code?
why are the copy constructor and assignment operator disallowed in this code?
I can't recall. Either I was just too lazy to implement them properly, or I thought they could be misleading to use.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
rbtree.hpp:172:24: error: expected ';' at end of declaration
Node *node = new Node{nil_, nil_, nil_, kRed, key};
^
;
1 error generated.