Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a single Arc shared between Wakers in WakerArray and WakerVec #118

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/future/join/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,10 @@ where
#[cfg(test)]
mod test {
use super::*;
use crate::utils::DummyWaker;
use crate::utils::dummy_waker;

use std::future;
use std::future::Future;
use std::sync::Arc;
use std::task::Context;

#[test]
Expand All @@ -189,7 +188,7 @@ mod test {
assert_eq!(format!("{:?}", fut), "[Pending, Pending]");
let mut fut = Pin::new(&mut fut);

let waker = Arc::new(DummyWaker()).into();
let waker = dummy_waker();
let mut cx = Context::from_waker(&waker);
let _ = fut.as_mut().poll(&mut cx);
assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]");
Expand Down
5 changes: 2 additions & 3 deletions src/future/join/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,10 @@ where
#[cfg(test)]
mod test {
use super::*;
use crate::utils::DummyWaker;
use crate::utils::dummy_waker;

use std::future;
use std::future::Future;
use std::sync::Arc;
use std::task::Context;

#[test]
Expand All @@ -191,7 +190,7 @@ mod test {
assert_eq!(format!("{:?}", fut), "[Pending, Pending]");
let mut fut = Pin::new(&mut fut);

let waker = Arc::new(DummyWaker()).into();
let waker = dummy_waker();
let mut cx = Context::from_waker(&waker);
let _ = fut.as_mut().poll(&mut cx);
assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]");
Expand Down
2 changes: 1 addition & 1 deletion src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub(crate) use tuple::{gen_conditions, tuple_len};
pub(crate) use wakers::{WakerArray, WakerVec};

#[cfg(test)]
pub(crate) use wakers::DummyWaker;
pub(crate) use wakers::dummy_waker;

#[cfg(test)]
pub(crate) mod channel;
2 changes: 0 additions & 2 deletions src/utils/wakers/array/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
mod readiness;
mod waker;
mod waker_array;

pub(crate) use readiness::ReadinessArray;
pub(crate) use waker::InlineWakerArray;
pub(crate) use waker_array::WakerArray;
31 changes: 0 additions & 31 deletions src/utils/wakers/array/waker.rs

This file was deleted.

109 changes: 96 additions & 13 deletions src/utils/wakers/array/waker_array.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,52 @@
use core::array;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Waker;
use core::task::Waker;
use std::sync::{Arc, Mutex};

use super::{InlineWakerArray, ReadinessArray};
use super::{
super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer},
ReadinessArray,
};

/// A collection of wakers which delegate to an in-line waker.
pub(crate) struct WakerArray<const N: usize> {
wakers: [Waker; N],
readiness: Arc<Mutex<ReadinessArray<N>>>,
inner: Arc<WakerArrayInner<N>>,
}

/// See [super::super::shared_arc] for how this works.
struct WakerArrayInner<const N: usize> {
wake_data: [*const Self; N],
readiness: Mutex<ReadinessArray<N>>,
}

impl<const N: usize> WakerArray<N> {
/// Create a new instance of `WakerArray`.
pub(crate) fn new() -> Self {
let readiness = Arc::new(Mutex::new(ReadinessArray::new()));
Self {
wakers: array::from_fn(|i| {
Arc::new(InlineWakerArray::new(i, readiness.clone())).into()
}),
readiness,
}
let mut inner = Arc::new(WakerArrayInner {
readiness: Mutex::new(ReadinessArray::new()),
wake_data: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now.
});
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.

// At this point the strong count is 2. Decrement it to 1.
// Each time we create/clone a Waker the count will be incremented by 1.
// So N Wakers -> count = N+1.
unsafe { Arc::decrement_strong_count(raw) }

// Make wake_data all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().wake_data = [raw; N];

// Now the wake_data array is complete. Time to create the actual Wakers.
let wakers = array::from_fn(|i| {
let data = inner.wake_data.get(i).unwrap();
unsafe {
waker_for_wake_data_slot::<WakerArrayInner<N>>(
data as *const *const WakerArrayInner<N>,
)
}
});

Self { inner, wakers }
}

pub(crate) fn get(&self, index: usize) -> Option<&Waker> {
Expand All @@ -29,6 +55,63 @@ impl<const N: usize> WakerArray<N> {

/// Access the `Readiness`.
pub(crate) fn readiness(&self) -> &Mutex<ReadinessArray<N>> {
self.readiness.as_ref()
&self.inner.readiness
}
}

impl<const N: usize> WakeDataContainer for WakerArrayInner<N> {
fn get_wake_data_slice(&self) -> &[*const Self] {
&self.wake_data
}

fn wake_index(&self, index: usize) {
let mut readiness = self.readiness.lock().unwrap();
if !readiness.set_ready(index) {
readiness
.parent_waker()
.as_ref()
.expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?")
.wake_by_ref();
}
}
}

#[cfg(test)]
mod tests {
use crate::utils::wakers::dummy_waker;

use super::*;
#[test]
fn check_refcount() {
let mut wa = WakerArray::<5>::new();
assert_eq!(Arc::strong_count(&wa.inner), 6);
wa.wakers[4] = dummy_waker();
assert_eq!(Arc::strong_count(&wa.inner), 5);
let cloned = wa.wakers[3].clone();
assert_eq!(Arc::strong_count(&wa.inner), 6);
wa.wakers[3] = wa.wakers[4].clone();
assert_eq!(Arc::strong_count(&wa.inner), 5);
drop(cloned);
assert_eq!(Arc::strong_count(&wa.inner), 4);

wa.wakers[0].wake_by_ref();
wa.wakers[0].wake_by_ref();
wa.wakers[0].wake_by_ref();
assert_eq!(Arc::strong_count(&wa.inner), 4);

wa.wakers[0] = wa.wakers[1].clone();
assert_eq!(Arc::strong_count(&wa.inner), 4);

let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker());
assert_eq!(Arc::strong_count(&wa.inner), 4);
taken.wake_by_ref();
taken.wake_by_ref();
taken.wake_by_ref();
assert_eq!(Arc::strong_count(&wa.inner), 4);
taken.wake();
assert_eq!(Arc::strong_count(&wa.inner), 3);

wa.wakers = array::from_fn(|_| dummy_waker());
assert_eq!(Arc::strong_count(&wa.inner), 1);
}
}
18 changes: 14 additions & 4 deletions src/utils/wakers/dummy.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
use std::{sync::Arc, task::Wake};
use core::task::{RawWaker, RawWakerVTable, Waker};

pub(crate) struct DummyWaker();
impl Wake for DummyWaker {
fn wake(self: Arc<Self>) {}
/// A Waker that doesn't do anything.
pub(crate) fn dummy_waker() -> Waker {
fn new_raw_waker() -> RawWaker {
unsafe fn no_op(_data: *const ()) {}
unsafe fn clone(_data: *const ()) -> RawWaker {
new_raw_waker()
}
RawWaker::new(
core::ptr::null() as *const usize as *const (),
&RawWakerVTable::new(clone, no_op, no_op, no_op),
)
}
unsafe { Waker::from_raw(new_raw_waker()) }
}
6 changes: 4 additions & 2 deletions src/utils/wakers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod array;
mod shared_arc;
mod vec;

#[cfg(test)]
mod dummy;
mod vec;

#[cfg(test)]
pub(crate) use dummy::DummyWaker;
pub(crate) use dummy::dummy_waker;

pub(crate) use array::*;
pub(crate) use vec::*;
92 changes: 92 additions & 0 deletions src/utils/wakers/shared_arc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use core::task::{RawWaker, RawWakerVTable, Waker};
wishawa marked this conversation as resolved.
Show resolved Hide resolved
use std::sync::Arc;

// In the diagram below, `A` is the upper block.
wishawa marked this conversation as resolved.
Show resolved Hide resolved
// It is a struct that implements WakeDataContainer (so either WakerVecInner or WakerArrayInner).
// The lower block is either WakerVec or WakerArray. Each waker there points to a slot of wake_data in `A`.
// Every one of these slots contain a pointer to the Arc wrapping `A` itself.
// Wakers figure out their indices by comparing the address they are pointing to to `wake_data`'s start address.
//
// ┌───────────────────────────┬──────────────┬──────────────┐
// │ │ │ │
// │ / ┌─────────────┬──────┼───────┬──────┼───────┬──────┼───────┬─────┐ \
// ▼ / │ │ │ │ │ │ │ │ │ \
// Arc < │ Readiness │ wake_data[0] │ wake_data[1] │ wake_data[2] │ ... │ >
// ▲ \ │ │ │ │ │ │ /
// │ \ └─────────────┴──────▲───────┴──────▲───────┴──────▲───────┴─────┘ /
// │ │ │ │
// └─┐ ┌───────────────┘ │ │
// │ │ │ │
// │ │ ┌──────────────────┘ │
// │ │ │ │
// │ │ │ ┌─────────────────────┘
// │ │ │ │
// │ │ │ │
// ┌────┼────┬────┼──────┬────┼──────┬────┼──────┬─────┐
// │ │ │ │ │ │ │ │ │ │
// │ Inner │ wakers[0] │ wakers[1] │ wakers[2] │ ... │
// │ │ │ │ │ │
// └─────────┴───────────┴───────────┴───────────┴─────┘

// TODO: Right now each waker gets its own wake_data slot.
// We can save space by making size_of::<usize>() wakers share the same slot.
// With such change, in 64-bit system, the wake_data array/vec would only need ⌈N/8⌉ slots instead of N.

pub(super) trait WakeDataContainer {
wishawa marked this conversation as resolved.
Show resolved Hide resolved
/// Get the reference of the wake_data slice. This is used to compute the index.
fn get_wake_data_slice(&self) -> &[*const Self];
/// Called when the subfuture at the specified index should be polled.
fn wake_index(&self, index: usize);
}
pub(super) unsafe fn waker_for_wake_data_slot<A: WakeDataContainer>(
wishawa marked this conversation as resolved.
Show resolved Hide resolved
pointer: *const *const A,
) -> Waker {
unsafe fn clone_waker<A: WakeDataContainer>(pointer: *const ()) -> RawWaker {
let pointer = pointer as *const *const A;
let raw = *pointer; // This is the raw pointer of Arc<Inner>.

// We're creating a new Waker, so we need to increment the count.
Arc::increment_strong_count(raw);

RawWaker::new(pointer as *const (), create_vtable::<A>())
}

// Convert a pointer to a wake_data slot to the Arc<Inner>.
unsafe fn to_arc<A: WakeDataContainer>(pointer: *const *const A) -> Arc<A> {
let raw = *pointer;
Arc::from_raw(raw)
}
unsafe fn wake<A: WakeDataContainer, const BY_REF: bool>(pointer: *const ()) {
let pointer = pointer as *const *const A;
let arc = to_arc::<A>(pointer);
// Calculate the index
let index = ((pointer as usize) // This is the slot our pointer points to.
- (arc.get_wake_data_slice() as *const [*const A] as *const () as usize)) // This is the starting address of wake_data.
/ std::mem::size_of::<*const A>();
wishawa marked this conversation as resolved.
Show resolved Hide resolved

arc.wake_index(index);

// Dropping the Arc would decrement the strong count.
// We only want to do that when we're not waking by ref.
if BY_REF {
std::mem::forget(arc);
wishawa marked this conversation as resolved.
Show resolved Hide resolved
} else {
std::mem::drop(arc);
}
}
unsafe fn drop_waker<A: WakeDataContainer>(pointer: *const ()) {
let pointer = pointer as *const *const A;
let arc = to_arc::<A>(pointer);
// Decrement the strong count by dropping the Arc.
std::mem::drop(arc);
wishawa marked this conversation as resolved.
Show resolved Hide resolved
}
fn create_vtable<A: WakeDataContainer>() -> &'static RawWakerVTable {
&RawWakerVTable::new(
clone_waker::<A>,
wake::<A, false>,
wake::<A, true>,
wishawa marked this conversation as resolved.
Show resolved Hide resolved
drop_waker::<A>,
)
}
Waker::from_raw(clone_waker::<A>(pointer as *const ()))
}
2 changes: 0 additions & 2 deletions src/utils/wakers/vec/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
mod readiness;
mod waker;
mod waker_vec;

pub(crate) use readiness::ReadinessVec;
pub(crate) use waker::InlineWakerVec;
pub(crate) use waker_vec::WakerVec;
31 changes: 0 additions & 31 deletions src/utils/wakers/vec/waker.rs

This file was deleted.

Loading