diff --git a/src/future/join/array.rs b/src/future/join/array.rs index e96a395..800b222 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -1,12 +1,13 @@ use super::Join as JoinTrait; -use crate::utils::{self, PollArray, WakerArray}; +use crate::utils::{self, FutureArray, PollArray, WakerArray}; use core::array; use core::fmt; use core::future::{Future, IntoFuture}; -use core::mem::{self, MaybeUninit}; +use core::mem::{self, ManuallyDrop, MaybeUninit}; use core::pin::Pin; use core::task::{Context, Poll}; +use std::ops::DerefMut; use pin_project::{pin_project, pinned_drop}; @@ -23,13 +24,20 @@ pub struct Join where Fut: Future, { + /// A boolean which holds whether the future has completed consumed: bool, + /// The number of futures which are currently still in-flight pending: usize, + /// The output data, to be returned after the future completes items: [MaybeUninit<::Output>; N], + /// A structure holding the waker passed to the future, and the various + /// sub-wakers passed to the contained futures. wakers: WakerArray, + /// The individual poll state of each future. state: PollArray, #[pin] - futures: [Fut; N], + /// The array of futures passed to the structure. + futures: FutureArray, } impl Join @@ -44,7 +52,7 @@ where items: array::from_fn(|_| MaybeUninit::uninit()), wakers: WakerArray::new(), state: PollArray::new(), - futures, + futures: FutureArray::new(futures), } } } @@ -79,7 +87,7 @@ where #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); + let this = self.project(); assert!( !*this.consumed, @@ -94,7 +102,7 @@ where } // Poll all ready futures - for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() { + for (i, mut fut) in this.futures.iter().enumerate() { if this.state[i].is_pending() && readiness.clear_ready(i) { // unlock readiness so we don't deadlock when polling drop(readiness); @@ -102,10 +110,19 @@ where // Obtain the intermediate waker. let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); - if let Poll::Ready(value) = fut.poll(&mut cx) { + // Poll the future + // SAFETY: the future's state was "pending", so it's safe to poll + if let Poll::Ready(value) = unsafe { + fut.as_mut() + .map_unchecked_mut(|t| t.deref_mut()) + .poll(&mut cx) + } { this.items[i] = MaybeUninit::new(value); this.state[i].set_ready(); *this.pending -= 1; + // SAFETY: the future state has been changed to "ready" which + // means we'll no longer poll the future, so it's safe to drop + unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) }; } // Lock readiness so we can use it again @@ -145,9 +162,9 @@ where Fut: Future, { fn drop(self: Pin<&mut Self>) { - let this = self.project(); + let mut this = self.project(); - // Get the indexes of the initialized values. + // Get the indexes of the initialized output values. let indexes = this .state .iter_mut() @@ -161,6 +178,21 @@ where // We can assume they're initialized, and this is where we drop them. unsafe { this.items[i].assume_init_drop() }; } + + // Get the indexes of the pending futures. + let indexes = this + .state + .iter_mut() + .enumerate() + .filter(|(_, state)| state.is_pending()) + .map(|(i, _)| i); + + // Drop each future at the index. + for i in indexes { + // SAFETY: we've just filtered down to *only* the pending futures, + // which have not yet been dropped. + unsafe { this.futures.as_mut().drop(i) }; + } } } diff --git a/src/utils/futures/array.rs b/src/utils/futures/array.rs new file mode 100644 index 0000000..111afd3 --- /dev/null +++ b/src/utils/futures/array.rs @@ -0,0 +1,44 @@ +use std::{ + mem::{self, ManuallyDrop, MaybeUninit}, + pin::Pin, +}; + +/// An array of futures which can be dropped in-place, intended to be +/// constructed once and then accessed through pin projections. +pub(crate) struct FutureArray { + futures: [ManuallyDrop; N], +} + +impl FutureArray { + /// Create a new instance of `FutureArray` + pub(crate) fn new(futures: [T; N]) -> Self { + // Implementation copied from: https://doc.rust-lang.org/src/core/mem/maybe_uninit.rs.html#1292 + let futures = MaybeUninit::new(futures); + // SAFETY: T and MaybeUninit have the same layout + let futures = unsafe { mem::transmute_copy(&mem::ManuallyDrop::new(futures)) }; + Self { futures } + } + + /// Create an iterator of pinned references. + pub(crate) fn iter(self: Pin<&mut Self>) -> impl Iterator>> { + // SAFETY: `std` _could_ make this unsound if it were to decide Pin's + // invariants aren't required to transmit through slices. Otherwise this has + // the same safety as a normal field pin projection. + unsafe { self.get_unchecked_mut() } + .futures + .iter_mut() + .map(|t| unsafe { Pin::new_unchecked(t) }) + } + + /// Drop a future at the given index. + /// + /// # Safety + /// + /// The future is held in a `ManuallyDrop`, so no double-dropping, etc + pub(crate) unsafe fn drop(mut self: Pin<&mut Self>, idx: usize) { + unsafe { + let futures = self.as_mut().get_unchecked_mut().futures.as_mut(); + ManuallyDrop::drop(&mut futures[idx]); + }; + } +} diff --git a/src/utils/futures/mod.rs b/src/utils/futures/mod.rs new file mode 100644 index 0000000..8658e34 --- /dev/null +++ b/src/utils/futures/mod.rs @@ -0,0 +1,3 @@ +mod array; + +pub(crate) use array::FutureArray; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ac5d38e..b8c6073 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,12 +1,14 @@ //! Utilities to implement the different futures of this crate. mod array; +mod futures; mod indexer; mod pin; mod poll_state; mod tuple; mod wakers; +pub(crate) use self::futures::FutureArray; pub(crate) use array::array_assume_init; pub(crate) use indexer::Indexer; pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec};