Created
May 26, 2020 05:36
-
-
Save FrancisMurillo/ae83ba089d4860b21ac02b82ef234e7e to your computer and use it in GitHub Desktop.
Peterson/Filter Lock
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#[macro_use] | |
extern crate log; | |
use std::{ | |
cell::{RefCell, UnsafeCell}, | |
sync::atomic::{AtomicU8, Ordering}, | |
}; | |
thread_local! {
    /// Id of the current thread inside a `Peterson` lock's bookkeeping arrays.
    ///
    /// `0` means "no id assigned yet"; `Peterson::acquire` replaces it with a
    /// non-zero id the first time the thread asks for the lock.
    static THREAD_ID: RefCell<u8> = RefCell::new(0);
}

/// Length of the `level` and `victim` arrays of a `Peterson` lock.
///
/// The loops in `Peterson` walk thread ids and levels over `1..N`; index `0`
/// doubles as the "unassigned" / "not contending" marker.
const N: u8 = 32;
#[derive(Debug)] | |
struct Peterson { | |
level: UnsafeCell<Box<[u8; N as usize]>>, | |
victim: UnsafeCell<Box<[u8; N as usize]>>, | |
thread_id: AtomicU8, | |
} | |
unsafe impl Sync for Peterson {} | |
#[derive(Debug)] | |
struct PetersonLock<'a> { | |
peterson: &'a Peterson, | |
} | |
impl Peterson { | |
fn new() -> Self { | |
Self { | |
level: UnsafeCell::new(Box::new([0; N as usize])), | |
victim: UnsafeCell::new(Box::new([0; N as usize])), | |
thread_id: AtomicU8::new(0), | |
} | |
} | |
fn acquire<'a>(&'a self) -> Option<PetersonLock<'a>> { | |
THREAD_ID.with(|id| { | |
if *id.borrow() == 0 { | |
let old_value = self.thread_id.fetch_add(1, Ordering::SeqCst); | |
*id.borrow_mut() = old_value + 1; | |
} | |
}); | |
if self.thread_id.load(Ordering::Relaxed) == N { | |
return None; | |
} | |
THREAD_ID.with(|id| { | |
debug!("Acquring lock for {}", *id.borrow()); | |
let self_thread_id = *id.borrow(); | |
let level = unsafe { &mut *self.level.get() }; | |
let victim = unsafe { &mut *self.victim.get() }; | |
for current_level in 1..N { | |
level[self_thread_id as usize] = current_level; | |
victim[self_thread_id as usize] = self_thread_id; | |
while self.same_or_higher(self_thread_id, current_level) | |
&& victim[current_level as usize] == self_thread_id | |
{} | |
} | |
Some(PetersonLock { peterson: &self }) | |
}) | |
} | |
fn same_or_higher(&self, thread_id: u8, current_level: u8) -> bool { | |
let level = unsafe { &*self.level.get() }; | |
for k in 1..N { | |
if k != thread_id && level[k as usize] >= current_level { | |
return true; | |
} | |
} | |
false | |
} | |
fn release(&self) { | |
THREAD_ID.with(|id| { | |
debug!("Releasing lock for {}", *id.borrow()); | |
unsafe { | |
(*self.level.get())[*id.borrow() as usize] = 0; | |
} | |
}); | |
} | |
} | |
impl<'a> Drop for PetersonLock<'a> { | |
fn drop(&mut self) { | |
self.peterson.release(); | |
} | |
} | |
#[cfg(test)]
mod tests {
    use simplelog::{Config, LevelFilter, SimpleLogger};
    use std::{
        sync::{
            atomic::{AtomicU8, Ordering},
            Arc,
        },
        thread,
        time::{Duration, Instant},
    };

    use crate::*;

    /// Installs the logger; a failure only means another test already did,
    /// so the result is deliberately ignored.
    fn init_logger() {
        let _ = SimpleLogger::init(LevelFilter::Trace, Config::default());
    }

    /// Three threads take turns holding the lock; the third one takes it
    /// twice. Passes when every thread finishes without panicking.
    #[test]
    fn it_works() {
        init_logger();

        let lock = Arc::new(Peterson::new());
        let first_lock = Arc::clone(&lock);
        let second_lock = Arc::clone(&lock);
        let third_lock = Arc::clone(&lock);

        let first_handle = thread::spawn(move || {
            info!("Acquiring T1 lock");
            let _guard = first_lock.acquire().expect("Could not acquire lock");
            info!("Using T1 lock for 3 seconds");
            thread::sleep(Duration::from_millis(3_000));
            info!("Done with T1 lock");
        });

        let second_handle = thread::spawn(move || {
            info!("Acquiring T2 lock");
            let _guard = second_lock.acquire().expect("Could not acquire lock");
            info!("Using T2 lock for 3 seconds");
            thread::sleep(Duration::from_millis(3_000));
            info!("Done with T2 lock");
        });

        let third_handle = thread::spawn(move || {
            {
                info!("Acquiring T3 lock");
                let _guard = third_lock.acquire().expect("Could not acquire lock");
                info!("Using T3 lock for 1 second");
                thread::sleep(Duration::from_millis(1_000));
                // `_guard` goes out of scope here and releases the lock.
            }
            {
                info!("Acquiring T3 lock again");
                let other_guard = third_lock.acquire().expect("Could not acquire lock");
                info!("Using T3 lock for 1 second");
                thread::sleep(Duration::from_millis(1_000));
                drop(other_guard);
                info!("Done with T3 lock");
            }
        });

        first_handle.join().unwrap();
        second_handle.join().unwrap();
        third_handle.join().unwrap();
    }

    /// `N - 1` threads each perform a deliberately slow read-then-write
    /// increment of a shared counter while holding the lock. If two threads
    /// were ever inside the critical section together, one increment would be
    /// lost and the final count would fall short of `N - 1`.
    #[test]
    fn it_works_with_max_threads() {
        init_logger();

        let lock = Arc::new(Peterson::new());
        // Atomic so that a broken lock shows up as a wrong count rather than
        // as undefined behaviour; `Arc` keeps it alive for detached threads.
        let counter = Arc::new(AtomicU8::new(0));

        for _ in 1..N {
            let lock = Arc::clone(&lock);
            let counter = Arc::clone(&counter);
            thread::spawn(move || {
                let _guard = lock.acquire().expect("Could not acquire lock");
                // Separate load and store (not `fetch_add`) on purpose: the
                // lock, not the atomic, has to make this increment safe.
                let old_value = counter.load(Ordering::SeqCst);
                thread::sleep(Duration::from_millis(10));
                counter.store(old_value + 1, Ordering::SeqCst);
                assert_eq!(old_value + 1, counter.load(Ordering::SeqCst));
                debug!("Updated counter {}", counter.load(Ordering::SeqCst));
            });
        }

        // Wait for all increments, but give up after 5 seconds so a
        // deadlocked lock fails the test instead of hanging it.
        let start = Instant::now();
        while counter.load(Ordering::SeqCst) < N - 1 && start.elapsed() < Duration::from_secs(5) {
            thread::sleep(Duration::from_millis(1));
        }

        assert_eq!(N - 1, counter.load(Ordering::SeqCst));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment