Skip to content

Commit

Permalink
Merge pull request #142 from yoshuawuyts/try-join
Browse files Browse the repository at this point in the history
Drop futures as soon as they're done for `{array, vec}::try_join`
  • Loading branch information
yoshuawuyts authored Jun 14, 2023
2 parents 28d6aea + 353c045 commit 919291d
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 83 deletions.
174 changes: 130 additions & 44 deletions src/future/try_join/array.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use super::TryJoin as TryJoinTrait;
use crate::utils::MaybeDone;
use crate::utils::{FutureArray, OutputArray, PollArray, WakerArray};

use core::fmt;
use core::future::{Future, IntoFuture};
use core::pin::Pin;
use core::task::{Context, Poll};
use std::mem::ManuallyDrop;
use std::ops::DerefMut;

use pin_project::pin_project;
use pin_project::{pin_project, pinned_drop};

/// A future which waits for all futures to complete successfully, or abort early on error.
///
Expand All @@ -16,21 +18,63 @@ use pin_project::pin_project;
/// [`try_join`]: crate::future::TryJoin::try_join
/// [`TryJoin`]: crate::future::TryJoin
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project]
#[pin_project(PinnedDrop)]
pub struct TryJoin<Fut, T, E, const N: usize>
where
Fut: Future<Output = Result<T, E>>,
{
elems: [MaybeDone<Fut>; N],
/// A boolean which holds whether the future has completed
consumed: bool,
/// The number of futures which are currently still in-flight
pending: usize,
/// The output data, to be returned after the future completes
items: OutputArray<T, N>,
/// A structure holding the waker passed to the future, and the various
/// sub-wakers passed to the contained futures.
wakers: WakerArray<N>,
/// The individual poll state of each future.
state: PollArray<N>,
#[pin]
/// The array of futures passed to the structure.
futures: FutureArray<Fut, N>,
}

impl<Fut, T, E, const N: usize> TryJoin<Fut, T, E, N>
where
Fut: Future<Output = Result<T, E>>,
{
#[inline]
pub(crate) fn new(futures: [Fut; N]) -> Self {
Self {
consumed: false,
pending: N,
items: OutputArray::uninit(),
wakers: WakerArray::new(),
state: PollArray::new(),
futures: FutureArray::new(futures),
}
}
}

impl<Fut, T, E, const N: usize> TryJoinTrait for [Fut; N]
where
Fut: IntoFuture<Output = Result<T, E>>,
{
type Output = [T; N];
type Error = E;
type Future = TryJoin<Fut::IntoFuture, T, E, N>;

fn try_join(self) -> Self::Future {
TryJoin::new(self.map(IntoFuture::into_future))
}
}

impl<Fut, T, E, const N: usize> fmt::Debug for TryJoin<Fut, T, E, N>
where
Fut: Future<Output = Result<T, E>> + fmt::Debug,
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.elems.iter()).finish()
f.debug_list().entries(self.state.iter()).finish()
}
}

Expand All @@ -40,60 +84,102 @@ where
{
type Output = Result<[T; N], E>;

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut all_done = true;

let this = self.project();

for elem in this.elems.iter_mut() {
// SAFETY: we don't ever move the pinned container here; we only pin project
let mut elem = unsafe { Pin::new_unchecked(elem) };
if elem.as_mut().poll(cx).is_pending() {
all_done = false
} else if let Some(err) = elem.take_err() {
return Poll::Ready(Err(err));
}
assert!(
!*this.consumed,
"Futures must not be polled after completing"
);

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());
if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
}

if all_done {
use core::array;
use core::mem::MaybeUninit;

// Create the result array based on the indices
// TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable
let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit());

// NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`.
#[allow(clippy::needless_range_loop)]
for (i, el) in this.elems.iter_mut().enumerate() {
// SAFETY: we don't ever move the pinned container here; we only pin project
let pin = unsafe { Pin::new_unchecked(el) };
match pin.take_ok() {
Some(el) => out[i] = MaybeUninit::new(el),
// All futures are done and we iterate only once to take them so this is not
// reachable
None => unreachable!(),
// Poll all ready futures
for (i, mut fut) in this.futures.iter().enumerate() {
if this.state[i].is_pending() && readiness.clear_ready(i) {
// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(i).unwrap());

// Poll the future
// SAFETY: the future's state was "pending", so it's safe to poll
if let Poll::Ready(value) = unsafe {
fut.as_mut()
.map_unchecked_mut(|t| t.deref_mut())
.poll(&mut cx)
} {
this.state[i].set_ready();
*this.pending -= 1;
// SAFETY: the future state has been changed to "ready" which
// means we'll no longer poll the future, so it's safe to drop
unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };

// Check the value, short-circuit on error.
match value {
Ok(value) => this.items.write(i, value),
Err(err) => {
// The future should no longer be polled after we're done here
*this.consumed = true;
return Poll::Ready(Err(err));
}
}
}

// Lock readiness so we can use it again
readiness = this.wakers.readiness().lock().unwrap();
}
let result = unsafe { out.as_ptr().cast::<[T; N]>().read() };
Poll::Ready(Ok(result))
}

// Check whether we're all done now or need to keep going.
if *this.pending == 0 {
// Mark all data as "consumed" before we take it
*this.consumed = true;
for state in this.state.iter_mut() {
debug_assert!(
state.is_ready(),
"Future should have reached a `Ready` state"
);
state.set_consumed();
}

// SAFETY: we've checked with the state that all of our outputs have been
// filled, which means we're ready to take the data and assume it's initialized.
Poll::Ready(Ok(unsafe { this.items.take() }))
} else {
Poll::Pending
}
}
}

impl<Fut, T, E, const N: usize> TryJoinTrait for [Fut; N]
/// Drop the already initialized values on cancellation.
#[pinned_drop]
impl<Fut, T, E, const N: usize> PinnedDrop for TryJoin<Fut, T, E, N>
where
Fut: IntoFuture<Output = Result<T, E>>,
Fut: Future<Output = Result<T, E>>,
{
type Output = [T; N];
type Error = E;
type Future = TryJoin<Fut::IntoFuture, T, E, N>;
fn drop(self: Pin<&mut Self>) {
let mut this = self.project();

// Drop all initialized values.
for i in this.state.ready_indexes() {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { this.items.drop(i) };
}

fn try_join(self) -> Self::Future {
TryJoin {
elems: self.map(|fut| MaybeDone::new(fut.into_future())),
// Drop all pending futures.
for i in this.state.pending_indexes() {
// SAFETY: we've just filtered down to *only* the pending futures,
// which have not yet been dropped.
unsafe { this.futures.as_mut().drop(i) };
}
}
}
Expand Down
Loading

0 comments on commit 919291d

Please sign in to comment.