From b67503c070f384e907074bb2d1ca20a03dc1e453 Mon Sep 17 00:00:00 2001 From: Yosh Date: Thu, 22 Jun 2023 02:12:43 +0200 Subject: [PATCH 1/5] start by rewriting the code --- src/future/try_join/tuple.rs | 378 ++++++++++++++++++++++++----------- 1 file changed, 266 insertions(+), 112 deletions(-) diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index 7214980..f36d0fd 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -1,141 +1,241 @@ -use super::TryJoin; -use crate::utils::{self, PollArray}; +use super::TryJoin as TryJoinTrait; +use crate::utils::{PollArray, WakerArray}; -use core::fmt; +use core::fmt::{self, Debug}; use core::future::{Future, IntoFuture}; -use core::mem::{self, MaybeUninit}; +use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; +use std::mem::ManuallyDrop; +use std::ops::DerefMut; use pin_project::{pin_project, pinned_drop}; +/// Generates the `poll` call for every `Future` inside `$futures`. +/// +/// SAFETY: pretty please only call this after having made very sure that the future you're trying +/// to call is actually marked `ready!`. If Rust had unsafe macros, this would be one. +// +// This is implemented as a tt-muncher of the future name `$($F:ident)` +// and the future index `$($rest)`, taking advantage that we only support +// tuples up to 12 elements +// +// # References +// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html +macro_rules! unsafe_poll { + // recursively iterate + (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { + if $fut_idx == $iteration { + + if let Poll::Ready(value) = unsafe { + $futures.$fut_name.as_mut() + .map_unchecked_mut(|t| t.deref_mut()) + .poll(&mut $cx) + } { + $this.outputs.$fut_idx.write(value); + *$this.completed += 1; + $this.state[$fut_idx].set_ready(); + // SAFETY: the future state has been changed to "ready" which + // means we'll no longer poll the future, so it's safe to drop + unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) }; + } + } + unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*); + }; + + // base condition + (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {}; + + // macro start + ($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => { + unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11); + }; +} + +/// Drop all initialized values +macro_rules! drop_initialized_values { + // recursively iterate + (@drop $output:ident, $($rem_outs:ident,)* | $states:expr, $state_idx:tt, $($rem_idx:tt,)*) => { + if $states[$state_idx].is_ready() { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { $output.assume_init_drop() }; + $states[$state_idx].set_consumed(); + } + drop_initialized_values!(@drop $($rem_outs,)* | $states, $($rem_idx,)*); + }; + + // base condition + (@drop | $states:expr, $($rem_idx:tt,)*) => {}; + + // macro start + ($($outs:ident,)+ | $states:expr) => { + drop_initialized_values!(@drop $($outs,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,); + }; +} + +/// Drop all pending futures +macro_rules! drop_pending_futures { + // recursively iterate + (@inner $states:ident, $futures:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { + if $states[$fut_idx].is_pending() { + // SAFETY: We're accessing the value behind the pinned reference to drop it exactly once. + let futures = unsafe { $futures.as_mut().get_unchecked_mut() }; + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { ManuallyDrop::drop(&mut futures.$fut_name) }; + } + drop_pending_futures!(@inner $states, $futures, $($F)* | $($rest)*); + }; + + // base condition + (@inner $states:ident, $futures:ident, | $($rest:tt)*) => {}; + + // macro start + ($states:ident, $futures:ident, $($F:ident,)+) => { + drop_pending_futures!(@inner $states, $futures, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11); + }; +} + macro_rules! impl_try_join_tuple { - ($mod_name:ident $StructName:ident $(($F:ident $R:ident))+) => { - mod $mod_name { - pub(super) struct Output<$($R,)+> - { - $(pub(super) $R: core::mem::MaybeUninit<$R>,)+ + ($mod_name:ident $StructName:ident) => { + /// A future which waits for similarly-typed futures to complete, or aborts early on error. + /// + /// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See + /// its documentation for more. + /// + /// [`try_join`]: crate::future::TryJoin::try_join + /// [`TryJoin`]: crate::future::Join + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[allow(non_snake_case)] + pub struct $StructName {} + + impl fmt::Debug for $StructName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("TryJoin").finish() } + } - impl<$($R,)+> Default for Output<$($R,)+> - { - fn default() -> Self { - Self { - $($R: core::mem::MaybeUninit::uninit(),)+ - } - } + impl Future for $StructName { + type Output = Result<(), std::convert::Infallible>; + + fn poll( + self: Pin<&mut Self>, _cx: &mut Context<'_> + ) -> Poll { + Poll::Ready(Ok(())) } + } - #[repr(usize)] - enum Indexes { - $($F,)+ + impl TryJoinTrait for () { + type Output = (); + type Error = std::convert::Infallible; + type Future = $StructName; + fn try_join(self) -> Self::Future { + $StructName {} } + } + }; + ($mod_name:ident $StructName:ident $($F:ident)+) => { + mod $mod_name { + use std::mem::ManuallyDrop; - $( - pub(super) const $F: usize = Indexes::$F as usize; - )+ + #[pin_project::pin_project] + pub(super) struct Futures<$($F,)+> {$( + #[pin] + pub(super) $F: ManuallyDrop<$F>, + )+} + + #[repr(u8)] + pub(super) enum Indexes { $($F,)+ } pub(super) const LEN: usize = [$(Indexes::$F,)+].len(); } - /// A future which waits for all futures to complete successfully, or abort early on error. + /// Waits for many similarly-typed futures to complete or abort early on error. /// /// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See /// its documentation for more. /// /// [`try_join`]: crate::future::TryJoin::try_join - /// [`TryJoin`]: crate::future::TryJoin + /// [`TryJoin`]: crate::future::Join + #[pin_project(PinnedDrop)] #[must_use = "futures do nothing unless you `.await` or poll them"] #[allow(non_snake_case)] - #[pin_project(PinnedDrop)] - pub struct $StructName<$($F, $R,)+> - { - done: bool, + pub struct $StructName<$($F: Future),+, Err> { + #[pin] + futures: $mod_name::Futures<$($F,)+>, + outputs: ($(MaybeUninit<$F::Output>,)+), + // trace the state of outputs, marking them as ready or consumed + // then, drop the non-consumed values, if any + state: PollArray<{$mod_name::LEN}>, + wakers: WakerArray<{$mod_name::LEN}>, completed: usize, - indexer: utils::Indexer, - output: $mod_name::Output<$($R,)+>, - output_states: PollArray<{ $mod_name::LEN }>, - $( #[pin] $F: $F, )+ } - impl<$($F, $R,)+> fmt::Debug for $StructName<$($F, $R,)+> + impl<$($F),+, Err> Debug for $StructName<$($F),+, Err> where - $($F: fmt::Debug,)+ + $( $F: Future + Debug, )+ { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("TryJoin") - $(.field(&self.$F))+ + $(.field(&self.futures.$F))+ .finish() } } - impl TryJoin for ($($F,)+) - where - $( $F: IntoFuture>, )+ - { - type Output = ($($R,)+); - type Error = ERR; - type Future = $StructName<$($F::IntoFuture, $R,)+>; - - fn try_join(self) -> Self::Future { - let ($($F,)+): ($($F,)+) = self; - $StructName { - completed: 0, - done: false, - indexer: utils::Indexer::new($mod_name::LEN), - output: Default::default(), - output_states: PollArray::new(), - $($F: $F.into_future()),+ - } - } - } - - impl Future for $StructName<$($F, $R,)+> - where - $( $F: Future>, )+ - { - type Output = Result<($($R,)+), ERR>; + #[allow(unused_mut)] + #[allow(unused_parens)] + #[allow(unused_variables)] + impl<$($F: Future),+, Err> Future for $StructName<$($F),+, Err> { + type Output = ($($F::Output,)+); fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll { + const LEN: usize = $mod_name::LEN; + let mut this = self.project(); + let all_completed = !(*this.completed == LEN); + assert!(all_completed, "Futures must not be polled after completing"); + + let mut futures = this.futures.project(); - assert!(!*this.done, "Futures must not be polled after completing"); + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); - for i in this.indexer.iter() { - if this.output_states[i].is_ready() { + for index in 0..LEN { + if !readiness.any_ready() { + // nothing ready yet + return Poll::Pending; + } + if !readiness.clear_ready(index) || this.state[index].is_ready() { + // future not ready yet or already polled to completion, skip continue; } - utils::gen_conditions!(i, this, cx, poll, $(($mod_name::$F; $F, { - Poll::Ready(output) => match output { - Ok(output) => { - this.output.$R = MaybeUninit::new(output); - this.output_states[$mod_name::$F].set_ready(); - *this.completed += 1; - continue; - }, - Err(err) => { - *this.done = true; - *this.completed += 1; - return Poll::Ready(Err(err)); - }, - }, - _ => continue, - }))*); - } - let all_completed = *this.completed == $mod_name::LEN; - if all_completed { - // mark all error states as consumed before we return it - this.output_states.set_all_completed(); + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // obtain the intermediate waker + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - let mut output = Default::default(); - mem::swap(&mut output, this.output); + // generate the needed code to poll `futures.{index}` + // SAFETY: the future's state should be "pending", so it's safe to poll + unsafe_poll!(index, this, futures, cx, LEN, $($F,)+); - *this.done = true; + if *this.completed == LEN { + let out = { + let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+); + core::mem::swap(&mut out, this.outputs); + let ($($F,)+) = out; + unsafe { ($($F.assume_init(),)+) } + }; - return Poll::Ready(Ok(( $(unsafe { output.$R.assume_init() }, )+ ))); + this.state.set_all_completed(); + + return Poll::Ready(out); + } + readiness = this.wakers.readiness().lock().unwrap(); } Poll::Pending @@ -143,36 +243,55 @@ macro_rules! impl_try_join_tuple { } #[pinned_drop] - impl<$($F, $R,)+> PinnedDrop for $StructName<$($F, $R,)+> - { + impl<$($F: Future),+, Err> PinnedDrop for $StructName<$($F),+, Err> { fn drop(self: Pin<&mut Self>) { let this = self.project(); - $( - let mut st = this.output_states[$mod_name::$F]; - if st.is_ready() { - // SAFETY: we've just filtered down to *only* the initialized values. - unsafe { this.output.$R.assume_init_drop() }; - st.set_consumed(); - } - )+ + let ($(ref mut $F,)+) = this.outputs; + + let states = this.state; + let mut futures = this.futures; + drop_initialized_values!($($F,)+ | states); + drop_pending_futures!(states, futures, $($F,)+); } } - } + + #[allow(unused_parens)] + impl<$($F, $F,)+ Err> TryJoinTrait for ($($F,)+) + where $( + $F: IntoFuture>, + )+ { + type Output = $FT; + type Error = Err; + type Future = $StructName<$($F::IntoFuture),*, Err>; + + fn try_join(self) -> Self::Future { + let ($($F,)+): ($($F,)+) = self; + $StructName { + futures: $mod_name::Futures {$($F: ManuallyDrop::new($F.into_future()),)+}, + state: PollArray::new(), + outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), + wakers: WakerArray::new(), + completed: 0, + } + } + } + }; } -impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) } -impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) } -impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) } -impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) } -impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) } -impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) } -impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) } -impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) } -impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) } -impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) } -impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) } -impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) } +impl_try_join_tuple! { try_join0 TryJoin0 } +impl_try_join_tuple! { try_join1 TryJoin1 A } +impl_try_join_tuple! { try_join2 TryJoin2 A B } +impl_try_join_tuple! { try_join3 TryJoin3 A B C } +impl_try_join_tuple! { try_join4 TryJoin4 A B C D } +impl_try_join_tuple! { try_join5 TryJoin5 A B C D E } +impl_try_join_tuple! { try_join6 TryJoin6 A B C D E F } +impl_try_join_tuple! { try_join7 TryJoin7 A B C D E F G } +impl_try_join_tuple! { try_join8 TryJoin8 A B C D E F G H } +impl_try_join_tuple! { try_join9 TryJoin9 A B C D E F G H I } +impl_try_join_tuple! { try_join10 TryJoin10 A B C D E F G H I J } +impl_try_join_tuple! { try_join11 TryJoin11 A B C D E F G H I J K } +impl_try_join_tuple! { try_join12 TryJoin12 A B C D E F G H I J K L } #[cfg(test)] mod test { @@ -220,4 +339,39 @@ mod test { assert_eq!(res.unwrap(), ((), ())); }); } + + #[test] + fn does_not_leak_memory() { + use core::cell::RefCell; + use futures_lite::future::pending; + + thread_local! { + static NOT_LEAKING: RefCell = RefCell::new(false); + }; + + struct FlipFlagAtDrop; + impl Drop for FlipFlagAtDrop { + fn drop(&mut self) { + NOT_LEAKING.with(|v| { + *v.borrow_mut() = true; + }); + } + } + + futures_lite::future::block_on(async { + // this will trigger Miri if we don't drop the memory + let string = future::ready("memory leak".to_owned()); + + // this will not flip the thread_local flag if we don't drop the memory + let flip = future::ready(FlipFlagAtDrop); + + let leak = (string, flip, pending::()).join(); + + _ = futures_lite::future::poll_once(leak).await; + }); + + NOT_LEAKING.with(|flag| { + assert!(*flag.borrow()); + }) + } } From 1f2c522808ae56fa886c74adda12c38358089171 Mon Sep 17 00:00:00 2001 From: Yosh Date: Thu, 22 Jun 2023 02:47:39 +0200 Subject: [PATCH 2/5] progressss --- src/future/try_join/tuple.rs | 220 ++++++++++++++++++----------------- 1 file changed, 113 insertions(+), 107 deletions(-) diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index f36d0fd..71a5237 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -6,6 +6,7 @@ use core::future::{Future, IntoFuture}; use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; +use std::marker::PhantomData; use std::mem::ManuallyDrop; use std::ops::DerefMut; @@ -98,6 +99,7 @@ macro_rules! drop_pending_futures { } macro_rules! impl_try_join_tuple { + // `Impl TryJoin for ()` ($mod_name:ident $StructName:ident) => { /// A future which waits for similarly-typed futures to complete, or aborts early on error. /// @@ -135,7 +137,9 @@ macro_rules! impl_try_join_tuple { } } }; - ($mod_name:ident $StructName:ident $($F:ident)+) => { + + // `Impl TryJoin for (F..)` + ($mod_name:ident $StructName:ident $(($F:ident $T:ident))+) => { mod $mod_name { use std::mem::ManuallyDrop; @@ -161,18 +165,19 @@ macro_rules! impl_try_join_tuple { #[pin_project(PinnedDrop)] #[must_use = "futures do nothing unless you `.await` or poll them"] #[allow(non_snake_case)] - pub struct $StructName<$($F: Future),+, Err> { + pub struct $StructName<$($F, $T,)+ Err> { #[pin] futures: $mod_name::Futures<$($F,)+>, - outputs: ($(MaybeUninit<$F::Output>,)+), + outputs: ($(MaybeUninit<$T>,)+), // trace the state of outputs, marking them as ready or consumed // then, drop the non-consumed values, if any state: PollArray<{$mod_name::LEN}>, wakers: WakerArray<{$mod_name::LEN}>, completed: usize, + _phantom: PhantomData, } - impl<$($F),+, Err> Debug for $StructName<$($F),+, Err> + impl<$($F, $T)+, Err> Debug for $StructName<$($F, $T,)+ Err> where $( $F: Future + Debug, )+ { @@ -186,8 +191,8 @@ macro_rules! impl_try_join_tuple { #[allow(unused_mut)] #[allow(unused_parens)] #[allow(unused_variables)] - impl<$($F: Future),+, Err> Future for $StructName<$($F),+, Err> { - type Output = ($($F::Output,)+); + impl<$($F: Future, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err> { + type Output = Result<($($F::Output,)+), Err>; fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> @@ -233,7 +238,7 @@ macro_rules! impl_try_join_tuple { this.state.set_all_completed(); - return Poll::Ready(out); + return Poll::Ready(Ok(out)); } readiness = this.wakers.readiness().lock().unwrap(); } @@ -243,7 +248,7 @@ macro_rules! impl_try_join_tuple { } #[pinned_drop] - impl<$($F: Future),+, Err> PinnedDrop for $StructName<$($F),+, Err> { + impl<$($F, $T)+, Err> PinnedDrop for $StructName<$($F, $T,)+ Err> { fn drop(self: Pin<&mut Self>) { let this = self.project(); @@ -257,13 +262,13 @@ macro_rules! impl_try_join_tuple { } #[allow(unused_parens)] - impl<$($F, $F,)+ Err> TryJoinTrait for ($($F,)+) + impl<$($F, $T)+, Err> TryJoinTrait for ($($F,)+) where $( - $F: IntoFuture>, + $F: IntoFuture>, )+ { - type Output = $FT; + type Output = ($($T,)+); type Error = Err; - type Future = $StructName<$($F::IntoFuture),*, Err>; + type Future = $StructName<$($F, $T,)+ Err>; fn try_join(self) -> Self::Future { let ($($F,)+): ($($F,)+) = self; @@ -273,6 +278,7 @@ macro_rules! impl_try_join_tuple { outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), wakers: WakerArray::new(), completed: 0, + _phantom: PhantomData, } } } @@ -280,98 +286,98 @@ macro_rules! impl_try_join_tuple { } impl_try_join_tuple! { try_join0 TryJoin0 } -impl_try_join_tuple! { try_join1 TryJoin1 A } -impl_try_join_tuple! { try_join2 TryJoin2 A B } -impl_try_join_tuple! { try_join3 TryJoin3 A B C } -impl_try_join_tuple! { try_join4 TryJoin4 A B C D } -impl_try_join_tuple! { try_join5 TryJoin5 A B C D E } -impl_try_join_tuple! { try_join6 TryJoin6 A B C D E F } -impl_try_join_tuple! { try_join7 TryJoin7 A B C D E F G } -impl_try_join_tuple! { try_join8 TryJoin8 A B C D E F G H } -impl_try_join_tuple! { try_join9 TryJoin9 A B C D E F G H I } -impl_try_join_tuple! { try_join10 TryJoin10 A B C D E F G H I J } -impl_try_join_tuple! { try_join11 TryJoin11 A B C D E F G H I J K } -impl_try_join_tuple! { try_join12 TryJoin12 A B C D E F G H I J K L } - -#[cfg(test)] -mod test { - use super::*; - - use std::convert::Infallible; - use std::future; - use std::io::{self, Error, ErrorKind}; - - #[test] - fn all_ok() { - futures_lite::future::block_on(async { - let a = async { Ok::<_, Infallible>("aaaa") }; - let b = async { Ok::<_, Infallible>(1) }; - let c = async { Ok::<_, Infallible>('z') }; - - let result = (a, b, c).try_join().await; - assert_eq!(result, Ok(("aaaa", 1, 'z'))); - }) - } - - #[test] - fn one_err() { - futures_lite::future::block_on(async { - let err = Error::new(ErrorKind::Other, "oh no"); - let res: io::Result<(_, char)> = (future::ready(Ok("hello")), future::ready(Err(err))) - .try_join() - .await; - assert_eq!(res.unwrap_err().to_string(), String::from("oh no")); - }) - } - - #[test] - fn issue_135_resume_after_completion() { - use futures_lite::future::yield_now; - futures_lite::future::block_on(async { - let ok = async { Ok::<_, ()>(()) }; - let err = async { - yield_now().await; - Ok::<_, ()>(()) - }; - - let res = (ok, err).try_join().await; - - assert_eq!(res.unwrap(), ((), ())); - }); - } - - #[test] - fn does_not_leak_memory() { - use core::cell::RefCell; - use futures_lite::future::pending; - - thread_local! { - static NOT_LEAKING: RefCell = RefCell::new(false); - }; - - struct FlipFlagAtDrop; - impl Drop for FlipFlagAtDrop { - fn drop(&mut self) { - NOT_LEAKING.with(|v| { - *v.borrow_mut() = true; - }); - } - } - - futures_lite::future::block_on(async { - // this will trigger Miri if we don't drop the memory - let string = future::ready("memory leak".to_owned()); - - // this will not flip the thread_local flag if we don't drop the memory - let flip = future::ready(FlipFlagAtDrop); - - let leak = (string, flip, pending::()).join(); - - _ = futures_lite::future::poll_once(leak).await; - }); - - NOT_LEAKING.with(|flag| { - assert!(*flag.borrow()); - }) - } -} +impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) } +// impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) } +// impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) } +// impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) } +// impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) } +// impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) } +// impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) } +// impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) } +// impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) } +// impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) } +// impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) } +// impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) } +// +// #[cfg(test)] +// mod test { +// use super::*; + +// use std::convert::Infallible; +// use std::future; +// use std::io::{self, Error, ErrorKind}; + +// #[test] +// fn all_ok() { +// futures_lite::future::block_on(async { +// let a = async { Ok::<_, Infallible>("aaaa") }; +// let b = async { Ok::<_, Infallible>(1) }; +// let c = async { Ok::<_, Infallible>('z') }; + +// let result = (a, b, c).try_join().await; +// assert_eq!(result, Ok(("aaaa", 1, 'z'))); +// }) +// } + +// #[test] +// fn one_err() { +// futures_lite::future::block_on(async { +// let err = Error::new(ErrorKind::Other, "oh no"); +// let res: io::Result<(_, char)> = (future::ready(Ok("hello")), future::ready(Err(err))) +// .try_join() +// .await; +// assert_eq!(res.unwrap_err().to_string(), String::from("oh no")); +// }) +// } + +// #[test] +// fn issue_135_resume_after_completion() { +// use futures_lite::future::yield_now; +// futures_lite::future::block_on(async { +// let ok = async { Ok::<_, ()>(()) }; +// let err = async { +// yield_now().await; +// Ok::<_, ()>(()) +// }; + +// let res = (ok, err).try_join().await; + +// assert_eq!(res.unwrap(), ((), ())); +// }); +// } + +// #[test] +// fn does_not_leak_memory() { +// use core::cell::RefCell; +// use futures_lite::future::pending; + +// thread_local! { +// static NOT_LEAKING: RefCell = RefCell::new(false); +// }; + +// struct FlipFlagAtDrop; +// impl Drop for FlipFlagAtDrop { +// fn drop(&mut self) { +// NOT_LEAKING.with(|v| { +// *v.borrow_mut() = true; +// }); +// } +// } + +// futures_lite::future::block_on(async { +// // this will trigger Miri if we don't drop the memory +// let string = future::ready("memory leak".to_owned()); + +// // this will not flip the thread_local flag if we don't drop the memory +// let flip = future::ready(FlipFlagAtDrop); + +// let leak = (string, flip, pending::()).join(); + +// _ = futures_lite::future::poll_once(leak).await; +// }); + +// NOT_LEAKING.with(|flag| { +// assert!(*flag.borrow()); +// }) +// } +// } From 962a8a219c6ed2ddc8d70d54a5855da52adbf12e Mon Sep 17 00:00:00 2001 From: Yosh Date: Thu, 22 Jun 2023 02:59:39 +0200 Subject: [PATCH 3/5] progress --- src/future/try_join/tuple.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index 71a5237..25a9c58 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -191,7 +191,10 @@ macro_rules! impl_try_join_tuple { #[allow(unused_mut)] #[allow(unused_parens)] #[allow(unused_variables)] - impl<$($F: Future, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err> { + impl<$($F: Future, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err> + where $( + $F: Future> + )+ { type Output = Result<($($F::Output,)+), Err>; fn poll( @@ -268,14 +271,16 @@ macro_rules! impl_try_join_tuple { )+ { type Output = ($($T,)+); type Error = Err; - type Future = $StructName<$($F, $T,)+ Err>; + type Future = $StructName<$($F::IntoFuture, $T,)+ Err>; fn try_join(self) -> Self::Future { let ($($F,)+): ($($F,)+) = self; $StructName { - futures: $mod_name::Futures {$($F: ManuallyDrop::new($F.into_future()),)+}, + futures: $mod_name::Futures {$( + $F: ManuallyDrop::new($F.into_future()), + )+}, state: PollArray::new(), - outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), + outputs: ($(MaybeUninit::<$T>::uninit(),)+), wakers: WakerArray::new(), completed: 0, _phantom: PhantomData, From f7624433e594dbe1eb56911e2ebceba055095413 Mon Sep 17 00:00:00 2001 From: Yosh Date: Fri, 23 Jun 2023 02:55:39 +0200 Subject: [PATCH 4/5] tests pass --- src/future/try_join/tuple.rs | 221 ++++++++++++++++++----------------- 1 file changed, 116 insertions(+), 105 deletions(-) diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index 25a9c58..f379cc3 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -33,12 +33,21 @@ macro_rules! unsafe_poll { .map_unchecked_mut(|t| t.deref_mut()) .poll(&mut $cx) } { - $this.outputs.$fut_idx.write(value); - *$this.completed += 1; $this.state[$fut_idx].set_ready(); + *$this.completed += 1; // SAFETY: the future state has been changed to "ready" which // means we'll no longer poll the future, so it's safe to drop unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) }; + + // Check the value, short-circuit on error. + match value { + Ok(value) => $this.outputs.$fut_idx.write(value), + Err(err) => { + // The future should no longer be polled after we're done here + *$this.consumed = true; + return Poll::Ready(Err(err)); + } + }; } } unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*); @@ -174,10 +183,11 @@ macro_rules! impl_try_join_tuple { state: PollArray<{$mod_name::LEN}>, wakers: WakerArray<{$mod_name::LEN}>, completed: usize, + consumed: bool, _phantom: PhantomData, } - impl<$($F, $T)+, Err> Debug for $StructName<$($F, $T,)+ Err> + impl<$($F, $T,)+ Err> Debug for $StructName<$($F, $T,)+ Err> where $( $F: Future + Debug, )+ { @@ -191,11 +201,11 @@ macro_rules! impl_try_join_tuple { #[allow(unused_mut)] #[allow(unused_parens)] #[allow(unused_variables)] - impl<$($F: Future, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err> + impl<$($F, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err> where $( - $F: Future> + $F: Future>, )+ { - type Output = Result<($($F::Output,)+), Err>; + type Output = Result<($($T,)+), Err>; fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> @@ -203,8 +213,7 @@ macro_rules! impl_try_join_tuple { const LEN: usize = $mod_name::LEN; let mut this = self.project(); - let all_completed = !(*this.completed == LEN); - assert!(all_completed, "Futures must not be polled after completing"); + assert!(!*this.consumed, "Futures must not be polled after completing"); let mut futures = this.futures.project(); @@ -233,13 +242,14 @@ macro_rules! impl_try_join_tuple { if *this.completed == LEN { let out = { - let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+); + let mut out = ($(MaybeUninit::<$T>::uninit(),)+); core::mem::swap(&mut out, this.outputs); let ($($F,)+) = out; unsafe { ($($F.assume_init(),)+) } }; this.state.set_all_completed(); + *this.consumed = true; return Poll::Ready(Ok(out)); } @@ -251,7 +261,7 @@ macro_rules! impl_try_join_tuple { } #[pinned_drop] - impl<$($F, $T)+, Err> PinnedDrop for $StructName<$($F, $T,)+ Err> { + impl<$($F, $T,)+ Err> PinnedDrop for $StructName<$($F, $T,)+ Err> { fn drop(self: Pin<&mut Self>) { let this = self.project(); @@ -265,7 +275,7 @@ macro_rules! impl_try_join_tuple { } #[allow(unused_parens)] - impl<$($F, $T)+, Err> TryJoinTrait for ($($F,)+) + impl<$($F, $T,)+ Err> TryJoinTrait for ($($F,)+) where $( $F: IntoFuture>, )+ { @@ -283,6 +293,7 @@ macro_rules! impl_try_join_tuple { outputs: ($(MaybeUninit::<$T>::uninit(),)+), wakers: WakerArray::new(), completed: 0, + consumed: false, _phantom: PhantomData, } } @@ -292,97 +303,97 @@ macro_rules! impl_try_join_tuple { impl_try_join_tuple! { try_join0 TryJoin0 } impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) } -// impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) } -// impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) } -// impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) } -// impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) } -// impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) } -// impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) } -// impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) } -// impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) } -// impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) } -// impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) } -// impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) } -// -// #[cfg(test)] -// mod test { -// use super::*; - -// use std::convert::Infallible; -// use std::future; -// use std::io::{self, Error, ErrorKind}; - -// #[test] -// fn all_ok() { -// futures_lite::future::block_on(async { -// let a = async { Ok::<_, Infallible>("aaaa") }; -// let b = async { Ok::<_, Infallible>(1) }; -// let c = async { Ok::<_, Infallible>('z') }; - -// let result = (a, b, c).try_join().await; -// assert_eq!(result, Ok(("aaaa", 1, 'z'))); -// }) -// } - -// #[test] -// fn one_err() { -// futures_lite::future::block_on(async { -// let err = Error::new(ErrorKind::Other, "oh no"); -// let res: io::Result<(_, char)> = (future::ready(Ok("hello")), future::ready(Err(err))) -// .try_join() -// .await; -// assert_eq!(res.unwrap_err().to_string(), String::from("oh no")); -// }) -// } - -// #[test] -// fn issue_135_resume_after_completion() { -// use futures_lite::future::yield_now; -// futures_lite::future::block_on(async { -// let ok = async { Ok::<_, ()>(()) }; -// let err = async { -// yield_now().await; -// Ok::<_, ()>(()) -// }; - -// let res = (ok, err).try_join().await; - -// assert_eq!(res.unwrap(), ((), ())); -// }); -// } - -// #[test] -// fn does_not_leak_memory() { -// use core::cell::RefCell; -// use futures_lite::future::pending; - -// thread_local! { -// static NOT_LEAKING: RefCell = RefCell::new(false); -// }; - -// struct FlipFlagAtDrop; -// impl Drop for FlipFlagAtDrop { -// fn drop(&mut self) { -// NOT_LEAKING.with(|v| { -// *v.borrow_mut() = true; -// }); -// } -// } - -// futures_lite::future::block_on(async { -// // this will trigger Miri if we don't drop the memory -// let string = future::ready("memory leak".to_owned()); - -// // this will not flip the thread_local flag if we don't drop the memory -// let flip = future::ready(FlipFlagAtDrop); - -// let leak = (string, flip, pending::()).join(); - -// _ = futures_lite::future::poll_once(leak).await; -// }); - -// NOT_LEAKING.with(|flag| { -// assert!(*flag.borrow()); -// }) -// } -// } +impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) } +impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) } +impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) } +impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) } +impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) } +impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) } +impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) } +impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) } +impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) } +impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) } +impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) } + +#[cfg(test)] +mod test { + use super::*; + + use std::convert::Infallible; + use std::future; + use std::io::{self, Error, ErrorKind}; + + #[test] + fn all_ok() { + futures_lite::future::block_on(async { + let a = async { Ok::<_, Infallible>("aaaa") }; + let b = async { Ok::<_, Infallible>(1) }; + let c = async { Ok::<_, Infallible>('z') }; + + let result = (a, b, c).try_join().await; + assert_eq!(result, Ok(("aaaa", 1, 'z'))); + }) + } + + #[test] + fn one_err() { + futures_lite::future::block_on(async { + let err = Error::new(ErrorKind::Other, "oh no"); + let res: io::Result<(_, char)> = (future::ready(Ok("hello")), future::ready(Err(err))) + .try_join() + .await; + assert_eq!(res.unwrap_err().to_string(), String::from("oh no")); + }) + } + + #[test] + fn issue_135_resume_after_completion() { + use futures_lite::future::yield_now; + futures_lite::future::block_on(async { + let ok = async { Ok::<_, ()>(()) }; + let err = async { + yield_now().await; + Ok::<_, ()>(()) + }; + + let res = (ok, err).try_join().await; + + assert_eq!(res.unwrap(), ((), ())); + }); + } + + // #[test] + // fn does_not_leak_memory() { + // use core::cell::RefCell; + // use futures_lite::future::pending; + + // thread_local! { + // static NOT_LEAKING: RefCell = RefCell::new(false); + // }; + + // struct FlipFlagAtDrop; + // impl Drop for FlipFlagAtDrop { + // fn drop(&mut self) { + // NOT_LEAKING.with(|v| { + // *v.borrow_mut() = true; + // }); + // } + // } + + // futures_lite::future::block_on(async { + // // this will trigger Miri if we don't drop the memory + // let string = future::ready("memory leak".to_owned()); + + // // this will not flip the thread_local flag if we don't drop the memory + // let flip = future::ready(FlipFlagAtDrop); + + // let leak = (string, flip, pending::()).try_join(); + + // _ = futures_lite::future::poll_once(leak).await; + // }); + + // NOT_LEAKING.with(|flag| { + // assert!(*flag.borrow()); + // }) + // } +} From 54b728a9e392f4e4ff580d8a3f06d79122b8d1dc Mon Sep 17 00:00:00 2001 From: Yosh Date: Fri, 23 Jun 2023 02:57:50 +0200 Subject: [PATCH 5/5] validate miri --- src/future/try_join/tuple.rs | 68 ++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index f379cc3..36f28e4 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -362,38 +362,38 @@ mod test { }); } - // #[test] - // fn does_not_leak_memory() { - // use core::cell::RefCell; - // use futures_lite::future::pending; - - // thread_local! { - // static NOT_LEAKING: RefCell = RefCell::new(false); - // }; - - // struct FlipFlagAtDrop; - // impl Drop for FlipFlagAtDrop { - // fn drop(&mut self) { - // NOT_LEAKING.with(|v| { - // *v.borrow_mut() = true; - // }); - // } - // } - - // futures_lite::future::block_on(async { - // // this will trigger Miri if we don't drop the memory - // let string = future::ready("memory leak".to_owned()); - - // // this will not flip the thread_local flag if we don't drop the memory - // let flip = future::ready(FlipFlagAtDrop); - - // let leak = (string, flip, pending::()).try_join(); - - // _ = futures_lite::future::poll_once(leak).await; - // }); - - // NOT_LEAKING.with(|flag| { - // assert!(*flag.borrow()); - // }) - // } + #[test] + fn does_not_leak_memory() { + use core::cell::RefCell; + use futures_lite::future::pending; + + thread_local! { + static NOT_LEAKING: RefCell = RefCell::new(false); + }; + + struct FlipFlagAtDrop; + impl Drop for FlipFlagAtDrop { + fn drop(&mut self) { + NOT_LEAKING.with(|v| { + *v.borrow_mut() = true; + }); + } + } + + futures_lite::future::block_on(async { + // this will trigger Miri if we don't drop the memory + let string = future::ready(io::Result::Ok("memory leak".to_owned())); + + // this will not flip the thread_local flag if we don't drop the memory + let flip = future::ready(io::Result::Ok(FlipFlagAtDrop)); + + let leak = (string, flip, pending::>()).try_join(); + + _ = futures_lite::future::poll_once(leak).await; + }); + + NOT_LEAKING.with(|flag| { + assert!(*flag.borrow()); + }) + } }