diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index c6a3912..0bde2d6 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -1,26 +1,31 @@ -use alloc::sync::Arc; +use alloc::sync::{Arc, Weak}; use core::array; use core::task::Waker; use std::sync::{Mutex, MutexGuard}; -use super::{InlineWakerArray, ReadinessArray}; +use crate::utils::wakers::shared_arc::{waker_from_redirec_position, SharedArcContent}; + +use super::ReadinessArray; /// A collection of wakers which delegate to an in-line waker. pub(crate) struct WakerArray { wakers: [Waker; N], - readiness: Arc>>, + inner: Arc>, } impl WakerArray { /// Create a new instance of `WakerArray`. pub(crate) fn new() -> Self { - let readiness = Arc::new(Mutex::new(ReadinessArray::new())); - Self { - wakers: array::from_fn(|i| { - Arc::new(InlineWakerArray::new(i, readiness.clone())).into() - }), - readiness, - } + let inner = Arc::new_cyclic(|w| { + let raw = Weak::as_ptr(w); + WakerArrayInner { + redirect: [raw; N], + readiness: Mutex::new(ReadinessArray::new()), + } + }); + let wakers = + array::from_fn(|i| unsafe { waker_from_redirec_position(Arc::clone(&inner), i) }); + Self { wakers, inner } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { @@ -29,6 +34,28 @@ impl WakerArray { /// Access the `Readiness`. pub(crate) fn readiness(&mut self) -> MutexGuard<'_, ReadinessArray> { - self.readiness.as_ref().lock().unwrap() + self.inner.readiness.lock().unwrap() // TODO: unwrap + } +} + +struct WakerArrayInner { + redirect: [*const Self; N], + readiness: Mutex>, +} + +unsafe impl SharedArcContent for WakerArrayInner { + fn get_redirect_slice(&self) -> &[*const Self] { + &self.redirect + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); // TODO: unwrap + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") + .wake_by_ref() + } } } diff --git a/src/utils/wakers/mod.rs b/src/utils/wakers/mod.rs index 089af6e..b391a41 100644 --- a/src/utils/wakers/mod.rs +++ b/src/utils/wakers/mod.rs @@ -3,6 +3,8 @@ mod array; mod dummy; #[cfg(feature = "alloc")] mod vec; +// #[cfg(feature = "alloc")] +mod shared_arc; #[cfg(all(test, feature = "alloc"))] pub(crate) use dummy::DummyWaker; diff --git a/src/utils/wakers/shared_arc.rs b/src/utils/wakers/shared_arc.rs new file mode 100644 index 0000000..84f907d --- /dev/null +++ b/src/utils/wakers/shared_arc.rs @@ -0,0 +1,58 @@ +use core::task::{RawWaker, RawWakerVTable, Waker}; + +use alloc::sync::Arc; + +pub(super) unsafe trait SharedArcContent { + /// Get the reference of the redirect slice + fn get_redirect_slice(&self) -> &[*const Self]; + /// Called when the subfuture at the specified index should be polled + /// Should call `Readiness::set_ready` + fn wake_index(&self, index: usize); +} + +pub(super) unsafe fn waker_from_redirec_position( + arc: Arc, + index: usize, +) -> Waker { + unsafe fn clone_waker(pointer: *const ()) -> RawWaker { + let pointer = pointer as *const *const A; + unsafe { Arc::increment_strong_count(*pointer) }; + RawWaker::new(pointer as *const (), create_vtable::()) + } + + unsafe fn wake_by_ref(pointer: *const ()) { + let pointer = pointer as *const *const A; + let raw: *const A = unsafe { *pointer }; + let arc_content: &A = unsafe { &*raw }; + let slice = arc_content.get_redirect_slice().as_ptr(); + let index = unsafe { pointer.offset_from(slice) } as usize; + arc_content.wake_index(index); + } + + unsafe fn drop_waker(pointer: *const ()) { + let pointer = pointer as *const *const A; + unsafe { Arc::decrement_strong_count(*pointer) }; + } + + unsafe fn wake(pointer: *const ()) { + unsafe { + wake_by_ref::(pointer); + drop_waker::(pointer); + } + } + + fn create_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_waker::, + wake::, + wake_by_ref::, + drop_waker::, + ) + } + + let redirect = arc.get_redirect_slice(); + let pointer = unsafe { redirect.as_ptr().add(index) } as *const (); + core::mem::forget(arc); + + unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::())) } +} diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index 18d289e..6ce977e 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -1,18 +1,32 @@ #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec; -use alloc::sync::Arc; +use alloc::sync::{Arc, Weak}; use core::task::Waker; use std::sync::{Mutex, MutexGuard}; -use super::{InlineWakerVec, ReadinessVec}; +use crate::utils::wakers::shared_arc::{waker_from_redirec_position, SharedArcContent}; + +use super::ReadinessVec; /// A collection of wakers which delegate to an in-line waker. pub(crate) struct WakerVec { wakers: Vec, - readiness: Arc>, + inner: Arc, +} + +struct WakerVecInner { + redirect: Mutex>, + readiness: Mutex, } +#[derive(Clone, Copy)] +#[repr(transparent)] +struct WakerVecInnerPtr(*const WakerVecInner); + +unsafe impl Send for WakerVecInnerPtr {} +unsafe impl Sync for WakerVecInnerPtr {} + impl Default for WakerVec { fn default() -> Self { Self::new(0) @@ -22,11 +36,17 @@ impl Default for WakerVec { impl WakerVec { /// Create a new instance of `WakerVec`. pub(crate) fn new(len: usize) -> Self { - let readiness = Arc::new(Mutex::new(ReadinessVec::new(len))); + let inner = Arc::new_cyclic(|weak| { + let raw = Weak::as_ptr(weak); + WakerVecInner { + redirect: Mutex::new(vec![WakerVecInnerPtr(raw); len]), + readiness: Mutex::new(ReadinessVec::new(len)), + } + }); let wakers = (0..len) - .map(|i| Arc::new(InlineWakerVec::new(i, readiness.clone())).into()) + .map(|i| unsafe { waker_from_redirec_position(Arc::clone(&inner), i) }) .collect(); - Self { wakers, readiness } + Self { wakers, inner } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { @@ -35,7 +55,7 @@ impl WakerVec { /// Access the `Readiness`. pub(crate) fn readiness(&self) -> MutexGuard<'_, ReadinessVec> { - self.readiness.lock().unwrap() + self.inner.readiness.lock().unwrap() } /// Resize the `WakerVec` to the new size. @@ -44,13 +64,38 @@ impl WakerVec { // Which means the first position is the current length, and every position // beyond that is incremented by 1. let mut index = self.wakers.len(); + + let ptr = WakerVecInnerPtr(Arc::as_ptr(&self.inner)); + let mut lock = self.inner.redirect.lock().unwrap(); + lock.resize_with(len, || ptr); + drop(lock); + self.wakers.resize_with(len, || { - let ret = Arc::new(InlineWakerVec::new(index, self.readiness.clone())).into(); + let ret = unsafe { waker_from_redirec_position(Arc::clone(&self.inner), index) }; index += 1; ret }); - let mut readiness = self.readiness.lock().unwrap(); + let mut readiness = self.inner.readiness.lock().unwrap(); readiness.resize(len); } } + +unsafe impl SharedArcContent for WakerVecInner { + fn get_redirect_slice(&self) -> &[*const Self] { + let slice = self.redirect.lock().unwrap(); + let slice = slice.as_slice(); + unsafe { core::mem::transmute(slice) } + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("msg") // todo message + .wake_by_ref() + } + } +}