From 9c4fcf982956e9111f751fbfb62ec21fce96e6b1 Mon Sep 17 00:00:00 2001 From: Yoshua Wuyts <2467194+yoshuawuyts@users.noreply.github.com> Date: Mon, 14 Nov 2022 16:46:51 +0100 Subject: [PATCH 1/2] Remove `MaybeDone` for `array::join` --- src/future/join/array.rs | 142 +++++++++++++++++++++++++++++---------- src/utils/array.rs | 22 ++++++ src/utils/mod.rs | 2 + src/utils/slice_ext.rs | 0 4 files changed, 130 insertions(+), 36 deletions(-) create mode 100644 src/utils/array.rs delete mode 100644 src/utils/slice_ext.rs diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 5570cfe..8f05b3d 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -1,12 +1,15 @@ use super::Join as JoinTrait; -use crate::utils::MaybeDone; +use crate::utils; +use crate::utils::PollState; +use core::array; use core::fmt; use core::future::{Future, IntoFuture}; +use core::mem::{self, MaybeUninit}; use core::pin::Pin; use core::task::{Context, Poll}; -use pin_project::pin_project; +use pin_project::{pin_project, pinned_drop}; /// Waits for two similarly-typed futures to complete. /// @@ -16,12 +19,44 @@ use pin_project::pin_project; /// [`join`]: crate::future::Join::join /// [`Join`]: crate::future::Join #[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] +#[pin_project(PinnedDrop)] pub struct Join where Fut: Future, { - pub(crate) elems: [MaybeDone; N], + consumed: bool, + pending: usize, + items: [MaybeUninit<::Output>; N], + state: [PollState; N], + #[pin] + futures: [Fut; N], +} + +impl Join +where + Fut: Future, +{ + pub(crate) fn new(futures: [Fut; N]) -> Self { + Join { + consumed: false, + pending: N, + items: array::from_fn(|_| MaybeUninit::uninit()), + state: [PollState::default(); N], + futures, + } + } +} + +impl JoinTrait for [Fut; N] +where + Fut: IntoFuture, +{ + type Output = [Fut::Output; N]; + type Future = Join; + + fn join(self) -> Self::Future { + Join::new(self.map(IntoFuture::into_future)) + } } impl fmt::Debug for Join @@ -30,7 +65,7 @@ where Fut::Output: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.elems.iter()).finish() + f.debug_list().entries(self.state.iter()).finish() } } @@ -41,48 +76,68 @@ where type Output = [Fut::Output; N]; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; + let mut this = self.project(); - let this = self.project(); + assert!( + !*this.consumed, + "Futures must not be polled after completing" + ); - for elem in this.elems.iter_mut() { - let elem = unsafe { Pin::new_unchecked(elem) }; - if elem.poll(cx).is_pending() { - all_done = false; + // Poll all futures + for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() { + if this.state[i].is_pending() { + if let Poll::Ready(value) = fut.poll(cx) { + this.items[i] = MaybeUninit::new(value); + this.state[i] = PollState::Done; + *this.pending -= 1; + } } } - if all_done { - use core::array; - use core::mem::MaybeUninit; + // Check whether we're all done now or need to keep going. + if *this.pending == 0 { + // Mark all data as "consumed" before we take it + *this.consumed = true; + for state in this.state.iter_mut() { + debug_assert!(state.is_done(), "Future should have reached a `Done` state"); + *state = PollState::Consumed; + } - // Create the result array based on the indices - // TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable - let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit()); + let mut items = array::from_fn(|_| MaybeUninit::uninit()); + mem::swap(this.items, &mut items); - // NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`. - #[allow(clippy::needless_range_loop)] - for (i, el) in this.elems.iter_mut().enumerate() { - let el = unsafe { Pin::new_unchecked(el) }.take().unwrap(); - out[i] = MaybeUninit::new(el); - } - let result = unsafe { out.as_ptr().cast::<[Fut::Output; N]>().read() }; - Poll::Ready(result) + // SAFETY: we've checked with the state that all of our outputs have been + // filled, which means we're ready to take the data and assume it's initialized. + let items = unsafe { utils::array_assume_init(items) }; + Poll::Ready(items) } else { Poll::Pending } } } -impl JoinTrait for [Fut; N] +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for Join where - Fut: IntoFuture, + Fut: Future, { - type Output = [Fut::Output; N]; - type Future = Join; - fn join(self) -> Self::Future { - Join { - elems: self.map(|fut| MaybeDone::new(fut.into_future())), + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + + // Get the indexes of the initialized values. + let indexes = this + .state + .iter_mut() + .enumerate() + .filter(|(_, state)| state.is_done()) + .map(|(i, _)| i); + + // Drop each value at the index. + for i in indexes { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { this.items[i].assume_init_drop() }; } } } @@ -90,15 +145,30 @@ where #[cfg(test)] mod test { use super::*; + use crate::utils::DummyWaker; + use std::future; + use std::future::Future; + use std::sync::Arc; + use std::task::Context; #[test] fn smoke() { futures_lite::future::block_on(async { - let res = [future::ready("hello"), future::ready("world")] - .join() - .await; - assert_eq!(res, ["hello", "world"]); + let fut = [future::ready("hello"), future::ready("world")].join(); + assert_eq!(fut.await, ["hello", "world"]); }); } + + #[test] + fn debug() { + let mut fut = [future::ready("hello"), future::ready("world")].join(); + assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); + let mut fut = Pin::new(&mut fut); + + let waker = Arc::new(DummyWaker()).into(); + let mut cx = Context::from_waker(&waker); + let _ = fut.as_mut().poll(&mut cx); + assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); + } } diff --git a/src/utils/array.rs b/src/utils/array.rs new file mode 100644 index 0000000..27be090 --- /dev/null +++ b/src/utils/array.rs @@ -0,0 +1,22 @@ +use std::mem::{self, MaybeUninit}; + +/// Extracts the values from an array of `MaybeUninit` containers. +/// +/// # Safety +/// +/// It is up to the caller to guarantee that all elements of the array are +/// in an initialized state. +/// +/// Inlined version of: https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.array_assume_init +pub(crate) unsafe fn array_assume_init(array: [MaybeUninit; N]) -> [T; N] { + // SAFETY: + // * The caller guarantees that all elements of the array are initialized + // * `MaybeUninit` and T are guaranteed to have the same layout + // * `MaybeUninit` does not drop, so there are no double-frees + // And thus the conversion is safe + let ret = unsafe { (&array as *const _ as *const [T; N]).read() }; + + // FIXME: required to avoid `~const Destruct` bound + mem::forget(array); + ret +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 6cccb0b..0102d5f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,6 +5,7 @@ //! This implementation was taken from the original `macro_rules` `join/try_join` //! macros in the `futures-preview` crate. +mod array; mod fuse; mod maybe_done; mod pin; @@ -13,6 +14,7 @@ mod rng; mod tuple; mod waker; +pub(crate) use array::array_assume_init; pub(crate) use fuse::Fuse; pub(crate) use maybe_done::MaybeDone; pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec}; diff --git a/src/utils/slice_ext.rs b/src/utils/slice_ext.rs deleted file mode 100644 index e69de29..0000000 From 8e84401d1666419d889a10c93a1dc208bf1a35bd Mon Sep 17 00:00:00 2001 From: Yoshua Wuyts <2467194+yoshuawuyts@users.noreply.github.com> Date: Wed, 16 Nov 2022 13:52:24 +0100 Subject: [PATCH 2/2] add feedback from review --- src/future/join/array.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 8f05b3d..5453d15 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -36,6 +36,7 @@ impl Join where Fut: Future, { + #[inline] pub(crate) fn new(futures: [Fut; N]) -> Self { Join { consumed: false, @@ -54,6 +55,7 @@ where type Output = [Fut::Output; N]; type Future = Join; + #[inline] fn join(self) -> Self::Future { Join::new(self.map(IntoFuture::into_future)) } @@ -75,6 +77,7 @@ where { type Output = [Fut::Output; N]; + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project();