diff --git a/src/futex.rs b/src/futex.rs index 25df6f5..f0a15f9 100644 --- a/src/futex.rs +++ b/src/futex.rs @@ -132,4 +132,49 @@ mod tests { task.join().unwrap(); } } + + #[test] + fn test_concurrent() { + let lock = Arc::new(Mutex::new()); + let counter = Arc::new(AtomicU32::new(0)); + const THREAD_COUNT: usize = 4; + const ITERATIONS: usize = 10000; + + let mut handles = vec![]; + + // Spawn multiple threads that increment and decrement a shared counter + for _ in 0..THREAD_COUNT { + let lock = Arc::clone(&lock); + let counter = Arc::clone(&counter); + + handles.push(std::thread::spawn(move || { + for _ in 0..ITERATIONS { + // Lock and modify shared state + lock.lock(); + let value = counter.load(Relaxed); + std::thread::yield_now(); // Force a context switch to increase contention + counter.store(value + 1, Relaxed); + lock.unlock(); + + // Do some work without the lock + std::thread::yield_now(); + + // Lock and modify shared state again + lock.lock(); + let value = counter.load(Relaxed); + std::thread::yield_now(); // Force a context switch to increase contention + counter.store(value - 1, Relaxed); + lock.unlock(); + } + })); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap(); + } + + // Verify the final counter value is 0 + assert_eq!(counter.load(Relaxed), 0); + } } diff --git a/src/lockmap.rs b/src/lockmap.rs index c264bc6..cbf9098 100644 --- a/src/lockmap.rs +++ b/src/lockmap.rs @@ -886,4 +886,50 @@ mod tests { set_thread.join().unwrap(); get_thread.join().unwrap(); } + + #[test] + fn test_lockmap_heavy_contention() { + let lock_map = Arc::new(LockMap::::new()); + const THREADS: usize = 16; + const OPS_PER_THREAD: usize = 10000; + const HOT_KEYS: u32 = 5; + + let counter = Arc::new(AtomicU32::new(0)); + + let threads: Vec<_> = (0..THREADS) + .map(|_| { + let lock_map = lock_map.clone(); + let counter = counter.clone(); + std::thread::spawn(move || { + for _ in 0..OPS_PER_THREAD { + let key = rand::random::() % HOT_KEYS; + let mut entry = lock_map.entry(key); + + // Simulate some work + std::thread::sleep(std::time::Duration::from_nanos(10)); + + match entry.get_mut() { + Some(value) => { + *value = value.wrapping_add(1); + counter.fetch_add(1, Ordering::Relaxed); + } + None => { + entry.insert(1); + counter.fetch_add(1, Ordering::Relaxed); + } + } + } + }) + }) + .collect(); + + for thread in threads { + thread.join().unwrap(); + } + + assert_eq!( + counter.load(Ordering::Relaxed), + THREADS as u32 * OPS_PER_THREAD as u32 + ); + } }