Skip to content

Commit

Permalink
Remove MaybeDone from tuple::join
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus-consoli committed Nov 20, 2022
1 parent e3f61cd commit 072998c
Showing 1 changed file with 148 additions and 40 deletions.
188 changes: 148 additions & 40 deletions src/future/join/tuple.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,81 @@
use super::Join as JoinTrait;
use crate::utils::MaybeDone;
use crate::utils::{PollArray, RandomGenerator, WakerArray};

use core::fmt::{self, Debug};
use core::future::{Future, IntoFuture};
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::pin_project;

macro_rules! impl_merge_tuple {
($StructName:ident $($F:ident)*) => {
macro_rules! poll_future {
($fut_idx:tt, $iteration:ident, $this:ident, $outputs:ident, $futures:ident . $fut_member:ident, $cx:ident) => {
if $fut_idx == $iteration {
if let Poll::Ready(value) =
unsafe { Pin::new_unchecked(&mut $futures.$fut_member) }.poll(&mut $cx)
{
$this.outputs.$fut_member.write(value);
*$this.completed += 1;
$this.state[$fut_idx].set_consumed();
}
}
};
}

macro_rules! impl_join_tuple {
($mod_name:ident $StructName:ident) => {
/// Waits for two similarly-typed futures to complete.
///
/// This `struct` is created by the [`join`] method on the [`Join`] trait. See
/// its documentation for more.
///
/// [`join`]: crate::future::Join::join
/// [`Join`]: crate::future::Join
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(non_snake_case)]
pub struct $StructName {}

impl fmt::Debug for $StructName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Join").finish()
}
}

impl Future for $StructName {
type Output = ();

fn poll(
self: Pin<&mut Self>, _cx: &mut Context<'_>
) -> Poll<Self::Output> {
Poll::Ready(())
}
}

impl JoinTrait for () {
type Output = ();
type Future = $StructName;
fn join(self) -> Self::Future {
$StructName {}
}
}
};
($mod_name:ident $StructName:ident $($F:ident)+) => {
mod $mod_name {
use core::mem::MaybeUninit;
use core::future::Future;

#[pin_project::pin_project]
pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ }

pub(super) struct Outputs<$($F: Future,)+> { $(pub(super) $F: MaybeUninit<$F::Output>,)+ }

#[repr(u8)]
pub(super) enum Indexes { $($F,)+ }

pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
}

/// Waits for two similarly-typed futures to complete.
///
/// This `struct` is created by the [`join`] method on the [`Join`] trait. See
Expand All @@ -20,79 +86,121 @@ macro_rules! impl_merge_tuple {
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(non_snake_case)]
pub struct $StructName<$($F: Future),*> {
done: bool,
$(#[pin] $F: MaybeDone<$F>,)*
pub struct $StructName<$($F: Future),+> {
#[pin] futures: $mod_name::Futures<$($F,)+>,
outputs: $mod_name::Outputs<$($F,)+>,
rng: RandomGenerator,
wakers: WakerArray<{$mod_name::LEN}>,
state: PollArray<{$mod_name::LEN}>,
completed: u8,
}

impl<$($F),*> Debug for $StructName<$($F),*>
impl<$($F),+> Debug for $StructName<$($F),+>
where $(
$F: Future + Debug,
$F::Output: Debug,
)* {
)+ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Join")
$(.field(&self.$F))*
$(.field(&self.futures.$F))+
.finish()
}
}

#[allow(unused_mut)]
#[allow(unused_parens)]
#[allow(unused_variables)]
impl<$($F: Future),*> Future for $StructName<$($F),*> {
type Output = ($($F::Output,)*);
impl<$($F: Future),+> Future for $StructName<$($F),+> {
type Output = ($($F::Output,)+);

fn poll(
self: Pin<&mut Self>, cx: &mut Context<'_>
) -> Poll<Self::Output> {
let mut all_done = true;
let mut this = self.project();
assert!(!*this.done, "Futures must not be polled after completing");

$(all_done &= this.$F.as_mut().poll(cx).is_ready();)*

if all_done {
*this.done = true;
Poll::Ready(($(this.$F.take().unwrap(),)*))
} else {
Poll::Pending
let this = self.project();

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

const LEN: u8 = $mod_name::LEN as u8;
let r = this.rng.generate(LEN as u32) as u8;

let mut futures = this.futures.project();

for index in (0..LEN).map(|n| (r + n).wrapping_rem(LEN) as usize) {
if !readiness.any_ready() {
return Poll::Pending;
} else if !readiness.clear_ready(index) || this.state[index].is_consumed() {
continue;
}

drop(readiness);

let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

$(
let fut_index = $mod_name::Indexes::$F as usize;
poll_future!(
fut_index,
index,
this,
outputs,
futures . $F,
cx
);
)+

if *this.completed == LEN {
let out = {
let mut output = $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ };
core::mem::swap(this.outputs, &mut output);
core::mem::forget(this.outputs);
unsafe { ( $(output.$F.assume_init(),)+ ) }
};
return Poll::Ready(out);
}
readiness = this.wakers.readiness().lock().unwrap();
}

Poll::Pending
}
}

#[allow(unused_parens)]
impl<$($F),*> JoinTrait for ($($F,)*)
impl<$($F),+> JoinTrait for ($($F,)+)
where $(
$F: IntoFuture,
)* {
)+ {
type Output = ($($F::Output,)*);
type Future = $StructName<$($F::IntoFuture),*>;

fn join(self) -> Self::Future {
let ($($F,)*): ($($F,)*) = self;
let ($($F,)+): ($($F,)+) = self;
$StructName {
done: false,
$($F: MaybeDone::new($F.into_future())),*
futures: $mod_name::Futures { $($F: $F.into_future(),)+ },
rng: RandomGenerator::new(),
wakers: WakerArray::new(),
state: PollArray::new(),
outputs: $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ },
completed: 0,
}
}
}
};
}

impl_merge_tuple! { Join0 }
impl_merge_tuple! { Join1 A }
impl_merge_tuple! { Join2 A B }
impl_merge_tuple! { Join3 A B C }
impl_merge_tuple! { Join4 A B C D }
impl_merge_tuple! { Join5 A B C D E }
impl_merge_tuple! { Join6 A B C D E F }
impl_merge_tuple! { Join7 A B C D E F G }
impl_merge_tuple! { Join8 A B C D E F G H }
impl_merge_tuple! { Join9 A B C D E F G H I }
impl_merge_tuple! { Join10 A B C D E F G H I J }
impl_merge_tuple! { Join11 A B C D E F G H I J K }
impl_merge_tuple! { Join12 A B C D E F G H I J K L }
impl_join_tuple! { join0 Join0 }
impl_join_tuple! { join1 Join1 A }
impl_join_tuple! { join2 Join2 A B }
impl_join_tuple! { join3 Join3 A B C }
impl_join_tuple! { join4 Join4 A B C D }
impl_join_tuple! { join5 Join5 A B C D E }
impl_join_tuple! { join6 Join6 A B C D E F }
impl_join_tuple! { join7 Join7 A B C D E F G }
impl_join_tuple! { join8 Join8 A B C D E F G H }
impl_join_tuple! { join9 Join9 A B C D E F G H I }
impl_join_tuple! { join10 Join10 A B C D E F G H I J }
impl_join_tuple! { join11 Join11 A B C D E F G H I J K }
impl_join_tuple! { join12 Join12 A B C D E F G H I J K L }

#[cfg(test)]
mod test {
Expand Down

0 comments on commit 072998c

Please sign in to comment.