Skip to content

Commit

Permalink
simplify creation of WakerArray and WakerVec and make waker_from_redi…
Browse files Browse the repository at this point in the history
…rect_position safer
  • Loading branch information
wishawa committed Apr 25, 2023
1 parent 4f164de commit 6466c29
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 77 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ harness = false
bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] }
futures-core = "0.3"
pin-project = "1.0.8"
sptr = "0.3.2"

[dev-dependencies]
futures = "0.3.25"
Expand Down
46 changes: 21 additions & 25 deletions src/utils/wakers/array/waker_array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::array;
use core::task::Waker;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, Weak};

use super::{
super::shared_arc::{waker_from_redirect_position, SharedArcContent},
Expand All @@ -22,30 +22,18 @@ struct WakerArrayInner<const N: usize> {
impl<const N: usize> WakerArray<N> {
/// Create a new instance of `WakerArray`.
pub(crate) fn new() -> Self {
let mut inner = Arc::new(WakerArrayInner {
readiness: Mutex::new(ReadinessArray::new()),
redirect: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now.
});
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.

// At this point the strong count is 2. Decrement it to 1.
// Each time we create/clone a Waker the count will be incremented by 1.
// So N Wakers -> count = N+1.
unsafe { Arc::decrement_strong_count(raw) }

// Make redirect all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().redirect = [raw; N];

// Now the redirect array is complete. Time to create the actual Wakers.
let wakers = array::from_fn(|i| {
let data = inner.redirect.get(i).unwrap();
unsafe {
waker_from_redirect_position::<WakerArrayInner<N>>(
data as *const *const WakerArrayInner<N>,
)
let inner = Arc::new_cyclic(|w| {
// `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`.
let raw = Weak::as_ptr(w);
WakerArrayInner {
readiness: Mutex::new(ReadinessArray::new()),
redirect: [raw; N],
}
});

let wakers =
array::from_fn(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) });

Self { inner, wakers }
}

Expand All @@ -59,7 +47,8 @@ impl<const N: usize> WakerArray<N> {
}
}

impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
#[deny(unsafe_op_in_unsafe_fn)]
unsafe impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
fn get_redirect_slice(&self) -> &[*const Self] {
&self.redirect
}
Expand All @@ -84,7 +73,10 @@ mod tests {
#[test]
fn check_refcount() {
let mut wa = WakerArray::<5>::new();

// Each waker holds 1 ref, and the combinator itself holds 1.
assert_eq!(Arc::strong_count(&wa.inner), 6);

wa.wakers[4] = dummy_waker();
assert_eq!(Arc::strong_count(&wa.inner), 5);
let cloned = wa.wakers[3].clone();
Expand All @@ -105,13 +97,17 @@ mod tests {
let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker());
assert_eq!(Arc::strong_count(&wa.inner), 4);
taken.wake_by_ref();
taken.wake_by_ref();
taken.wake_by_ref();
assert_eq!(Arc::strong_count(&wa.inner), 4);
taken.clone().wake();
assert_eq!(Arc::strong_count(&wa.inner), 4);
taken.wake();
assert_eq!(Arc::strong_count(&wa.inner), 3);

wa.wakers = array::from_fn(|_| dummy_waker());
assert_eq!(Arc::strong_count(&wa.inner), 1);

let weak = Arc::downgrade(&wa.inner);
drop(wa);
assert_eq!(weak.strong_count(), 0);
}
}
76 changes: 45 additions & 31 deletions src/utils/wakers/shared_arc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,46 @@ use std::sync::Arc;

/// A trait to be implemented on [super::WakerArray] and [super::WakerVec] for polymorphism.
/// These are the type that goes in the Arc. They both contain the Readiness and the redirect array/vec.
pub(super) trait SharedArcContent {
/// Get the reference of the redirect slice. This is used to compute the index.
/// # Safety
/// The `get_redirect_slice` method MUST always return the same slice for the same self.
pub(super) unsafe trait SharedArcContent {
/// Get the reference of the redirect slice.
fn get_redirect_slice(&self) -> &[*const Self];

/// Called when the subfuture at the specified index should be polled.
/// Should call `Readiness::set_ready`.
fn wake_index(&self, index: usize);
}

/// Create one waker following the mechanism described in the [module][self] doc.
/// The following must be upheld for safety:
/// - `pointer` must points to a slot in the redirect array.
/// - that slot must contain a pointer obtained by `Arc::<A>::into_raw`.
/// - the Arc must still be alive at the time this function is called.
/// The following should be upheld for correct behavior:
/// - calling `SharedArcContent::get_redirect_slice` on the content of the Arc should give the redirect array within which `pointer` points to.
/// For safety, the index MUST be within bounds of the slice returned by `A::get_redirect_slice()`.
#[deny(unsafe_op_in_unsafe_fn)]
pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
pointer: *const *const A,
arc: Arc<A>,
index: usize,
) -> Waker {
/// Create a Waker from a type-erased pointer.
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
unsafe fn create_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
// For `create_waker`, `wake_by_ref`, `wake`, and `drop_waker`, the following MUST be upheld for safety:
// - `pointer` must points to a slot in the redirect array.
// - that slot must contain a pointer of an Arc obtained from `Arc::<A>::into_raw`.
// - that Arc must still be alive (strong count > 0) at the time the function is called.

/// Clone a Waker from a type-erased pointer.
/// The pointer must satisfy the safety constraints listed in the code comments above.
unsafe fn clone_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
// Retype the type-erased pointer.
let pointer = pointer as *const *const A;

// We're creating a new Waker, so we need to increment the count.
// SAFETY: The constraints listed for the wrapping function documentation means
// Increment the count so that the Arc won't die before this new Waker we're creating.
// SAFETY: The required constraints means
// - `*pointer` is an `*const A` obtained from `Arc::<A>::into_raw`.
// - the Arc is alive.
// So this operation is safe.
// - the Arc is alive right now.
unsafe { Arc::increment_strong_count(*pointer) };

RawWaker::new(pointer as *const (), create_vtable::<A>())
}

/// Invoke `SharedArcContent::wake_index` with the index in the redirect slice where this pointer points to.
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
/// The pointer must satisfy the safety constraints listed in the code comments above.
unsafe fn wake_by_ref<A: SharedArcContent>(pointer: *const ()) {
// Retype the type-erased pointer.
let pointer = pointer as *const *const A;
Expand All @@ -89,31 +92,28 @@ pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
// SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
let arc_content: &A = unsafe { &*raw };

// Calculate the index.
// This is your familiar pointer math
// `item_address = array_address + (index * item_size)`
// rearranged to
// `index = (item_address - array_address) / item_size`.
let item_address = sptr::Strict::addr(pointer);
let redirect_slice_address = sptr::Strict::addr(arc_content.get_redirect_slice().as_ptr());
let redirect_item_size = core::mem::size_of::<*const A>(); // the size of each item in the redirect slice
let index = (item_address - redirect_slice_address) / redirect_item_size;
let slice_start = arc_content.get_redirect_slice().as_ptr();

// We'll switch to [`sub_ptr`](https://github.com/rust-lang/rust/issues/95892) once that's stable.
let index = unsafe { pointer.offset_from(slice_start) } as usize;

arc_content.wake_index(index);
}

/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
/// Drop the waker (and drop the shared Arc if other Wakers and the combinator have all been dropped).
/// The pointer must satisfy the safety constraints listed in the code comments above.
unsafe fn drop_waker<A: SharedArcContent>(pointer: *const ()) {
// Retype the type-erased pointer.
let pointer = pointer as *const *const A;

// SAFETY: we are already requiring `pointer` to point to a slot in the redirect array.
let raw = unsafe { *pointer };
let raw: *const A = unsafe { *pointer };
// SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
unsafe { Arc::decrement_strong_count(raw) };
}

/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
/// Like `wake_by_ref` but consumes the Waker.
/// The pointer must satisfy the safety constraints listed in the code comments above.
unsafe fn wake<A: SharedArcContent>(pointer: *const ()) {
// SAFETY: we are already requiring the constraints of `wake_by_ref` and `drop_waker`.
unsafe {
Expand All @@ -124,13 +124,27 @@ pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(

fn create_vtable<A: SharedArcContent>() -> &'static RawWakerVTable {
&RawWakerVTable::new(
create_waker::<A>,
clone_waker::<A>,
wake::<A>,
wake_by_ref::<A>,
drop_waker::<A>,
)
}

let redirect_slice: &[*const A] = arc.get_redirect_slice();

debug_assert!(redirect_slice.len() > index);

// SAFETY: we are already requiring that index be in bound of the slice.
let pointer: *const *const A = unsafe { redirect_slice.as_ptr().add(index) };
// Type-erase the pointer because the Waker methods expect so.
let pointer = pointer as *const ();

// We want to transfer management of the one strong count associated with `arc` to the Waker we're creating.
// That count should only be decremented when the Waker is dropped (by `drop_waker`).
core::mem::forget(arc);

// SAFETY: All our vtable functions adhere to the RawWakerVTable contract,
// and we are already requiring that `pointer` is what our functions expect.
unsafe { Waker::from_raw(create_waker::<A>(pointer as *const ())) }
unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::<A>())) }
}
32 changes: 12 additions & 20 deletions src/utils/wakers/vec/waker_vec.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::task::Waker;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, Weak};

