Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tuple::join - drop futures in-place #143

Merged
merged 2 commits into from
Jun 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 69 additions & 20 deletions src/future/join/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,94 @@ use core::future::{Future, IntoFuture};
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::mem::ManuallyDrop;
use std::ops::DerefMut;

use pin_project::{pin_project, pinned_drop};

/// Generates the `poll` call for every `Future` inside `$futures`.
///
/// SAFETY: pretty please only call this after having made very sure that the future you're trying
/// to call is actually marked `ready!`. If Rust had unsafe macros, this would be one.
//
// This is implemented as a tt-muncher of the future name `$($F:ident)`
// and the future index `$($rest)`, taking advantage that we only support
// tuples up to 12 elements
//
// # References
// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html
macro_rules! poll {
macro_rules! unsafe_poll {
// recursively iterate
(@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => {
if $fut_idx == $iteration {
if let Poll::Ready(value) = $futures.$fut_name.as_mut().poll(&mut $cx) {

if let Poll::Ready(value) = unsafe {
$futures.$fut_name.as_mut()
.map_unchecked_mut(|t| t.deref_mut())
.poll(&mut $cx)
} {
$this.outputs.$fut_idx.write(value);
*$this.completed += 1;
$this.state[$fut_idx].set_ready();
// SAFETY: the future state has been changed to "ready" which
// means we'll no longer poll the future, so it's safe to drop
unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) };
}
}
poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*);
unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*);
};

// base condition, no more futures to poll
// base condition
(@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {};

// macro start
($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => {
poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
};
}

macro_rules! drop_outputs {
(@drop $output:ident, $($rem_outs:ident,)* | $states:expr, $stix:tt, $($rem_idx:tt,)*) => {
if $states[$stix].is_ready() {
// SAFETY: we're filtering out only the outputs marked as `ready`,
// which means that this memory is initialized
/// Drop all initialized values
macro_rules! drop_initialized_values {
// recursively iterate
(@drop $output:ident, $($rem_outs:ident,)* | $states:expr, $state_idx:tt, $($rem_idx:tt,)*) => {
if $states[$state_idx].is_ready() {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { $output.assume_init_drop() };
$states[$stix].set_consumed();
$states[$state_idx].set_consumed();
}
drop_outputs!(@drop $($rem_outs,)* | $states, $($rem_idx,)*);
drop_initialized_values!(@drop $($rem_outs,)* | $states, $($rem_idx,)*);
};

// base condition, no more outputs to look
// base condition
(@drop | $states:expr, $($rem_idx:tt,)*) => {};

// macro start
($($outs:ident,)+ | $states:expr) => {
drop_outputs!(@drop $($outs,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,);
drop_initialized_values!(@drop $($outs,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,);
};
}

/// Drop all pending futures
macro_rules! drop_pending_futures {
// recursively iterate
(@inner $states:ident, $futures:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => {
if $states[$fut_idx].is_pending() {
// SAFETY: We're accessing the value behind the pinned reference to drop it exactly once.
let futures = unsafe { $futures.as_mut().get_unchecked_mut() };
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { ManuallyDrop::drop(&mut futures.$fut_name) };
}
drop_pending_futures!(@inner $states, $futures, $($F)* | $($rest)*);
};

// base condition
(@inner $states:ident, $futures:ident, | $($rest:tt)*) => {};

// macro start
($states:ident, $futures:ident, $($F:ident,)+) => {
drop_pending_futures!(@inner $states, $futures, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
};
}

Expand Down Expand Up @@ -94,9 +136,13 @@ macro_rules! impl_join_tuple {
};
($mod_name:ident $StructName:ident $($F:ident)+) => {
mod $mod_name {
use std::mem::ManuallyDrop;

#[pin_project::pin_project]
pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ }
pub(super) struct Futures<$($F,)+> {$(
#[pin]
pub(super) $F: ManuallyDrop<$F>,
)+}

#[repr(u8)]
pub(super) enum Indexes { $($F,)+ }
Expand All @@ -115,7 +161,8 @@ macro_rules! impl_join_tuple {
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(non_snake_case)]
pub struct $StructName<$($F: Future),+> {
#[pin] futures: $mod_name::Futures<$($F,)+>,
#[pin]
futures: $mod_name::Futures<$($F,)+>,
outputs: ($(MaybeUninit<$F::Output>,)+),
// trace the state of outputs, marking them as ready or consumed
// then, drop the non-consumed values, if any
Expand Down Expand Up @@ -172,7 +219,8 @@ macro_rules! impl_join_tuple {
let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

// generate the needed code to poll `futures.{index}`
poll!(index, this, futures, cx, LEN, $($F,)+);
// SAFETY: the future's state should be "pending", so it's safe to poll
unsafe_poll!(index, this, futures, cx, LEN, $($F,)+);

if *this.completed == LEN {
let out = {
Expand Down Expand Up @@ -201,7 +249,9 @@ macro_rules! impl_join_tuple {
let ($(ref mut $F,)+) = this.outputs;

let states = this.state;
drop_outputs!($($F,)+ | states);
let mut futures = this.futures;
drop_initialized_values!($($F,)+ | states);
drop_pending_futures!(states, futures, $($F,)+);
}
}

Expand All @@ -216,7 +266,7 @@ macro_rules! impl_join_tuple {
fn join(self) -> Self::Future {
let ($($F,)+): ($($F,)+) = self;
$StructName {
futures: $mod_name::Futures {$($F: $F.into_future(),)+},
futures: $mod_name::Futures {$($F: ManuallyDrop::new($F.into_future()),)+},
state: PollArray::new(),
outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+),
wakers: WakerArray::new(),
Expand All @@ -225,7 +275,6 @@ macro_rules! impl_join_tuple {
}
}
};

}

impl_join_tuple! { join0 Join0 }
Expand Down