diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index 7214980..36f28e4 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -1,141 +1,259 @@ -use super::TryJoin; -use crate::utils::{self, PollArray}; +use super::TryJoin as TryJoinTrait; +use crate::utils::{PollArray, WakerArray}; -use core::fmt; +use core::fmt::{self, Debug}; use core::future::{Future, IntoFuture}; -use core::mem::{self, MaybeUninit}; +use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops::DerefMut; use pin_project::{pin_project, pinned_drop}; +/// Generates the `poll` call for every `Future` inside `$futures`. +/// +/// SAFETY: pretty please only call this after having made very sure that the future you're trying +/// to call is actually marked `ready!`. If Rust had unsafe macros, this would be one. +// +// This is implemented as a tt-muncher of the future name `$($F:ident)` +// and the future index `$($rest)`, taking advantage that we only support +// tuples up to 12 elements +// +// # References +// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html +macro_rules! unsafe_poll { + // recursively iterate + (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { + if $fut_idx == $iteration { + + if let Poll::Ready(value) = unsafe { + $futures.$fut_name.as_mut() + .map_unchecked_mut(|t| t.deref_mut()) + .poll(&mut $cx) + } { + $this.state[$fut_idx].set_ready(); + *$this.completed += 1; + // SAFETY: the future state has been changed to "ready" which + // means we'll no longer poll the future, so it's safe to drop + unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) }; + + // Check the value, short-circuit on error. + match value { + Ok(value) => $this.outputs.$fut_idx.write(value), + Err(err) => { + // The future should no longer be polled after we're done here + *$this.consumed = true; + return Poll::Ready(Err(err)); + } + }; + } + } + unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*); + }; + + // base condition + (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {}; + + // macro start + ($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => { + unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11); + }; +} + +/// Drop all initialized values +macro_rules! drop_initialized_values { + // recursively iterate + (@drop $output:ident, $($rem_outs:ident,)* | $states:expr, $state_idx:tt, $($rem_idx:tt,)*) => { + if $states[$state_idx].is_ready() { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { $output.assume_init_drop() }; + $states[$state_idx].set_consumed(); + } + drop_initialized_values!(@drop $($rem_outs,)* | $states, $($rem_idx,)*); + }; + + // base condition + (@drop | $states:expr, $($rem_idx:tt,)*) => {}; + + // macro start + ($($outs:ident,)+ | $states:expr) => { + drop_initialized_values!(@drop $($outs,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,); + }; +} + +/// Drop all pending futures +macro_rules! drop_pending_futures { + // recursively iterate + (@inner $states:ident, $futures:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { + if $states[$fut_idx].is_pending() { + // SAFETY: We're accessing the value behind the pinned reference to drop it exactly once. + let futures = unsafe { $futures.as_mut().get_unchecked_mut() }; + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { ManuallyDrop::drop(&mut futures.$fut_name) }; + } + drop_pending_futures!(@inner $states, $futures, $($F)* | $($rest)*); + }; + + // base condition + (@inner $states:ident, $futures:ident, | $($rest:tt)*) => {}; + + // macro start + ($states:ident, $futures:ident, $($F:ident,)+) => { + drop_pending_futures!(@inner $states, $futures, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11); + }; +} + macro_rules! impl_try_join_tuple { - ($mod_name:ident $StructName:ident $(($F:ident $R:ident))+) => { - mod $mod_name { - pub(super) struct Output<$($R,)+> - { - $(pub(super) $R: core::mem::MaybeUninit<$R>,)+ + // `Impl TryJoin for ()` + ($mod_name:ident $StructName:ident) => { + /// A future which waits for similarly-typed futures to complete, or aborts early on error. + /// + /// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See + /// its documentation for more. + /// + /// [`try_join`]: crate::future::TryJoin::try_join + /// [`TryJoin`]: crate::future::Join + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[allow(non_snake_case)] + pub struct $StructName {} + + impl fmt::Debug for $StructName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("TryJoin").finish() } + } - impl<$($R,)+> Default for Output<$($R,)+> - { - fn default() -> Self { - Self { - $($R: core::mem::MaybeUninit::uninit(),)+ - } - } + impl Future for $StructName { + type Output = Result<(), std::convert::Infallible>; + + fn poll( + self: Pin<&mut Self>, _cx: &mut Context<'_> + ) -> Poll { + Poll::Ready(Ok(())) } + } - #[repr(usize)] - enum Indexes { - $($F,)+ + impl TryJoinTrait for () { + type Output = (); + type Error = std::convert::Infallible; + type Future = $StructName; + fn try_join(self) -> Self::Future { + $StructName {} } + } + }; + + // `Impl TryJoin for (F..)` + ($mod_name:ident $StructName:ident $(($F:ident $T:ident))+) => { + mod $mod_name { + use std::mem::ManuallyDrop; + + #[pin_project::pin_project] + pub(super) struct Futures<$($F,)+> {$( + #[pin] + pub(super) $F: ManuallyDrop<$F>, + )+} - $( - pub(super) const $F: usize = Indexes::$F as usize; - )+ + #[repr(u8)] + pub(super) enum Indexes { $($F,)+ } pub(super) const LEN: usize = [$(Indexes::$F,)+].len(); } - /// A future which waits for all futures to complete successfully, or abort early on error. + /// Waits for many similarly-typed futures to complete or abort early on error. /// /// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See /// its documentation for more. /// /// [`try_join`]: crate::future::TryJoin::try_join - /// [`TryJoin`]: crate::future::TryJoin + /// [`TryJoin`]: crate::future::Join + #[pin_project(PinnedDrop)] #[must_use = "futures do nothing unless you `.await` or poll them"] #[allow(non_snake_case)] - #[pin_project(PinnedDrop)] - pub struct $StructName<$($F, $R,)+> - { - done: bool, + pub struct $StructName<$($F, $T,)+ Err> { + #[pin] + futures: $mod_name::Futures<$($F,)+>, + outputs: ($(MaybeUninit<$T>,)+), + // trace the state of outputs, marking them as ready or consumed + // then, drop the non-consumed values, if any + state: PollArray<{$mod_name::LEN}>, + wakers: WakerArray<{$mod_name::LEN}>, completed: usize, - indexer: utils::Indexer, - output: $mod_name::Output<$($R,)+>, - output_states: PollArray<{ $mod_name::LEN }>, - $( #[pin] $F: $F, )+ + consumed: bool, + _phantom: PhantomData, } - impl<$($F, $R,)+> fmt::Debug for $StructName<$($F, $R,)+> + impl<$($F, $T,)+ Err> Debug for $StructName<$($F, $T,)+ Err> where - $($F: fmt::Debug,)+ + $( $F: Future + Debug, )+ { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("TryJoin") - $(.field(&self.$F))+ + $(.field(&self.futures.$F))+ .finish() } } - impl TryJoin for ($($F,)+) - where - $( $F: IntoFuture>, )+ - { - type Output = ($($R,)+); - type Error = ERR; - type Future = $StructName<$($F::IntoFuture, $R,)+>; - - fn try_join(self) -> Self::Future { - let ($($F,)+): ($($F,)+) = self; - $StructName { - completed: 0, - done: false, - indexer: utils::Indexer::new($mod_name::LEN), - output: Default::default(), - output_states: PollArray::new(), - $($F: $F.into_future()),+ - } - } - } - - impl Future for $StructName<$($F, $R,)+> - where - $( $F: Future>, )+ - { - type Output = Result<($($R,)+), ERR>; + #[allow(unused_mut)] + #[allow(unused_parens)] + #[allow(unused_variables)] + impl<$($F, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err> + where $( + $F: Future>, + )+ { + type Output = Result<($($T,)+), Err>; fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll { + const LEN: usize = $mod_name::LEN; + let mut this = self.project(); + assert!(!*this.consumed, "Futures must not be polled after completing"); - assert!(!*this.done, "Futures must not be polled after completing"); + let mut futures = this.futures.project(); - for i in this.indexer.iter() { - if this.output_states[i].is_ready() { + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + + for index in 0..LEN { + if !readiness.any_ready() { + // nothing ready yet + return Poll::Pending; + } + if !readiness.clear_ready(index) || this.state[index].is_ready() { + // future not ready yet or already polled to completion, skip continue; } - utils::gen_conditions!(i, this, cx, poll, $(($mod_name::$F; $F, { - Poll::Ready(output) => match output { - Ok(output) => { - this.output.$R = MaybeUninit::new(output); - this.output_states[$mod_name::$F].set_ready(); - *this.completed += 1; - continue; - }, - Err(err) => { - *this.done = true; - *this.completed += 1; - return Poll::Ready(Err(err)); - }, - }, - _ => continue, - }))*); - } - let all_completed = *this.completed == $mod_name::LEN; - if all_completed { - // mark all error states as consumed before we return it - this.output_states.set_all_completed(); + // unlock readiness so we don't deadlock when polling + drop(readiness); - let mut output = Default::default(); - mem::swap(&mut output, this.output); + // obtain the intermediate waker + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - *this.done = true; + // generate the needed code to poll `futures.{index}` + // SAFETY: the future's state should be "pending", so it's safe to poll + unsafe_poll!(index, this, futures, cx, LEN, $($F,)+); - return Poll::Ready(Ok(( $(unsafe { output.$R.assume_init() }, )+ ))); + if *this.completed == LEN { + let out = { + let mut out = ($(MaybeUninit::<$T>::uninit(),)+); + core::mem::swap(&mut out, this.outputs); + let ($($F,)+) = out; + unsafe { ($($F.assume_init(),)+) } + }; + + this.state.set_all_completed(); + *this.consumed = true; + + return Poll::Ready(Ok(out)); + } + readiness = this.wakers.readiness().lock().unwrap(); } Poll::Pending @@ -143,24 +261,47 @@ macro_rules! impl_try_join_tuple { } #[pinned_drop] - impl<$($F, $R,)+> PinnedDrop for $StructName<$($F, $R,)+> - { + impl<$($F, $T,)+ Err> PinnedDrop for $StructName<$($F, $T,)+ Err> { fn drop(self: Pin<&mut Self>) { let this = self.project(); - $( - let mut st = this.output_states[$mod_name::$F]; - if st.is_ready() { - // SAFETY: we've just filtered down to *only* the initialized values. - unsafe { this.output.$R.assume_init_drop() }; - st.set_consumed(); - } - )+ + let ($(ref mut $F,)+) = this.outputs; + + let states = this.state; + let mut futures = this.futures; + drop_initialized_values!($($F,)+ | states); + drop_pending_futures!(states, futures, $($F,)+); } } - } + + #[allow(unused_parens)] + impl<$($F, $T,)+ Err> TryJoinTrait for ($($F,)+) + where $( + $F: IntoFuture>, + )+ { + type Output = ($($T,)+); + type Error = Err; + type Future = $StructName<$($F::IntoFuture, $T,)+ Err>; + + fn try_join(self) -> Self::Future { + let ($($F,)+): ($($F,)+) = self; + $StructName { + futures: $mod_name::Futures {$( + $F: ManuallyDrop::new($F.into_future()), + )+}, + state: PollArray::new(), + outputs: ($(MaybeUninit::<$T>::uninit(),)+), + wakers: WakerArray::new(), + completed: 0, + consumed: false, + _phantom: PhantomData, + } + } + } + }; } +impl_try_join_tuple! { try_join0 TryJoin0 } impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) } impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) } impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) } @@ -220,4 +361,39 @@ mod test { assert_eq!(res.unwrap(), ((), ())); }); } + + #[test] + fn does_not_leak_memory() { + use core::cell::RefCell; + use futures_lite::future::pending; + + thread_local! { + static NOT_LEAKING: RefCell = RefCell::new(false); + }; + + struct FlipFlagAtDrop; + impl Drop for FlipFlagAtDrop { + fn drop(&mut self) { + NOT_LEAKING.with(|v| { + *v.borrow_mut() = true; + }); + } + } + + futures_lite::future::block_on(async { + // this will trigger Miri if we don't drop the memory + let string = future::ready(io::Result::Ok("memory leak".to_owned())); + + // this will not flip the thread_local flag if we don't drop the memory + let flip = future::ready(io::Result::Ok(FlipFlagAtDrop)); + + let leak = (string, flip, pending::>()).try_join(); + + _ = futures_lite::future::poll_once(leak).await; + }); + + NOT_LEAKING.with(|flag| { + assert!(*flag.borrow()); + }) + } }