use super::{
super::shared_arc::{waker_from_redirect_position, SharedArcContent},
Expand All @@ -21,27 +21,18 @@ struct WakerVecInner {
impl WakerVec {
/// Create a new instance of `WakerVec`.
pub(crate) fn new(len: usize) -> Self {
let mut inner = Arc::new(WakerVecInner {
readiness: Mutex::new(ReadinessVec::new(len)),
redirect: Vec::new(),
let inner = Arc::new_cyclic(|w| {
// `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`.
let raw = Weak::as_ptr(w);
WakerVecInner {
readiness: Mutex::new(ReadinessVec::new(len)),
redirect: vec![raw; len],
}
});
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.

// At this point the strong count is 2. Decrement it to 1.
// Each time we create/clone a Waker the count will be incremented by 1.
// So N Wakers -> count = N+1.
unsafe { Arc::decrement_strong_count(raw) }

// Make redirect all point to the Arc itself.
Arc::get_mut(&mut inner).unwrap().redirect = vec![raw; len];

// Now the redirect vec is complete. Time to create the actual Wakers.
let wakers = inner
.redirect
.iter()
.map(|data| unsafe {
waker_from_redirect_position::<WakerVecInner>(data as *const *const WakerVecInner)
})
let wakers = (0..len)
.map(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) })
.collect();

Self { inner, wakers }
Expand All @@ -56,7 +47,8 @@ impl WakerVec {
}
}

impl SharedArcContent for WakerVecInner {
#[deny(unsafe_op_in_unsafe_fn)]
unsafe impl SharedArcContent for WakerVecInner {
fn get_redirect_slice(&self) -> &[*const Self] {
&self.redirect
}
Expand Down

0 comments on commit 6466c29

Please sign in to comment.