From ac35c51cee84eff42b2c96ea116e59a24df75dcd Mon Sep 17 00:00:00 2001 From: SF-Zhou Date: Fri, 27 Dec 2024 11:43:40 +0800 Subject: [PATCH] optimize interface --- Cargo.toml | 2 +- README.md | 4 +- src/lockmap.rs | 198 +++++++++++----------------------------------- src/shards_map.rs | 151 ++++++++++++++++++++++++----------- 4 files changed, 153 insertions(+), 202 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c4760e2..2f027f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lockmap" -version = "0.1.2" +version = "0.1.3" edition = "2021" authors = ["SF-Zhou "] diff --git a/README.md b/README.md index 539ac9f..35f5369 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ let map = LockMap::::new(); map.set_by_ref("key", "value".into()); // Get a value -assert_eq!(map.get_by_ref("key"), Some("value".into())); +assert_eq!(map.get("key"), Some("value".into())); // Use entry API for exclusive access { @@ -28,5 +28,5 @@ assert_eq!(map.get_by_ref("key"), Some("value".into())); } // Remove a value -map.remove_by_ref("key"); +map.remove("key"); ``` diff --git a/src/lockmap.rs b/src/lockmap.rs index 08247f9..0f597ba 100644 --- a/src/lockmap.rs +++ b/src/lockmap.rs @@ -1,4 +1,4 @@ -use crate::{Action, ShardsMap, WaiterPtr}; +use crate::{ShardsMap, SimpleAction, UpdateAction, WaiterPtr}; use std::borrow::Borrow; use std::collections::LinkedList; use std::hash::Hash; @@ -120,7 +120,7 @@ impl LockMap { waiter.store(1, Ordering::Release); } state.queue.push_back(WaiterPtr::new(&waiter)); - (Action::Keep, state.value.as_mut() as *mut _) + (UpdateAction::Keep, state.value.as_mut() as *mut _) } None => { let mut state = State::default(); @@ -128,7 +128,7 @@ impl LockMap { state.queue.push_back(WaiterPtr::new(&waiter)); let ptr = state.value.as_mut() as *mut _; waiter.store(1, Ordering::Release); - (Action::Update(state), ptr) + (UpdateAction::Update(state), ptr) } }); @@ -155,9 +155,9 @@ impl LockMap { /// { /// let mut entry = map.entry_by_ref("key"); /// entry.value.replace(42); - /// // let _ = map.get_by_ref("key"); // DEADLOCK! + /// // let _ = map.get("key"); // DEADLOCK! /// // map.set_by_ref("key", 21); // DEADLOCK! - /// // map.remove_by_ref("key"); // DEADLOCK! + /// // map.remove("key"); // DEADLOCK! /// // let mut entry2 = map.entry_by_ref("key"); // DEADLOCK! /// } /// ``` @@ -174,7 +174,7 @@ impl LockMap { waiter.store(1, Ordering::Release); } state.queue.push_back(WaiterPtr::new(&waiter)); - (Action::Keep, state.value.as_mut() as *mut _) + (UpdateAction::Keep, state.value.as_mut() as *mut _) } None => { let mut state = State::default(); @@ -182,7 +182,7 @@ impl LockMap { state.queue.push_back(WaiterPtr::new(&waiter)); let ptr = state.value.as_mut() as *mut _; waiter.store(1, Ordering::Release); - (Action::Update(state), ptr) + (UpdateAction::Update(state), ptr) } }); @@ -226,12 +226,12 @@ impl LockMap { if state.queue.is_empty() { // no need to wait. state.value.replace(value); - (Action::Keep, None) + (UpdateAction::Keep, None) } else { // need to wait. state.queue.push_back(WaiterPtr::new(&waiter)); ptr = state.value.as_mut() as *mut _; - (Action::Keep, Some(value)) + (UpdateAction::Keep, Some(value)) } } None => { @@ -240,7 +240,7 @@ impl LockMap { value: Box::new(Some(value)), queue: Default::default(), }; - (Action::Update(state), None) + (UpdateAction::Update(state), None) } }); @@ -251,7 +251,7 @@ impl LockMap { WaiterPtr::wait(&waiter); *Self::value_ptr_to_ref(ptr) = value; - self.unlock(key); + self.unlock(&key); } /// Sets a value in the map. @@ -289,12 +289,12 @@ impl LockMap { if state.queue.is_empty() { // no need to wait. state.value.replace(value); - (Action::Keep, None) + (UpdateAction::Keep, None) } else { // need to wait. state.queue.push_back(WaiterPtr::new(&waiter)); ptr = state.value.as_mut() as *mut _; - (Action::Keep, Some(value)) + (UpdateAction::Keep, Some(value)) } } None => { @@ -303,7 +303,7 @@ impl LockMap { value: Box::new(Some(value)), queue: Default::default(), }; - (Action::Update(state), None) + (UpdateAction::Update(state), None) } }); @@ -314,51 +314,6 @@ impl LockMap { WaiterPtr::wait(&waiter); *Self::value_ptr_to_ref(ptr) = value; - self.unlock_by_ref(key); - } - - /// Removes a key from the map. - /// - /// If other threads are currently accessing the key, this will wait - /// until exclusive access is available before removing. - /// - /// # Arguments - /// * `key` - The key to remove - /// - /// **Locking behaviour:** Deadlock if called when holding the same entry. - /// - /// # Examples - /// ``` - /// use lockmap::LockMap; - /// - /// let map = LockMap::::new(); - /// map.set("key".to_string(), 42); - /// map.remove("key".to_string()); - /// assert_eq!(map.get("key".to_string()), None); - /// ``` - pub fn remove(&self, key: K) { - let waiter = AtomicU32::new(0); - let ptr = self.map.update(key.clone(), |v| match v { - Some(state) => { - if state.queue.is_empty() { - // no need to wait. - (Action::Remove, std::ptr::null_mut()) - } else { - // need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - (Action::Keep, state.value.as_mut() as *mut Option) - } - } - None => (Action::Keep, std::ptr::null_mut()), // no need to wait. - }); - - if ptr.is_null() { - return; - } - - WaiterPtr::wait(&waiter); - - Self::value_ptr_to_ref(ptr).take(); self.unlock(key); } @@ -378,27 +333,27 @@ impl LockMap { /// /// let map = LockMap::::new(); /// map.set_by_ref("key", 42); - /// map.remove_by_ref("key"); - /// assert_eq!(map.get_by_ref("key"), None); + /// map.remove("key"); + /// assert_eq!(map.get("key"), None); /// ``` - pub fn remove_by_ref(&self, key: &Q) + pub fn remove(&self, key: &Q) where - K: Borrow + for<'c> From<&'c Q>, + K: Borrow, Q: Eq + Hash + ?Sized, { let waiter = AtomicU32::new(0); - let ptr = self.map.update_by_ref(key, |v| match v { + let ptr = self.map.simple_update(key, |v| match v { Some(state) => { if state.queue.is_empty() { // no need to wait. - (Action::Remove, std::ptr::null_mut()) + (SimpleAction::Remove, std::ptr::null_mut()) } else { // need to wait. state.queue.push_back(WaiterPtr::new(&waiter)); - (Action::Keep, state.value.as_mut() as *mut Option) + (SimpleAction::Keep, state.value.as_mut() as *mut Option) } } - None => (Action::Keep, std::ptr::null_mut()), // no need to wait. + None => (SimpleAction::Keep, std::ptr::null_mut()), // no need to wait. }); if ptr.is_null() { @@ -408,36 +363,29 @@ impl LockMap { WaiterPtr::wait(&waiter); Self::value_ptr_to_ref(ptr).take(); - self.unlock_by_ref(key); - } - - fn unlock(&self, key: K) { - self.map.update(key, |value| match value { - Some(state) => (Self::wake_up_next_one(state), ()), - None => panic!("impossible: unlock a non-existent key!"), - }); + self.unlock(key); } - fn unlock_by_ref(&self, key: &Q) + fn unlock(&self, key: &Q) where - K: Borrow + for<'c> From<&'c Q>, + K: Borrow, Q: Eq + Hash + ?Sized, { - self.map.update_by_ref(key, |value| match value { + self.map.simple_update(key, |value| match value { Some(state) => (Self::wake_up_next_one(state), ()), None => panic!("impossible: unlock a non-existent key!"), }); } - fn wake_up_next_one(state: &mut State) -> Action> { + fn wake_up_next_one(state: &mut State) -> SimpleAction { state.queue.pop_front(); match state.queue.front() { Some(waiter) => { waiter.wake_up(); - Action::Keep + SimpleAction::Keep } - None if state.value.is_none() => Action::Remove, - None => Action::Keep, + None if state.value.is_none() => SimpleAction::Remove, + None => SimpleAction::Keep, } } @@ -447,58 +395,6 @@ impl LockMap { } impl LockMap { - /// Gets the value associated with the given key. - /// - /// If other threads are currently accessing the key, this will wait - /// until exclusive access is available before returning. - /// - /// # Arguments - /// * `key` - The key to look up - /// - /// # Returns - /// * `Some(V)` if the key exists - /// * `None` if the key doesn't exist - /// - /// **Locking behaviour:** Deadlock if called when holding the same entry. - /// - /// # Examples - /// ``` - /// use lockmap::LockMap; - /// - /// let map = LockMap::::new(); - /// map.set("key".to_string(), 42); - /// assert_eq!(map.get("key".to_string()), Some(42)); - /// assert_eq!(map.get("missing".to_string()), None); - /// ``` - pub fn get(&self, key: K) -> Option { - let waiter = AtomicU32::new(0); - let mut ptr: *mut Option = std::ptr::null_mut(); - let value = self.map.update(key.clone(), |value| match value { - Some(state) => { - if state.queue.is_empty() { - // no need to wait. - (Action::Keep, state.value.as_mut().clone()) - } else { - // need to wait. - state.queue.push_back(WaiterPtr::new(&waiter)); - ptr = state.value.as_mut() as *mut _; - (Action::Keep, None) - } - } - None => (Action::Keep, None), - }); - - if ptr.is_null() { - return value; - } - - WaiterPtr::wait(&waiter); - - let value = Self::value_ptr_to_ref(ptr).clone(); - self.unlock(key); - value - } - /// Gets the value associated with the given key. /// /// If other threads are currently accessing the key, this will wait @@ -519,29 +415,29 @@ impl LockMap { /// /// let map = LockMap::::new(); /// map.set_by_ref("key", 42); - /// assert_eq!(map.get_by_ref("key"), Some(42)); - /// assert_eq!(map.get_by_ref("missing"), None); + /// assert_eq!(map.get("key"), Some(42)); + /// assert_eq!(map.get("missing"), None); /// ``` - pub fn get_by_ref(&self, key: &Q) -> Option + pub fn get(&self, key: &Q) -> Option where - K: Borrow + for<'c> From<&'c Q>, + K: Borrow, Q: Eq + Hash + ?Sized, { let waiter = AtomicU32::new(0); let mut ptr: *mut Option = std::ptr::null_mut(); - let value = self.map.update_by_ref(key, |value| match value { + let value = self.map.simple_update(key, |value| match value { Some(state) => { if state.queue.is_empty() { // no need to wait. - (Action::Keep, state.value.as_mut().clone()) + (SimpleAction::Keep, state.value.as_mut().clone()) } else { // need to wait. state.queue.push_back(WaiterPtr::new(&waiter)); ptr = state.value.as_mut() as *mut _; - (Action::Keep, None) + (SimpleAction::Keep, None) } } - None => (Action::Keep, None), + None => (SimpleAction::Keep, None), }); if ptr.is_null() { @@ -551,7 +447,7 @@ impl LockMap { WaiterPtr::wait(&waiter); let value = Self::value_ptr_to_ref(ptr).clone(); - self.unlock_by_ref(key); + self.unlock(key); value } } @@ -588,7 +484,7 @@ pub struct Entry<'a, K: Eq + Hash + Clone, V> { impl Drop for Entry<'_, K, V> { fn drop(&mut self) { - self.map.unlock(self.key.clone()); + self.map.unlock(&self.key); } } @@ -620,7 +516,7 @@ impl Drop for Entry<'_, K, V> { /// ``` pub struct EntryByRef<'a, 'b, K: Eq + Hash + Clone, Q, V> where - K: Borrow + for<'c> From<&'c Q>, + K: Borrow, Q: Eq + Hash + ?Sized, { map: &'a LockMap, @@ -630,11 +526,11 @@ where impl Drop for EntryByRef<'_, '_, K, Q, V> where - K: Borrow + for<'c> From<&'c Q>, + K: Borrow, Q: Eq + Hash + ?Sized, { fn drop(&mut self) { - self.map.unlock_by_ref(self.key); + self.map.unlock(self.key); } } @@ -783,7 +679,7 @@ mod tests { let key = rand::random::() % 32; let value = rand::random::() % 32; if value < 16 { - lock_map.remove(key); + lock_map.remove(&key); } else { lock_map.set(key, value); } @@ -796,7 +692,7 @@ mod tests { std::thread::spawn(move || { for _ in 0..N { let key = rand::random::() % 32; - let value = lock_map.get(key); + let value = lock_map.get(&key); if let Some(v) = value { assert!(v >= 16) } @@ -839,7 +735,7 @@ mod tests { let key = (rand::random::() % 32).to_string(); let value = rand::random::() % 32; if value < 16 { - lock_map.remove_by_ref(&key); + lock_map.remove(&key); } else { lock_map.set_by_ref(&key, value); } @@ -852,7 +748,7 @@ mod tests { std::thread::spawn(move || { for _ in 0..N { let key = (rand::random::() % 32).to_string(); - let value = lock_map.get_by_ref(&key); + let value = lock_map.get(&key); if let Some(v) = value { assert!(v >= 16) } diff --git a/src/shards_map.rs b/src/shards_map.rs index 2dce873..85dca48 100644 --- a/src/shards_map.rs +++ b/src/shards_map.rs @@ -4,11 +4,17 @@ use std::hash::{Hash, Hasher}; use std::sync::Mutex; /// Represents the action to be taken on a value in the `ShardMap`. -pub enum Action { +pub enum SimpleAction { /// Keep the current value unchanged. Keep, /// Remove the value from the map. Remove, +} + +/// Represents the action to be taken on a value in the `ShardMap`. +pub enum UpdateAction { + /// Keep the current value unchanged. + Keep, /// Update the value with the provided new value. Update(V), } @@ -41,6 +47,22 @@ where } } + pub fn simple_update(&self, key: &Q, func: F) -> R + where + K: Borrow, + Q: Eq + Hash + ?Sized, + F: FnOnce(Option<&mut V>) -> (SimpleAction, R), + { + let mut map = self.map.lock().unwrap(); + let value = map.get_mut(key); + let has_value = value.is_some(); + let (action, ret) = func(value); + if has_value && matches!(action, SimpleAction::Remove) { + let _ = map.remove_entry(key); + } + ret + } + /// Updates the value associated with the given key using the provided function. /// /// # Arguments @@ -53,19 +75,16 @@ where /// The result returned by the provided function. pub fn update(&self, key: K, func: F) -> R where - F: FnOnce(Option<&mut V>) -> (Action, R), + F: FnOnce(Option<&mut V>) -> (UpdateAction, R), { let mut map = self.map.lock().unwrap(); match map.get_mut(&key) { Some(value) => { let (action, ret) = func(Some(value)); match action { - Action::Keep => {} - Action::Remove => { - map.remove(&key); - } - Action::Update(value) => { - map.insert(key, value); + UpdateAction::Keep => {} + UpdateAction::Update(v) => { + *value = v; } } ret @@ -73,9 +92,8 @@ where None => { let (action, ret) = func(None); match action { - Action::Keep => {} - Action::Remove => {} - Action::Update(value) => { + UpdateAction::Keep => {} + UpdateAction::Update(value) => { map.insert(key, value); } } @@ -98,19 +116,16 @@ where where K: Borrow + for<'c> From<&'c Q>, Q: Eq + Hash + ?Sized, - F: FnOnce(Option<&mut V>) -> (Action, R), + F: FnOnce(Option<&mut V>) -> (UpdateAction, R), { let mut map = self.map.lock().unwrap(); match map.get_mut(key) { Some(value) => { let (action, ret) = func(Some(value)); match action { - Action::Keep => {} - Action::Remove => { - map.remove(key); - } - Action::Update(value) => { - map.insert(key.into(), value); + UpdateAction::Keep => {} + UpdateAction::Update(v) => { + *value = v; } } ret @@ -118,9 +133,8 @@ where None => { let (action, ret) = func(None); match action { - Action::Keep => {} - Action::Remove => {} - Action::Update(value) => { + UpdateAction::Keep => {} + UpdateAction::Update(value) => { map.insert(key.into(), value); } } @@ -159,6 +173,25 @@ where } } + /// Updates the value associated with the given key using the provided function. + /// + /// # Arguments + /// + /// * `key` - The key to update. + /// * `func` - A function that takes an `Option<&mut V>` and returns a tuple containing the action to take and the result. + /// + /// # Returns + /// + /// The result returned by the provided function. + pub fn simple_update(&self, key: &Q, func: F) -> R + where + K: Borrow, + Q: Eq + Hash + ?Sized, + F: FnOnce(Option<&mut V>) -> (SimpleAction, R), + { + self.shard(key).simple_update(key, func) + } + /// Updates the value associated with the given key using the provided function. /// /// # Arguments @@ -171,12 +204,9 @@ where /// The result returned by the provided function. pub fn update(&self, key: K, func: F) -> R where - F: FnOnce(Option<&mut V>) -> (Action, R), + F: FnOnce(Option<&mut V>) -> (UpdateAction, R), { - let mut s = DefaultHasher::new(); - key.hash(&mut s); - let idx = s.finish() as usize % self.shards.len(); - self.shards[idx].update(key, func) + self.shard(&key).update(key, func) } /// Updates the value associated with the given key using the provided function. @@ -193,12 +223,21 @@ where where K: Borrow + for<'c> From<&'c Q>, Q: Eq + Hash + ?Sized, - F: FnOnce(Option<&mut V>) -> (Action, R), + F: FnOnce(Option<&mut V>) -> (UpdateAction, R), + { + self.shard(key).update_by_ref(key, func) + } + + #[inline(always)] + fn shard(&self, key: &Q) -> &ShardMap + where + K: Borrow, + Q: Eq + Hash + ?Sized, { let mut s = DefaultHasher::new(); key.hash(&mut s); let idx = s.finish() as usize % self.shards.len(); - self.shards[idx].update_by_ref(key, func) + &self.shards[idx] } } @@ -216,31 +255,31 @@ mod tests { let shards_map = ShardsMap::::with_capacity_and_shard_amount(256, 16); shards_map.update(1, |v| { assert_eq!(v, None); - (Action::Update(1), ()) + (UpdateAction::Update(1), ()) }); shards_map.update(2, |v| { assert_eq!(v, None); - (Action::Keep, ()) + (UpdateAction::Keep, ()) }); - shards_map.update(3, |v| { + shards_map.simple_update(&3, |v| { assert_eq!(v, None); - (Action::Remove, ()) + (SimpleAction::Remove, ()) }); shards_map.update(1, |v| { assert_eq!(v.cloned(), Some(1)); - (Action::Update(2), ()) + (UpdateAction::Update(2), ()) }); shards_map.update(1, |v| { assert_eq!(v.cloned(), Some(2)); - (Action::Keep, ()) + (UpdateAction::Keep, ()) }); - shards_map.update(1, |v| { + shards_map.simple_update(&1, |v| { assert_eq!(v.cloned(), Some(2)); - (Action::Remove, ()) + (SimpleAction::Remove, ()) }); - shards_map.update(1, |v| { + shards_map.simple_update(&1, |v| { assert_eq!(v, None); - (Action::Remove, ()) + (SimpleAction::Remove, ()) }); } @@ -249,27 +288,43 @@ mod tests { let shards_map = ShardsMap::::with_capacity_and_shard_amount(256, 16); shards_map.update_by_ref("hello", |v| { assert_eq!(v, None); - (Action::Update("world".to_string()), ()) + (UpdateAction::Update("world".to_string()), ()) }); shards_map.update_by_ref("hello", |v| { assert_eq!(v.unwrap(), "world"); - (Action::Update("lockmap".to_string()), ()) + (UpdateAction::Update("lockmap".to_string()), ()) + }); + shards_map.simple_update("hello", |v| { + assert_eq!(v, Some(&mut "lockmap".to_string())); + (SimpleAction::Remove, ()) + }); + shards_map.simple_update("hello", |v| { + assert_eq!(v, None); + (SimpleAction::Remove, ()) }); shards_map.update_by_ref("hello", |v| { + assert_eq!(v, None); + (UpdateAction::Update("lockmap".to_string()), ()) + }); + shards_map.simple_update("hello", |v| { assert_eq!(v.unwrap(), "lockmap"); - (Action::Remove, ()) + (SimpleAction::Remove, ()) }); - shards_map.update_by_ref("hello", |v| { + shards_map.simple_update("hello", |v| { assert_eq!(v, None); - (Action::Remove, ()) + (SimpleAction::Remove, ()) }); shards_map.update_by_ref("hello", |v| { assert_eq!(v, None); - (Action::Keep, ()) + (UpdateAction::Keep, ()) }); shards_map.update_by_ref(&"hello".to_owned(), |v| { assert_eq!(v, None); - (Action::Keep, ()) + (UpdateAction::Keep, ()) + }); + shards_map.simple_update("hello", |v| { + assert_eq!(v, None); + (SimpleAction::Keep, ()) }); } @@ -282,7 +337,7 @@ mod tests { const N: usize = 1 << 12; const M: usize = 8; - lock_map.update(1, |_| (Action::Update(0), ())); + lock_map.update(1, |_| (UpdateAction::Update(0), ())); let threads = (0..M) .map(|_| { @@ -296,7 +351,7 @@ mod tests { *v.unwrap() += 1; let now = current.fetch_sub(1, Ordering::AcqRel); assert_eq!(now, 1); - (Action::Keep, ()) + (UpdateAction::Keep, ()) }); } }) @@ -305,7 +360,7 @@ mod tests { threads.into_iter().for_each(|t| t.join().unwrap()); assert_eq!( - lock_map.update(1, |v| (Action::Update(0), *v.unwrap())), + lock_map.update(1, |v| (UpdateAction::Update(0), *v.unwrap())), N * M ); }