From 6466c29dec2b918afe4099dea835e853b0778f42 Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 13 Feb 2023 04:25:16 +0000 Subject: [PATCH] simplify creation of WakerArray and WakerVec and make waker_from_redirect_position safer --- Cargo.toml | 1 - src/utils/wakers/array/waker_array.rs | 46 ++++++++-------- src/utils/wakers/shared_arc.rs | 76 ++++++++++++++++----------- src/utils/wakers/vec/waker_vec.rs | 32 +++++------ 4 files changed, 78 insertions(+), 77 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7512913..25f411f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,6 @@ harness = false bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] } futures-core = "0.3" pin-project = "1.0.8" -sptr = "0.3.2" [dev-dependencies] futures = "0.3.25" diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index 3ad3291..ac6def9 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -1,6 +1,6 @@ use core::array; use core::task::Waker; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, Weak}; use super::{ super::shared_arc::{waker_from_redirect_position, SharedArcContent}, @@ -22,30 +22,18 @@ struct WakerArrayInner { impl WakerArray { /// Create a new instance of `WakerArray`. pub(crate) fn new() -> Self { - let mut inner = Arc::new(WakerArrayInner { - readiness: Mutex::new(ReadinessArray::new()), - redirect: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now. - }); - let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address. - - // At this point the strong count is 2. Decrement it to 1. - // Each time we create/clone a Waker the count will be incremented by 1. - // So N Wakers -> count = N+1. - unsafe { Arc::decrement_strong_count(raw) } - - // Make redirect all point to the Arc itself. - Arc::get_mut(&mut inner).unwrap().redirect = [raw; N]; - - // Now the redirect array is complete. Time to create the actual Wakers. - let wakers = array::from_fn(|i| { - let data = inner.redirect.get(i).unwrap(); - unsafe { - waker_from_redirect_position::>( - data as *const *const WakerArrayInner, - ) + let inner = Arc::new_cyclic(|w| { + // `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`. + let raw = Weak::as_ptr(w); + WakerArrayInner { + readiness: Mutex::new(ReadinessArray::new()), + redirect: [raw; N], } }); + let wakers = + array::from_fn(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) }); + Self { inner, wakers } } @@ -59,7 +47,8 @@ impl WakerArray { } } -impl SharedArcContent for WakerArrayInner { +#[deny(unsafe_op_in_unsafe_fn)] +unsafe impl SharedArcContent for WakerArrayInner { fn get_redirect_slice(&self) -> &[*const Self] { &self.redirect } @@ -84,7 +73,10 @@ mod tests { #[test] fn check_refcount() { let mut wa = WakerArray::<5>::new(); + + // Each waker holds 1 ref, and the combinator itself holds 1. assert_eq!(Arc::strong_count(&wa.inner), 6); + wa.wakers[4] = dummy_waker(); assert_eq!(Arc::strong_count(&wa.inner), 5); let cloned = wa.wakers[3].clone(); @@ -105,13 +97,17 @@ mod tests { let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker()); assert_eq!(Arc::strong_count(&wa.inner), 4); taken.wake_by_ref(); - taken.wake_by_ref(); - taken.wake_by_ref(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + taken.clone().wake(); assert_eq!(Arc::strong_count(&wa.inner), 4); taken.wake(); assert_eq!(Arc::strong_count(&wa.inner), 3); wa.wakers = array::from_fn(|_| dummy_waker()); assert_eq!(Arc::strong_count(&wa.inner), 1); + + let weak = Arc::downgrade(&wa.inner); + drop(wa); + assert_eq!(weak.strong_count(), 0); } } diff --git a/src/utils/wakers/shared_arc.rs b/src/utils/wakers/shared_arc.rs index 313fb41..1e9a555 100644 --- a/src/utils/wakers/shared_arc.rs +++ b/src/utils/wakers/shared_arc.rs @@ -43,43 +43,46 @@ use std::sync::Arc; /// A trait to be implemented on [super::WakerArray] and [super::WakerVec] for polymorphism. /// These are the type that goes in the Arc. They both contain the Readiness and the redirect array/vec. -pub(super) trait SharedArcContent { - /// Get the reference of the redirect slice. This is used to compute the index. +/// # Safety +/// The `get_redirect_slice` method MUST always return the same slice for the same self. +pub(super) unsafe trait SharedArcContent { + /// Get the reference of the redirect slice. fn get_redirect_slice(&self) -> &[*const Self]; + /// Called when the subfuture at the specified index should be polled. /// Should call `Readiness::set_ready`. fn wake_index(&self, index: usize); } /// Create one waker following the mechanism described in the [module][self] doc. -/// The following must be upheld for safety: -/// - `pointer` must points to a slot in the redirect array. -/// - that slot must contain a pointer obtained by `Arc::::into_raw`. -/// - the Arc must still be alive at the time this function is called. -/// The following should be upheld for correct behavior: -/// - calling `SharedArcContent::get_redirect_slice` on the content of the Arc should give the redirect array within which `pointer` points to. +/// For safety, the index MUST be within bounds of the slice returned by `A::get_redirect_slice()`. #[deny(unsafe_op_in_unsafe_fn)] pub(super) unsafe fn waker_from_redirect_position( - pointer: *const *const A, + arc: Arc, + index: usize, ) -> Waker { - /// Create a Waker from a type-erased pointer. - /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation. - unsafe fn create_waker(pointer: *const ()) -> RawWaker { + // For `create_waker`, `wake_by_ref`, `wake`, and `drop_waker`, the following MUST be upheld for safety: + // - `pointer` must points to a slot in the redirect array. + // - that slot must contain a pointer of an Arc obtained from `Arc::::into_raw`. + // - that Arc must still be alive (strong count > 0) at the time the function is called. + + /// Clone a Waker from a type-erased pointer. + /// The pointer must satisfy the safety constraints listed in the code comments above. + unsafe fn clone_waker(pointer: *const ()) -> RawWaker { // Retype the type-erased pointer. let pointer = pointer as *const *const A; - // We're creating a new Waker, so we need to increment the count. - // SAFETY: The constraints listed for the wrapping function documentation means + // Increment the count so that the Arc won't die before this new Waker we're creating. + // SAFETY: The required constraints means // - `*pointer` is an `*const A` obtained from `Arc::::into_raw`. - // - the Arc is alive. - // So this operation is safe. + // - the Arc is alive right now. unsafe { Arc::increment_strong_count(*pointer) }; RawWaker::new(pointer as *const (), create_vtable::()) } /// Invoke `SharedArcContent::wake_index` with the index in the redirect slice where this pointer points to. - /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation. + /// The pointer must satisfy the safety constraints listed in the code comments above. unsafe fn wake_by_ref(pointer: *const ()) { // Retype the type-erased pointer. let pointer = pointer as *const *const A; @@ -89,31 +92,28 @@ pub(super) unsafe fn waker_from_redirect_position( // SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`. let arc_content: &A = unsafe { &*raw }; - // Calculate the index. - // This is your familiar pointer math - // `item_address = array_address + (index * item_size)` - // rearranged to - // `index = (item_address - array_address) / item_size`. - let item_address = sptr::Strict::addr(pointer); - let redirect_slice_address = sptr::Strict::addr(arc_content.get_redirect_slice().as_ptr()); - let redirect_item_size = core::mem::size_of::<*const A>(); // the size of each item in the redirect slice - let index = (item_address - redirect_slice_address) / redirect_item_size; + let slice_start = arc_content.get_redirect_slice().as_ptr(); + + // We'll switch to [`sub_ptr`](https://github.com/rust-lang/rust/issues/95892) once that's stable. + let index = unsafe { pointer.offset_from(slice_start) } as usize; arc_content.wake_index(index); } - /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation. + /// Drop the waker (and drop the shared Arc if other Wakers and the combinator have all been dropped). + /// The pointer must satisfy the safety constraints listed in the code comments above. unsafe fn drop_waker(pointer: *const ()) { // Retype the type-erased pointer. let pointer = pointer as *const *const A; // SAFETY: we are already requiring `pointer` to point to a slot in the redirect array. - let raw = unsafe { *pointer }; + let raw: *const A = unsafe { *pointer }; // SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`. unsafe { Arc::decrement_strong_count(raw) }; } - /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation. + /// Like `wake_by_ref` but consumes the Waker. + /// The pointer must satisfy the safety constraints listed in the code comments above. unsafe fn wake(pointer: *const ()) { // SAFETY: we are already requiring the constraints of `wake_by_ref` and `drop_waker`. unsafe { @@ -124,13 +124,27 @@ pub(super) unsafe fn waker_from_redirect_position( fn create_vtable() -> &'static RawWakerVTable { &RawWakerVTable::new( - create_waker::, + clone_waker::, wake::, wake_by_ref::, drop_waker::, ) } + + let redirect_slice: &[*const A] = arc.get_redirect_slice(); + + debug_assert!(redirect_slice.len() > index); + + // SAFETY: we are already requiring that index be in bound of the slice. + let pointer: *const *const A = unsafe { redirect_slice.as_ptr().add(index) }; + // Type-erase the pointer because the Waker methods expect so. + let pointer = pointer as *const (); + + // We want to transfer management of the one strong count associated with `arc` to the Waker we're creating. + // That count should only be decremented when the Waker is dropped (by `drop_waker`). + core::mem::forget(arc); + // SAFETY: All our vtable functions adhere to the RawWakerVTable contract, // and we are already requiring that `pointer` is what our functions expect. - unsafe { Waker::from_raw(create_waker::(pointer as *const ())) } + unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::())) } } diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index e4f993e..a84d9e7 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -1,5 +1,5 @@ use core::task::Waker; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, Weak}; use super::{ super::shared_arc::{waker_from_redirect_position, SharedArcContent}, @@ -21,27 +21,18 @@ struct WakerVecInner { impl WakerVec { /// Create a new instance of `WakerVec`. pub(crate) fn new(len: usize) -> Self { - let mut inner = Arc::new(WakerVecInner { - readiness: Mutex::new(ReadinessVec::new(len)), - redirect: Vec::new(), + let inner = Arc::new_cyclic(|w| { + // `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`. + let raw = Weak::as_ptr(w); + WakerVecInner { + readiness: Mutex::new(ReadinessVec::new(len)), + redirect: vec![raw; len], + } }); - let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address. - - // At this point the strong count is 2. Decrement it to 1. - // Each time we create/clone a Waker the count will be incremented by 1. - // So N Wakers -> count = N+1. - unsafe { Arc::decrement_strong_count(raw) } - - // Make redirect all point to the Arc itself. - Arc::get_mut(&mut inner).unwrap().redirect = vec![raw; len]; // Now the redirect vec is complete. Time to create the actual Wakers. - let wakers = inner - .redirect - .iter() - .map(|data| unsafe { - waker_from_redirect_position::(data as *const *const WakerVecInner) - }) + let wakers = (0..len) + .map(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) }) .collect(); Self { inner, wakers } @@ -56,7 +47,8 @@ impl WakerVec { } } -impl SharedArcContent for WakerVecInner { +#[deny(unsafe_op_in_unsafe_fn)] +unsafe impl SharedArcContent for WakerVecInner { fn get_redirect_slice(&self) -> &[*const Self] { &self.redirect }