diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 830d0cc..0deb71b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,10 +1,4 @@
-on:
-  pull_request:
-    branches:
-      - main
-  push:
-    branches:
-      - main
+on: [push, pull_request]
 
 env:
   RUST_BACKTRACE: 1
diff --git a/Cargo.toml b/Cargo.toml
index ce77074..f13633c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "aarc"
-version = "0.0.1"
+version = "0.1.0"
 edition = "2021"
 description = "Atomically updatable variants of Arc and Weak for lock-free concurrency."
 homepage = "https://github.com/aarc-rs/aarc"
@@ -11,4 +11,4 @@ categories = ["concurrency", "memory-management", "data-structures", "algorithms
 exclude = [".github/", ".gitignore"]
 
 [dev-dependencies]
-rand = "0.8"
\ No newline at end of file
+rand = "0.8"
diff --git a/src/atomics.rs b/src/atomics.rs
index b4cca30..2948691 100644
--- a/src/atomics.rs
+++ b/src/atomics.rs
@@ -1,8 +1,10 @@
 use crate::statics::{acquire, release, retire};
+use std::mem::ManuallyDrop;
+use std::ops::Deref;
 use std::ptr::null;
 use std::sync::atomic::Ordering::SeqCst;
 use std::sync::atomic::{AtomicPtr, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Weak};
 pub struct AtomicArc<T: 'static> {
     ptr: AtomicPtr<T>,
 }
@@ -16,7 +18,7 @@ impl<T: 'static> AtomicArc<T> {
     /// # Examples
     /// ```
     /// use std::sync::atomic::Ordering::SeqCst;
-    /// use aarc::atomics::AtomicArc;
+    /// use aarc::AtomicArc;
     ///
     /// let atomic1: AtomicArc<i32> = AtomicArc::new(None);
     /// assert_eq!(atomic1.load(SeqCst), None);
@@ -32,17 +34,17 @@
         }
     }
 
-    /// Loads the [`Arc`] with its reference count incremented appropriately if not [`None`].
-    ///
-    /// # Example
+    /// Loads the [`Arc`] and increments its strong count appropriately if not [`None`].
     ///
+    /// # Examples
     /// ```
     /// use std::sync::Arc;
     /// use std::sync::atomic::Ordering::SeqCst;
-    /// use aarc::atomics::AtomicArc;
+    /// use aarc::AtomicArc;
     ///
-    /// let atomic1 = AtomicArc::new(Some(42));
-    /// let loaded = atomic1.load(SeqCst).unwrap();
+    /// let atomic = AtomicArc::new(None);
+    /// atomic.store(Some(&Arc::new(42)), SeqCst);
+    /// let loaded = atomic.load(SeqCst).unwrap();
     ///
     /// assert_eq!(Arc::strong_count(&loaded), 2);
     /// assert_eq!(*loaded, 42);
@@ -56,6 +58,18 @@
         }
     }
 
+    /// Stores `val` into `self`. See `AtomicArc::load` for examples.
+    pub fn store(&self, val: Option<&Arc<T>>, order: Ordering) {
+        let ptr: *const T = val.map_or(null(), Arc::as_ptr);
+        unsafe {
+            Arc::increment_strong_count(ptr);
+        }
+        let before = self.ptr.swap(ptr.cast_mut(), order);
+        if !before.is_null() {
+            retire::<_, true>(before);
+        }
+    }
+
     /// Stores `new` into `self` if `self` is equal to `current`.
     ///
     /// The comparison is shallow and is determined by pointer equality; see [`Arc::ptr_eq`].
@@ -64,11 +78,11 @@
     /// (instead of a copy of `current`). This eliminates the overhead of providing the caller with
     /// a redundant [`Arc`].
     ///
-    /// # Example:
+    /// # Examples:
     /// ```
     /// use std::sync::Arc;
     /// use std::sync::atomic::Ordering::SeqCst;
-    /// use aarc::atomics::AtomicArc;
+    /// use aarc::AtomicArc;
     ///
     /// let atomic1: AtomicArc<i32> = AtomicArc::new(None);
     /// let arc1 = Arc::new(42);
@@ -92,7 +106,7 @@
         {
             Ok(before) => unsafe {
                 if !before.is_null() {
-                    retire(before);
+                    retire::<_, true>(before);
                 }
                 if !n.is_null() {
                     Arc::increment_strong_count(n);
@@ -134,9 +148,11 @@
     }
 }
 impl<T: 'static> Default for AtomicArc<T> {
-    /// Equivalent to `AtomicArc::new(None)`.
+    /// Creates an empty `AtomicArc`. Equivalent to `AtomicArc::new(None)`.
     fn default() -> Self {
-        Self::new(None)
+        Self {
+            ptr: AtomicPtr::default(),
+        }
     }
 }
 
@@ -144,7 +160,7 @@ impl<T: 'static> Drop for AtomicArc<T> {
     fn drop(&mut self) {
         let ptr = self.ptr.load(SeqCst);
         if !ptr.is_null() {
-            retire(ptr);
+            retire::<_, true>(ptr);
         }
     }
 }
@@ -164,3 +180,158 @@ impl<T: 'static> From<Option<Arc<T>>> for AtomicArc<T> {
         }
     }
 }
+
+pub struct AtomicWeak<T: 'static> {
+    ptr: AtomicPtr<T>,
+}
+
+impl<T: 'static> AtomicWeak<T> {
+    /// Loads the [`Weak`] and increments its weak count appropriately if not [`None`].
+    ///
+    /// # Examples
+    /// ```
+    /// use std::sync::{Arc, Weak};
+    /// use std::sync::atomic::Ordering::SeqCst;
+    /// use aarc::{AtomicArc, AtomicWeak};
+    ///
+    /// let arc = Arc::new(42);
+    /// let weak = Arc::downgrade(&arc);
+    ///
+    /// let atomic_weak = AtomicWeak::default();
+    /// atomic_weak.store(Some(&weak), SeqCst);
+    /// let loaded = atomic_weak.load(SeqCst).unwrap();
+    ///
+    /// assert_eq!(*loaded.upgrade().unwrap(), 42);
+    /// assert_eq!(Weak::weak_count(&loaded), 3);
+    /// ```
+    pub fn load(&self, order: Ordering) -> Option<Weak<T>> {
+        acquire();
+        let ptr = self.ptr.load(order);
+        let result = if !ptr.is_null() {
+            unsafe { Some(clone_weak_from_raw(ptr)) }
+        } else {
+            None
+        };
+        release();
+        result
+    }
+
+    /// Stores `val` into `self`. See `AtomicWeak::load` for examples.
+    pub fn store(&self, val: Option<&Weak<T>>, order: Ordering) {
+        let ptr: *const T = val.map_or(null(), |w| {
+            increment_weak_count(w);
+            Weak::as_ptr(w)
+        });
+        let before = self.ptr.swap(ptr.cast_mut(), order);
+        if !before.is_null() {
+            retire::<_, false>(before);
+        }
+    }
+
+    /// Stores `new` into `self` if `self` is equal to `current`.
+    ///
+    /// The comparison is shallow and is determined by pointer equality; see [`Weak::ptr_eq`].
+    ///
+    /// If the comparison succeeds, the return value will be an [`Ok`] containing the unit type
+    /// (instead of a copy of `current`). This eliminates the overhead of providing the caller with
+    /// a redundant [`Weak`].
+    ///
+    /// # Examples:
+    /// ```
+    /// use std::sync::Arc;
+    /// use std::sync::atomic::Ordering::SeqCst;
+    /// use aarc::AtomicWeak;
+    ///
+    /// let arc = Arc::new(42);
+    /// let weak = Arc::downgrade(&arc);
+    ///
+    /// let atomic_weak = AtomicWeak::default();
+    /// assert!(atomic_weak.compare_exchange(None, Some(&weak), SeqCst, SeqCst).is_ok());
+    /// let loaded = atomic_weak.load(SeqCst).unwrap();
+    /// assert_eq!(*loaded.upgrade().unwrap(), 42);
+    /// ```
+    pub fn compare_exchange(
+        &self,
+        current: Option<&Weak<T>>,
+        new: Option<&Weak<T>>,
+        success: Ordering,
+        failure: Ordering,
+    ) -> Result<(), Option<Weak<T>>> {
+        let c: *const T = current.map_or(null(), Weak::as_ptr);
+        let n: *const T = new.map_or(null(), Weak::as_ptr);
+        acquire();
+        let result = match self
+            .ptr
+            .compare_exchange(c.cast_mut(), n.cast_mut(), success, failure)
+        {
+            Ok(before) => {
+                if !before.is_null() {
+                    retire::<_, false>(before);
+                }
+                if let Some(weak_new) = new {
+                    increment_weak_count(weak_new);
+                }
+                Ok(())
+            }
+            Err(before) => unsafe {
+                if !before.is_null() {
+                    Err(Some(clone_weak_from_raw(before)))
+                } else {
+                    Err(None)
+                }
+            },
+        };
+        release();
+        result
+    }
+}
+
+impl<T: 'static> Clone for AtomicWeak<T> {
+    fn clone(&self) -> Self {
+        self.load(SeqCst).map_or(Self::default(), |weak| Self {
+            ptr: AtomicPtr::new(Weak::into_raw(weak).cast_mut()),
+        })
+    }
+}
+
+impl<T: 'static> Default for AtomicWeak<T> {
+    /// Creates an empty `AtomicWeak`.
+    fn default() -> Self {
+        Self {
+            ptr: AtomicPtr::default(),
+        }
+    }
+}
+
+impl<T: 'static> Drop for AtomicWeak<T> {
+    fn drop(&mut self) {
+        let ptr = self.ptr.load(SeqCst);
+        if !ptr.is_null() {
+            retire::<_, false>(ptr);
+        }
+    }
+}
+
+impl<T: 'static> From<Weak<T>> for AtomicWeak<T> {
+    fn from(value: Weak<T>) -> Self {
+        Self {
+            ptr: AtomicPtr::new(Weak::into_raw(value).cast_mut()),
+        }
+    }
+}
+
+impl<T: 'static> From<Option<Weak<T>>> for AtomicWeak<T> {
+    fn from(value: Option<Weak<T>>) -> Self {
+        Self {
+            ptr: AtomicPtr::new(value.map_or(null(), Weak::into_raw).cast_mut()),
+        }
+    }
+}
+
+fn increment_weak_count<T>(w: &Weak<T>) {
+    _ = ManuallyDrop::new(w.clone());
+}
+
+unsafe fn clone_weak_from_raw<T>(ptr: *const T) -> Weak<T> {
+    ManuallyDrop::new(Weak::from_raw(ptr)).deref().clone()
+}
diff --git a/src/lib.rs b/src/lib.rs
index 339e6fa..ca6306f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,7 @@
-pub mod atomics;
+pub use atomics::AtomicArc;
+pub use atomics::AtomicWeak;
+
+pub(crate) mod atomics;
 pub(crate) mod hyaline;
 pub(crate) mod statics;
 
diff --git a/src/statics.rs b/src/statics.rs
index fb470ce..e8f1534 100644
--- a/src/statics.rs
+++ b/src/statics.rs
@@ -1,7 +1,7 @@
 use crate::hyaline::{Context, DeferredFn, Slot};
 use std::cell::RefCell;
 use std::ptr::null_mut;
-use std::sync::{Arc, OnceLock};
+use std::sync::{Arc, OnceLock, Weak};
 use std::thread::AccessError;
 
 const SLOTS_PER_NODE: usize = 64;
@@ -38,21 +38,29 @@ pub(crate) fn release() {
     }
 }
 
-pub(crate) fn retire<T>(raw_arc_ptr: *const T) {
+pub(crate) fn retire<T, const IS_STRONG: bool>(ptr: *const T) {
     if let Ok(slot) = get_slot() {
         get_context().retire(
             slot,
             DeferredFn {
-                ptr: raw_arc_ptr as *mut u8,
-                f: Box::new(|ptr| unsafe {
-                    Arc::decrement_strong_count(ptr as *const T);
+                ptr: ptr as *mut u8,
+                f: Box::new(|p| unsafe {
+                    if IS_STRONG {
+                        drop(Arc::from_raw(p as *const T));
+                    } else {
+                        drop(Weak::from_raw(p as *const T));
+                    };
                 }),
             },
         );
     } else {
         // The TLS key is being destroyed. This path is only safe during cleanup.
         unsafe {
-            Arc::decrement_strong_count(raw_arc_ptr);
+            if IS_STRONG {
+                drop(Arc::from_raw(ptr));
+            } else {
+                drop(Weak::from_raw(ptr));
+            };
         }
     }
 }
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index 7165622..b5552e6 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -1,5 +1,6 @@
-use aarc::atomics::AtomicArc;
+use aarc::{AtomicArc, AtomicWeak};
 use rand::random;
+use std::collections::VecDeque;
 use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering::SeqCst;
 use std::sync::Arc;
@@ -15,6 +16,7 @@ fn test_stack() {
         next: Option<Arc<StackNode>>,
     }
 
+    #[derive(Default)]
     struct Stack {
         top: AtomicArc<StackNode>,
     }
@@ -49,9 +51,7 @@
         }
     }
 
-    let stack = Stack {
-        top: AtomicArc::default(),
-    };
+    let stack = Stack::default();
 
     thread::scope(|s| {
         for _ in 0..THREADS_COUNT {
@@ -82,59 +82,81 @@
 }
 
 #[test]
-fn test_sorted_linked_list() {
+fn test_sorted_doubly_linked_list() {
     const THREADS_COUNT: usize = 5;
     const ITERS_PER_THREAD: usize = 10;
 
+    #[derive(Default)]
     struct ListNode {
         val: usize,
-        next: AtomicArc<ListNode>,
+        prev: AtomicWeak<ListNode>,
+        next: AtomicArc<ListNode>,
+    }
+
+    struct LinkedList {
+        head: Arc<ListNode>,
+    }
+
+    impl LinkedList {
+        fn insert_sorted(&self, val: usize) {
+            let mut curr_arc = self.head.clone();
+            let mut next = curr_arc.next.load(SeqCst);
+            loop {
+                if next.is_none() || val < next.as_ref().unwrap().val {
+                    let new = Arc::new(ListNode {
+                        val,
+                        prev: AtomicWeak::from(Arc::downgrade(&curr_arc)),
+                        next: AtomicArc::from(next.clone()),
+                    });
+                    match curr_arc
+                        .next
+                        .compare_exchange(next.as_ref(), Some(&new), SeqCst, SeqCst)
+                    {
+                        Ok(_) => {
+                            if let Some(next_node) = next {
+                                next_node.prev.store(Some(&Arc::downgrade(&new)), SeqCst);
+                            }
+                            break;
+                        }
+                        Err(actual_next) => next = actual_next,
+                    }
+                } else {
+                    curr_arc = next.unwrap();
+                    next = curr_arc.next.load(SeqCst);
+                }
+            }
+        }
     }
 
-    let head = AtomicArc::new(Some(ListNode {
-        val: 0,
-        next: AtomicArc::new(None),
-    }));
+    let list = LinkedList {
+        head: Arc::new(ListNode::default()),
+    };
 
     thread::scope(|s| {
         for _ in 0..THREADS_COUNT {
             s.spawn(|| {
                 for _ in 0..ITERS_PER_THREAD {
-                    let val = random::<usize>();
-                    let mut curr_arc = head.load(SeqCst).unwrap();
-                    let mut next = curr_arc.next.load(SeqCst);
-                    'inner: loop {
-                        if next.is_none() || val < next.as_ref().unwrap().val {
-                            let new = Arc::new(ListNode {
-                                val,
-                                next: AtomicArc::from(next.clone()),
-                            });
-                            match curr_arc.next.compare_exchange(
-                                next.as_ref(),
-                                Some(&new),
-                                SeqCst,
-                                SeqCst,
-                            ) {
-                                Ok(_) => break 'inner,
-                                Err(actual_next) => next = actual_next,
-                            }
-                        } else {
-                            curr_arc = next.unwrap();
-                            next = curr_arc.next.load(SeqCst);
-                        }
-                    }
+                    list.insert_sorted(random::<usize>());
                 }
             });
         }
     });
 
     // Verify that no nodes were lost and that the list is in sorted order.
-    let mut i = 0;
-    let mut curr_arc = head.load(SeqCst).unwrap();
-    while let Some(next_node) = curr_arc.next.load(SeqCst) {
-        assert!(curr_arc.val <= next_node.val);
-        curr_arc = next_node;
-        i += 1;
+    let mut stack = VecDeque::new();
+    let mut curr_node = list.head.clone();
+    while let Some(next_node) = curr_node.next.load(SeqCst) {
+        assert!(curr_node.val <= next_node.val);
+        stack.push_back(next_node.val);
+        curr_node = next_node;
+    }
+    assert_eq!(THREADS_COUNT * ITERS_PER_THREAD, stack.len());
+
+    // Check the weak pointers by iterating in reverse order and popping from the stack.
+    while let Some(prev_weak) = curr_node.prev.load(SeqCst) {
+        let prev_node = prev_weak.upgrade().unwrap();
+        assert_eq!(stack.pop_back().unwrap(), curr_node.val);
+        curr_node = prev_node;
     }
-    assert_eq!(THREADS_COUNT * ITERS_PER_THREAD, i);
+    assert_eq!(stack.len(), 0);
 }
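Taken together, this change set adds `AtomicArc::store`, the `AtomicWeak` counterpart, and the strong/weak-aware `retire`, which is what lets the new test keep weak back-links alongside strong forward links. The sketch below is a minimal, single-threaded illustration of that pattern using only the public API introduced in this diff; the `Node` type and the surrounding `main` are hypothetical and not part of the crate.

```rust
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;

use aarc::{AtomicArc, AtomicWeak};

// Hypothetical node: a strong forward link and a non-owning back link,
// mirroring the doubly linked list exercised in the integration test.
#[derive(Default)]
struct Node {
    val: usize,
    prev: AtomicWeak<Node>,
    next: AtomicArc<Node>,
}

fn main() {
    let first = Arc::new(Node::default());
    let second = Arc::new(Node {
        val: 1,
        prev: AtomicWeak::from(Arc::downgrade(&first)),
        next: AtomicArc::new(None),
    });

    // Publish the forward link only if no other writer got there first.
    if first
        .next
        .compare_exchange(None, Some(&second), SeqCst, SeqCst)
        .is_ok()
    {
        // Unconditionally (re)point the back link; `store` retires any previous value.
        second.prev.store(Some(&Arc::downgrade(&first)), SeqCst);
    }

    // Each load hands the reader its own reference-counted handle.
    let next = first.next.load(SeqCst).unwrap();
    assert_eq!(next.val, 1);
    let prev = next.prev.load(SeqCst).unwrap().upgrade().unwrap();
    assert_eq!(prev.val, 0);
}
```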