From 8a672dbbbeb17f883d9e068667c45bd3b3211bf1 Mon Sep 17 00:00:00 2001 From: Matheus Consoli Date: Sun, 20 Nov 2022 00:03:30 -0300 Subject: [PATCH 1/3] Remove `MaybeDone` from tuple::join --- src/future/join/tuple.rs | 187 ++++++++++++++++++++++++++++++--------- 1 file changed, 147 insertions(+), 40 deletions(-) diff --git a/src/future/join/tuple.rs b/src/future/join/tuple.rs index 3a17b76..f246e4c 100644 --- a/src/future/join/tuple.rs +++ b/src/future/join/tuple.rs @@ -1,15 +1,81 @@ use super::Join as JoinTrait; -use crate::utils::MaybeDone; +use crate::utils::{PollArray, RandomGenerator, WakerArray}; use core::fmt::{self, Debug}; use core::future::{Future, IntoFuture}; +use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; use pin_project::pin_project; -macro_rules! impl_merge_tuple { - ($StructName:ident $($F:ident)*) => { +macro_rules! poll_future { + ($fut_idx:tt, $iteration:ident, $this:ident, $outputs:ident, $futures:ident . $fut_member:ident, $cx:ident) => { + if $fut_idx == $iteration { + if let Poll::Ready(value) = + unsafe { Pin::new_unchecked(&mut $futures.$fut_member) }.poll(&mut $cx) + { + $this.outputs.$fut_member.write(value); + *$this.completed += 1; + $this.state[$fut_idx].set_consumed(); + } + } + }; +} + +macro_rules! impl_join_tuple { + ($mod_name:ident $StructName:ident) => { + /// Waits for two similarly-typed futures to complete. + /// + /// This `struct` is created by the [`join`] method on the [`Join`] trait. See + /// its documentation for more. + /// + /// [`join`]: crate::future::Join::join + /// [`Join`]: crate::future::Join + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[allow(non_snake_case)] + pub struct $StructName {} + + impl fmt::Debug for $StructName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Join").finish() + } + } + + impl Future for $StructName { + type Output = (); + + fn poll( + self: Pin<&mut Self>, _cx: &mut Context<'_> + ) -> Poll { + Poll::Ready(()) + } + } + + impl JoinTrait for () { + type Output = (); + type Future = $StructName; + fn join(self) -> Self::Future { + $StructName {} + } + } + }; + ($mod_name:ident $StructName:ident $($F:ident)+) => { + mod $mod_name { + use core::mem::MaybeUninit; + use core::future::Future; + + #[pin_project::pin_project] + pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ } + + pub(super) struct Outputs<$($F: Future,)+> { $(pub(super) $F: MaybeUninit<$F::Output>,)+ } + + #[repr(u8)] + pub(super) enum Indexes { $($F,)+ } + + pub(super) const LEN: usize = [$(Indexes::$F,)+].len(); + } + /// Waits for two similarly-typed futures to complete. /// /// This `struct` is created by the [`join`] method on the [`Join`] trait. See @@ -20,19 +86,23 @@ macro_rules! impl_merge_tuple { #[pin_project] #[must_use = "futures do nothing unless you `.await` or poll them"] #[allow(non_snake_case)] - pub struct $StructName<$($F: Future),*> { - done: bool, - $(#[pin] $F: MaybeDone<$F>,)* + pub struct $StructName<$($F: Future),+> { + #[pin] futures: $mod_name::Futures<$($F,)+>, + outputs: $mod_name::Outputs<$($F,)+>, + rng: RandomGenerator, + wakers: WakerArray<{$mod_name::LEN}>, + state: PollArray<{$mod_name::LEN}>, + completed: u8, } - impl<$($F),*> Debug for $StructName<$($F),*> + impl<$($F),+> Debug for $StructName<$($F),+> where $( $F: Future + Debug, $F::Output: Debug, - )* { + )+ { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("Join") - $(.field(&self.$F))* + $(.field(&self.futures.$F))+ .finish() } } @@ -40,59 +110,96 @@ macro_rules! impl_merge_tuple { #[allow(unused_mut)] #[allow(unused_parens)] #[allow(unused_variables)] - impl<$($F: Future),*> Future for $StructName<$($F),*> { - type Output = ($($F::Output,)*); + impl<$($F: Future),+> Future for $StructName<$($F),+> { + type Output = ($($F::Output,)+); fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll { - let mut all_done = true; - let mut this = self.project(); - assert!(!*this.done, "Futures must not be polled after completing"); - - $(all_done &= this.$F.as_mut().poll(cx).is_ready();)* - - if all_done { - *this.done = true; - Poll::Ready(($(this.$F.take().unwrap(),)*)) - } else { - Poll::Pending + let this = self.project(); + + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + + const LEN: u8 = $mod_name::LEN as u8; + let r = this.rng.generate(LEN as u32) as u8; + + let mut futures = this.futures.project(); + + for index in (0..LEN).map(|n| (r + n).wrapping_rem(LEN) as usize) { + if !readiness.any_ready() { + return Poll::Pending; + } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + continue; + } + + drop(readiness); + + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + + $( + let fut_index = $mod_name::Indexes::$F as usize; + poll_future!( + fut_index, + index, + this, + outputs, + futures . $F, + cx + ); + )+ + + if *this.completed == LEN { + let out = { + let mut output = $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ }; + core::mem::swap(this.outputs, &mut output); + unsafe { ( $(output.$F.assume_init(),)+ ) } + }; + return Poll::Ready(out); + } + readiness = this.wakers.readiness().lock().unwrap(); } + + Poll::Pending } } #[allow(unused_parens)] - impl<$($F),*> JoinTrait for ($($F,)*) + impl<$($F),+> JoinTrait for ($($F,)+) where $( $F: IntoFuture, - )* { + )+ { type Output = ($($F::Output,)*); type Future = $StructName<$($F::IntoFuture),*>; fn join(self) -> Self::Future { - let ($($F,)*): ($($F,)*) = self; + let ($($F,)+): ($($F,)+) = self; $StructName { - done: false, - $($F: MaybeDone::new($F.into_future())),* + futures: $mod_name::Futures { $($F: $F.into_future(),)+ }, + rng: RandomGenerator::new(), + wakers: WakerArray::new(), + state: PollArray::new(), + outputs: $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ }, + completed: 0, } } } }; } -impl_merge_tuple! { Join0 } -impl_merge_tuple! { Join1 A } -impl_merge_tuple! { Join2 A B } -impl_merge_tuple! { Join3 A B C } -impl_merge_tuple! { Join4 A B C D } -impl_merge_tuple! { Join5 A B C D E } -impl_merge_tuple! { Join6 A B C D E F } -impl_merge_tuple! { Join7 A B C D E F G } -impl_merge_tuple! { Join8 A B C D E F G H } -impl_merge_tuple! { Join9 A B C D E F G H I } -impl_merge_tuple! { Join10 A B C D E F G H I J } -impl_merge_tuple! { Join11 A B C D E F G H I J K } -impl_merge_tuple! { Join12 A B C D E F G H I J K L } +impl_join_tuple! { join0 Join0 } +impl_join_tuple! { join1 Join1 A } +impl_join_tuple! { join2 Join2 A B } +impl_join_tuple! { join3 Join3 A B C } +impl_join_tuple! { join4 Join4 A B C D } +impl_join_tuple! { join5 Join5 A B C D E } +impl_join_tuple! { join6 Join6 A B C D E F } +impl_join_tuple! { join7 Join7 A B C D E F G } +impl_join_tuple! { join8 Join8 A B C D E F G H } +impl_join_tuple! { join9 Join9 A B C D E F G H I } +impl_join_tuple! { join10 Join10 A B C D E F G H I J } +impl_join_tuple! { join11 Join11 A B C D E F G H I J K } +impl_join_tuple! { join12 Join12 A B C D E F G H I J K L } #[cfg(test)] mod test { From 03c26b20da284f9beed5c1b4982cba72b2d075dc Mon Sep 17 00:00:00 2001 From: Matheus Consoli Date: Tue, 22 Nov 2022 01:26:47 -0300 Subject: [PATCH 2/3] Efficiently remove MaybeDone from tuple::join --- src/future/join/tuple.rs | 83 +++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 48 deletions(-) diff --git a/src/future/join/tuple.rs b/src/future/join/tuple.rs index f246e4c..4e44aa5 100644 --- a/src/future/join/tuple.rs +++ b/src/future/join/tuple.rs @@ -1,5 +1,5 @@ use super::Join as JoinTrait; -use crate::utils::{PollArray, RandomGenerator, WakerArray}; +use crate::utils::PollArray; use core::fmt::{self, Debug}; use core::future::{Future, IntoFuture}; @@ -9,17 +9,30 @@ use core::task::{Context, Poll}; use pin_project::pin_project; -macro_rules! poll_future { - ($fut_idx:tt, $iteration:ident, $this:ident, $outputs:ident, $futures:ident . $fut_member:ident, $cx:ident) => { +/// Generates the `poll` call for every `Future` inside `$futures`. +// This is implemented as a tt-muncher of the future name `$($F:ident)` +// and the future index `$($rest)`, taking advantage that we only support +// tuples up to 12 elements +// +// # References +// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html +macro_rules! poll { + (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { if $fut_idx == $iteration { - if let Poll::Ready(value) = - unsafe { Pin::new_unchecked(&mut $futures.$fut_member) }.poll(&mut $cx) - { - $this.outputs.$fut_member.write(value); + if let Poll::Ready(value) = $futures.$fut_name.as_mut().poll($cx) { + $this.outputs.$fut_idx.write(value); *$this.completed += 1; $this.state[$fut_idx].set_consumed(); } } + poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*); + }; + + // base condition, no more futures to poll + (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {}; + + ($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => { + poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11); }; } @@ -62,14 +75,10 @@ macro_rules! impl_join_tuple { }; ($mod_name:ident $StructName:ident $($F:ident)+) => { mod $mod_name { - use core::mem::MaybeUninit; - use core::future::Future; #[pin_project::pin_project] pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ } - pub(super) struct Outputs<$($F: Future,)+> { $(pub(super) $F: MaybeUninit<$F::Output>,)+ } - #[repr(u8)] pub(super) enum Indexes { $($F,)+ } @@ -88,11 +97,9 @@ macro_rules! impl_join_tuple { #[allow(non_snake_case)] pub struct $StructName<$($F: Future),+> { #[pin] futures: $mod_name::Futures<$($F,)+>, - outputs: $mod_name::Outputs<$($F,)+>, - rng: RandomGenerator, - wakers: WakerArray<{$mod_name::LEN}>, + outputs: ($(MaybeUninit<$F::Output>,)+), state: PollArray<{$mod_name::LEN}>, - completed: u8, + completed: usize, } impl<$($F),+> Debug for $StructName<$($F),+> @@ -116,48 +123,30 @@ macro_rules! impl_join_tuple { fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll { - let this = self.project(); + let mut this = self.project(); - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - - const LEN: u8 = $mod_name::LEN as u8; - let r = this.rng.generate(LEN as u32) as u8; + const LEN: usize = $mod_name::LEN; let mut futures = this.futures.project(); - for index in (0..LEN).map(|n| (r + n).wrapping_rem(LEN) as usize) { - if !readiness.any_ready() { - return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + for index in 0..LEN { + if this.state[index].is_consumed() { continue; } - drop(readiness); - - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - - $( - let fut_index = $mod_name::Indexes::$F as usize; - poll_future!( - fut_index, - index, - this, - outputs, - futures . $F, - cx - ); - )+ + // generate the needed code to poll `futures.{index}` + poll!(index, this, futures, cx, LEN, $($F,)+); if *this.completed == LEN { let out = { - let mut output = $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ }; - core::mem::swap(this.outputs, &mut output); - unsafe { ( $(output.$F.assume_init(),)+ ) } + let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+); + core::mem::swap(&mut out, this.outputs); + let ($($F,)+) = out; + unsafe { ($($F.assume_init(),)+) } }; + return Poll::Ready(out); } - readiness = this.wakers.readiness().lock().unwrap(); } Poll::Pending @@ -175,11 +164,9 @@ macro_rules! impl_join_tuple { fn join(self) -> Self::Future { let ($($F,)+): ($($F,)+) = self; $StructName { - futures: $mod_name::Futures { $($F: $F.into_future(),)+ }, - rng: RandomGenerator::new(), - wakers: WakerArray::new(), + futures: $mod_name::Futures {$($F: $F.into_future(),)+}, state: PollArray::new(), - outputs: $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ }, + outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), completed: 0, } } From 7c272f6333aa260d30a4ed144d2216053df6634f Mon Sep 17 00:00:00 2001 From: Matheus Consoli Date: Tue, 22 Nov 2022 14:13:20 -0300 Subject: [PATCH 3/3] Make tuple Join fair --- src/future/join/tuple.rs | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/future/join/tuple.rs b/src/future/join/tuple.rs index 4e44aa5..68732ce 100644 --- a/src/future/join/tuple.rs +++ b/src/future/join/tuple.rs @@ -1,5 +1,5 @@ use super::Join as JoinTrait; -use crate::utils::PollArray; +use crate::utils::{PollArray, WakerArray}; use core::fmt::{self, Debug}; use core::future::{Future, IntoFuture}; @@ -19,7 +19,7 @@ use pin_project::pin_project; macro_rules! poll { (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { if $fut_idx == $iteration { - if let Poll::Ready(value) = $futures.$fut_name.as_mut().poll($cx) { + if let Poll::Ready(value) = $futures.$fut_name.as_mut().poll(&mut $cx) { $this.outputs.$fut_idx.write(value); *$this.completed += 1; $this.state[$fut_idx].set_consumed(); @@ -99,6 +99,7 @@ macro_rules! impl_join_tuple { #[pin] futures: $mod_name::Futures<$($F,)+>, outputs: ($(MaybeUninit<$F::Output>,)+), state: PollArray<{$mod_name::LEN}>, + wakers: WakerArray<{$mod_name::LEN}>, completed: usize, } @@ -123,17 +124,33 @@ macro_rules! impl_join_tuple { fn poll( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll { - let mut this = self.project(); - const LEN: usize = $mod_name::LEN; + let mut this = self.project(); + let all_completed = !(*this.completed == LEN); + assert!(all_completed, "Futures must not be polled after completing"); + let mut futures = this.futures.project(); + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + for index in 0..LEN { - if this.state[index].is_consumed() { + if !readiness.any_ready() { + // nothing ready yet + return Poll::Pending; + } + if !readiness.clear_ready(index) || this.state[index].is_consumed() { + // future not ready yet or already polled to completion, skip continue; } + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // obtain the intermediate waker + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + // generate the needed code to poll `futures.{index}` poll!(index, this, futures, cx, LEN, $($F,)+); @@ -147,6 +164,7 @@ macro_rules! impl_join_tuple { return Poll::Ready(out); } + readiness = this.wakers.readiness().lock().unwrap(); } Poll::Pending @@ -167,6 +185,7 @@ macro_rules! impl_join_tuple { futures: $mod_name::Futures {$($F: $F.into_future(),)+}, state: PollArray::new(), outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), + wakers: WakerArray::new(), completed: 0, } }