From 33dfa8871e11c3093c94ed5460baf502f592870a Mon Sep 17 00:00:00 2001 From: Yosh Date: Tue, 13 Jun 2023 00:39:48 +0200 Subject: [PATCH 1/2] array::try_join - drop futures early --- src/future/try_join/array.rs | 174 ++++++++++++++++++++++++++--------- 1 file changed, 130 insertions(+), 44 deletions(-) diff --git a/src/future/try_join/array.rs b/src/future/try_join/array.rs index b3cbd39..7440b65 100644 --- a/src/future/try_join/array.rs +++ b/src/future/try_join/array.rs @@ -1,12 +1,14 @@ use super::TryJoin as TryJoinTrait; -use crate::utils::MaybeDone; +use crate::utils::{FutureArray, OutputArray, PollArray, WakerArray}; use core::fmt; use core::future::{Future, IntoFuture}; use core::pin::Pin; use core::task::{Context, Poll}; +use std::mem::ManuallyDrop; +use std::ops::DerefMut; -use pin_project::pin_project; +use pin_project::{pin_project, pinned_drop}; /// A future which waits for all futures to complete successfully, or abort early on error. /// @@ -16,21 +18,63 @@ use pin_project::pin_project; /// [`try_join`]: crate::future::TryJoin::try_join /// [`TryJoin`]: crate::future::TryJoin #[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] +#[pin_project(PinnedDrop)] pub struct TryJoin where Fut: Future>, { - elems: [MaybeDone; N], + /// A boolean which holds whether the future has completed + consumed: bool, + /// The number of futures which are currently still in-flight + pending: usize, + /// The output data, to be returned after the future completes + items: OutputArray, + /// A structure holding the waker passed to the future, and the various + /// sub-wakers passed to the contained futures. + wakers: WakerArray, + /// The individual poll state of each future. + state: PollArray, + #[pin] + /// The array of futures passed to the structure. + futures: FutureArray, +} + +impl TryJoin +where + Fut: Future>, +{ + #[inline] + pub(crate) fn new(futures: [Fut; N]) -> Self { + Self { + consumed: false, + pending: N, + items: OutputArray::uninit(), + wakers: WakerArray::new(), + state: PollArray::new(), + futures: FutureArray::new(futures), + } + } +} + +impl TryJoinTrait for [Fut; N] +where + Fut: IntoFuture>, +{ + type Output = [T; N]; + type Error = E; + type Future = TryJoin; + + fn try_join(self) -> Self::Future { + TryJoin::new(self.map(IntoFuture::into_future)) + } } impl fmt::Debug for TryJoin where Fut: Future> + fmt::Debug, - Fut::Output: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.elems.iter()).finish() + f.debug_list().entries(self.state.iter()).finish() } } @@ -40,60 +84,102 @@ where { type Output = Result<[T; N], E>; + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; - let this = self.project(); - for elem in this.elems.iter_mut() { - // SAFETY: we don't ever move the pinned container here; we only pin project - let mut elem = unsafe { Pin::new_unchecked(elem) }; - if elem.as_mut().poll(cx).is_pending() { - all_done = false - } else if let Some(err) = elem.take_err() { - return Poll::Ready(Err(err)); - } + assert!( + !*this.consumed, + "Futures must not be polled after completing" + ); + + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; } - if all_done { - use core::array; - use core::mem::MaybeUninit; - - // Create the result array based on the indices - // TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable - let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit()); - - // NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`. - #[allow(clippy::needless_range_loop)] - for (i, el) in this.elems.iter_mut().enumerate() { - // SAFETY: we don't ever move the pinned container here; we only pin project - let pin = unsafe { Pin::new_unchecked(el) }; - match pin.take_ok() { - Some(el) => out[i] = MaybeUninit::new(el), - // All futures are done and we iterate only once to take them so this is not - // reachable - None => unreachable!(), + // Poll all ready futures + for (i, mut fut) in this.futures.iter().enumerate() { + if this.state[i].is_pending() && readiness.clear_ready(i) { + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); + + // Poll the future + // SAFETY: the future's state was "pending", so it's safe to poll + if let Poll::Ready(value) = unsafe { + fut.as_mut() + .map_unchecked_mut(|t| t.deref_mut()) + .poll(&mut cx) + } { + this.state[i].set_ready(); + *this.pending -= 1; + // SAFETY: the future state has been changed to "ready" which + // means we'll no longer poll the future, so it's safe to drop + unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) }; + + // Check the value, short-circuit on error. + match value { + Ok(value) => this.items.write(i, value), + Err(err) => { + // The future should no longer be polled after we're done here + *this.consumed = true; + return Poll::Ready(Err(err)); + } + } } + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); } - let result = unsafe { out.as_ptr().cast::<[T; N]>().read() }; - Poll::Ready(Ok(result)) + } + + // Check whether we're all done now or need to keep going. + if *this.pending == 0 { + // Mark all data as "consumed" before we take it + *this.consumed = true; + for state in this.state.iter_mut() { + debug_assert!( + state.is_ready(), + "Future should have reached a `Ready` state" + ); + state.set_consumed(); + } + + // SAFETY: we've checked with the state that all of our outputs have been + // filled, which means we're ready to take the data and assume it's initialized. + Poll::Ready(Ok(unsafe { this.items.take() })) } else { Poll::Pending } } } -impl TryJoinTrait for [Fut; N] +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for TryJoin where - Fut: IntoFuture>, + Fut: Future>, { - type Output = [T; N]; - type Error = E; - type Future = TryJoin; + fn drop(self: Pin<&mut Self>) { + let mut this = self.project(); + + // Drop all initialized values. + for i in this.state.ready_indexes() { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { this.items.drop(i) }; + } - fn try_join(self) -> Self::Future { - TryJoin { - elems: self.map(|fut| MaybeDone::new(fut.into_future())), + // Drop all pending futures. + for i in this.state.pending_indexes() { + // SAFETY: we've just filtered down to *only* the pending futures, + // which have not yet been dropped. + unsafe { this.futures.as_mut().drop(i) }; } } } From 353c045c87e9787467331117011586b290974960 Mon Sep 17 00:00:00 2001 From: Yosh Date: Tue, 13 Jun 2023 00:43:23 +0200 Subject: [PATCH 2/2] `vec::try_join` - drop futures early --- src/future/try_join/vec.rs | 175 ++++++++++++++++++++++++++++--------- 1 file changed, 136 insertions(+), 39 deletions(-) diff --git a/src/future/try_join/vec.rs b/src/future/try_join/vec.rs index 9f3a133..07207d9 100644 --- a/src/future/try_join/vec.rs +++ b/src/future/try_join/vec.rs @@ -1,14 +1,14 @@ use super::TryJoin as TryJoinTrait; -use crate::utils::iter_pin_mut; -use crate::utils::MaybeDone; +use crate::utils::{FutureVec, OutputVec, PollVec, WakerVec}; use core::fmt; use core::future::{Future, IntoFuture}; -use core::mem; use core::pin::Pin; use core::task::{Context, Poll}; -use std::boxed::Box; -use std::vec::Vec; +use std::mem::ManuallyDrop; +use std::ops::DerefMut; + +use pin_project::{pin_project, pinned_drop}; /// A future which waits for all futures to complete successfully, or abort early on error. /// @@ -18,20 +18,64 @@ use std::vec::Vec; /// [`try_join`]: crate::future::TryJoin::try_join /// [`TryJoin`]: crate::future::TryJoin #[must_use = "futures do nothing unless you `.await` or poll them"] +#[pin_project(PinnedDrop)] pub struct TryJoin where Fut: Future>, { - elems: Pin]>>, + /// A boolean which holds whether the future has completed + consumed: bool, + /// The number of futures which are currently still in-flight + pending: usize, + /// The output data, to be returned after the future completes + items: OutputVec, + /// A structure holding the waker passed to the future, and the various + /// sub-wakers passed to the contained futures. + wakers: WakerVec, + /// The individual poll state of each future. + state: PollVec, + #[pin] + /// The array of futures passed to the structure. + futures: FutureVec, +} + +impl TryJoin +where + Fut: Future>, +{ + #[inline] + pub(crate) fn new(futures: Vec) -> Self { + let len = futures.len(); + Self { + consumed: false, + pending: len, + items: OutputVec::uninit(len), + wakers: WakerVec::new(len), + state: PollVec::new(len), + futures: FutureVec::new(futures), + } + } +} + +impl TryJoinTrait for Vec +where + Fut: IntoFuture>, +{ + type Output = Vec; + type Error = E; + type Future = TryJoin; + + fn try_join(self) -> Self::Future { + TryJoin::new(self.into_iter().map(IntoFuture::into_future).collect()) + } } impl fmt::Debug for TryJoin where Fut: Future> + fmt::Debug, - Fut::Output: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.elems.iter()).finish() + f.debug_list().entries(self.state.iter()).finish() } } @@ -41,49 +85,102 @@ where { type Output = Result, E>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + assert!( + !*this.consumed, + "Futures must not be polled after completing" + ); + + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } + + // Poll all ready futures + for (i, mut fut) in this.futures.iter().enumerate() { + if this.state[i].is_pending() && readiness.clear_ready(i) { + // unlock readiness so we don't deadlock when polling + drop(readiness); - for mut elem in iter_pin_mut(self.elems.as_mut()) { - if elem.as_mut().poll(cx).is_pending() { - all_done = false - } else if let Some(err) = elem.take_err() { - return Poll::Ready(Err(err)); + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); + + // Poll the future + // SAFETY: the future's state was "pending", so it's safe to poll + if let Poll::Ready(value) = unsafe { + fut.as_mut() + .map_unchecked_mut(|t| t.deref_mut()) + .poll(&mut cx) + } { + this.state[i].set_ready(); + *this.pending -= 1; + // SAFETY: the future state has been changed to "ready" which + // means we'll no longer poll the future, so it's safe to drop + unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) }; + + // Check the value, short-circuit on error. + match value { + Ok(value) => this.items.write(i, value), + Err(err) => { + // The future should no longer be polled after we're done here + *this.consumed = true; + return Poll::Ready(Err(err)); + } + } + } + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); } } - if all_done { - let mut elems = mem::replace(&mut self.elems, Box::pin([])); - let result = iter_pin_mut(elems.as_mut()) - .map(|e| match e.take_ok() { - Some(output) => output, - // Since all futures are done and we reached here, it means none returned an - // `Err` and so this is unreachable. - None => unreachable!(), - }) - .collect(); - Poll::Ready(Ok(result)) + // Check whether we're all done now or need to keep going. + if *this.pending == 0 { + // Mark all data as "consumed" before we take it + *this.consumed = true; + for state in this.state.iter_mut() { + debug_assert!( + state.is_ready(), + "Future should have reached a `Ready` state" + ); + state.set_consumed(); + } + + // SAFETY: we've checked with the state that all of our outputs have been + // filled, which means we're ready to take the data and assume it's initialized. + Poll::Ready(Ok(unsafe { this.items.take() })) } else { Poll::Pending } } } -impl TryJoinTrait for Vec +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for TryJoin where - Fut: IntoFuture>, + Fut: Future>, { - type Output = Vec; - type Error = E; - type Future = TryJoin; + fn drop(self: Pin<&mut Self>) { + let mut this = self.project(); - fn try_join(self) -> Self::Future { - let elems: Box<[_]> = self - .into_iter() - .map(|fut| MaybeDone::new(fut.into_future())) - .collect(); - TryJoin { - elems: elems.into(), + // Drop all initialized values. + for i in this.state.ready_indexes() { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { this.items.drop(i) }; + } + + // Drop all pending futures. + for i in this.state.pending_indexes() { + // SAFETY: we've just filtered down to *only* the pending futures, + // which have not yet been dropped. + unsafe { this.futures.as_mut().drop(i) }; } } } @@ -100,7 +197,7 @@ mod test { let res: io::Result<_> = vec![future::ready(Ok("hello")), future::ready(Ok("world"))] .try_join() .await; - assert_eq!(res.unwrap(), vec!["hello", "world"]); + assert_eq!(res.unwrap(), ["hello", "world"]); }) }