Skip to content

Commit

Permalink
improve readability, provenance compliance, and documentation for sha…
Browse files Browse the repository at this point in the history
…red arc waker
  • Loading branch information
wishawa committed Feb 12, 2023
1 parent 1d0be87 commit 4f164de
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 88 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ harness = false
bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] }
futures-core = "0.3"
pin-project = "1.0.8"
sptr = "0.3.2"

[dev-dependencies]
futures = "0.3.25"
Expand Down
22 changes: 11 additions & 11 deletions src/utils/wakers/array/waker_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::task::Waker;
use std::sync::{Arc, Mutex};

use super::{
super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer},
super::shared_arc::{waker_from_redirect_position, SharedArcContent},
ReadinessArray,
};

Expand All @@ -15,7 +15,7 @@ pub(crate) struct WakerArray<const N: usize> {

/// See [super::super::shared_arc] for how this works.
struct WakerArrayInner<const N: usize> {
wake_data: [*const Self; N],
redirect: [*const Self; N],
readiness: Mutex<ReadinessArray<N>>,
}

Expand All @@ -24,7 +24,7 @@ impl<const N: usize> WakerArray<N> {
pub(crate) fn new() -> Self {
let mut inner = Arc::new(WakerArrayInner {
readiness: Mutex::new(ReadinessArray::new()),
wake_data: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now.
redirect: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now.
});
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.

Expand All @@ -33,14 +33,14 @@ impl<const N: usize> WakerArray<N> {
// So N Wakers -> count = N+1.
unsafe { Arc::decrement_strong_count(raw) }

// Make wake_data all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().wake_data = [raw; N];
// Make redirect all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().redirect = [raw; N];

// Now the wake_data array is complete. Time to create the actual Wakers.
// Now the redirect array is complete. Time to create the actual Wakers.
let wakers = array::from_fn(|i| {
let data = inner.wake_data.get(i).unwrap();
let data = inner.redirect.get(i).unwrap();
unsafe {
waker_for_wake_data_slot::<WakerArrayInner<N>>(
waker_from_redirect_position::<WakerArrayInner<N>>(
data as *const *const WakerArrayInner<N>,
)
}
Expand All @@ -59,9 +59,9 @@ impl<const N: usize> WakerArray<N> {
}
}

impl<const N: usize> WakeDataContainer for WakerArrayInner<N> {
fn get_wake_data_slice(&self) -> &[*const Self] {
&self.wake_data
impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
fn get_redirect_slice(&self) -> &[*const Self] {
&self.redirect
}

fn wake_index(&self, index: usize) {
Expand Down
8 changes: 8 additions & 0 deletions src/utils/wakers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
//! Wakers that track when they are woken.
//!
//! By tracking which subfutures have woken, we can avoid having to re-poll N subfutures every time.
//! This tracking is done by a [ReadinessArray]/[ReadinessVec]. These store the indexes of the subfutures that have woken.
//! Each subfuture are given a Waker when polled.
//! This waker must know the index of its corresponding subfuture so that it can update Readiness correctly.
//!
mod array;
mod shared_arc;
mod vec;
Expand Down
176 changes: 110 additions & 66 deletions src/utils/wakers/shared_arc.rs
Original file line number Diff line number Diff line change
@@ -1,92 +1,136 @@
use core::task::{RawWaker, RawWakerVTable, Waker};
use std::sync::Arc;
//! To save on allocations, we avoid making a separate Arc Waker for every subfuture.
//! Rather, we have all N Wakers share a single Arc, and use a "redirect" mechanism to allow different wakers to be distinguished.
//! The mechanism works as follows.
//! The Arc contains 2 things:
//! - the Readiness structure ([ReadinessArray][super::array::ReadinessArray] / [ReadinessVec][super::vec::ReadinessVec])
//! - the redirect array.
//! The redirect array contains N repeated copies of the pointer to the Arc itself (obtained by `Arc::into_raw`).
//! The Waker for the `i`th subfuture points to the `i`th item in the redirect array.
//! (i.e. the Waker pointer is `*const *const A` where `A` is the type of the item in the Arc)
//! When the Waker is woken, we deref it twice (giving reference to the content of the Arc),
//! and compare it to the address of the redirect slice.
//! The difference tells us the index of the waker. We can then record this woken index in the Readiness.
//!
//! ```text
//! ┌───────────────────────────┬──────────────┬──────────────┐
//! │ │ │ │
//! │ / ┌─────────────┬──────┼───────┬──────┼───────┬──────┼───────┬─────┐ \
//! ▼ / │ │ │ │ │ │ │ │ │ \
//! Arc < │ Readiness │ redirect[0] │ redirect[1] │ redirect[2] │ ... │ >
//! ▲ \ │ │ │ │ │ │ /
//! │ \ └─────────────┴──────▲───────┴──────▲───────┴──────▲───────┴─────┘ /
//! │ │ │ │
//! └─┐ ┌───────────────┘ │ │
//! │ │ │ │
//! │ │ ┌──────────────────┘ │
//! │ │ │ │
//! │ │ │ ┌─────────────────────┘
//! │ │ │ │
//! │ │ │ │
//! ┌────┼────┬────┼──────┬────┼──────┬────┼──────┬─────┐
//! │ │ │ │ │ │ │ │ │ │
//! │ │ wakers[0] │ wakers[1] │ wakers[2] │ ... │
//! │ │ │ │ │ │
//! └─────────┴───────────┴───────────┴───────────┴─────┘
//! ```
// In the diagram below, `A` is the upper block.
// It is a struct that implements WakeDataContainer (so either WakerVecInner or WakerArrayInner).
// The lower block is either WakerVec or WakerArray. Each waker there points to a slot of wake_data in `A`.
// Every one of these slots contain a pointer to the Arc wrapping `A` itself.
// Wakers figure out their indices by comparing the address they are pointing to to `wake_data`'s start address.
//
// ┌───────────────────────────┬──────────────┬──────────────┐
// │ │ │ │
// │ / ┌─────────────┬──────┼───────┬──────┼───────┬──────┼───────┬─────┐ \
// ▼ / │ │ │ │ │ │ │ │ │ \
// Arc < │ Readiness │ wake_data[0] │ wake_data[1] │ wake_data[2] │ ... │ >
// ▲ \ │ │ │ │ │ │ /
// │ \ └─────────────┴──────▲───────┴──────▲───────┴──────▲───────┴─────┘ /
// │ │ │ │
// └─┐ ┌───────────────┘ │ │
// │ │ │ │
// │ │ ┌──────────────────┘ │
// │ │ │ │
// │ │ │ ┌─────────────────────┘
// │ │ │ │
// │ │ │ │
// ┌────┼────┬────┼──────┬────┼──────┬────┼──────┬─────┐
// │ │ │ │ │ │ │ │ │ │
// │ Inner │ wakers[0] │ wakers[1] │ wakers[2] │ ... │
// │ │ │ │ │ │
// └─────────┴───────────┴───────────┴───────────┴─────┘
// TODO: Right now each waker gets its own redirect slot.
// We can save space by making size_of::<*const _>() wakers share the same slot.
// With such change, in 64-bit system, the redirect array/vec would only need ⌈N/8⌉ slots instead of N.

// TODO: Right now each waker gets its own wake_data slot.
// We can save space by making size_of::<usize>() wakers share the same slot.
// With such change, in 64-bit system, the wake_data array/vec would only need ⌈N/8⌉ slots instead of N.
use core::task::{RawWaker, RawWakerVTable, Waker};
use std::sync::Arc;

pub(super) trait WakeDataContainer {
/// Get the reference of the wake_data slice. This is used to compute the index.
fn get_wake_data_slice(&self) -> &[*const Self];
/// A trait to be implemented on [super::WakerArray] and [super::WakerVec] for polymorphism.
/// These are the type that goes in the Arc. They both contain the Readiness and the redirect array/vec.
pub(super) trait SharedArcContent {
/// Get the reference of the redirect slice. This is used to compute the index.
fn get_redirect_slice(&self) -> &[*const Self];
/// Called when the subfuture at the specified index should be polled.
/// Should call `Readiness::set_ready`.
fn wake_index(&self, index: usize);
}
pub(super) unsafe fn waker_for_wake_data_slot<A: WakeDataContainer>(

/// Create one waker following the mechanism described in the [module][self] doc.
/// The following must be upheld for safety:
/// - `pointer` must points to a slot in the redirect array.
/// - that slot must contain a pointer obtained by `Arc::<A>::into_raw`.
/// - the Arc must still be alive at the time this function is called.
/// The following should be upheld for correct behavior:
/// - calling `SharedArcContent::get_redirect_slice` on the content of the Arc should give the redirect array within which `pointer` points to.
#[deny(unsafe_op_in_unsafe_fn)]
pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
pointer: *const *const A,
) -> Waker {
unsafe fn clone_waker<A: WakeDataContainer>(pointer: *const ()) -> RawWaker {
/// Create a Waker from a type-erased pointer.
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
unsafe fn create_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
// Retype the type-erased pointer.
let pointer = pointer as *const *const A;
let raw = *pointer; // This is the raw pointer of Arc<Inner>.

// We're creating a new Waker, so we need to increment the count.
Arc::increment_strong_count(raw);
// SAFETY: The constraints listed for the wrapping function documentation means
// - `*pointer` is an `*const A` obtained from `Arc::<A>::into_raw`.
// - the Arc is alive.
// So this operation is safe.
unsafe { Arc::increment_strong_count(*pointer) };

RawWaker::new(pointer as *const (), create_vtable::<A>())
}

// Convert a pointer to a wake_data slot to the Arc<Inner>.
unsafe fn to_arc<A: WakeDataContainer>(pointer: *const *const A) -> Arc<A> {
let raw = *pointer;
Arc::from_raw(raw)
}
unsafe fn wake<A: WakeDataContainer, const BY_REF: bool>(pointer: *const ()) {
/// Invoke `SharedArcContent::wake_index` with the index in the redirect slice where this pointer points to.
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
unsafe fn wake_by_ref<A: SharedArcContent>(pointer: *const ()) {
// Retype the type-erased pointer.
let pointer = pointer as *const *const A;
let arc = to_arc::<A>(pointer);
// Calculate the index
let index = ((pointer as usize) // This is the slot our pointer points to.
- (arc.get_wake_data_slice() as *const [*const A] as *const () as usize)) // This is the starting address of wake_data.
/ std::mem::size_of::<*const A>();

arc.wake_index(index);
// SAFETY: we are already requiring `pointer` to point to a slot in the redirect array.
let raw: *const A = unsafe { *pointer };
// SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
let arc_content: &A = unsafe { &*raw };

// Dropping the Arc would decrement the strong count.
// We only want to do that when we're not waking by ref.
if BY_REF {
std::mem::forget(arc);
} else {
std::mem::drop(arc);
}
// Calculate the index.
// This is your familiar pointer math
// `item_address = array_address + (index * item_size)`
// rearranged to
// `index = (item_address - array_address) / item_size`.
let item_address = sptr::Strict::addr(pointer);
let redirect_slice_address = sptr::Strict::addr(arc_content.get_redirect_slice().as_ptr());
let redirect_item_size = core::mem::size_of::<*const A>(); // the size of each item in the redirect slice
let index = (item_address - redirect_slice_address) / redirect_item_size;

arc_content.wake_index(index);
}
unsafe fn drop_waker<A: WakeDataContainer>(pointer: *const ()) {

/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
unsafe fn drop_waker<A: SharedArcContent>(pointer: *const ()) {
// Retype the type-erased pointer.
let pointer = pointer as *const *const A;
let arc = to_arc::<A>(pointer);
// Decrement the strong count by dropping the Arc.
std::mem::drop(arc);

// SAFETY: we are already requiring `pointer` to point to a slot in the redirect array.
let raw = unsafe { *pointer };
// SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
unsafe { Arc::decrement_strong_count(raw) };
}
fn create_vtable<A: WakeDataContainer>() -> &'static RawWakerVTable {

/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
unsafe fn wake<A: SharedArcContent>(pointer: *const ()) {
// SAFETY: we are already requiring the constraints of `wake_by_ref` and `drop_waker`.
unsafe {
wake_by_ref::<A>(pointer);
drop_waker::<A>(pointer);
}
}

fn create_vtable<A: SharedArcContent>() -> &'static RawWakerVTable {
&RawWakerVTable::new(
clone_waker::<A>,
wake::<A, false>,
wake::<A, true>,
create_waker::<A>,
wake::<A>,
wake_by_ref::<A>,
drop_waker::<A>,
)
}
Waker::from_raw(clone_waker::<A>(pointer as *const ()))
// SAFETY: All our vtable functions adhere to the RawWakerVTable contract,
// and we are already requiring that `pointer` is what our functions expect.
unsafe { Waker::from_raw(create_waker::<A>(pointer as *const ())) }
}
22 changes: 11 additions & 11 deletions src/utils/wakers/vec/waker_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::task::Waker;
use std::sync::{Arc, Mutex};

use super::{
super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer},
super::shared_arc::{waker_from_redirect_position, SharedArcContent},
ReadinessVec,
};

Expand All @@ -14,7 +14,7 @@ pub(crate) struct WakerVec {

/// See [super::super::shared_arc] for how this works.
struct WakerVecInner {
wake_data: Vec<*const Self>,
redirect: Vec<*const Self>,
readiness: Mutex<ReadinessVec>,
}

Expand All @@ -23,7 +23,7 @@ impl WakerVec {
pub(crate) fn new(len: usize) -> Self {
let mut inner = Arc::new(WakerVecInner {
readiness: Mutex::new(ReadinessVec::new(len)),
wake_data: Vec::new(),
redirect: Vec::new(),
});
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.

Expand All @@ -32,15 +32,15 @@ impl WakerVec {
// So N Wakers -> count = N+1.
unsafe { Arc::decrement_strong_count(raw) }

// Make wake_data all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().wake_data = vec![raw; len];
// Make redirect all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().redirect = vec![raw; len];

// Now the wake_data vec is complete. Time to create the actual Wakers.
// Now the redirect vec is complete. Time to create the actual Wakers.
let wakers = inner
.wake_data
.redirect
.iter()
.map(|data| unsafe {
waker_for_wake_data_slot::<WakerVecInner>(data as *const *const WakerVecInner)
waker_from_redirect_position::<WakerVecInner>(data as *const *const WakerVecInner)
})
.collect();

Expand All @@ -56,9 +56,9 @@ impl WakerVec {
}
}

impl WakeDataContainer for WakerVecInner {
fn get_wake_data_slice(&self) -> &[*const Self] {
&self.wake_data
impl SharedArcContent for WakerVecInner {
fn get_redirect_slice(&self) -> &[*const Self] {
&self.redirect
}

fn wake_index(&self, index: usize) {
Expand Down

0 comments on commit 4f164de

Please sign in to comment.