Created
March 31, 2025 18:40
-
-
Save pskrgag/8622977b6485900ff30ea27fc16e89c6 to your computer and use it in GitHub Desktop.
msc lock
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cell::UnsafeCell; | |
use std::ops::{Deref, DerefMut}; | |
use std::ptr; | |
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; | |
/// One entry in a lock's waiter queue.
struct Node {
    /// The next waiter in the queue, or null until a later waiter links itself in.
    next: AtomicPtr<Node>,
    /// Flipped to `true` by the predecessor when it hands the lock to this node.
    locked: AtomicBool,
}

impl Default for Node {
    /// An unlinked node: no successor and not yet granted the lock.
    fn default() -> Self {
        Node {
            next: AtomicPtr::new(ptr::null_mut()),
            locked: AtomicBool::new(false),
        }
    }
}

impl Node {
    /// Heap-allocates a fresh default node and returns an owning raw pointer to it.
    ///
    /// Ownership passes to the caller, which must eventually release the node
    /// with `Box::from_raw`.
    pub fn new() -> *mut Self {
        Box::into_raw(Box::default())
    }
}
pub struct LockGuard<'lock, T> { | |
lock: &'lock MscLock<T>, | |
node: *mut Node, | |
} | |
pub struct MscLock<T> { | |
val: UnsafeCell<T>, | |
tail: AtomicPtr<Node>, | |
} | |
impl<T> MscLock<T> { | |
pub fn new(val: T) -> Self { | |
Self { | |
val: UnsafeCell::new(val), | |
tail: ptr::null_mut::<Node>().into(), | |
} | |
} | |
pub fn lock(&self) -> LockGuard<T> { | |
let node = Node::new(); | |
let prev_tail = self.tail.swap(node, Ordering::AcqRel); | |
// Enqueue new node. If there is no other nodes, lock is held | |
if prev_tail.is_null() { | |
return LockGuard { lock: self, node }; | |
} | |
unsafe { | |
// [`prev_tail`] points to the old tail. Link it to newly allocated node | |
(*prev_tail).next.store(node, Ordering::Release); | |
// Wait until node becomes the owner | |
while !(*node).locked.load(Ordering::Acquire) {} | |
} | |
LockGuard { lock: self, node } | |
} | |
pub fn into_inner(self) -> T { | |
self.val.into_inner() | |
} | |
// Should be called only from LockGuard::drop() | |
fn unlock(&self, node_ptr: *mut Node) { | |
let node = unsafe { &(*node_ptr) }; | |
let mut next = node.next.load(Ordering::Acquire); | |
if next.is_null() { | |
// There are two cases: either there is no contention, or the other thread | |
// just called `swap` on a `tail` pointer. | |
// If `tail` is null, we are the only owners of the lock | |
if self | |
.tail | |
.compare_exchange( | |
node_ptr, | |
ptr::null_mut(), | |
Ordering::Release, | |
Ordering::Relaxed, | |
) | |
.is_ok() | |
{ | |
// SAFETY: No other thread can access this node | |
unsafe { drop(Box::from_raw(node_ptr)) }; | |
return; | |
} | |
// Wait until other thread update our next field | |
while { | |
next = node.next.load(Ordering::Acquire); | |
next.is_null() | |
} {} | |
} | |
// Release the lock | |
unsafe { | |
(*next).locked.store(true, Ordering::Release); | |
drop(Box::from_raw(node_ptr)); | |
}; | |
} | |
} | |
unsafe impl<T> Sync for MscLock<T> {} | |
unsafe impl<T> Send for MscLock<T> {} | |
impl<T> Drop for LockGuard<'_, T> { | |
fn drop(&mut self) { | |
self.lock.unlock(self.node); | |
} | |
} | |
impl<T> Deref for LockGuard<'_, T> { | |
type Target = T; | |
fn deref(&self) -> &Self::Target { | |
// SAFETY: The current thread owns the lock | |
unsafe { &(*self.lock.val.get()) } | |
} | |
} | |
impl<T> DerefMut for LockGuard<'_, T> { | |
fn deref_mut(&mut self) -> &mut Self::Target { | |
// SAFETY: The current thread owns the lock | |
unsafe { &mut (*self.lock.val.get()) } | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_st() {
        let lock = MscLock::new(1);
        for _ in 0..3 {
            *lock.lock() += 1;
        }
        assert_eq!(lock.into_inner(), 4);
    }

    #[test]
    fn test_compile_error() {
        let lock = MscLock::new(1);
        let mut guard = lock.lock();
        let val = guard.deref_mut();
        *val += 10;
        drop(guard);
        // Uncommenting the next line must fail to compile: `val` borrows from
        // the guard that was just dropped.
        // *val += 10;
    }

    #[test]
    fn test_more_threads() {
        // Miri is far slower than native execution, so shrink the workload there.
        const ITERS: usize = if cfg!(miri) { 100 } else { 100_000 };
        const NUM_THREADS: usize = if cfg!(miri) { 2 } else { 10 };

        let lock = Arc::new(MscLock::new(0));
        let handles: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let lock = Arc::clone(&lock);
                thread::spawn(move || {
                    for _ in 0..ITERS {
                        *lock.lock() += 1;
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(*lock.lock(), ITERS * NUM_THREADS);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment