Skip to content

Instantly share code, notes, and snippets.

@Element118
Last active July 12, 2025 17:52
Show Gist options
  • Select an option

  • Save Element118/470ca60423de5a0bbf4dc8770b9a8b9f to your computer and use it in GitHub Desktop.

Select an option

Save Element118/470ca60423de5a0bbf4dc8770b9a8b9f to your computer and use it in GitHub Desktop.
Immutable Persistent AVL Tree
use std::rc::Rc;
use std::fmt;
use std::cmp::max;
/// A node of an immutable, persistent AVL tree.
///
/// Nodes are not mutated after construction: every operation builds new
/// nodes and shares untouched subtrees through `Rc`.
#[derive(Debug, Clone)]
pub struct AvlNode<T> {
/// Balance factor: height(right) - height(left). Inside a finished tree it is
/// in `-1..=1`; it is `+2`/`-2` only transiently, before rebalancing.
balance: i8,
/// The stored value; the tree is ordered by `T: Ord` on this field.
data: T,
/// Left subtree (values less than `data`), possibly shared with other trees.
left: Option<Rc<AvlNode<T>>>,
/// Right subtree (values greater than `data`), possibly shared with other trees.
right: Option<Rc<AvlNode<T>>>,
}
/// Borrowed view of a node paired with its right child, as produced by
/// `AvlNode::view_with_right`; this is the receiver of `rotate_left`.
pub struct AvlNodeWithRight<'a, T> {
/// The node that will be rotated.
node: &'a AvlNode<T>,
/// `node`'s right child (already known to exist).
right: &'a AvlNode<T>,
}
/// Borrowed view of a node paired with its left child, as produced by
/// `AvlNode::view_with_left`; this is the receiver of `rotate_right`.
pub struct AvlNodeWithLeft<'a, T> {
/// The node that will be rotated.
node: &'a AvlNode<T>,
/// `node`'s left child (already known to exist).
left: &'a AvlNode<T>,
}
/// Result of a tree operation: the resulting subtree plus its height change.
///
/// `delta` is the new subtree height minus the height of the subtree the
/// operation started from (for example `1` when a leaf is inserted into an
/// empty subtree, `-1` when a node with at most one child is removed).
///
/// `Debug` is derived alongside `Clone` so results can be inspected in
/// assertions and diagnostics, like `AvlNode`.
#[derive(Debug, Clone)]
pub struct AvlResult<N> {
    /// Change in subtree height caused by the operation.
    pub delta: i8,
    /// The resulting subtree (or node handle).
    pub node: N,
}
/// Result carrying a non-empty subtree (returned by insertion and rotations).
pub type AvlNodeResult<T> = AvlResult<Rc<AvlNode<T>>>;
/// Result carrying a possibly-empty subtree (returned by deletion).
pub type AvlOptionalNodeResult<T> = AvlResult<Option<Rc<AvlNode<T>>>>;
impl<N: Clone> AvlResult<N> {
    /// Returns a copy of this result whose height change is shifted by `delta`.
    ///
    /// Used to fold the height change of a rotation into the change that
    /// preceded it.
    pub fn add_delta(&self, delta: i8) -> AvlResult<N> {
        let combined = self.delta + delta;
        AvlResult {
            delta: combined,
            node: self.node.clone(),
        }
    }

    /// Returns a copy of this result with the node wrapped in `Some`.
    ///
    /// Despite the name no `Box` is involved; this lifts, e.g., an
    /// `AvlNodeResult` into an `AvlOptionalNodeResult`.
    pub fn optional_box(&self) -> AvlResult<Option<N>> {
        let AvlResult { delta, node } = self;
        AvlResult {
            delta: *delta,
            node: Some(node.clone()),
        }
    }
}
impl<T> AvlNode<T> {
pub fn new(data: T) -> Self {
AvlNode {
balance: 0,
data,
left: None,
right: None,
}
}
pub fn view_with_left(&self) -> Option<AvlNodeWithLeft<T>> {
self.left.as_ref().map(|r| AvlNodeWithLeft { node: self, left: r })
}
pub fn view_with_right(&self) -> Option<AvlNodeWithRight<T>> {
self.right.as_ref().map(|r| AvlNodeWithRight { node: self, right: r })
}
}
impl<T: Clone> AvlNodeWithRight<'_, T> {
    /// Left rotation: the right child (the pivot) becomes the subtree root and
    /// the old root becomes the pivot's left child.
    ///
    /// Returns the rotated subtree and its height change. All heights are
    /// measured relative to the pivot's left ("inner") subtree, taken as 0, so
    /// only the two stored balance factors are needed. Both subtree heights omit
    /// the same "+1" for their root, which cancels in the difference.
    pub fn rotate_left(&self) -> AvlNodeResult<T> {
        let (pivot, root) = (self.right, self.node);

        let inner = 0;
        let outer = pivot.balance;
        let pivot_height = max(inner, outer) + 1;
        let lhs = pivot_height - root.balance;

        let height_before = max(lhs, pivot_height);
        let demoted_height = max(inner, lhs) + 1;
        let height_after = max(outer, demoted_height);

        // Old root keeps its left subtree and adopts the pivot's inner subtree.
        let demoted = AvlNode {
            balance: inner - lhs,
            data: root.data.clone(),
            left: root.left.clone(),
            right: pivot.left.clone(),
        };
        let promoted = AvlNode {
            balance: outer - demoted_height,
            data: pivot.data.clone(),
            left: Some(Rc::new(demoted)),
            right: pivot.right.clone(),
        };
        AvlNodeResult {
            delta: height_after - height_before,
            node: Rc::new(promoted),
        }
    }
}
impl<T: Clone> AvlNodeWithLeft<'_, T> {
    /// Right rotation: the left child (the pivot) becomes the subtree root and
    /// the old root becomes the pivot's right child.
    ///
    /// Returns the rotated subtree and its height change. All heights are
    /// measured relative to the pivot's left ("outer") subtree, taken as 0, so
    /// only the two stored balance factors are needed. Both subtree heights omit
    /// the same "+1" for their root, which cancels in the difference.
    pub fn rotate_right(&self) -> AvlNodeResult<T> {
        let (pivot, root) = (self.left, self.node);

        let outer = 0;
        let inner = pivot.balance;
        let pivot_height = max(inner, outer) + 1;
        let rhs = pivot_height + root.balance;

        let height_before = max(rhs, pivot_height);
        let demoted_height = max(rhs, inner) + 1;
        let height_after = max(outer, demoted_height);

        // Old root keeps its right subtree and adopts the pivot's inner subtree.
        let demoted = AvlNode {
            balance: rhs - inner,
            data: root.data.clone(),
            left: pivot.right.clone(),
            right: root.right.clone(),
        };
        let promoted = AvlNode {
            balance: demoted_height - outer,
            data: pivot.data.clone(),
            left: pivot.left.clone(),
            right: Some(Rc::new(demoted)),
        };
        AvlNodeResult {
            delta: height_after - height_before,
            node: Rc::new(promoted),
        }
    }
}
pub fn balanced_avl_node_from<T>(
delta: i8, data: T,
left: Option<Rc<AvlNode<T>>>,
right: Option<Rc<AvlNode<T>>>,
balance: i8) -> AvlNodeResult<T>
where T: Clone {
(AvlNodeResult {
delta,
node: Rc::new(AvlNode {
data,
left,
right,
balance,
})
}).balance()
}
/// Inserts `value` into the (possibly empty) subtree `node`.
///
/// Returns the new subtree root together with its height change (`delta`).
/// The input subtree is left untouched. Nodes on the search path are copied
/// and rebalanced; all other nodes are shared through `Rc`.
///
/// If an equal value is already present, the original subtree itself is
/// returned (same `Rc`, `delta == 0`) instead of a copy of the search path.
/// A no-op insert therefore allocates nothing and keeps full structural sharing
/// with earlier versions of the tree.
pub fn avl_insert<T>(node: &Option<Rc<AvlNode<T>>>, value: T) -> AvlNodeResult<T>
where
    T: Ord + Clone,
{
    use std::cmp::Ordering;

    let Some(node) = node else {
        // A new leaf raises the height of an empty subtree by one.
        return AvlNodeResult { delta: 1, node: Rc::new(AvlNode::new(value)) };
    };
    match value.cmp(&node.data) {
        Ordering::Less => {
            let AvlNodeResult { delta, node: new_left } = avl_insert(&node.left, value);
            // Getting the very same child back means the value already existed below.
            if matches!(&node.left, Some(old_left) if Rc::ptr_eq(old_left, &new_left)) {
                return AvlNodeResult { delta: 0, node: Rc::clone(node) };
            }
            // Heights relative to the old left subtree.
            let old_l_height = 0;
            let l_height = old_l_height + delta;
            let r_height = node.balance;
            balanced_avl_node_from(
                max(r_height, l_height) - max(r_height, old_l_height),
                node.data.clone(),
                Some(new_left),
                node.right.clone(),
                r_height - l_height, // left got heavier
            )
        }
        Ordering::Greater => {
            let AvlNodeResult { delta, node: new_right } = avl_insert(&node.right, value);
            // Getting the very same child back means the value already existed below.
            if matches!(&node.right, Some(old_right) if Rc::ptr_eq(old_right, &new_right)) {
                return AvlNodeResult { delta: 0, node: Rc::clone(node) };
            }
            // Heights relative to the left subtree.
            let l_height = 0;
            let old_r_height = node.balance;
            let r_height = old_r_height + delta;
            balanced_avl_node_from(
                max(r_height, l_height) - max(old_r_height, l_height),
                node.data.clone(),
                node.left.clone(),
                Some(new_right),
                r_height - l_height, // right got heavier
            )
        }
        // Value already exists here: nothing changes.
        Ordering::Equal => AvlNodeResult { delta: 0, node: Rc::clone(node) },
    }
}
/// Removes `value` from the (possibly empty) subtree `node`.
///
/// Returns the new subtree (possibly empty) together with its height change.
/// The input subtree is left untouched: nodes on the search path are copied
/// and rebalanced, and everything else is shared through `Rc`.
///
/// If `value` is not present, the original subtree itself is returned (same
/// `Rc`, `delta == 0`) instead of a rebuilt copy of the search path, so a
/// no-op delete allocates nothing and keeps full structural sharing.
pub fn avl_delete<T>(node: &Option<Rc<AvlNode<T>>>, value: &T) -> AvlOptionalNodeResult<T>
where
    T: Ord + Clone,
{
    use std::cmp::Ordering;

    // True when a recursive delete handed back the very same subtree, i.e.
    // the value was not found below and nothing has to be rebuilt.
    fn unchanged<T>(before: &Option<Rc<AvlNode<T>>>, after: &Option<Rc<AvlNode<T>>>) -> bool {
        match (before, after) {
            (None, None) => true,
            (Some(before), Some(after)) => Rc::ptr_eq(before, after),
            _ => false,
        }
    }

    let Some(node) = node else {
        // Value not present: the empty subtree stays empty.
        return AvlOptionalNodeResult { delta: 0, node: None };
    };
    match value.cmp(&node.data) {
        Ordering::Less => {
            let AvlOptionalNodeResult { delta, node: new_left } = avl_delete(&node.left, value);
            if unchanged(&node.left, &new_left) {
                return AvlOptionalNodeResult { delta: 0, node: Some(Rc::clone(node)) };
            }
            // Heights relative to the old left subtree.
            let old_l_height = 0;
            let l_height = old_l_height + delta;
            let r_height = node.balance;
            balanced_avl_node_from(
                max(r_height, l_height) - max(r_height, old_l_height),
                node.data.clone(),
                new_left,
                node.right.clone(),
                r_height - l_height,
            )
            .optional_box()
        }
        Ordering::Greater => {
            let AvlOptionalNodeResult { delta, node: new_right } = avl_delete(&node.right, value);
            if unchanged(&node.right, &new_right) {
                return AvlOptionalNodeResult { delta: 0, node: Some(Rc::clone(node)) };
            }
            // Heights relative to the left subtree.
            let l_height = 0;
            let old_r_height = node.balance;
            let r_height = old_r_height + delta;
            balanced_avl_node_from(
                max(r_height, l_height) - max(old_r_height, l_height),
                node.data.clone(),
                node.left.clone(),
                new_right,
                r_height - l_height,
            )
            .optional_box()
        }
        // Found it: remove this node and discard its value.
        Ordering::Equal => avl_delete_root(node).0,
    }
}
/// Removes the root of a non-empty subtree.
///
/// Returns the rebuilt subtree (with its height change) and a clone of the
/// removed root's value. With at most one child, that child (or nothing)
/// takes the root's place and the height drops by one. With two children, the
/// minimum of the right subtree is removed and promoted to the root.
pub fn avl_delete_root<T>(node: &Rc<AvlNode<T>>) -> (AvlOptionalNodeResult<T>, T)
where
    T: Ord + Clone,
{
    let rebuilt = match (node.left.as_ref(), node.right.as_ref()) {
        (Some(left), Some(right)) => {
            let (rest, successor) = avl_delete_min(right);
            // Heights relative to the left subtree.
            let l_height = 0;
            let old_r_height = node.balance;
            let r_height = old_r_height + rest.delta;
            balanced_avl_node_from(
                max(r_height, l_height) - max(old_r_height, l_height),
                successor,
                Some(Rc::clone(left)),
                rest.node,
                r_height - l_height,
            )
            .optional_box()
        }
        // Zero or one child: the child (if any) replaces the root, one level shorter.
        (only_left, only_right) => AvlOptionalNodeResult {
            delta: -1,
            node: only_left.or(only_right).cloned(),
        },
    };
    (rebuilt, node.data.clone())
}
/// Removes the smallest value from a non-empty subtree.
///
/// Returns the rebuilt subtree (with its height change) and the removed
/// minimum value.
pub fn avl_delete_min<T>(node: &Rc<AvlNode<T>>) -> (AvlOptionalNodeResult<T>, T)
where
    T: Ord + Clone,
{
    let Some(left) = &node.left else {
        // No left child: this node holds the minimum, so remove it outright.
        return avl_delete_root(node);
    };
    let (shrunk, minimum) = avl_delete_min(left);
    // Heights relative to the old left subtree.
    let old_l_height = 0;
    let l_height = old_l_height + shrunk.delta;
    let r_height = node.balance;
    let rebuilt = balanced_avl_node_from(
        max(r_height, l_height) - max(r_height, old_l_height),
        node.data.clone(),
        shrunk.node,
        node.right.clone(),
        r_height - l_height,
    );
    (rebuilt.optional_box(), minimum)
}
/// Recomputes a node's balance factor and height change after its left and
/// right subtrees change height by `left_delta` and `right_delta`.
///
/// Returns `(new_balance, height_delta)`. `new_balance` is
/// `height(right) - height(left)` after the change, and `height_delta` is how
/// much the node's own height changed.
pub fn new_balance_delta(balance: i8, left_delta: i8, right_delta: i8) -> (i8, i8) {
    // Measure both children relative to the old left child (height 0).
    let (left_before, right_before) = (0, balance);
    let (left_after, right_after) = (left_before + left_delta, right_before + right_delta);
    let grown = max(left_after, right_after) - max(left_before, right_before);
    (balance + right_delta - left_delta, grown)
}
impl<T: Clone> AvlNodeResult<T> {
pub fn balance(&self) -> AvlNodeResult<T> {
match (self.node.balance, self.node.left.clone(), self.node.right.clone()) {
(2, _, Some(right)) => {
if right.balance >= 0 { // Right-right case
self.node.view_with_right().unwrap().rotate_left().add_delta(self.delta)
} else { // Right-left case
let intermediate = right.view_with_left().unwrap().rotate_right();
let (balance, delta) = new_balance_delta(self.node.balance, 0, intermediate.delta);
(AvlNode {
balance,
data: self.node.data.clone(),
left: self.node.left.clone(),
right: Some(intermediate.node),
}).view_with_right().unwrap().rotate_left().add_delta(delta + self.delta)
}
},
(-2, Some(left), _) => {
if left.balance <= 0 { // Left-left case
self.node.view_with_left().unwrap().rotate_right().add_delta(self.delta)
} else { // Left-right case
let intermediate = left.view_with_right().unwrap().rotate_left();
let (balance, delta) = new_balance_delta(self.node.balance, intermediate.delta, 0);
(AvlNode {
balance,
data: self.node.data.clone(),
left: Some(intermediate.node),
right: self.node.right.clone(),
}).view_with_left().unwrap().rotate_right().add_delta(delta + self.delta)
}
},
(balance, _, _) => {
assert!((-1..=1).contains(&balance));
self.clone()
},
}
}
}
/// Looks up `value` in the subtree `node`.
///
/// Returns a shared handle to the node whose value compares equal to `value`,
/// or `None` if there is no such node.
pub fn avl_find<T>(node: &Option<Rc<AvlNode<T>>>, value: &T) -> Option<Rc<AvlNode<T>>>
where
    T: Ord,
{
    use std::cmp::Ordering;
    // Walk down iteratively instead of recursing.
    let mut cursor = node.as_ref();
    while let Some(current) = cursor {
        cursor = match value.cmp(&current.data) {
            Ordering::Less => current.left.as_ref(),
            Ordering::Greater => current.right.as_ref(),
            Ordering::Equal => return Some(Rc::clone(current)),
        };
    }
    None
}
impl<T: fmt::Display> fmt::Display for AvlNode<T> {
    /// Writes the subtree as `(data, balance, left, right)`, recursing into
    /// present children and writing `-` for missing ones.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "({}, {}, ", self.data, self.balance)?;
        for (position, child) in [&self.left, &self.right].iter().enumerate() {
            if position > 0 {
                write!(f, ", ")?;
            }
            match child {
                Some(subtree) => fmt::Display::fmt(&**subtree, f)?,
                None => write!(f, "-")?,
            }
        }
        write!(f, ")")
    }
}
/// A persistent, ordered set backed by an immutable AVL tree.
///
/// Cloning is O(1): the clone shares every node with the original. Later
/// updates to either tree copy only the nodes they touch and never mutate
/// existing nodes, so earlier clones act as immutable snapshots.
#[derive(Debug)]
pub struct AvlTree<T> {
    root: Option<Rc<AvlNode<T>>>,
}

