diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 9a6ec0c..cdedcc4 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -168,11 +168,10 @@ where #[cfg(test)] mod test { use super::*; - use crate::utils::DummyWaker; + use crate::utils::dummy_waker; use std::future; use std::future::Future; - use std::sync::Arc; use std::task::Context; #[test] @@ -189,7 +188,7 @@ mod test { assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); let mut fut = Pin::new(&mut fut); - let waker = Arc::new(DummyWaker()).into(); + let waker = dummy_waker(); let mut cx = Context::from_waker(&waker); let _ = fut.as_mut().poll(&mut cx); assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index 1817de6..31d643b 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -170,11 +170,10 @@ where #[cfg(test)] mod test { use super::*; - use crate::utils::DummyWaker; + use crate::utils::dummy_waker; use std::future; use std::future::Future; - use std::sync::Arc; use std::task::Context; #[test] @@ -191,7 +190,7 @@ mod test { assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); let mut fut = Pin::new(&mut fut); - let waker = Arc::new(DummyWaker()).into(); + let waker = dummy_waker(); let mut cx = Context::from_waker(&waker); let _ = fut.as_mut().poll(&mut cx); assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ac5d38e..ed9932a 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -16,7 +16,7 @@ pub(crate) use tuple::{gen_conditions, tuple_len}; pub(crate) use wakers::{WakerArray, WakerVec}; #[cfg(test)] -pub(crate) use wakers::DummyWaker; +pub(crate) use wakers::dummy_waker; #[cfg(test)] pub(crate) mod channel; diff --git a/src/utils/wakers/array/mod.rs b/src/utils/wakers/array/mod.rs index 7303d9c..20a9392 100644 --- a/src/utils/wakers/array/mod.rs +++ b/src/utils/wakers/array/mod.rs @@ -1,7 +1,5 @@ mod readiness; -mod 
waker; mod waker_array; pub(crate) use readiness::ReadinessArray; -pub(crate) use waker::InlineWakerArray; pub(crate) use waker_array::WakerArray; diff --git a/src/utils/wakers/array/waker.rs b/src/utils/wakers/array/waker.rs deleted file mode 100644 index 960dff0..0000000 --- a/src/utils/wakers/array/waker.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::sync::{Arc, Mutex}; -use std::task::Wake; - -use super::ReadinessArray; - -/// An efficient waker which delegates wake events. -#[derive(Debug, Clone)] -pub(crate) struct InlineWakerArray { - pub(crate) id: usize, - pub(crate) readiness: Arc>>, -} - -impl InlineWakerArray { - /// Create a new instance of `InlineWaker`. - pub(crate) fn new(id: usize, readiness: Arc>>) -> Self { - Self { id, readiness } - } -} - -impl Wake for InlineWakerArray { - fn wake(self: std::sync::Arc) { - let mut readiness = self.readiness.lock().unwrap(); - if !readiness.set_ready(self.id) { - readiness - .parent_waker() - .as_mut() - .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") - .wake_by_ref() - } - } -} diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index 293155b..ac6def9 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -1,26 +1,40 @@ use core::array; -use std::sync::Arc; -use std::sync::Mutex; -use std::task::Waker; +use core::task::Waker; +use std::sync::{Arc, Mutex, Weak}; -use super::{InlineWakerArray, ReadinessArray}; +use super::{ + super::shared_arc::{waker_from_redirect_position, SharedArcContent}, + ReadinessArray, +}; /// A collection of wakers which delegate to an in-line waker. pub(crate) struct WakerArray { wakers: [Waker; N], - readiness: Arc>>, + inner: Arc>, +} + +/// See [super::super::shared_arc] for how this works. +struct WakerArrayInner { + redirect: [*const Self; N], + readiness: Mutex>, } impl WakerArray { /// Create a new instance of `WakerArray`. 
pub(crate) fn new() -> Self { - let readiness = Arc::new(Mutex::new(ReadinessArray::new())); - Self { - wakers: array::from_fn(|i| { - Arc::new(InlineWakerArray::new(i, readiness.clone())).into() - }), - readiness, - } + let inner = Arc::new_cyclic(|w| { + // `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`. + let raw = Weak::as_ptr(w); + WakerArrayInner { + readiness: Mutex::new(ReadinessArray::new()), + redirect: [raw; N], + } + }); + + let wakers = + array::from_fn(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) }); + + Self { inner, wakers } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { @@ -29,6 +43,71 @@ impl WakerArray { /// Access the `Readiness`. pub(crate) fn readiness(&self) -> &Mutex> { - self.readiness.as_ref() + &self.inner.readiness + } +} + +#[deny(unsafe_op_in_unsafe_fn)] +unsafe impl SharedArcContent for WakerArrayInner { + fn get_redirect_slice(&self) -> &[*const Self] { + &self.redirect + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") + .wake_by_ref(); + } + } +} + +#[cfg(test)] +mod tests { + use crate::utils::wakers::dummy_waker; + + use super::*; + #[test] + fn check_refcount() { + let mut wa = WakerArray::<5>::new(); + + // Each waker holds 1 ref, and the combinator itself holds 1. 
+ assert_eq!(Arc::strong_count(&wa.inner), 6); + + wa.wakers[4] = dummy_waker(); + assert_eq!(Arc::strong_count(&wa.inner), 5); + let cloned = wa.wakers[3].clone(); + assert_eq!(Arc::strong_count(&wa.inner), 6); + wa.wakers[3] = wa.wakers[4].clone(); + assert_eq!(Arc::strong_count(&wa.inner), 5); + drop(cloned); + assert_eq!(Arc::strong_count(&wa.inner), 4); + + wa.wakers[0].wake_by_ref(); + wa.wakers[0].wake_by_ref(); + wa.wakers[0].wake_by_ref(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + + wa.wakers[0] = wa.wakers[1].clone(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + + let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker()); + assert_eq!(Arc::strong_count(&wa.inner), 4); + taken.wake_by_ref(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + taken.clone().wake(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + taken.wake(); + assert_eq!(Arc::strong_count(&wa.inner), 3); + + wa.wakers = array::from_fn(|_| dummy_waker()); + assert_eq!(Arc::strong_count(&wa.inner), 1); + + let weak = Arc::downgrade(&wa.inner); + drop(wa); + assert_eq!(weak.strong_count(), 0); } } diff --git a/src/utils/wakers/dummy.rs b/src/utils/wakers/dummy.rs index 0f454b0..b60c996 100644 --- a/src/utils/wakers/dummy.rs +++ b/src/utils/wakers/dummy.rs @@ -1,6 +1,16 @@ -use std::{sync::Arc, task::Wake}; +use core::task::{RawWaker, RawWakerVTable, Waker}; -pub(crate) struct DummyWaker(); -impl Wake for DummyWaker { - fn wake(self: Arc) {} +/// A Waker that doesn't do anything. 
+pub(crate) fn dummy_waker() -> Waker { + fn new_raw_waker() -> RawWaker { + unsafe fn no_op(_data: *const ()) {} + unsafe fn clone(_data: *const ()) -> RawWaker { + new_raw_waker() + } + RawWaker::new( + core::ptr::null() as *const usize as *const (), + &RawWakerVTable::new(clone, no_op, no_op, no_op), + ) + } + unsafe { Waker::from_raw(new_raw_waker()) } } diff --git a/src/utils/wakers/mod.rs b/src/utils/wakers/mod.rs index d5c7f1d..72e0815 100644 --- a/src/utils/wakers/mod.rs +++ b/src/utils/wakers/mod.rs @@ -1,10 +1,20 @@ +//! Wakers that track when they are woken. +//! +//! By tracking which subfutures have woken, we can avoid having to re-poll N subfutures every time. +//! This tracking is done by a [ReadinessArray]/[ReadinessVec]. These store the indexes of the subfutures that have woken. +//! Each subfuture is given a Waker when polled. +//! This waker must know the index of its corresponding subfuture so that it can update Readiness correctly. +//! + mod array; +mod shared_arc; +mod vec; + #[cfg(test)] mod dummy; -mod vec; #[cfg(test)] -pub(crate) use dummy::DummyWaker; +pub(crate) use dummy::dummy_waker; pub(crate) use array::*; pub(crate) use vec::*; diff --git a/src/utils/wakers/shared_arc.rs b/src/utils/wakers/shared_arc.rs new file mode 100644 index 0000000..1e9a555 --- /dev/null +++ b/src/utils/wakers/shared_arc.rs @@ -0,0 +1,150 @@ +//! To save on allocations, we avoid making a separate Arc Waker for every subfuture. +//! Rather, we have all N Wakers share a single Arc, and use a "redirect" mechanism to allow different wakers to be distinguished. +//! The mechanism works as follows. +//! The Arc contains 2 things: +//! - the Readiness structure ([ReadinessArray][super::array::ReadinessArray] / [ReadinessVec][super::vec::ReadinessVec]) +//! - the redirect array. +//! The redirect array contains N repeated copies of the pointer to the Arc itself (obtained by `Arc::into_raw`). +//! 
The Waker for the `i`th subfuture points to the `i`th item in the redirect array. +//! (i.e. the Waker pointer is `*const *const A` where `A` is the type of the item in the Arc) +//! When the Waker is woken, we deref it twice (giving reference to the content of the Arc), +//! and compare it to the address of the redirect slice. +//! The difference tells us the index of the waker. We can then record this woken index in the Readiness. +//! +//! ```text +//! ┌───────────────────────────┬──────────────┬──────────────┐ +//! │ │ │ │ +//! │ / ┌─────────────┬──────┼───────┬──────┼───────┬──────┼───────┬─────┐ \ +//! ▼ / │ │ │ │ │ │ │ │ │ \ +//! Arc < │ Readiness │ redirect[0] │ redirect[1] │ redirect[2] │ ... │ > +//! ▲ \ │ │ │ │ │ │ / +//! │ \ └─────────────┴──────▲───────┴──────▲───────┴──────▲───────┴─────┘ / +//! │ │ │ │ +//! └─┐ ┌───────────────┘ │ │ +//! │ │ │ │ +//! │ │ ┌──────────────────┘ │ +//! │ │ │ │ +//! │ │ │ ┌─────────────────────┘ +//! │ │ │ │ +//! │ │ │ │ +//! ┌────┼────┬────┼──────┬────┼──────┬────┼──────┬─────┐ +//! │ │ │ │ │ │ │ │ │ │ +//! │ │ wakers[0] │ wakers[1] │ wakers[2] │ ... │ +//! │ │ │ │ │ │ +//! └─────────┴───────────┴───────────┴───────────┴─────┘ +//! ``` + +// TODO: Right now each waker gets its own redirect slot. +// We can save space by making size_of::<*const _>() wakers share the same slot. +// With such change, in 64-bit system, the redirect array/vec would only need ⌈N/8⌉ slots instead of N. + +use core::task::{RawWaker, RawWakerVTable, Waker}; +use std::sync::Arc; + +/// A trait to be implemented on [super::WakerArray] and [super::WakerVec] for polymorphism. +/// These are the type that goes in the Arc. They both contain the Readiness and the redirect array/vec. +/// # Safety +/// The `get_redirect_slice` method MUST always return the same slice for the same self. +pub(super) unsafe trait SharedArcContent { + /// Get the reference of the redirect slice. 
+ fn get_redirect_slice(&self) -> &[*const Self]; + + /// Called when the subfuture at the specified index should be polled. + /// Should call `Readiness::set_ready`. + fn wake_index(&self, index: usize); +} + +/// Create one waker following the mechanism described in the [module][self] doc. +/// For safety, the index MUST be within bounds of the slice returned by `A::get_redirect_slice()`. +#[deny(unsafe_op_in_unsafe_fn)] +pub(super) unsafe fn waker_from_redirect_position( + arc: Arc, + index: usize, +) -> Waker { + // For `clone_waker`, `wake_by_ref`, `wake`, and `drop_waker`, the following MUST be upheld for safety: + // - `pointer` must point to a slot in the redirect array. + // - that slot must contain a pointer of an Arc obtained from `Arc::<A>::into_raw`. + // - that Arc must still be alive (strong count > 0) at the time the function is called. + + /// Clone a Waker from a type-erased pointer. + /// The pointer must satisfy the safety constraints listed in the code comments above. + unsafe fn clone_waker(pointer: *const ()) -> RawWaker { + // Retype the type-erased pointer. + let pointer = pointer as *const *const A; + + // Increment the count so that the Arc won't die before this new Waker we're creating. + // SAFETY: The required constraints mean + // - `*pointer` is a `*const A` obtained from `Arc::<A>::into_raw`. + // - the Arc is alive right now. + unsafe { Arc::increment_strong_count(*pointer) }; + + RawWaker::new(pointer as *const (), create_vtable::()) + } + + /// Invoke `SharedArcContent::wake_index` with the index in the redirect slice where this pointer points to. + /// The pointer must satisfy the safety constraints listed in the code comments above. + unsafe fn wake_by_ref(pointer: *const ()) { + // Retype the type-erased pointer. + let pointer = pointer as *const *const A; + + // SAFETY: we are already requiring `pointer` to point to a slot in the redirect array. 
+ let raw: *const A = unsafe { *pointer }; + // SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`. + let arc_content: &A = unsafe { &*raw }; + + let slice_start = arc_content.get_redirect_slice().as_ptr(); + + // We'll switch to [`sub_ptr`](https://github.com/rust-lang/rust/issues/95892) once that's stable. + let index = unsafe { pointer.offset_from(slice_start) } as usize; + + arc_content.wake_index(index); + } + + /// Drop the waker (and drop the shared Arc if other Wakers and the combinator have all been dropped). + /// The pointer must satisfy the safety constraints listed in the code comments above. + unsafe fn drop_waker(pointer: *const ()) { + // Retype the type-erased pointer. + let pointer = pointer as *const *const A; + + // SAFETY: we are already requiring `pointer` to point to a slot in the redirect array. + let raw: *const A = unsafe { *pointer }; + // SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`. + unsafe { Arc::decrement_strong_count(raw) }; + } + + /// Like `wake_by_ref` but consumes the Waker. + /// The pointer must satisfy the safety constraints listed in the code comments above. + unsafe fn wake(pointer: *const ()) { + // SAFETY: we are already requiring the constraints of `wake_by_ref` and `drop_waker`. + unsafe { + wake_by_ref::(pointer); + drop_waker::(pointer); + } + } + + fn create_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_waker::, + wake::, + wake_by_ref::, + drop_waker::, + ) + } + + let redirect_slice: &[*const A] = arc.get_redirect_slice(); + + debug_assert!(redirect_slice.len() > index); + + // SAFETY: we are already requiring that index be in bound of the slice. + let pointer: *const *const A = unsafe { redirect_slice.as_ptr().add(index) }; + // Type-erase the pointer because the Waker methods expect so. 
+ let pointer = pointer as *const (); + + // We want to transfer management of the one strong count associated with `arc` to the Waker we're creating. + // That count should only be decremented when the Waker is dropped (by `drop_waker`). + core::mem::forget(arc); + + // SAFETY: All our vtable functions adhere to the RawWakerVTable contract, + // and we are already requiring that `pointer` is what our functions expect. + unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::())) } +} diff --git a/src/utils/wakers/vec/mod.rs b/src/utils/wakers/vec/mod.rs index e002fd3..7f064fb 100644 --- a/src/utils/wakers/vec/mod.rs +++ b/src/utils/wakers/vec/mod.rs @@ -1,7 +1,5 @@ mod readiness; -mod waker; mod waker_vec; pub(crate) use readiness::ReadinessVec; -pub(crate) use waker::InlineWakerVec; pub(crate) use waker_vec::WakerVec; diff --git a/src/utils/wakers/vec/waker.rs b/src/utils/wakers/vec/waker.rs deleted file mode 100644 index cfd12ad..0000000 --- a/src/utils/wakers/vec/waker.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::sync::{Arc, Mutex}; -use std::task::Wake; - -use super::ReadinessVec; - -/// An efficient waker which delegates wake events. -#[derive(Debug, Clone)] -pub(crate) struct InlineWakerVec { - pub(crate) id: usize, - pub(crate) readiness: Arc>, -} - -impl InlineWakerVec { - /// Create a new instance of `InlineWaker`. - pub(crate) fn new(id: usize, readiness: Arc>) -> Self { - Self { id, readiness } - } -} - -impl Wake for InlineWakerVec { - fn wake(self: std::sync::Arc) { - let mut readiness = self.readiness.lock().unwrap(); - if !readiness.set_ready(self.id) { - readiness - .parent_waker() - .as_mut() - .expect("`parent_waker` not available from `Readiness`. 
Did you forget to call `Readiness::set_waker`?") - .wake_by_ref() - } - } -} diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index e317f0b..a84d9e7 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -1,31 +1,66 @@ -use std::sync::Arc; -use std::sync::Mutex; -use std::task::Waker; +use core::task::Waker; +use std::sync::{Arc, Mutex, Weak}; -use super::{InlineWakerVec, ReadinessVec}; +use super::{ + super::shared_arc::{waker_from_redirect_position, SharedArcContent}, + ReadinessVec, +}; -/// A collection of wakers which delegate to an in-line waker. +/// A collection of wakers sharing the same allocation. pub(crate) struct WakerVec { wakers: Vec, - readiness: Arc>, + inner: Arc, +} + +/// See [super::super::shared_arc] for how this works. +struct WakerVecInner { + redirect: Vec<*const Self>, + readiness: Mutex, } impl WakerVec { /// Create a new instance of `WakerVec`. pub(crate) fn new(len: usize) -> Self { - let readiness = Arc::new(Mutex::new(ReadinessVec::new(len))); + let inner = Arc::new_cyclic(|w| { + // `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`. + let raw = Weak::as_ptr(w); + WakerVecInner { + readiness: Mutex::new(ReadinessVec::new(len)), + redirect: vec![raw; len], + } + }); + + // Now the redirect vec is complete. Time to create the actual Wakers. let wakers = (0..len) - .map(|i| Arc::new(InlineWakerVec::new(i, readiness.clone())).into()) + .map(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) }) .collect(); - Self { wakers, readiness } + + Self { inner, wakers } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { self.wakers.get(index) } - /// Access the `Readiness`. 
pub(crate) fn readiness(&self) -> &Mutex { - self.readiness.as_ref() + &self.inner.readiness + } +} + +#[deny(unsafe_op_in_unsafe_fn)] +unsafe impl SharedArcContent for WakerVecInner { + fn get_redirect_slice(&self) -> &[*const Self] { + &self.redirect + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") + .wake_by_ref(); + } } }