Skip to content

Commit

Permalink
Merge pull request #138 from yoshuawuyts/join-array-drop-sooner
Browse files Browse the repository at this point in the history
Drop futures as soon as they're done for `array::join`
  • Loading branch information
yoshuawuyts authored Jun 12, 2023
2 parents ad1764e + 0bd7d99 commit 4264e6d
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 9 deletions.
50 changes: 41 additions & 9 deletions src/future/join/array.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::Join as JoinTrait;
use crate::utils::{self, PollArray, WakerArray};
use crate::utils::{self, FutureArray, PollArray, WakerArray};

use core::array;
use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem::{self, MaybeUninit};
use core::mem::{self, ManuallyDrop, MaybeUninit};
use core::pin::Pin;
use core::task::{Context, Poll};
use std::ops::DerefMut;

use pin_project::{pin_project, pinned_drop};

Expand All @@ -23,13 +24,20 @@ pub struct Join<Fut, const N: usize>
where
Fut: Future,
{
/// A boolean which holds whether the future has completed
consumed: bool,
/// The number of futures which are currently still in-flight
pending: usize,
/// The output data, to be returned after the future completes
items: [MaybeUninit<<Fut as Future>::Output>; N],
/// A structure holding the waker passed to the future, and the various
/// sub-wakers passed to the contained futures.
wakers: WakerArray<N>,
/// The individual poll state of each future.
state: PollArray<N>,
#[pin]
futures: [Fut; N],
/// The array of futures passed to the structure.
futures: FutureArray<Fut, N>,
}

impl<Fut, const N: usize> Join<Fut, N>
Expand All @@ -44,7 +52,7 @@ where
items: array::from_fn(|_| MaybeUninit::uninit()),
wakers: WakerArray::new(),
state: PollArray::new(),
futures,
futures: FutureArray::new(futures),
}
}
}
Expand Down Expand Up @@ -79,7 +87,7 @@ where

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
let this = self.project();

assert!(
!*this.consumed,
Expand All @@ -94,18 +102,27 @@ where
}

// Poll all ready futures
for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() {
for (i, mut fut) in this.futures.iter().enumerate() {
if this.state[i].is_pending() && readiness.clear_ready(i) {
// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(i).unwrap());

if let Poll::Ready(value) = fut.poll(&mut cx) {
// Poll the future
// SAFETY: the future's state was "pending", so it's safe to poll
if let Poll::Ready(value) = unsafe {
fut.as_mut()
.map_unchecked_mut(|t| t.deref_mut())
.poll(&mut cx)
} {
this.items[i] = MaybeUninit::new(value);
this.state[i].set_ready();
*this.pending -= 1;
// SAFETY: the future state has been changed to "ready" which
// means we'll no longer poll the future, so it's safe to drop
unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
}

// Lock readiness so we can use it again
Expand Down Expand Up @@ -145,9 +162,9 @@ where
Fut: Future,
{
fn drop(self: Pin<&mut Self>) {
let this = self.project();
let mut this = self.project();

// Get the indexes of the initialized values.
// Get the indexes of the initialized output values.
let indexes = this
.state
.iter_mut()
Expand All @@ -161,6 +178,21 @@ where
// We can assume they're initialized, and this is where we drop them.
unsafe { this.items[i].assume_init_drop() };
}

// Get the indexes of the pending futures.
let indexes = this
.state
.iter_mut()
.enumerate()
.filter(|(_, state)| state.is_pending())
.map(|(i, _)| i);

// Drop each future at the index.
for i in indexes {
// SAFETY: we've just filtered down to *only* the pending futures,
// which have not yet been dropped.
unsafe { this.futures.as_mut().drop(i) };
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions src/utils/futures/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::{
mem::{self, ManuallyDrop, MaybeUninit},
pin::Pin,
};

/// An array of futures which can be dropped in-place, intended to be
/// constructed once and then accessed through pin projections.
pub(crate) struct FutureArray<T, const N: usize> {
futures: [ManuallyDrop<T>; N],
}

impl<T, const N: usize> FutureArray<T, N> {
/// Create a new instance of `FutureArray`
pub(crate) fn new(futures: [T; N]) -> Self {
// Implementation copied from: https://doc.rust-lang.org/src/core/mem/maybe_uninit.rs.html#1292
let futures = MaybeUninit::new(futures);
// SAFETY: T and MaybeUninit<T> have the same layout
let futures = unsafe { mem::transmute_copy(&mem::ManuallyDrop::new(futures)) };
Self { futures }
}

/// Create an iterator of pinned references.
pub(crate) fn iter(self: Pin<&mut Self>) -> impl Iterator<Item = Pin<&mut ManuallyDrop<T>>> {
// SAFETY: `std` _could_ make this unsound if it were to decide Pin's
// invariants aren't required to transmit through slices. Otherwise this has
// the same safety as a normal field pin projection.
unsafe { self.get_unchecked_mut() }
.futures
.iter_mut()
.map(|t| unsafe { Pin::new_unchecked(t) })
}

/// Drop a future at the given index.
///
/// # Safety
///
/// The future is held in a `ManuallyDrop`, so no double-dropping, etc
pub(crate) unsafe fn drop(mut self: Pin<&mut Self>, idx: usize) {
unsafe {
let futures = self.as_mut().get_unchecked_mut().futures.as_mut();
ManuallyDrop::drop(&mut futures[idx]);
};
}
}
3 changes: 3 additions & 0 deletions src/utils/futures/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod array;

pub(crate) use array::FutureArray;
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Utilities to implement the different futures of this crate.
mod array;
mod futures;
mod indexer;
mod pin;
mod poll_state;
mod tuple;
mod wakers;

pub(crate) use self::futures::FutureArray;
pub(crate) use array::array_assume_init;
pub(crate) use indexer::Indexer;
pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec};
Expand Down

0 comments on commit 4264e6d

Please sign in to comment.