Skip to content

Commit

Permalink
Waker array optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus-consoli committed Mar 20, 2024
1 parent 79903a6 commit 42c5856
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 10 deletions.
47 changes: 37 additions & 10 deletions src/utils/wakers/array/waker_array.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
use alloc::sync::Arc;
use alloc::sync::{Arc, Weak};
use core::array;
use core::task::Waker;
use std::sync::{Mutex, MutexGuard};

use crate::utils::wakers::shared_arc::{waker_from_redirec_position, SharedArcContent};

use super::{InlineWakerArray, ReadinessArray};

/// A collection of wakers which delegate to an in-line waker.
pub(crate) struct WakerArray<const N: usize> {
wakers: [Waker; N],
readiness: Arc<Mutex<ReadinessArray<N>>>,
inner: Arc<WakerArrayInner<N>>,
}

impl<const N: usize> WakerArray<N> {
/// Create a new instance of `WakerArray`.
pub(crate) fn new() -> Self {
let readiness = Arc::new(Mutex::new(ReadinessArray::new()));
Self {
wakers: array::from_fn(|i| {
Arc::new(InlineWakerArray::new(i, readiness.clone())).into()
}),
readiness,
}
let inner = Arc::new_cyclic(|w| {
let raw = Weak::as_ptr(w);
WakerArrayInner {
redirect: [raw; N],
readiness: Mutex::new(ReadinessArray::new()),
}
});
let wakers =
array::from_fn(|i| unsafe { waker_from_redirec_position(Arc::clone(&inner), i) });
Self { wakers, inner }
}

pub(crate) fn get(&self, index: usize) -> Option<&Waker> {
Expand All @@ -29,6 +34,28 @@ impl<const N: usize> WakerArray<N> {

/// Access the `Readiness`.
pub(crate) fn readiness(&mut self) -> MutexGuard<'_, ReadinessArray<N>> {
self.readiness.as_ref().lock().unwrap()
self.inner.readiness.lock().unwrap() // TODO: unwrap
}
}

struct WakerArrayInner<const N: usize> {
redirect: [*const Self; N],
readiness: Mutex<ReadinessArray<N>>,
}

unsafe impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
fn get_redirect_slice(&self) -> &[*const Self] {
&self.redirect
}

fn wake_index(&self, index: usize) {
let mut readiness = self.readiness.lock().unwrap(); // TODO: unwrap
if !readiness.set_ready(index) {
readiness
.parent_waker()
.as_ref()
.expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?")
.wake_by_ref()
}
}
}
2 changes: 2 additions & 0 deletions src/utils/wakers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod array;
mod dummy;
#[cfg(feature = "alloc")]
mod vec;
// #[cfg(feature = "alloc")]
mod shared_arc;

#[cfg(all(test, feature = "alloc"))]
pub(crate) use dummy::DummyWaker;
Expand Down
59 changes: 59 additions & 0 deletions src/utils/wakers/shared_arc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use core::task::{RawWaker, RawWakerVTable, Waker};

use alloc::sync::Arc;
use bitvec::index;

pub(super) unsafe trait SharedArcContent {
/// Get the reference of the redirect slice
fn get_redirect_slice(&self) -> &[*const Self];
/// Called when the subfuture at the specified index should be polled
/// Should call `Readiness::set_ready`
fn wake_index(&self, index: usize);
}

pub(super) unsafe fn waker_from_redirec_position<A: SharedArcContent>(
arc: Arc<A>,
index: usize,
) -> Waker {
unsafe fn clone_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
let pointer = pointer as *const *const A;
unsafe { Arc::increment_strong_count(*pointer) };
RawWaker::new(pointer as *const (), create_vtable::<A>())
}

unsafe fn wake_by_ref<A: SharedArcContent>(pointer: *const ()) {
let pointer = pointer as *const *const A;
let raw: *const A = unsafe { *pointer };
let arc_content: &A = unsafe { &*raw };
let slice = arc_content.get_redirect_slice().as_ptr();
let index = unsafe { pointer.offset_from(slice) } as usize;
arc_content.wake_index(index);
}

unsafe fn drop_waker<A: SharedArcContent>(pointer: *const ()) {
let pointer = pointer as *const *const A;
unsafe { Arc::decrement_strong_count(*pointer) };
}

unsafe fn wake<A: SharedArcContent>(pointer: *const ()) {
unsafe {
wake_by_ref::<A>(pointer);
drop_waker::<A>(pointer);
}
}

fn create_vtable<A: SharedArcContent>() -> &'static RawWakerVTable {
&RawWakerVTable::new(
clone_waker::<A>,
wake::<A>,
wake_by_ref::<A>,
drop_waker::<A>,
)
}

let redirect = arc.get_redirect_slice();
let pointer = unsafe { redirect.as_ptr().add(index) } as *const ();
core::mem::forget(arc);

unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::<A>())) }
}

0 comments on commit 42c5856

Please sign in to comment.