From 69f7db78498f4aed23cddefbfef8ed41e3f049a5 Mon Sep 17 00:00:00 2001
From: Wisha Wa <wisha.wa@yandex.com>
Date: Mon, 13 Feb 2023 04:25:16 +0000
Subject: [PATCH] simplify creation of WakerArray and WakerVec and make
 waker_from_redirect_position safer

---
 src/utils/wakers/array/waker_array.rs | 46 ++++++++---------
 src/utils/wakers/shared_arc.rs        | 74 ++++++++++++++++-----------
 src/utils/wakers/vec/waker_vec.rs     | 32 +++++-------
 3 files changed, 77 insertions(+), 75 deletions(-)

diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs
index 3ad3291..ac6def9 100644
--- a/src/utils/wakers/array/waker_array.rs
+++ b/src/utils/wakers/array/waker_array.rs
@@ -1,6 +1,6 @@
 use core::array;
 use core::task::Waker;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, Weak};
 
 use super::{
     super::shared_arc::{waker_from_redirect_position, SharedArcContent},
@@ -22,30 +22,18 @@ struct WakerArrayInner<const N: usize> {
 impl<const N: usize> WakerArray<N> {
     /// Create a new instance of `WakerArray`.
     pub(crate) fn new() -> Self {
-        let mut inner = Arc::new(WakerArrayInner {
-            readiness: Mutex::new(ReadinessArray::new()),
-            redirect: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now.
-        });
-        let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.
-
-        // At this point the strong count is 2. Decrement it to 1.
-        // Each time we create/clone a Waker the count will be incremented by 1.
-        // So N Wakers -> count = N+1.
-        unsafe { Arc::decrement_strong_count(raw) }
-
-        // Make redirect all point to the Arc itself.
-        Arc::get_mut(&mut inner).unwrap().redirect = [raw; N];
-
-        // Now the redirect array is complete. Time to create the actual Wakers.
-        let wakers = array::from_fn(|i| {
-            let data = inner.redirect.get(i).unwrap();
-            unsafe {
-                waker_from_redirect_position::<WakerArrayInner<N>>(
-                    data as *const *const WakerArrayInner<N>,
-                )
+        let inner = Arc::new_cyclic(|w| {
+            // `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`.
+            let raw = Weak::as_ptr(w);
+            WakerArrayInner {
+                readiness: Mutex::new(ReadinessArray::new()),
+                redirect: [raw; N],
             }
         });
 
+        let wakers =
+            array::from_fn(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) });
+
         Self { inner, wakers }
     }
 
@@ -59,7 +47,8 @@ impl<const N: usize> WakerArray<N> {
     }
 }
 
-impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
+#[deny(unsafe_op_in_unsafe_fn)]
+unsafe impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
     fn get_redirect_slice(&self) -> &[*const Self] {
         &self.redirect
     }
@@ -84,7 +73,10 @@ mod tests {
     #[test]
     fn check_refcount() {
         let mut wa = WakerArray::<5>::new();
+
+        // Each waker holds 1 ref, and the combinator itself holds 1.
         assert_eq!(Arc::strong_count(&wa.inner), 6);
+
         wa.wakers[4] = dummy_waker();
         assert_eq!(Arc::strong_count(&wa.inner), 5);
         let cloned = wa.wakers[3].clone();
@@ -105,13 +97,17 @@ mod tests {
         let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker());
         assert_eq!(Arc::strong_count(&wa.inner), 4);
         taken.wake_by_ref();
-        taken.wake_by_ref();
-        taken.wake_by_ref();
+        assert_eq!(Arc::strong_count(&wa.inner), 4);
+        taken.clone().wake();
         assert_eq!(Arc::strong_count(&wa.inner), 4);
         taken.wake();
         assert_eq!(Arc::strong_count(&wa.inner), 3);
 
         wa.wakers = array::from_fn(|_| dummy_waker());
         assert_eq!(Arc::strong_count(&wa.inner), 1);
+
+        let weak = Arc::downgrade(&wa.inner);
+        drop(wa);
+        assert_eq!(weak.strong_count(), 0);
     }
 }
diff --git a/src/utils/wakers/shared_arc.rs b/src/utils/wakers/shared_arc.rs
index 313fb41..da46e48 100644
--- a/src/utils/wakers/shared_arc.rs
+++ b/src/utils/wakers/shared_arc.rs
@@ -43,43 +43,46 @@ use std::sync::Arc;
 
 /// A trait to be implemented on [super::WakerArray] and [super::WakerVec] for polymorphism.
 /// These are the type that goes in the Arc. They both contain the Readiness and the redirect array/vec.
-pub(super) trait SharedArcContent {
+pub(super) unsafe trait SharedArcContent {
     /// Get the reference of the redirect slice. This is used to compute the index.
+    /// This method MUST always return the same slice for the same self.
+    /// Failure to do so may lead to unsafety.
     fn get_redirect_slice(&self) -> &[*const Self];
+
     /// Called when the subfuture at the specified index should be polled.
     /// Should call `Readiness::set_ready`.
     fn wake_index(&self, index: usize);
 }
 
 /// Create one waker following the mechanism described in the [module][self] doc.
-/// The following must be upheld for safety:
-/// - `pointer` must points to a slot in the redirect array.
-/// - that slot must contain a pointer obtained by `Arc::<A>::into_raw`.
-/// - the Arc must still be alive at the time this function is called.
-/// The following should be upheld for correct behavior:
-/// - calling `SharedArcContent::get_redirect_slice` on the content of the Arc should give the redirect array within which `pointer` points to.
+/// For safety, the index MUST be within bounds of the slice returned by `A::get_redirect_slice()`.
 #[deny(unsafe_op_in_unsafe_fn)]
 pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
-    pointer: *const *const A,
+    arc: Arc<A>,
+    index: usize,
 ) -> Waker {
-    /// Create a Waker from a type-erased pointer.
-    /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
-    unsafe fn create_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
+    // For `create_waker`, `wake_by_ref`, `wake`, and `drop_waker`, the following MUST be upheld for safety:
+    // - `pointer` must points to a slot in the redirect array.
+    // - that slot must contain a pointer of an Arc obtained from `Arc::<A>::into_raw`.
+    // - that Arc must still be alive (strong count > 0) at the time the function is called.
+
+    /// Clone a Waker from a type-erased pointer.
+    /// The pointer must satisfy the safety constraints listed in the code comments above.
+    unsafe fn clone_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
         // Retype the type-erased pointer.
         let pointer = pointer as *const *const A;
 
-        // We're creating a new Waker, so we need to increment the count.
-        // SAFETY: The constraints listed for the wrapping function documentation means
+        // Increment the count so that the Arc won't die before this new Waker we're creating.
+        // SAFETY: The required constraints means
         // - `*pointer` is an `*const A` obtained from `Arc::<A>::into_raw`.
-        // - the Arc is alive.
-        // So this operation is safe.
+        // - the Arc is alive right now.
         unsafe { Arc::increment_strong_count(*pointer) };
 
         RawWaker::new(pointer as *const (), create_vtable::<A>())
     }
 
     /// Invoke `SharedArcContent::wake_index` with the index in the redirect slice where this pointer points to.
-    /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
+    /// The pointer must satisfy the safety constraints listed in the code comments above.
     unsafe fn wake_by_ref<A: SharedArcContent>(pointer: *const ()) {
         // Retype the type-erased pointer.
         let pointer = pointer as *const *const A;
@@ -89,31 +92,28 @@ pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
         // SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
         let arc_content: &A = unsafe { &*raw };
 
-        // Calculate the index.
-        // This is your familiar pointer math
-        // `item_address = array_address + (index * item_size)`
-        // rearranged to
-        // `index = (item_address - array_address) / item_size`.
-        let item_address = sptr::Strict::addr(pointer);
-        let redirect_slice_address = sptr::Strict::addr(arc_content.get_redirect_slice().as_ptr());
-        let redirect_item_size = core::mem::size_of::<*const A>(); // the size of each item in the redirect slice
-        let index = (item_address - redirect_slice_address) / redirect_item_size;
+        let slice_start = arc_content.get_redirect_slice().as_ptr();
+
+        // We'll switch to [`sub_ptr`](https://github.com/rust-lang/rust/issues/95892) once that's stable.
+        let index = unsafe { pointer.offset_from(slice_start) } as usize;
 
         arc_content.wake_index(index);
     }
 
-    /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
+    /// Drop the waker (and drop the shared Arc if other Wakers and the combinator have all been dropped).
+    /// The pointer must satisfy the safety constraints listed in the code comments above.
     unsafe fn drop_waker<A: SharedArcContent>(pointer: *const ()) {
         // Retype the type-erased pointer.
         let pointer = pointer as *const *const A;
 
         // SAFETY: we are already requiring `pointer` to point to a slot in the redirect array.
-        let raw = unsafe { *pointer };
+        let raw: *const A = unsafe { *pointer };
         // SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
         unsafe { Arc::decrement_strong_count(raw) };
     }
 
-    /// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
+    /// Like `wake_by_ref` but consumes the Waker.
+    /// The pointer must satisfy the safety constraints listed in the code comments above.
     unsafe fn wake<A: SharedArcContent>(pointer: *const ()) {
         // SAFETY: we are already requiring the constraints of `wake_by_ref` and `drop_waker`.
         unsafe {
@@ -124,13 +124,27 @@ pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
 
     fn create_vtable<A: SharedArcContent>() -> &'static RawWakerVTable {
         &RawWakerVTable::new(
-            create_waker::<A>,
+            clone_waker::<A>,
             wake::<A>,
             wake_by_ref::<A>,
             drop_waker::<A>,
         )
     }
+
+    let redirect_slice: &[*const A] = arc.get_redirect_slice();
+
+    debug_assert!(redirect_slice.len() > index);
+
+    // SAFETY: we are already requiring that index be in bound of the slice.
+    let pointer: *const *const A = unsafe { redirect_slice.as_ptr().add(index) };
+    // Type-erase the pointer because the Waker methods expect so.
+    let pointer = pointer as *const ();
+
+    // We want to transfer management of the one strong count associated with `arc` to the Waker we're creating.
+    // That count should only be decremented when the Waker is dropped (by `drop_waker`).
+    core::mem::forget(arc);
+
     // SAFETY: All our vtable functions adhere to the RawWakerVTable contract,
     // and we are already requiring that `pointer` is what our functions expect.
-    unsafe { Waker::from_raw(create_waker::<A>(pointer as *const ())) }
+    unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::<A>())) }
 }
diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs
index e4f993e..a84d9e7 100644
--- a/src/utils/wakers/vec/waker_vec.rs
+++ b/src/utils/wakers/vec/waker_vec.rs
@@ -1,5 +1,5 @@
 use core::task::Waker;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, Weak};
 
 use super::{
     super::shared_arc::{waker_from_redirect_position, SharedArcContent},
@@ -21,27 +21,18 @@ struct WakerVecInner {
 impl WakerVec {
     /// Create a new instance of `WakerVec`.
     pub(crate) fn new(len: usize) -> Self {
-        let mut inner = Arc::new(WakerVecInner {
-            readiness: Mutex::new(ReadinessVec::new(len)),
-            redirect: Vec::new(),
+        let inner = Arc::new_cyclic(|w| {
+            // `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`.
+            let raw = Weak::as_ptr(w);
+            WakerVecInner {
+                readiness: Mutex::new(ReadinessVec::new(len)),
+                redirect: vec![raw; len],
+            }
         });
-        let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.
-
-        // At this point the strong count is 2. Decrement it to 1.
-        // Each time we create/clone a Waker the count will be incremented by 1.
-        // So N Wakers -> count = N+1.
-        unsafe { Arc::decrement_strong_count(raw) }
-
-        // Make redirect all point to the Arc itself.
-        Arc::get_mut(&mut inner).unwrap().redirect = vec![raw; len];
 
         // Now the redirect vec is complete. Time to create the actual Wakers.
-        let wakers = inner
-            .redirect
-            .iter()
-            .map(|data| unsafe {
-                waker_from_redirect_position::<WakerVecInner>(data as *const *const WakerVecInner)
-            })
+        let wakers = (0..len)
+            .map(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) })
             .collect();
 
         Self { inner, wakers }
@@ -56,7 +47,8 @@ impl WakerVec {
     }
 }
 
-impl SharedArcContent for WakerVecInner {
+#[deny(unsafe_op_in_unsafe_fn)]
+unsafe impl SharedArcContent for WakerVecInner {
     fn get_redirect_slice(&self) -> &[*const Self] {
         &self.redirect
     }