From 42c5856c32882c82181295778c6ed2e7c447af58 Mon Sep 17 00:00:00 2001 From: Matheus Consoli Date: Wed, 20 Mar 2024 12:36:41 -0300 Subject: [PATCH] Waker array optimization --- src/utils/wakers/array/waker_array.rs | 47 ++++++++++++++++----- src/utils/wakers/mod.rs | 2 + src/utils/wakers/shared_arc.rs | 59 +++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 10 deletions(-) create mode 100644 src/utils/wakers/shared_arc.rs diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index c6a3912..d4c1ff1 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -1,26 +1,31 @@ -use alloc::sync::Arc; +use alloc::sync::{Arc, Weak}; use core::array; use core::task::Waker; use std::sync::{Mutex, MutexGuard}; +use crate::utils::wakers::shared_arc::{waker_from_redirec_position, SharedArcContent}; + use super::{InlineWakerArray, ReadinessArray}; /// A collection of wakers which delegate to an in-line waker. pub(crate) struct WakerArray { wakers: [Waker; N], - readiness: Arc>>, + inner: Arc>, } impl WakerArray { /// Create a new instance of `WakerArray`. pub(crate) fn new() -> Self { - let readiness = Arc::new(Mutex::new(ReadinessArray::new())); - Self { - wakers: array::from_fn(|i| { - Arc::new(InlineWakerArray::new(i, readiness.clone())).into() - }), - readiness, - } + let inner = Arc::new_cyclic(|w| { + let raw = Weak::as_ptr(w); + WakerArrayInner { + redirect: [raw; N], + readiness: Mutex::new(ReadinessArray::new()), + } + }); + let wakers = + array::from_fn(|i| unsafe { waker_from_redirec_position(Arc::clone(&inner), i) }); + Self { wakers, inner } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { @@ -29,6 +34,28 @@ impl WakerArray { /// Access the `Readiness`. pub(crate) fn readiness(&mut self) -> MutexGuard<'_, ReadinessArray> { - self.readiness.as_ref().lock().unwrap() + self.inner.readiness.lock().unwrap() // TODO: unwrap + } +} + +struct WakerArrayInner { + redirect: [*const Self; N], + readiness: Mutex>, +} + +unsafe impl SharedArcContent for WakerArrayInner { + fn get_redirect_slice(&self) -> &[*const Self] { + &self.redirect + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); // TODO: unwrap + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") + .wake_by_ref() + } } } diff --git a/src/utils/wakers/mod.rs b/src/utils/wakers/mod.rs index 089af6e..b391a41 100644 --- a/src/utils/wakers/mod.rs +++ b/src/utils/wakers/mod.rs @@ -3,6 +3,8 @@ mod array; mod dummy; #[cfg(feature = "alloc")] mod vec; +// #[cfg(feature = "alloc")] +mod shared_arc; #[cfg(all(test, feature = "alloc"))] pub(crate) use dummy::DummyWaker; diff --git a/src/utils/wakers/shared_arc.rs b/src/utils/wakers/shared_arc.rs new file mode 100644 index 0000000..c9ba3b4 --- /dev/null +++ b/src/utils/wakers/shared_arc.rs @@ -0,0 +1,59 @@ +use core::task::{RawWaker, RawWakerVTable, Waker}; + +use alloc::sync::Arc; +use bitvec::index; + +pub(super) unsafe trait SharedArcContent { + /// Get the reference of the redirect slice + fn get_redirect_slice(&self) -> &[*const Self]; + /// Called when the subfuture at the specified index should be polled + /// Should call `Readiness::set_ready` + fn wake_index(&self, index: usize); +} + +pub(super) unsafe fn waker_from_redirec_position( + arc: Arc, + index: usize, +) -> Waker { + unsafe fn clone_waker(pointer: *const ()) -> RawWaker { + let pointer = pointer as *const *const A; + unsafe { Arc::increment_strong_count(*pointer) }; + RawWaker::new(pointer as *const (), create_vtable::()) + } + + unsafe fn wake_by_ref(pointer: *const ()) { + let pointer = pointer as *const *const A; + let raw: *const A = unsafe { *pointer }; + let arc_content: &A = unsafe { &*raw }; + let slice = arc_content.get_redirect_slice().as_ptr(); + let index = unsafe { pointer.offset_from(slice) } as usize; + arc_content.wake_index(index); + } + + unsafe fn drop_waker(pointer: *const ()) { + let pointer = pointer as *const *const A; + unsafe { Arc::decrement_strong_count(*pointer) }; + } + + unsafe fn wake(pointer: *const ()) { + unsafe { + wake_by_ref::(pointer); + drop_waker::(pointer); + } + } + + fn create_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_waker::, + wake::, + wake_by_ref::, + drop_waker::, + ) + } + + let redirect = arc.get_redirect_slice(); + let pointer = unsafe { redirect.as_ptr().add(index) } as *const (); + core::mem::forget(arc); + + unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::())) } +}