// Implemented by hand rather than derived so that cloning a tree only bumps the
// root's reference count and does not require `T: Clone`.
impl<T> Clone for AvlTree<T> {
    fn clone(&self) -> Self {
        Self { root: self.root.clone() }
    }
}

impl<T> Default for AvlTree<T> {
    /// Returns an empty tree.
    fn default() -> Self {
        Self { root: None }
    }
}

impl<T: Ord + Clone> AvlTree<T> {
    /// Creates an empty tree.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `value`. Inserting a value that is already present leaves the
    /// set's contents unchanged.
    pub fn insert(&mut self, value: T) {
        self.root = Some(avl_insert(&self.root, value).node);
    }

    /// Removes `value` if it is present; otherwise the contents are unchanged.
    pub fn delete(&mut self, value: &T) {
        self.root = avl_delete(&self.root, value).node;
    }

    /// Returns the node holding a value equal to `value`, if any.
    pub fn find(&self, value: &T) -> Option<Rc<AvlNode<T>>> {
        avl_find(&self.root, value)
    }
}
impl<T: fmt::Display> fmt::Display for AvlTree<T> {
    /// Writes the root's parenthesized rendering, or `(Empty AvlTree)` when the
    /// tree has no nodes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(root) = self.root.as_ref() {
            return fmt::Display::fmt(&**root, f);
        }
        f.write_str("(Empty AvlTree)")
    }
}
/// Small demo: builds a root whose two children are the same shared leaf and
/// prints it together with its left and right rotations.
pub fn main() {
    println!("Hello, avl!");
    let leaf = Rc::new(AvlNode::new(1));
    let root = Rc::new(AvlNode {
        balance: 0,
        data: 2,
        left: Some(Rc::clone(&leaf)),
        right: Some(Rc::clone(&leaf)),
    });
    println!("x: {}", leaf);
    println!("y: {}", root);

    let rotated_left = root.view_with_right().unwrap().rotate_left();
    println!("z: {}", rotated_left.node);

    let rotated_right = root.view_with_left().unwrap().rotate_right();
    println!("w: {}", rotated_right.node);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `(height, ok)` for a subtree, where an empty subtree has height
    /// -1. `ok` is true only if every node's stored balance factor equals
    /// `height(right) - height(left)` and lies in `-1..=1`.
    fn check_avl_height_balance<T>(node: &Option<Rc<AvlNode<T>>>) -> (i8, bool) {
        match node {
            None => (-1, true),
            Some(node) => {
                let (l_height, left) = check_avl_height_balance(&node.left);
                let (r_height, right) = check_avl_height_balance(&node.right);
                let balance_ok = (-1..=1).contains(&node.balance)
                    && node.balance == r_height - l_height;
                (max(l_height, r_height) + 1, left && right && balance_ok)
            }
        }
    }

    #[test]
    fn test_basic_insertion() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in [10, 5, 15, 3, 7, 13, 17] {
            tree.insert(v);
        }
        let expected = "(10, 0, (5, 0, (3, 0, -, -), (7, 0, -, -)), (15, 0, (13, 0, -, -), (17, 0, -, -)))";
        assert_eq!(format!("{}", tree.root.clone().unwrap()), expected);
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 2);
        assert!(balance);
    }

    #[test]
    fn test_insertion_two() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in 1..=10 {
            tree.insert(v);
        }
        let expected = "(4, 1, (2, 0, (1, 0, -, -), (3, 0, -, -)), (8, 0, (6, 0, (5, 0, -, -), (7, 0, -, -)), (9, 1, -, (10, 0, -, -))))";
        assert_eq!(format!("{}", tree), expected);
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 3);
        assert!(balance);
    }

    #[test]
    fn test_delete_two_child_root() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in 1..=7 {
            tree.insert(v);
        }
        // The root 4 has two children, so its in-order successor 5 replaces it.
        tree.delete(&4);
        let expected = "(5, 0, (2, 0, (1, 0, -, -), (3, 0, -, -)), (6, 1, -, (7, 0, -, -)))";
        assert_eq!(format!("{}", tree), expected);
        assert!(tree.find(&4).is_none());
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 2);
        assert!(balance);
    }

    #[test]
    fn test_noop_updates_keep_contents() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in [5, 3, 8] {
            tree.insert(v);
        }
        let before = format!("{}", tree);
        tree.insert(3);
        tree.delete(&42);
        assert_eq!(format!("{}", tree), before);
    }

    #[test]
    fn test_old_root_is_a_snapshot() {
        let mut tree: AvlTree<i32> = AvlTree::new();
        for v in [2, 1, 3] {
            tree.insert(v);
        }
        let snapshot = tree.root.clone();
        let before = format!("{}", tree);
        tree.insert(4);
        tree.delete(&1);
        // Updates build new nodes, so the old root still describes the old set.
        assert_eq!(format!("{}", snapshot.as_ref().unwrap()), before);
        assert!(avl_find(&snapshot, &1).is_some());
        assert!(avl_find(&snapshot, &4).is_none());
        assert!(tree.find(&1).is_none());
        assert!(tree.find(&4).is_some());
        assert!(check_avl_height_balance(&tree.root).1);
    }

    #[test]
    fn test_large_performance() {
        use rand::{seq::SliceRandom, SeedableRng};
        use rand_chacha::ChaCha8Rng;
        use std::time::Instant;

        let mut tree: AvlTree<u32> = AvlTree::new();
        {
            let mut insert_values: Vec<u32> = (1..=100_000).collect();
            // Fixed seed for reproducibility
            let mut rng = ChaCha8Rng::seed_from_u64(42);
            insert_values.shuffle(&mut rng);
            let start = Instant::now();
            for v in insert_values {
                tree.insert(v);
            }
            let duration = start.elapsed();
            println!("Inserting 100,000 elements took {:?}", duration);
        }
        assert!(tree.root.is_some());
        let (height, balance) = check_avl_height_balance(&tree.root);
        assert_eq!(height, 19);
        assert!(balance);
        {
            let mut find_values: Vec<u32> = (1..=100_000).collect();
            // Fixed seed for reproducibility
            let mut rng = ChaCha8Rng::seed_from_u64(1234);
            find_values.shuffle(&mut rng);
            let start = Instant::now();
            for v in find_values {
                assert!(tree.find(&v).is_some());
            }
            let duration = start.elapsed();
            println!("Finding 100,000 elements took {:?}", duration);
        }
        {
            let mut delete_values: Vec<u32> = (1..=100_000).collect();
            // Fixed seed for reproducibility
            let mut rng = ChaCha8Rng::seed_from_u64(1337);
            delete_values.shuffle(&mut rng);
            let start = Instant::now();
            for v in delete_values {
                tree.delete(&v);
            }
            let duration = start.elapsed();
            println!("Deleting 100,000 elements took {:?}", duration);
        }
        assert!(tree.root.is_none());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment