Last active
July 12, 2025 17:52
-
-
Save Element118/470ca60423de5a0bbf4dc8770b9a8b9f to your computer and use it in GitHub Desktop.
Immutable Persistent AVL Tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::rc::Rc; | |
| use std::fmt; | |
| use std::cmp::max; | |
/// A single immutable node of a persistent AVL tree.
///
/// Children are held through [`Rc`], so a new version of a tree can reuse
/// (share) any subtree it does not need to change.
#[derive(Clone, Debug)]
pub struct AvlNode<T> {
    /// Height of the right subtree minus height of the left subtree.
    /// Stays within `-1..=1` once a node has been rebalanced; `±2` only
    /// appears transiently before [`AvlNodeResult::balance`] fixes it.
    balance: i8,
    /// The value stored at this node.
    data: T,
    /// Subtree holding values that compare less than `data`, if any.
    left: Option<Rc<AvlNode<T>>>,
    /// Subtree holding values that compare greater than `data`, if any.
    right: Option<Rc<AvlNode<T>>>,
}
| pub struct AvlNodeWithRight<'a, T> { | |
| node: &'a AvlNode<T>, | |
| right: &'a AvlNode<T>, | |
| } | |
| pub struct AvlNodeWithLeft<'a, T> { | |
| node: &'a AvlNode<T>, | |
| left: &'a AvlNode<T>, | |
| } | |
/// Outcome of a tree operation: the resulting subtree plus how its height changed.
#[derive(Clone)]
pub struct AvlResult<N> {
    /// Height of the resulting subtree minus height of the subtree before the
    /// operation (for example `1` when an insertion made it one level taller).
    pub delta: i8,
    /// The resulting subtree (root node, or optional root node).
    pub node: N,
}
| pub type AvlNodeResult<T> = AvlResult<Rc<AvlNode<T>>>; | |
| pub type AvlOptionalNodeResult<T> = AvlResult<Option<Rc<AvlNode<T>>>>; | |
| impl<N: Clone> AvlResult<N> { | |
| pub fn add_delta(&self, delta: i8) -> AvlResult<N> { AvlResult { delta: self.delta + delta, node: self.node.clone() } } | |
| pub fn optional_box(&self) -> AvlResult<Option<N>> { | |
| AvlResult { delta: self.delta, node: Some(self.node.clone()) } | |
| } | |
| } | |
| impl<T> AvlNode<T> { | |
| pub fn new(data: T) -> Self { | |
| AvlNode { | |
| balance: 0, | |
| data, | |
| left: None, | |
| right: None, | |
| } | |
| } | |
| pub fn view_with_left(&self) -> Option<AvlNodeWithLeft<T>> { | |
| self.left.as_ref().map(|r| AvlNodeWithLeft { node: self, left: r }) | |
| } | |
| pub fn view_with_right(&self) -> Option<AvlNodeWithRight<T>> { | |
| self.right.as_ref().map(|r| AvlNodeWithRight { node: self, right: r }) | |
| } | |
| } | |
| impl<T: Clone> AvlNodeWithRight<'_, T> { | |
| pub fn rotate_left(&self) -> AvlNodeResult<T> { | |
| // relative heights | |
| let rl_height = 0; | |
| let rr_height = self.right.balance; | |
| let l_height = max(rr_height, rl_height) + 1 - self.node.balance; | |
| let old_height = max(l_height, max(rr_height, rl_height)+1); | |
| let new_height = max(rr_height, max(rl_height, l_height)+1); | |
| AvlNodeResult { delta: new_height - old_height, node: Rc::new(AvlNode { | |
| balance: rr_height - (max(rl_height, l_height) + 1), | |
| data: self.right.data.clone(), | |
| left: Some(Rc::new(AvlNode { | |
| balance: rl_height - l_height, | |
| data: self.node.data.clone(), | |
| left: self.node.left.clone(), | |
| right: self.right.left.clone(), | |
| })), | |
| right: self.right.right.clone(), | |
| }) } | |
| } | |
| } | |
| impl<T: Clone> AvlNodeWithLeft<'_, T> { | |
| pub fn rotate_right(&self) -> AvlNodeResult<T> { | |
| // relative heights | |
| let ll_height = 0; | |
| let lr_height = self.left.balance; | |
| let r_height = max(lr_height, ll_height) + 1 + self.node.balance; | |
| let old_height = max(r_height, max(lr_height, ll_height)+1); | |
| let new_height = max(ll_height, max(r_height, lr_height)+1); | |
| AvlNodeResult { delta: new_height - old_height, node: Rc::new(AvlNode { | |
| balance: (max(r_height, lr_height) + 1) - ll_height, | |
| data: self.left.data.clone(), | |
| left: self.left.left.clone(), | |
| right: Some(Rc::new(AvlNode { | |
| balance: r_height - lr_height, | |
| data: self.node.data.clone(), | |
| left: self.left.right.clone(), | |
| right: self.node.right.clone(), | |
| })), | |
| }) } | |
| } | |
| } | |
| pub fn balanced_avl_node_from<T>( | |
| delta: i8, data: T, | |
| left: Option<Rc<AvlNode<T>>>, | |
| right: Option<Rc<AvlNode<T>>>, | |
| balance: i8) -> AvlNodeResult<T> | |
| where T: Clone { | |
| (AvlNodeResult { | |
| delta, | |
| node: Rc::new(AvlNode { | |
| data, | |
| left, | |
| right, | |
| balance, | |
| }) | |
| }).balance() | |
| } | |
| pub fn avl_insert<T>(node: &Option<Rc<AvlNode<T>>>, value: T) -> AvlNodeResult<T> | |
| where T: Ord + Clone, { | |
| match node { | |
| None => AvlNodeResult { delta: 1, node: Rc::new(AvlNode::new(value)) }, // new leaf increases height | |
| Some(node) => { | |
| if value < node.data { | |
| let AvlNodeResult { delta, node: new_left } = avl_insert(&node.left, value); | |
| // relative heights | |
| let old_l_height = 0; | |
| let l_height = delta; | |
| let r_height = node.balance; | |
| balanced_avl_node_from( | |
| max(r_height, l_height) - max(r_height, old_l_height), | |
| node.data.clone(), | |
| Some(new_left), | |
| node.right.clone(), | |
| r_height - l_height // left heavier | |
| ) | |
| } else if value > node.data { | |
| let AvlNodeResult { delta, node: new_right } = avl_insert(&node.right, value); | |
| // relative heights | |
| let l_height = 0; | |
| let old_r_height = node.balance; | |
| let r_height = old_r_height + delta; | |
| balanced_avl_node_from( | |
| max(r_height, l_height) - max(old_r_height, l_height), | |
| node.data.clone(), | |
| node.left.clone(), | |
| Some(new_right), | |
| r_height - l_height // right heavier | |
| ) | |
| } else { | |
| AvlNodeResult { delta: 0, node: node.clone() } // value already exists; no change | |
| } | |
| }, | |
| } | |
| } | |
| pub fn avl_delete<T>(node: &Option<Rc<AvlNode<T>>>, value: &T) -> AvlOptionalNodeResult<T> | |
| where T: Ord + Clone, { | |
| match node { | |
| None => AvlOptionalNodeResult { delta: 0, node: None }, | |
| Some(node) => { | |
| if *value < node.data { | |
| let AvlOptionalNodeResult { delta, node: new_left } = avl_delete(&node.left, value); | |
| // relative heights | |
| let old_l_height = 0; | |
| let l_height = delta; | |
| let r_height = node.balance; | |
| balanced_avl_node_from( | |
| max(r_height, l_height) - max(r_height, old_l_height), | |
| node.data.clone(), | |
| new_left, | |
| node.right.clone(), | |
| r_height - l_height | |
| ).optional_box() | |
| } else if *value > node.data { | |
| let AvlOptionalNodeResult { delta, node: new_right } = avl_delete(&node.right, value); | |
| // relative heights | |
| let l_height = 0; | |
| let old_r_height = node.balance; | |
| let r_height = old_r_height + delta; | |
| balanced_avl_node_from( | |
| max(r_height, l_height) - max(old_r_height, l_height), | |
| node.data.clone(), | |
| node.left.clone(), | |
| new_right, | |
| r_height - l_height | |
| ).optional_box() | |
| } else { | |
| avl_delete_root(node).0 | |
| } | |
| }, | |
| } | |
| } | |
| pub fn avl_delete_root<T>(node: &Rc<AvlNode<T>>) -> (AvlOptionalNodeResult<T>, T) | |
| where T: Ord + Clone, { | |
| (match (&node.left, &node.right) { | |
| (None, None) => AvlResult { delta: -1, node: None }, | |
| (None, Some(right)) => AvlOptionalNodeResult { delta: -1, node: Some(right.clone()) }, | |
| (Some(left), None) => AvlOptionalNodeResult { delta: -1, node: Some(left.clone()) }, | |
| (Some(left), Some(right)) => { | |
| let (subtree, value) = avl_delete_min(right); | |
| // relative heights | |
| let l_height = 0; | |
| let old_r_height = node.balance; | |
| let r_height = old_r_height + subtree.delta; | |
| balanced_avl_node_from( | |
| max(r_height, l_height) - max(old_r_height, l_height), | |
| value, | |
| Some(left.clone()), | |
| subtree.node, | |
| r_height - l_height | |
| ).optional_box() | |
| }, | |
| }, node.data.clone()) | |
| } | |
| pub fn avl_delete_min<T>(node: &Rc<AvlNode<T>>) -> (AvlOptionalNodeResult<T>, T) | |
| where T: Ord + Clone, { | |
| match &node.left { | |
| None => avl_delete_root(node), | |
| Some(left) => { | |
| let (subtree, value) = avl_delete_min(left); | |
| // relative heights | |
| let old_l_height = 0; | |
| let r_height = node.balance; | |
| let l_height = old_l_height + subtree.delta; | |
| (balanced_avl_node_from( | |
| max(r_height, l_height) - max(r_height, old_l_height), | |
| node.data.clone(), | |
| subtree.node.clone(), | |
| node.right.clone(), | |
| r_height - l_height | |
| ).optional_box(), value) | |
| } | |
| } | |
| } | |
/// Computes how a node changes when its subtrees change height.
///
/// `balance` is the node's right height minus its left height. `left_delta`
/// and `right_delta` are the height changes of its left and right subtrees.
/// Returns `(new_balance, height_delta)`, where `height_delta` is the change in
/// the node's own height.
pub fn new_balance_delta(balance: i8, left_delta: i8, right_delta: i8) -> (i8, i8) {
    // Heights relative to the old left subtree, which sits at 0.
    let (old_left, old_right) = (0, balance);
    let taller_before = max(old_right, old_left);
    let taller_after = max(old_right + right_delta, old_left + left_delta);
    let updated_balance = balance + right_delta - left_delta;
    (updated_balance, taller_after - taller_before)
}
| impl<T: Clone> AvlNodeResult<T> { | |
| pub fn balance(&self) -> AvlNodeResult<T> { | |
| match (self.node.balance, self.node.left.clone(), self.node.right.clone()) { | |
| (2, _, Some(right)) => { | |
| if right.balance >= 0 { // Right-right case | |
| self.node.view_with_right().unwrap().rotate_left().add_delta(self.delta) | |
| } else { // Right-left case | |
| let intermediate = right.view_with_left().unwrap().rotate_right(); | |
| let (balance, delta) = new_balance_delta(self.node.balance, 0, intermediate.delta); | |
| (AvlNode { | |
| balance, | |
| data: self.node.data.clone(), | |
| left: self.node.left.clone(), | |
| right: Some(intermediate.node), | |
| }).view_with_right().unwrap().rotate_left().add_delta(delta + self.delta) | |
| } | |
| }, | |
| (-2, Some(left), _) => { | |
| if left.balance <= 0 { // Left-left case | |
| self.node.view_with_left().unwrap().rotate_right().add_delta(self.delta) | |
| } else { // Left-right case | |
| let intermediate = left.view_with_right().unwrap().rotate_left(); | |
| let (balance, delta) = new_balance_delta(self.node.balance, intermediate.delta, 0); | |
| (AvlNode { | |
| balance, | |
| data: self.node.data.clone(), | |
| left: Some(intermediate.node), | |
| right: self.node.right.clone(), | |
| }).view_with_left().unwrap().rotate_right().add_delta(delta + self.delta) | |
| } | |
| }, | |
| (balance, _, _) => { | |
| assert!((-1..=1).contains(&balance)); | |
| self.clone() | |
| }, | |
| } | |
| } | |
| } | |
| pub fn avl_find<T>(node: &Option<Rc<AvlNode<T>>>, value: &T) -> Option<Rc<AvlNode<T>>> | |
| where T: Ord, { | |
| use std::cmp::Ordering; | |
| let Some(node) = node else { return None }; | |
| match (*value).cmp(&node.data) { | |
| Ordering::Less => avl_find(&node.left, value), | |
| Ordering::Greater => avl_find(&node.right, value), | |
| Ordering::Equal => Some(node.clone()), | |
| } | |
| } | |
| impl<T: fmt::Display> fmt::Display for AvlNode<T> { | |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
| write!(f, "({}, {}, ", self.data, self.balance)?; | |
| match &self.left { | |
| Some(left) => left.fmt(f)?, // recursively calls fmt | |
| None => write!(f, "-")?, | |
| } | |
| write!(f, ", ")?; // comma between left and right | |
| match &self.right { | |
| Some(right) => right.fmt(f)?, // recursively calls fmt | |
| None => write!(f, "-")?, | |
| } | |
| write!(f, ")") | |
| } | |
| } | |
| pub struct AvlTree<T> { | |
| root: Option<Rc<AvlNode<T>>> | |
| } | |
| impl<T: Ord + Clone> AvlTree<T> { | |
| pub fn new() -> Self { | |
| Self { root: None } | |
| } | |
| pub fn insert(&mut self, value: T) { | |
| self.root = Some(avl_insert(&self.root, value).node); | |
| } | |
| pub fn delete(&mut self, value: &T) { | |
| self.root = avl_delete(&self.root, value).node; | |
| } | |
| pub fn find(&self, value: &T) -> Option<Rc<AvlNode<T>>> { | |
| avl_find(&self.root, value) | |
| } | |
| } | |
| impl<T: fmt::Display> fmt::Display for AvlTree<T> { | |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
| match &self.root { | |
| Some(node) => node.fmt(f), // recursively calls fmt | |
| None => write!(f, "(Empty AvlTree)") | |
| } | |
| } | |
| } | |
| pub fn main() { | |
| println!("Hello, avl!"); | |
| let x = Rc::new(AvlNode::new(1)); | |
| let y = Rc::new(AvlNode { | |
| balance: 0, | |
| data: 2, | |
| left: Some(x.clone()), | |
| right: Some(x.clone()), | |
| }); | |
| println!("x: {}", x); | |
| println!("y: {}", y); | |
| let z = y.view_with_right().unwrap().rotate_left(); | |
| println!("z: {}", z.node); | |
| let w = y.view_with_left().unwrap().rotate_right(); | |
| println!("w: {}", w.node); | |
| } | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Recomputes the real height of `node` (an empty subtree has height -1).
    /// Also checks that every stored balance equals right height minus left
    /// height and lies within `-1..=1`. Returns `(height, all_checks_passed)`.
    fn check_avl_height_balance<T>(node: &Option<Rc<AvlNode<T>>>) -> (i8, bool) {
        match node {
            None => (-1, true),
            Some(node) => {
                let (l_height, left) = check_avl_height_balance(&node.left);
                let (r_height, right) = check_avl_height_balance(&node.right);
                let balance_ok = (-1 <= node.balance && node.balance <= 1)
                    && node.balance == r_height - l_height;
                (max(l_height, r_height) + 1, left && right && balance_ok)
            }
        }
    }

    #[test]
    fn test_basic_insertion() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        let values = vec![10, 5, 15, 3, 7, 13, 17];
        for v in values {
            tree.insert(v);
        }
        let expected = "(10, 0, (5, 0, (3, 0, -, -), (7, 0, -, -)), (15, 0, (13, 0, -, -), (17, 0, -, -)))";
        assert_eq!(format!("{}", tree.root.clone().unwrap()), expected);
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 2);
        assert!(balance);
    }

    #[test]
    fn test_insertion_two() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        for v in values {
            tree.insert(v);
        }
        let expected = "(4, 1, (2, 0, (1, 0, -, -), (3, 0, -, -)), (8, 0, (6, 0, (5, 0, -, -), (7, 0, -, -)), (9, 1, -, (10, 0, -, -))))";
        assert_eq!(format!("{}", tree), expected);
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 3);
        assert!(balance);
    }

    #[test]
    fn test_duplicate_insert_and_missing_delete_keep_contents() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in vec![5, 3, 8, 1, 4] {
            tree.insert(v);
        }
        let before = format!("{}", tree);
        tree.insert(4); // already present
        tree.delete(&42); // never present
        assert_eq!(format!("{}", tree), before);
        let (_, balance) = check_avl_height_balance(&tree.root);
        assert!(balance);
    }

    #[test]
    fn test_old_versions_are_unchanged() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in 1..=7 {
            tree.insert(v);
        }
        let snapshot = tree.root.clone();
        let before = format!("{}", snapshot.clone().unwrap());
        tree.delete(&4);
        tree.insert(100);
        // The snapshot still sees the old version.
        assert_eq!(format!("{}", snapshot.unwrap()), before);
        assert!(tree.find(&4).is_none());
        assert!(tree.find(&100).is_some());
        let (_, balance) = check_avl_height_balance(&tree.root);
        assert!(balance);
    }

    #[test]
    fn test_large_performance() {
        use rand::{seq::SliceRandom, SeedableRng};
        use rand_chacha::ChaCha8Rng;
        use std::time::Instant;

        let mut tree: AvlTree<u32> = AvlTree::new();
        {
            let mut insert_values: Vec<u32> = (1..=100_000).collect();
            // Fixed seed for reproducibility
            let mut rng = ChaCha8Rng::seed_from_u64(42);
            insert_values.shuffle(&mut rng);
            let start = Instant::now();
            for v in insert_values {
                tree.insert(v);
            }
            let duration = start.elapsed();
            println!("Inserting 100,000 elements took {:?}", duration);
        }
        assert!(tree.root.is_some());
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 19);
        assert!(balance);
        {
            let mut find_values: Vec<u32> = (1..=100_000).collect();
            // Fixed seed for reproducibility
            let mut rng = ChaCha8Rng::seed_from_u64(1234);
            find_values.shuffle(&mut rng);
            let start = Instant::now();
            for v in find_values {
                assert!(tree.find(&v).is_some());
            }
            let duration = start.elapsed();
            println!("Finding 100,000 elements took {:?}", duration);
        }
        {
            let mut delete_values: Vec<u32> = (1..=100_000).collect();
            // Fixed seed for reproducibility
            let mut rng = ChaCha8Rng::seed_from_u64(1337);
            delete_values.shuffle(&mut rng);
            let start = Instant::now();
            for v in delete_values {
                tree.delete(&v);
            }
            let duration = start.elapsed();
            println!("Deleting 100,000 elements took {:?}", duration);
        }
        assert!(tree.root.is_none());
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment