Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop futures as soon as they're done for {array, vec}::try_join #142

Merged
merged 2 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 130 additions & 44 deletions src/future/try_join/array.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use super::TryJoin as TryJoinTrait;
use crate::utils::MaybeDone;
use crate::utils::{FutureArray, OutputArray, PollArray, WakerArray};

use core::fmt;
use core::future::{Future, IntoFuture};
use core::pin::Pin;
use core::task::{Context, Poll};
use std::mem::ManuallyDrop;
use std::ops::DerefMut;

use pin_project::pin_project;
use pin_project::{pin_project, pinned_drop};

/// A future which waits for all futures to complete successfully, or abort early on error.
///
Expand All @@ -16,21 +18,63 @@ use pin_project::pin_project;
/// [`try_join`]: crate::future::TryJoin::try_join
/// [`TryJoin`]: crate::future::TryJoin
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project]
#[pin_project(PinnedDrop)]
pub struct TryJoin<Fut, T, E, const N: usize>
where
Fut: Future<Output = Result<T, E>>,
{
elems: [MaybeDone<Fut>; N],
/// A boolean which holds whether the future has completed
consumed: bool,
/// The number of futures which are currently still in-flight
pending: usize,
/// The output data, to be returned after the future completes
items: OutputArray<T, N>,
/// A structure holding the waker passed to the future, and the various
/// sub-wakers passed to the contained futures.
wakers: WakerArray<N>,
/// The individual poll state of each future.
state: PollArray<N>,
#[pin]
/// The array of futures passed to the structure.
futures: FutureArray<Fut, N>,
}

impl<Fut, T, E, const N: usize> TryJoin<Fut, T, E, N>
where
Fut: Future<Output = Result<T, E>>,
{
#[inline]
pub(crate) fn new(futures: [Fut; N]) -> Self {
Self {
consumed: false,
pending: N,
items: OutputArray::uninit(),
wakers: WakerArray::new(),
state: PollArray::new(),
futures: FutureArray::new(futures),
}
}
}

impl<Fut, T, E, const N: usize> TryJoinTrait for [Fut; N]
where
Fut: IntoFuture<Output = Result<T, E>>,
{
type Output = [T; N];
type Error = E;
type Future = TryJoin<Fut::IntoFuture, T, E, N>;

fn try_join(self) -> Self::Future {
TryJoin::new(self.map(IntoFuture::into_future))
}
}

impl<Fut, T, E, const N: usize> fmt::Debug for TryJoin<Fut, T, E, N>
where
Fut: Future<Output = Result<T, E>> + fmt::Debug,
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.elems.iter()).finish()
f.debug_list().entries(self.state.iter()).finish()
}
}

Expand All @@ -40,60 +84,102 @@ where
{
type Output = Result<[T; N], E>;

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut all_done = true;

let this = self.project();

for elem in this.elems.iter_mut() {
// SAFETY: we don't ever move the pinned container here; we only pin project
let mut elem = unsafe { Pin::new_unchecked(elem) };
if elem.as_mut().poll(cx).is_pending() {
all_done = false
} else if let Some(err) = elem.take_err() {
return Poll::Ready(Err(err));
}
assert!(
!*this.consumed,
"Futures must not be polled after completing"
);

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());
if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
}

if all_done {
use core::array;
use core::mem::MaybeUninit;

// Create the result array based on the indices
// TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable
let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit());

// NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`.
#[allow(clippy::needless_range_loop)]
for (i, el) in this.elems.iter_mut().enumerate() {
// SAFETY: we don't ever move the pinned container here; we only pin project
let pin = unsafe { Pin::new_unchecked(el) };
match pin.take_ok() {
Some(el) => out[i] = MaybeUninit::new(el),
// All futures are done and we iterate only once to take them so this is not
// reachable
None => unreachable!(),
// Poll all ready futures
for (i, mut fut) in this.futures.iter().enumerate() {
if this.state[i].is_pending() && readiness.clear_ready(i) {
// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(i).unwrap());

// Poll the future
// SAFETY: the future's state was "pending", so it's safe to poll
if let Poll::Ready(value) = unsafe {
fut.as_mut()
.map_unchecked_mut(|t| t.deref_mut())
.poll(&mut cx)
} {
this.state[i].set_ready();
*this.pending -= 1;
// SAFETY: the future state has been changed to "ready" which
// means we'll no longer poll the future, so it's safe to drop
unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };

// Check the value, short-circuit on error.
match value {
Ok(value) => this.items.write(i, value),
Err(err) => {
// The future should no longer be polled after we're done here
*this.consumed = true;
return Poll::Ready(Err(err));
}
}
}

// Lock readiness so we can use it again
readiness = this.wakers.readiness().lock().unwrap();
}
let result = unsafe { out.as_ptr().cast::<[T; N]>().read() };
Poll::Ready(Ok(result))
}

// Check whether we're all done now or need to keep going.
if *this.pending == 0 {
// Mark all data as "consumed" before we take it
*this.consumed = true;
for state in this.state.iter_mut() {
debug_assert!(
state.is_ready(),
"Future should have reached a `Ready` state"
);
state.set_consumed();
}

// SAFETY: we've checked with the state that all of our outputs have been
// filled, which means we're ready to take the data and assume it's initialized.
Poll::Ready(Ok(unsafe { this.items.take() }))
} else {
Poll::Pending
}
}
}

impl<Fut, T, E, const N: usize> TryJoinTrait for [Fut; N]
/// Drop the already initialized values on cancellation.
#[pinned_drop]
impl<Fut, T, E, const N: usize> PinnedDrop for TryJoin<Fut, T, E, N>
where
Fut: IntoFuture<Output = Result<T, E>>,
Fut: Future<Output = Result<T, E>>,
{
type Output = [T; N];
type Error = E;
type Future = TryJoin<Fut::IntoFuture, T, E, N>;
fn drop(self: Pin<&mut Self>) {
let mut this = self.project();

// Drop all initialized values.
for i in this.state.ready_indexes() {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { this.items.drop(i) };
}

fn try_join(self) -> Self::Future {
TryJoin {
elems: self.map(|fut| MaybeDone::new(fut.into_future())),
// Drop all pending futures.
for i in this.state.pending_indexes() {
// SAFETY: we've just filtered down to *only* the pending futures,
// which have not yet been dropped.
unsafe { this.futures.as_mut().drop(i) };
}
}
}
Expand Down
Loading