Skip to content

Instantly share code, notes, and snippets.

@robert-king
Created September 6, 2024 21:43
Show Gist options
  • Save robert-king/9f0a6f86681d6e9658285ccb7fa0bd05 to your computer and use it in GitHub Desktop.
Save robert-king/9f0a6f86681d6e9658285ccb7fa0bd05 to your computer and use it in GitHub Desktop.
proptest linkedlist
#![allow(dead_code)]
/*
We use hash maps for the linked list to sidestep the borrow checker.
We use a proptest strategy to generate a series of operations on the linked list and compare the result to a VecDeque for correctness.
We can add subtle bugs, and proptest manages to catch them
and shrink them down to a minimal reproduction, which we can then copy and paste to create a new test case.
PROPTEST_CASES=1000000 PROPTEST_MAX_SHRINK_ITERS=100000000 cargo test --release
cargo test copy_test_case
*/
use std::{collections::HashMap, fmt::Debug, hash::Hash};
trait LinkedListItem: Default + Eq + Hash + Copy + Debug + Clone {}
impl<T: Default + Eq + Hash + Copy + Debug + Clone> LinkedListItem for T {}
type Idx = u8;
#[derive(Default, Debug)]
struct LinkedList<T: LinkedListItem> {
head: Option<Idx>,
tail: Option<Idx>,
next: HashMap<Idx, Idx>,
prev: HashMap<Idx, Idx>,
val: HashMap<Idx, T>,
idx: HashMap<T, Idx>, // todo: MultiMap
increment: Idx,
}
impl<T: LinkedListItem> LinkedList<T> {
fn new() -> Self {
LinkedList::default()
}
fn pop_front(&mut self) -> Option<T> {
if let Some(head) = self.head {
if Some(head) == self.tail {
self.head = None;
self.tail = None;
self.next.clear();
self.prev.clear();
self.idx.clear();
let tmp = self.val.remove(&head).unwrap();
self.idx.remove(&tmp);
return Some(tmp);
}
self.head = self.next.remove(&head);
self.prev.remove(&self.head.unwrap());
let tmp = self.val.remove(&head).unwrap();
self.idx.remove(&tmp);
return Some(tmp);
}
None
}
fn pop_back(&mut self) -> Option<T> {
if let Some(tail) = self.tail {
if Some(tail) == self.head {
return self.pop_front();
}
self.tail = self.prev.remove(&tail);
self.next.remove(&self.tail.unwrap());
let tmp = self.val.remove(&tail).unwrap();
self.idx.remove(&tmp);
return Some(tmp);
}
None
}
fn push_back(&mut self, x: T) {
assert!(!self.idx.contains_key(&x));
let i = self.increment;
self.increment += 1;
self.val.insert(i, x);
self.idx.insert(x, i);
let Some(tail) = self.tail else {
self.head = Some(i);
self.tail = Some(i);
return;
};
self.tail = Some(i);
self.next.insert(tail, i);
self.prev.insert(i, tail);
}
fn remove(&mut self, x: T) -> bool {
let Some(&idx) = self.idx.get(&x) else {
return false;
};
let Some(&prev) = self.prev.get(&idx) else {
self.pop_front();
return true;
};
let Some(&nxt) = self.next.get(&idx) else {
self.pop_back();
return true;
};
self.next.insert(prev, nxt);
self.prev.insert(nxt, prev);
self.idx.remove(&x);
self.val.remove(&idx);
true
}
fn iter(&self) -> LinkedListIter<'_, T> {
LinkedListIter {
ll: self,
idx: self.head,
}
}
}
struct LinkedListIter<'a, T: LinkedListItem> {
ll: &'a LinkedList<T>,
idx: Option<Idx>,
}
impl<'a, T: LinkedListItem> Iterator for LinkedListIter<'a, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
let Some(idx) = self.idx else {
return None;
};
let item = self.ll.val.get(&idx).unwrap();
self.idx = self.ll.next.get(&idx).cloned();
Some(*item)
}
}
// Entry point with an empty body: running the binary does nothing.
fn main() {}
/// Hand-written smoke test that walks the list through push, pop and remove
/// sequences and checks every returned value.
#[test]
fn it_works() {
    let mut list = LinkedList::new();

    // Draining from the front yields insertion order; the second round
    // checks that the list is usable again after being emptied.
    for _ in 0..2 {
        list.push_back(1);
        list.push_back(2);
        assert_eq!(list.pop_front(), Some(1));
        assert_eq!(list.pop_front(), Some(2));
        assert_eq!(list.pop_front(), None);
    }

    // Draining from the back yields reverse insertion order.
    list.push_back(1);
    list.push_back(2);
    assert_eq!(list.pop_back(), Some(2));
    assert_eq!(list.pop_back(), Some(1));
    assert_eq!(list.pop_back(), None);

    list.push_back(2);
    list.push_back(1);
    assert_eq!(list.pop_back(), Some(1));
    assert_eq!(list.pop_back(), Some(2));
    assert_eq!(list.pop_back(), None);

    // Removing by value reports whether the value was present.
    assert!(!list.remove(1));
    list.push_back(1);
    assert!(list.remove(1));
    assert!(!list.remove(1));

    // Remove an interior node, then the head, and drain what is left.
    for value in [1, 2, 3] {
        list.push_back(value);
    }
    assert!(list.remove(2));
    assert!(list.remove(1));
    assert_eq!(list.pop_front(), Some(3));
    assert_eq!(list.pop_front(), None);
    assert!(!list.remove(3));
}
/// One step of a test scenario; each variant names the list method that the
/// test harness invokes for it.
#[derive(Debug, Clone)]
enum Operation<T>
where
    T: LinkedListItem,
{
    /// Take the value at the front of the list.
    PopFront,
    /// Take the value at the back of the list.
    PopBack,
    /// Append the carried value at the back of the list.
    PushBack(T),
    /// Remove the carried value wherever it sits in the list.
    Remove(T),
}
use proptest::prelude::*;
fn get_operation() -> impl Strategy<Value = Operation<i32>> {
prop_oneof![
Just(Operation::PopFront),
Just(Operation::PopBack),
(0..10).prop_map(Operation::PushBack),
(0..10).prop_map(Operation::Remove),
]
}
/// Strategy producing a scenario: a vector of 1 to 9 random operations.
fn get_operations() -> impl Strategy<Value = Vec<Operation<i32>>> {
    let scenario_len = 1..10_usize;
    proptest::collection::vec(get_operation(), scenario_len)
}
proptest! {
    // #[ignore]
    // Replays a random scenario against both the linked list and a VecDeque
    // model, checking after every step that the returned values and the full
    // contents agree.
    #[test]
    fn linked_list_and_vec_deque_behave_the_same(
        operations in get_operations()
    ) {
        // Values currently stored. The list rejects duplicates, so a push is
        // only replayed when the value is absent from this set.
        let mut s = std::collections::HashSet::new();
        let mut ll = LinkedList::new();
        let mut v = std::collections::VecDeque::new();
        for op in operations {
            match op {
                Operation::PopFront => {
                    let expected = v.pop_front();
                    // Keep the set in sync so the value may be pushed again.
                    if let Some(x) = expected {
                        s.remove(&x);
                    }
                    prop_assert_eq!(ll.pop_front(), expected);
                },
                Operation::PopBack => {
                    let expected = v.pop_back();
                    if let Some(x) = expected {
                        s.remove(&x);
                    }
                    prop_assert_eq!(ll.pop_back(), expected);
                },
                Operation::PushBack(x) => {
                    if s.insert(x) {
                        ll.remove(x);
                        ll.push_back(x);
                        v.push_back(x);
                    }
                }
                Operation::Remove(x) => {
                    s.remove(&x);
                    let position = v.iter().position(|z| *z==x);
                    if let Some(i) = position {
                        v.remove(i);
                    }
                    // The list must report presence exactly as the model does.
                    prop_assert_eq!(ll.remove(x), position.is_some());
                }
            }
            println!("{op:?}!");
            println!("{ll:?}!");
            let ll_vec: Vec<i32> = ll.iter().collect();
            println!("{ll_vec:?} <-- ll");
            println!("{v:?} <-- v");
            // Compare whole sequences: a zip would stop at the shorter one
            // and hide missing or extra trailing elements.
            let v_vec: Vec<i32> = v.iter().copied().collect();
            prop_assert_eq!(&ll_vec, &v_vec);
        }
    }
}
/// Replays a fixed scenario (pasted from a shrunk proptest failure) against
/// both the linked list and a `VecDeque` model, checking after every step
/// that the returned values and the full contents agree.
#[test]
fn copy_test_case() {
    use Operation::*;
    let ops = [PushBack(0), PushBack(2), Remove(2), PushBack(2)];
    // Values currently stored. The list rejects duplicates, so a push is only
    // replayed when the value is absent from this set.
    let mut s = std::collections::HashSet::new();
    let mut ll = LinkedList::new();
    let mut v = std::collections::VecDeque::new();
    for op in ops {
        match op {
            Operation::PopFront => {
                let expected = v.pop_front();
                // Keep the set in sync so the value may be pushed again.
                if let Some(x) = expected {
                    s.remove(&x);
                }
                assert_eq!(ll.pop_front(), expected);
            }
            Operation::PopBack => {
                let expected = v.pop_back();
                if let Some(x) = expected {
                    s.remove(&x);
                }
                assert_eq!(ll.pop_back(), expected);
            }
            Operation::PushBack(x) => {
                if s.insert(x) {
                    ll.remove(x);
                    ll.push_back(x);
                    v.push_back(x);
                }
            }
            Operation::Remove(x) => {
                s.remove(&x);
                let position = v.iter().position(|z| *z == x);
                if let Some(i) = position {
                    v.remove(i);
                }
                // The list must report presence exactly as the model does.
                assert_eq!(ll.remove(x), position.is_some());
            }
        }
        println!("{op:?}!");
        println!("{ll:?}!");
        let ll_vec: Vec<i32> = ll.iter().collect();
        println!("{ll_vec:?} <-- ll");
        println!("{v:?} <-- v");
        // Compare whole sequences: a zip would stop at the shorter one and
        // hide missing or extra trailing elements.
        let v_vec: Vec<i32> = v.iter().copied().collect();
        assert_eq!(ll_vec, v_vec);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment