Skip to content

Instantly share code, notes, and snippets.

@pskrgag
Created March 31, 2025 18:40
Show Gist options
  • Save pskrgag/8622977b6485900ff30ea27fc16e89c6 to your computer and use it in GitHub Desktop.
Save pskrgag/8622977b6485900ff30ea27fc16e89c6 to your computer and use it in GitHub Desktop.
MCS lock (Mellor-Crummey & Scott queue spinlock)
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
/// A node in the MCS lock's waiter queue. Each `lock()` call heap-allocates
/// one node; ownership is reclaimed (via `Box::from_raw`) during unlock.
struct Node {
// Next waiter in the queue; null while this node is the queue tail.
next: AtomicPtr<Node>,
// Set to `true` by the previous holder to hand the lock to this node.
locked: AtomicBool,
}
impl Default for Node {
fn default() -> Self {
Self {
next: ptr::null_mut::<Node>().into(),
locked: false.into(),
}
}
}
impl Node {
    /// Heap-allocates a default node and hands back a raw pointer.
    /// The allocation is reclaimed later with `Box::from_raw` in `unlock()`.
    pub fn new() -> *mut Self {
        Box::into_raw(Box::<Self>::default())
    }
}
/// RAII guard proving lock ownership; releases the lock when dropped.
pub struct LockGuard<'lock, T> {
// The lock this guard belongs to; used to unlock on drop.
lock: &'lock MscLock<T>,
// This thread's queue node; freed during unlock.
node: *mut Node,
}
/// An MCS-style queue spinlock: waiters form a linked queue and each spins
/// on its own node's flag, avoiding contention on a single shared word.
pub struct MscLock<T> {
// The protected value; access is mediated by the lock.
val: UnsafeCell<T>,
// Tail of the waiter queue; null when the lock is free and uncontended.
tail: AtomicPtr<Node>,
}
impl<T> MscLock<T> {
    /// Creates a new, unlocked MCS lock protecting `val`.
    pub fn new(val: T) -> Self {
        Self {
            val: UnsafeCell::new(val),
            tail: ptr::null_mut::<Node>().into(),
        }
    }

    /// Acquires the lock, returning a guard that releases it on drop.
    ///
    /// Each acquisition heap-allocates a queue node and appends it to the
    /// tail of the waiter queue; the node is freed again during unlock.
    pub fn lock(&self) -> LockGuard<T> {
        let node = Node::new();
        // Publish ourselves as the new tail. AcqRel: Release so our node's
        // initialization is visible to the thread that links after us,
        // Acquire so, when the queue was empty, we observe the previous
        // holder's critical-section writes (released by its CAS to null).
        let prev_tail = self.tail.swap(node, Ordering::AcqRel);
        // Enqueue new node. If there are no other nodes, lock is held.
        if prev_tail.is_null() {
            return LockGuard { lock: self, node };
        }
        unsafe {
            // `prev_tail` points to the old tail. Link it to the newly
            // allocated node so its owner can hand the lock to us.
            (*prev_tail).next.store(node, Ordering::Release);
            // Spin on our *own* node's flag until the predecessor releases.
            // spin_loop() hints the CPU we are busy-waiting (helps SMT/power).
            while !(*node).locked.load(Ordering::Acquire) {
                std::hint::spin_loop();
            }
        }
        LockGuard { lock: self, node }
    }

    /// Consumes the lock and returns the protected value.
    pub fn into_inner(self) -> T {
        self.val.into_inner()
    }

    // Should be called only from LockGuard::drop().
    fn unlock(&self, node_ptr: *mut Node) {
        // SAFETY: `node_ptr` was allocated by `lock()` and is exclusively
        // owned by the current (lock-holding) thread until freed below.
        let node = unsafe { &(*node_ptr) };
        let mut next = node.next.load(Ordering::Acquire);
        if next.is_null() {
            // Two cases: either there is no contention, or another thread
            // has just swapped `tail` but has not linked itself to our
            // `next` field yet. If `tail` still points at us, nobody waits.
            if self
                .tail
                .compare_exchange(
                    node_ptr,
                    ptr::null_mut(),
                    Ordering::Release,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                // SAFETY: the CAS succeeded, so no other thread can reach
                // this node anymore.
                unsafe { drop(Box::from_raw(node_ptr)) };
                return;
            }
            // A waiter is mid-enqueue; wait for it to publish itself.
            while {
                next = node.next.load(Ordering::Acquire);
                next.is_null()
            } {
                std::hint::spin_loop();
            }
        }
        // Hand the lock to the successor, then free our node.
        // SAFETY: the successor wrote our `next` field during its enqueue and
        // afterwards only spins on its *own* `locked` flag, so it never
        // touches our node again once we have read `next`.
        unsafe {
            (*next).locked.store(true, Ordering::Release);
            drop(Box::from_raw(node_ptr));
        };
    }
}
// SAFETY: the lock provides mutual exclusion, so `MscLock<T>` may be shared
// across threads — but since any thread holding the lock gets `&mut T`, the
// value is effectively moved between threads, which requires `T: Send`.
// (An unbounded impl would unsoundly allow e.g. `MscLock<Rc<u8>>` to be
// shared; std's `Mutex` uses the same `T: Send` bound.)
unsafe impl<T: Send> Sync for MscLock<T> {}
// SAFETY: sending the lock sends the contained value, so `T: Send` is needed.
unsafe impl<T: Send> Send for MscLock<T> {}
impl<T> Drop for LockGuard<'_, T> {
// Dropping the guard releases the lock and frees this thread's queue node.
fn drop(&mut self) {
self.lock.unlock(self.node);
}
}
impl<T> Deref for LockGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        // SAFETY: holding the guard means holding the lock, so no other
        // thread can access the value concurrently.
        unsafe { &*self.lock.val.get() }
    }
}
impl<T> DerefMut for LockGuard<'_, T> {
    /// Exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: holding the guard means holding the lock, and `&mut self`
        // guarantees this is the only reference derived from this guard.
        unsafe { &mut *self.lock.val.get() }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread::spawn;

    /// Single-threaded sanity check: three increments through the lock.
    #[test]
    fn test_st() {
        let lock = MscLock::new(1);
        for _ in 0..3 {
            *lock.lock() += 1;
        }
        assert_eq!(lock.into_inner(), 4);
    }

    /// The guard must not be usable after it is dropped.
    #[test]
    fn test_compile_error() {
        let lock = MscLock::new(1);
        let mut guard = lock.lock();
        let val = guard.deref_mut();
        *val += 10;
        drop(guard);
        // Should give a compiler error
        // *val += 10;
    }

    /// Hammer the lock from several threads; the total must be exact.
    #[test]
    fn test_more_threads() {
        #[cfg(not(miri))]
        static ITERS: usize = 100_000;
        #[cfg(miri)]
        static ITERS: usize = 100;
        #[cfg(not(miri))]
        static NUM_THREADS: usize = 10;
        #[cfg(miri)]
        static NUM_THREADS: usize = 2;

        let counter = Arc::new(MscLock::new(0));
        let handles: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let counter = Arc::clone(&counter);
                spawn(move || {
                    (0..ITERS).for_each(|_| *counter.lock() += 1);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(*counter.lock(), ITERS * NUM_THREADS);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment