diff --git a/Cargo.toml b/Cargo.toml index a3d7d53..8f68ccf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lockmap" -version = "0.1.4" +version = "0.1.5" edition = "2021" authors = ["SF-Zhou "] diff --git a/README.md b/README.md index e0132e4..94ef1c1 100644 --- a/README.md +++ b/README.md @@ -16,15 +16,15 @@ use lockmap::LockMap; let map = LockMap::::new(); // Set a value -map.set_by_ref("key", "value".into()); +map.insert_by_ref("key", "value".into()); // Get a value assert_eq!(map.get("key"), Some("value".into())); // Use entry API for exclusive access { - let entry = map.entry_by_ref("key"); - *entry.value = Some("new value".into()); + let mut entry = map.entry_by_ref("key"); + *entry.get_mut() = Some("new value".into()); } // Remove a value diff --git a/src/futex.rs b/src/futex.rs new file mode 100644 index 0000000..25df6f5 --- /dev/null +++ b/src/futex.rs @@ -0,0 +1,135 @@ +// Modified from https://github.com/rust-lang/rust/blob/master/library/std/src/sys/sync/mutex/futex.rs +use std::sync::atomic::{ + AtomicU32, + Ordering::{Acquire, Relaxed, Release}, +}; + +pub struct Mutex { + futex: AtomicU32, +} + +const UNLOCKED: u32 = 0; +const LOCKED: u32 = 1; // locked, no other threads waiting +const CONTENDED: u32 = 2; // locked, and other threads waiting (contended) + +impl Mutex { + #[inline] + pub const fn new() -> Self { + Self { + futex: AtomicU32::new(UNLOCKED), + } + } + + #[inline] + pub fn try_lock(&self) -> bool { + self.futex + .compare_exchange(UNLOCKED, LOCKED, Acquire, Relaxed) + .is_ok() + } + + #[inline] + pub fn lock(&self) { + if !self.try_lock() { + self.lock_contended(); + } + } + + #[cold] + fn lock_contended(&self) { + // Spin first to speed things up if the lock is released quickly. + let mut state = self.spin(); + + // If it's unlocked now, attempt to take the lock + // without marking it as contended. + if state == UNLOCKED { + match self + .futex + .compare_exchange(UNLOCKED, LOCKED, Acquire, Relaxed) + { + Ok(_) => return, // Locked! + Err(s) => state = s, + } + } + + loop { + // Put the lock in contended state. + // We avoid an unnecessary write if it as already set to CONTENDED, + // to be friendlier for the caches. + if state != CONTENDED && self.futex.swap(CONTENDED, Acquire) == UNLOCKED { + // We changed it from UNLOCKED to CONTENDED, so we just successfully locked it. + return; + } + + // Wait for the futex to change state, assuming it is still CONTENDED. + atomic_wait::wait(&self.futex, CONTENDED); + + // Spin again after waking up. + state = self.spin(); + } + } + + fn spin(&self) -> u32 { + let mut spin = 100; + loop { + // We only use `load` (and not `swap` or `compare_exchange`) + // while spinning, to be easier on the caches. + let state = self.futex.load(Relaxed); + + // We stop spinning when the mutex is UNLOCKED, + // but also when it's CONTENDED. + if state != LOCKED || spin == 0 { + return state; + } + + std::hint::spin_loop(); + spin -= 1; + } + } + + #[inline] + pub fn unlock(&self) { + if self.futex.swap(UNLOCKED, Release) == CONTENDED { + // We only wake up one thread. When that thread locks the mutex, it + // will mark the mutex as CONTENDED (see lock_contended above), + // which makes sure that any other waiting threads will also be + // woken up eventually. + self.wake(); + } + } + + #[cold] + fn wake(&self) { + atomic_wait::wake_one(&self.futex); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_futex() { + let lock = Arc::new(Mutex::new()); + let current = Arc::new(AtomicU32::new(0)); + const N: usize = 8; + const M: usize = 1 << 20; + + let mut tasks = vec![]; + for _ in 0..N { + let lock = lock.clone(); + let current = current.clone(); + tasks.push(std::thread::spawn(move || { + for _ in 0..M { + lock.lock(); + assert_eq!(current.fetch_add(1, Acquire), 0); + current.fetch_sub(1, Acquire); + lock.unlock(); + } + })); + } + for task in tasks { + task.join().unwrap(); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 14be8bd..fff7f94 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,24 +17,24 @@ //! let map = LockMap::::new(); //! //! // Basic operations -//! map.set("key1".into(), 42); +//! map.insert("key1".into(), 42); //! assert_eq!(map.get("key1"), Some(42)); //! //! // Entry API for exclusive access //! { -//! let entry = map.entry("key2".into()); -//! entry.value.replace(123); +//! let mut entry = map.entry("key2".into()); +//! entry.get_mut().replace(123); //! } //! //! // Remove a value //! assert_eq!(map.remove("key1"), Some(42)); //! assert_eq!(map.get("key1"), None); //! ``` +mod futex; #[doc = include_str!("../README.md")] mod lockmap; mod shards_map; -mod waiter; +use futex::*; pub use lockmap::*; use shards_map::*; -use waiter::*; diff --git a/src/lockmap.rs b/src/lockmap.rs index 14b1c02..dbf0c4d 100644 --- a/src/lockmap.rs +++ b/src/lockmap.rs @@ -1,8 +1,6 @@ -use crate::{ShardsMap, SimpleAction, UpdateAction, WaiterPtr}; +use crate::{Mutex, ShardsMap, SimpleAction, UpdateAction}; use std::borrow::Borrow; -use std::collections::LinkedList; use std::hash::Hash; -use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::OnceLock; /// Internal state for a key-value pair in the `LockMap`. @@ -10,24 +8,14 @@ use std::sync::OnceLock; /// This type manages both the stored value and the queue of waiting threads /// for per-key synchronization. struct State { - /// The stored value, wrapped in a Box to ensure stable memory location - value: Box>, - /// Queue of threads waiting for access to this key - queue: LinkedList, -} - -impl Default for State { - fn default() -> Self { - Self { - value: Default::default(), - queue: Default::default(), - } - } + refcnt: u32, + mutex: Mutex, + value: Option, } /// A thread-safe hashmap that supports locking entries at the key level. pub struct LockMap { - map: ShardsMap>, + map: ShardsMap>>, } impl Default for LockMap { @@ -90,7 +78,7 @@ impl LockMap { /// Gets exclusive access to an entry in the map. /// - /// The returned `Entry` provides exclusive access to the key and its associated value + /// The returned `EntryByVal` provides exclusive access to the key and its associated value /// until it is dropped. /// /// **Locking behaviour:** Deadlock if called when holding the same entry. @@ -101,49 +89,40 @@ impl LockMap { /// let map = LockMap::::new(); /// { /// let mut entry = map.entry("key".to_string()); - /// entry.value.replace(42); + /// entry.get_mut().replace(42); /// // let _ = map.get("key".to_string()); // DEADLOCK! - /// // map.set("key".to_string(), 21); // DEADLOCK! + /// // map.insert("key".to_string(), 21); // DEADLOCK! /// // map.remove("key".to_string()); // DEADLOCK! /// // let mut entry2 = map.entry("key".to_string()); // DEADLOCK! /// } /// ``` - pub fn entry(&self, key: K) -> Entry<'_, K, V> + pub fn entry(&self, key: K) -> EntryByVal<'_, K, V> where K: Clone, { - let waiter = AtomicU32::new(0); - let ptr = self.map.update(key.clone(), |value| match value { + let ptr: *mut State = self.map.update(key.clone(), |s| match s { Some(state) => { - if state.queue.is_empty() { - // no need to wait. - waiter.store(1, Ordering::Release); - } - state.queue.push_back(WaiterPtr::new(&waiter)); - (UpdateAction::Keep, state.value.as_mut() as *mut _) + state.refcnt += 1; + let ptr = state.as_mut() as _; + (UpdateAction::Keep, ptr) } None => { - let mut state = State::default(); - // no need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - let ptr = state.value.as_mut() as *mut _; - waiter.store(1, Ordering::Release); - (UpdateAction::Update(state), ptr) + let mut state: Box<_> = Box::new(State { + refcnt: 1, + mutex: Mutex::new(), + value: None, + }); + let ptr = state.as_mut() as _; + (UpdateAction::Replace(state), ptr) } }); - WaiterPtr::wait(&waiter); - - Entry { - map: self, - key, - value: Self::value_ptr_to_ref(ptr), - } + self.guard_by_val(ptr, key.clone()) } /// Gets exclusive access to an entry in the map. /// - /// The returned `Entry` provides exclusive access to the key and its associated value + /// The returned `EntryByVal` provides exclusive access to the key and its associated value /// until it is dropped. /// /// **Locking behaviour:** Deadlock if called when holding the same entry. @@ -154,9 +133,9 @@ impl LockMap { /// let map = LockMap::::new(); /// { /// let mut entry = map.entry_by_ref("key"); - /// entry.value.replace(42); + /// entry.get_mut().replace(42); /// // let _ = map.get("key"); // DEADLOCK! - /// // map.set_by_ref("key", 21); // DEADLOCK! + /// // map.insert_by_ref("key", 21); // DEADLOCK! /// // map.remove("key"); // DEADLOCK! /// // let mut entry2 = map.entry_by_ref("key"); // DEADLOCK! /// } @@ -166,33 +145,24 @@ impl LockMap { K: Borrow + for<'c> From<&'c Q>, Q: Eq + Hash + ?Sized, { - let waiter = AtomicU32::new(0); - let ptr = self.map.update_by_ref(key, |value| match value { + let ptr: *mut State = self.map.update_by_ref(key, |s| match s { Some(state) => { - if state.queue.is_empty() { - // no need to wait. - waiter.store(1, Ordering::Release); - } - state.queue.push_back(WaiterPtr::new(&waiter)); - (UpdateAction::Keep, state.value.as_mut() as *mut _) + state.refcnt += 1; + let ptr = state.as_mut() as _; + (UpdateAction::Keep, ptr) } None => { - let mut state = State::default(); - // no need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - let ptr = state.value.as_mut() as *mut _; - waiter.store(1, Ordering::Release); - (UpdateAction::Update(state), ptr) + let mut state: Box<_> = Box::new(State { + refcnt: 1, + mutex: Mutex::new(), + value: None, + }); + let ptr = state.as_mut() as _; + (UpdateAction::Replace(state), ptr) } }); - WaiterPtr::wait(&waiter); - - EntryByRef { - map: self, - key, - value: Self::value_ptr_to_ref(ptr), - } + self.guard_by_ref(ptr, key) } /// Gets the value associated with the given key. @@ -214,7 +184,7 @@ impl LockMap { /// use lockmap::LockMap; /// /// let map = LockMap::::new(); - /// map.set_by_ref("key", 42); + /// map.insert_by_ref("key", 42); /// assert_eq!(map.get("key"), Some(42)); /// assert_eq!(map.get("missing"), None); /// ``` @@ -224,17 +194,15 @@ impl LockMap { V: Clone, Q: Eq + Hash + ?Sized, { - let waiter = AtomicU32::new(0); - let mut ptr: *mut Option = std::ptr::null_mut(); - let value = self.map.simple_update(key, |value| match value { + let mut ptr: *mut State = std::ptr::null_mut(); + let value = self.map.simple_update(key, |s| match s { Some(state) => { - if state.queue.is_empty() { - // no need to wait. - (SimpleAction::Keep, state.value.as_mut().clone()) + if state.refcnt == 0 { + let value = state.value.clone(); + (SimpleAction::Keep, value) } else { - // need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - ptr = state.value.as_mut() as *mut _; + state.refcnt += 1; + ptr = state.as_mut(); (SimpleAction::Keep, None) } } @@ -245,11 +213,7 @@ impl LockMap { return value; } - WaiterPtr::wait(&waiter); - - let value = Self::value_ptr_to_ref(ptr).clone(); - self.unlock(key); - value + self.guard_by_ref(ptr, key).state.value.clone() } /// Sets a value in the map. @@ -270,48 +234,42 @@ impl LockMap { /// let map = LockMap::::new(); /// /// // Set a value - /// map.set("key".to_string(), 42); + /// assert_eq!(map.insert("key".to_string(), 42), None); /// /// // Update existing value - /// map.set("key".to_string(), 123); + /// assert_eq!(map.insert("key".to_string(), 123), Some(42)); /// ``` - pub fn set(&self, key: K, value: V) + pub fn insert(&self, key: K, value: V) -> Option where K: Clone, { - let waiter = AtomicU32::new(0); - let mut ptr: *mut Option = std::ptr::null_mut(); - let value = self.map.update(key.clone(), |v| match v { + let (ptr, value) = self.map.update(key.clone(), move |s| match s { Some(state) => { - if state.queue.is_empty() { - // no need to wait. - state.value.replace(value); - (UpdateAction::Keep, None) + if state.refcnt == 0 { + let value = state.value.replace(value); + (UpdateAction::Keep, (std::ptr::null_mut(), value)) } else { - // need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - ptr = state.value.as_mut() as *mut _; - (UpdateAction::Keep, Some(value)) + state.refcnt += 1; + let ptr: *mut State = state.as_mut(); + (UpdateAction::Keep, (ptr, Some(value))) } } None => { - // no need to wait. - let state = State { - value: Box::new(Some(value)), - queue: Default::default(), - }; - (UpdateAction::Update(state), None) + let state: Box<_> = Box::new(State { + refcnt: 0, + mutex: Mutex::new(), + value: Some(value), + }); + (UpdateAction::Replace(state), (std::ptr::null_mut(), None)) } }); if ptr.is_null() { - return; + return value; } - WaiterPtr::wait(&waiter); - - *Self::value_ptr_to_ref(ptr) = value; - self.unlock(&key); + let mut entry = self.guard_by_val(ptr, key.clone()); + std::mem::replace(entry.get_mut(), value) } /// Sets a value in the map. @@ -332,49 +290,43 @@ impl LockMap { /// let map = LockMap::::new(); /// /// // Set a value - /// map.set_by_ref("key", 42); + /// map.insert_by_ref("key", 42); /// /// // Update existing value - /// map.set_by_ref("key", 123); + /// map.insert_by_ref("key", 123); /// ``` - pub fn set_by_ref(&self, key: &Q, value: V) + pub fn insert_by_ref(&self, key: &Q, value: V) -> Option where K: Borrow + for<'c> From<&'c Q>, Q: Eq + Hash + ?Sized, { - let waiter = AtomicU32::new(0); - let mut ptr: *mut Option = std::ptr::null_mut(); - let value = self.map.update_by_ref(key, |v| match v { + let (ptr, value) = self.map.update_by_ref(key, move |s| match s { Some(state) => { - if state.queue.is_empty() { - // no need to wait. - state.value.replace(value); - (UpdateAction::Keep, None) + if state.refcnt == 0 { + let value = state.value.replace(value); + (UpdateAction::Keep, (std::ptr::null_mut(), value)) } else { - // need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - ptr = state.value.as_mut() as *mut _; - (UpdateAction::Keep, Some(value)) + state.refcnt += 1; + let ptr: *mut State = state.as_mut(); + (UpdateAction::Keep, (ptr, Some(value))) } } None => { - // no need to wait. - let state = State { - value: Box::new(Some(value)), - queue: Default::default(), - }; - (UpdateAction::Update(state), None) + let state: Box<_> = Box::new(State { + refcnt: 0, + mutex: Mutex::new(), + value: Some(value), + }); + (UpdateAction::Replace(state), (std::ptr::null_mut(), None)) } }); if ptr.is_null() { - return; + return value; } - WaiterPtr::wait(&waiter); - - *Self::value_ptr_to_ref(ptr) = value; - self.unlock(key); + let mut entry = self.guard_by_ref(ptr, key); + std::mem::replace(entry.get_mut(), value) } /// Removes a key from the map. @@ -396,7 +348,7 @@ impl LockMap { /// use lockmap::LockMap; /// /// let map = LockMap::::new(); - /// map.set_by_ref("key", 42); + /// map.insert_by_ref("key", 42); /// assert_eq!(map.remove("key"), Some(42)); /// assert_eq!(map.get("key"), None); /// ``` @@ -405,32 +357,26 @@ impl LockMap { K: Borrow, Q: Eq + Hash + ?Sized, { - let waiter = AtomicU32::new(0); - let mut ptr: *mut Option = std::ptr::null_mut(); - let value = self.map.simple_update(key, |v| match v { + let mut ptr: *mut State = std::ptr::null_mut(); + let value = self.map.simple_update(key, |s| match s { Some(state) => { - if state.queue.is_empty() { - // no need to wait. - (SimpleAction::Remove, state.value.take()) + if state.refcnt == 0 { + let value = state.value.take(); + (SimpleAction::Remove, value) } else { - // need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - ptr = state.value.as_mut() as _; + state.refcnt += 1; + ptr = state.as_mut(); (SimpleAction::Keep, None) } } - None => (SimpleAction::Keep, None), // no need to wait. + None => (SimpleAction::Keep, None), }); if ptr.is_null() { return value; } - WaiterPtr::wait(&waiter); - - let value = Self::value_ptr_to_ref(ptr).take(); - self.unlock(key); - value + self.guard_by_ref(ptr, key).state.value.take() } fn unlock(&self, key: &Q) @@ -439,25 +385,50 @@ impl LockMap { Q: Eq + Hash + ?Sized, { self.map.simple_update(key, |value| match value { - Some(state) => (Self::wake_up_next_one(state), ()), + Some(state) => { + state.refcnt -= 1; + if state.value.is_none() && state.refcnt == 0 { + (SimpleAction::Remove, ()) + } else { + (SimpleAction::Keep, ()) + } + } None => panic!("impossible: unlock a non-existent key!"), }); } - fn wake_up_next_one(state: &mut State) -> SimpleAction { - state.queue.pop_front(); - match state.queue.front() { - Some(waiter) => { - waiter.wake_up(); - SimpleAction::Keep - } - None if state.value.is_none() => SimpleAction::Remove, - None => SimpleAction::Keep, + fn guard_by_val(&self, ptr: *mut State, key: K) -> EntryByVal { + let state = unsafe { &mut *ptr }; + state.mutex.lock(); + EntryByVal { + map: self, + key, + state, } } - fn value_ptr_to_ref<'env>(ptr: *mut Option) -> &'env mut Option { - unsafe { &mut *ptr } + fn guard_by_ref<'a, 'b, Q>( + &'a self, + ptr: *mut State, + key: &'b Q, + ) -> EntryByRef<'a, 'b, K, Q, V> + where + K: Borrow, + Q: Eq + Hash + ?Sized, + { + let state = unsafe { &mut *ptr }; + state.mutex.lock(); + EntryByRef { + map: self, + key, + state, + } + } +} + +impl std::fmt::Debug for LockMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LockMap").finish() } } @@ -480,19 +451,47 @@ impl LockMap { /// let mut entry = map.entry("key"); /// /// // Modify the value -/// entry.value.replace(42); +/// entry.get_mut().replace(42); /// -/// // Entry is automatically unlocked when dropped +/// // EntryByVal is automatically unlocked when dropped /// } /// ``` -pub struct Entry<'a, K: Eq + Hash, V> { +pub struct EntryByVal<'a, K: Eq + Hash, V> { map: &'a LockMap, - pub key: K, - pub value: &'a mut Option, + key: K, + state: &'a mut State, } -impl Drop for Entry<'_, K, V> { +impl EntryByVal<'_, K, V> { + pub fn key(&self) -> &K { + &self.key + } + + pub fn get(&mut self) -> &Option { + &self.state.value + } + + pub fn get_mut(&mut self) -> &mut Option { + &mut self.state.value + } + + pub fn insert(&mut self, value: V) -> Option { + self.state.value.replace(value) + } +} + +impl std::fmt::Debug for EntryByVal<'_, K, V> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EntryByVal") + .field("key", &self.key) + .field("value", &self.state.value) + .finish() + } +} + +impl Drop for EntryByVal<'_, K, V> { fn drop(&mut self) { + self.state.mutex.unlock(); self.map.unlock(&self.key); } } @@ -512,33 +511,58 @@ impl Drop for Entry<'_, K, V> { /// ``` /// use lockmap::LockMap; /// -/// let map = LockMap::new(); +/// let map = LockMap::::new(); /// { /// // Get exclusive access to an entry -/// let mut entry = map.entry("key"); +/// let mut entry = map.entry_by_ref("key"); /// /// // Modify the value -/// entry.value.replace(42); +/// entry.get_mut().replace(42); /// -/// // Entry is automatically unlocked when dropped +/// // EntryByRef is automatically unlocked when dropped /// } /// ``` -pub struct EntryByRef<'a, 'b, K: Eq + Hash, Q, V> -where - K: Borrow, - Q: Eq + Hash + ?Sized, -{ +pub struct EntryByRef<'a, 'b, K: Eq + Hash + Borrow, Q: Eq + Hash + ?Sized, V> { map: &'a LockMap, - pub key: &'b Q, - pub value: &'a mut Option, + key: &'b Q, + state: &'a mut State, } -impl Drop for EntryByRef<'_, '_, K, Q, V> +impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q, V> { + pub fn key(&self) -> &Q { + self.key + } + + pub fn get(&mut self) -> &Option { + &self.state.value + } + + pub fn get_mut(&mut self) -> &mut Option { + &mut self.state.value + } + + pub fn insert(&mut self, value: V) -> Option { + self.state.value.replace(value) + } +} + +impl std::fmt::Debug for EntryByRef<'_, '_, K, Q, V> where - K: Borrow, - Q: Eq + Hash + ?Sized, + K: Eq + Hash + Borrow + std::fmt::Debug, + Q: Eq + Hash + ?Sized + std::fmt::Debug, + V: std::fmt::Debug, { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EntryByRef") + .field("key", &self.key) + .field("value", &self.state.value) + .finish() + } +} + +impl, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_, '_, K, Q, V> { fn drop(&mut self) { + self.state.mutex.unlock(); self.map.unlock(self.key); } } @@ -546,28 +570,34 @@ where #[cfg(test)] mod tests { use super::*; - use std::sync::{atomic::AtomicUsize, Arc}; + use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }; #[test] fn test_lockmap_lock() { let map = LockMap::::new(); + println!("{:?}", map); { - let entry = map.entry(1); - entry.value.replace(2); + let mut entry = map.entry(1); + assert_eq!(*entry.key(), 1); + assert_eq!(entry.insert(2), None); + println!("{:?}", entry); } { - let entry = map.entry(1); - assert_eq!(entry.value.unwrap(), 2); - entry.value.take(); + let mut entry = map.entry(1); + assert_eq!(entry.get_mut().unwrap(), 2); + entry.get_mut().take(); } { - let entry = map.entry(1); - assert!(entry.value.is_none()); + let mut entry = map.entry(1); + assert!(entry.get_mut().is_none()); } let map = LockMap::::default(); { - let entry = map.entry(1); - entry.value.replace(2); + let mut entry = map.entry(1); + entry.get_mut().replace(2); } assert_eq!(map.remove(&1), Some(2)); assert_eq!(map.remove(&1), None); @@ -577,37 +607,61 @@ mod tests { #[should_panic(expected = "impossible: unlock a non-existent key!")] fn test_lockmap_invalid_unlock() { let map = LockMap::::new(); - let key = 0xB1; - let mut dummy = Some(7268); - let _ = Entry { + let mut state = State { + refcnt: 1, + mutex: Mutex::new(), + value: None, + }; + let _ = EntryByVal { map: &map, - key, - value: &mut dummy, + key: 7268, + state: &mut state, }; } #[test] - #[should_panic(expected = "impossible: unlock a non-existent key!")] - fn test_lockmap_invalid_unlock_by_ref() { - let map = LockMap::::new(); - let key = "hello"; - let mut dummy = Some(7268); - let _ = EntryByRef { - map: &map, - key, - value: &mut dummy, - }; + fn test_lockmap_same_key_by_value() { + let lock_map = Arc::new(LockMap::::with_capacity(256)); + let current = Arc::new(AtomicU32::default()); + const N: usize = 1 << 20; + const M: usize = 4; + + const S: usize = 0; + lock_map.insert(S, 0); + + let threads = (0..M) + .map(|_| { + let lock_map = lock_map.clone(); + let current = current.clone(); + std::thread::spawn(move || { + for _ in 0..N { + let mut entry = lock_map.entry(S); + let now = current.fetch_add(1, Ordering::AcqRel); + assert_eq!(now, 0); + let v = entry.get_mut().as_mut().unwrap(); + *v += 1; + let now = current.fetch_sub(1, Ordering::AcqRel); + assert_eq!(now, 1); + } + }) + }) + .collect::>(); + threads.into_iter().for_each(|t| t.join().unwrap()); + + let mut entry = lock_map.entry(S); + assert_eq!(*entry.get(), Some(N * M)); + assert_eq!(entry.get_mut().unwrap(), N * M); } #[test] - fn test_lockmap_same_key() { + fn test_lockmap_same_key_by_ref() { let lock_map = Arc::new(LockMap::::with_capacity(256)); let current = Arc::new(AtomicU32::default()); - const N: usize = 1 << 12; - const M: usize = 16; + const N: usize = 1 << 20; + const M: usize = 4; const S: &str = "hello"; - lock_map.set_by_ref(S, 0); + lock_map.insert_by_ref(S, 0); let threads = (0..M) .map(|_| { @@ -615,10 +669,10 @@ mod tests { let current = current.clone(); std::thread::spawn(move || { for _ in 0..N { - let entry = lock_map.entry_by_ref(S); + let mut entry = lock_map.entry_by_ref(S); let now = current.fetch_add(1, Ordering::AcqRel); assert_eq!(now, 0); - let v = entry.value.as_mut().unwrap(); + let v = entry.get_mut().as_mut().unwrap(); *v += 1; let now = current.fetch_sub(1, Ordering::AcqRel); assert_eq!(now, 1); @@ -628,14 +682,17 @@ mod tests { .collect::>(); threads.into_iter().for_each(|t| t.join().unwrap()); - let entry = lock_map.entry_by_ref(S); - assert_eq!(entry.value.unwrap(), N * M); + let mut entry = lock_map.entry_by_ref(S); + println!("{:?}", entry); + assert_eq!(entry.key(), S); + assert_eq!(*entry.get(), Some(N * M)); + assert_eq!(entry.insert(0).unwrap(), N * M); } #[test] fn test_lockmap_random_key() { let lock_map = Arc::new(LockMap::::with_capacity_and_shard_amount(256, 16)); - let total = Arc::new(AtomicUsize::default()); + let total = Arc::new(AtomicU32::default()); const N: usize = 1 << 12; const M: usize = 8; @@ -646,18 +703,18 @@ mod tests { std::thread::spawn(move || { for _ in 0..N { let key = rand::random::() % 32; - let entry = lock_map.entry(key); - assert!(entry.value.is_none()); - entry.value.replace(1); + let mut entry = lock_map.entry(key); + assert!(entry.get_mut().is_none()); + entry.get_mut().replace(1); total.fetch_add(1, Ordering::AcqRel); - entry.value.take(); + entry.get_mut().take(); } }) }) .collect::>(); threads.into_iter().for_each(|t| t.join().unwrap()); - assert_eq!(total.load(Ordering::Acquire), N * M); + assert_eq!(total.load(Ordering::Acquire) as usize, N * M); } #[test] @@ -671,11 +728,11 @@ mod tests { for _ in 0..N { let key = rand::random::() % 32; let value = rand::random::() % 32; - let entry = lock_map.entry(key); + let mut entry = lock_map.entry(key); if value < 16 { - entry.value.take(); + entry.get_mut().take(); } else { - entry.value.replace(value); + entry.get_mut().replace(value); } } }) @@ -690,7 +747,7 @@ mod tests { if value < 16 { lock_map.remove(&key); } else { - lock_map.set(key, value); + lock_map.insert(key, value); } } }) @@ -727,11 +784,11 @@ mod tests { for _ in 0..N { let key = (rand::random::() % 32).to_string(); let value = rand::random::() % 32; - let entry = lock_map.entry_by_ref(&key); + let mut entry = lock_map.entry_by_ref(&key); if value < 16 { - entry.value.take(); + entry.get_mut().take(); } else { - entry.value.replace(value); + entry.get_mut().replace(value); } } }) @@ -746,7 +803,7 @@ mod tests { if value < 16 { lock_map.remove(&key); } else { - lock_map.set_by_ref(&key, value); + lock_map.insert_by_ref(&key, value); } } }) diff --git a/src/shards_map.rs b/src/shards_map.rs index 63e6823..f86dbdf 100644 --- a/src/shards_map.rs +++ b/src/shards_map.rs @@ -16,7 +16,7 @@ pub enum UpdateAction { /// Keep the current value unchanged. Keep, /// Update the value with the provided new value. - Update(V), + Replace(V), } /// A thread-safe hashmap shard. @@ -83,7 +83,7 @@ where let (action, ret) = func(Some(value)); match action { UpdateAction::Keep => {} - UpdateAction::Update(v) => { + UpdateAction::Replace(v) => { *value = v; } } @@ -93,7 +93,7 @@ where let (action, ret) = func(None); match action { UpdateAction::Keep => {} - UpdateAction::Update(value) => { + UpdateAction::Replace(value) => { map.insert(key, value); } } @@ -124,7 +124,7 @@ where let (action, ret) = func(Some(value)); match action { UpdateAction::Keep => {} - UpdateAction::Update(v) => { + UpdateAction::Replace(v) => { *value = v; } } @@ -134,7 +134,7 @@ where let (action, ret) = func(None); match action { UpdateAction::Keep => {} - UpdateAction::Update(value) => { + UpdateAction::Replace(value) => { map.insert(key.into(), value); } } @@ -255,7 +255,7 @@ mod tests { let shards_map = ShardsMap::::with_capacity_and_shard_amount(256, 16); shards_map.update(1, |v| { assert_eq!(v, None); - (UpdateAction::Update(1), ()) + (UpdateAction::Replace(1), ()) }); shards_map.update(2, |v| { assert_eq!(v, None); @@ -267,7 +267,7 @@ mod tests { }); shards_map.update(1, |v| { assert_eq!(v.cloned(), Some(1)); - (UpdateAction::Update(2), ()) + (UpdateAction::Replace(2), ()) }); shards_map.update(1, |v| { assert_eq!(v.cloned(), Some(2)); @@ -288,11 +288,11 @@ mod tests { let shards_map = ShardsMap::::with_capacity_and_shard_amount(256, 16); shards_map.update_by_ref("hello", |v| { assert_eq!(v, None); - (UpdateAction::Update("world".to_string()), ()) + (UpdateAction::Replace("world".to_string()), ()) }); shards_map.update_by_ref("hello", |v| { assert_eq!(v.unwrap(), "world"); - (UpdateAction::Update("lockmap".to_string()), ()) + (UpdateAction::Replace("lockmap".to_string()), ()) }); shards_map.simple_update("hello", |v| { assert_eq!(v, Some(&mut "lockmap".to_string())); @@ -304,7 +304,7 @@ mod tests { }); shards_map.update_by_ref("hello", |v| { assert_eq!(v, None); - (UpdateAction::Update("lockmap".to_string()), ()) + (UpdateAction::Replace("lockmap".to_string()), ()) }); shards_map.simple_update("hello", |v| { assert_eq!(v.unwrap(), "lockmap"); @@ -337,7 +337,7 @@ mod tests { const N: usize = 1 << 12; const M: usize = 8; - lock_map.update(1, |_| (UpdateAction::Update(0), ())); + lock_map.update(1, |_| (UpdateAction::Replace(0), ())); let threads = (0..M) .map(|_| { @@ -360,7 +360,7 @@ mod tests { threads.into_iter().for_each(|t| t.join().unwrap()); assert_eq!( - lock_map.update(1, |v| (UpdateAction::Update(0), *v.unwrap())), + lock_map.update(1, |v| (UpdateAction::Replace(0), *v.unwrap())), N * M ); } diff --git a/src/waiter.rs b/src/waiter.rs deleted file mode 100644 index babbdbc..0000000 --- a/src/waiter.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::sync::atomic::{AtomicU32, Ordering}; - -/// A pointer type used for thread synchronization that provides waiting and waking capabilities. -/// -/// This type is used internally by `LockMap` to implement per-key locking behavior. It wraps an -/// `AtomicU32` and provides safe access to waiting and waking operations. -/// -/// # Safety -/// This type is both `Send` and `Sync` safe as it only provides controlled access to an atomic value. -pub struct WaiterPtr(*const AtomicU32); - -impl WaiterPtr { - /// Creates a new `WaiterPtr` from a reference to an `AtomicU32`. - /// - /// # Arguments - /// * `w` - Reference to the `AtomicU32` to wrap - /// - /// # Safety - /// The wrapped `AtomicU32` must outlive the `WaiterPtr`. - pub fn new(w: &AtomicU32) -> Self { - Self(w as *const _) - } - - /// Wakes up a single thread waiting on this value. - /// - /// Sets the atomic value to 1 and wakes one waiting thread. - pub fn wake_up(&self) { - let waiter = unsafe { &*self.0 }; - waiter.store(1, Ordering::Release); - atomic_wait::wake_one(self.0); - } - - /// Waits until the atomic value becomes non-zero. - /// - /// This will block the current thread until another thread calls `wake_up()`. - /// - /// # Arguments - /// * `w` - The `AtomicU32` to wait on - pub fn wait(w: &AtomicU32) { - while w.load(Ordering::Acquire) == 0 { - atomic_wait::wait(w, 0); - } - } -} - -// Safety: WaiterPtr can be safely shared between threads -unsafe impl Sync for WaiterPtr {} -unsafe impl Send for WaiterPtr {}