Skip to content

Commit

Permalink
Merge pull request #72 from yoshuawuyts/no-maybedone-join
Browse files Browse the repository at this point in the history
Remove `MaybeDone` from `impl Join for array`
  • Loading branch information
yoshuawuyts authored Nov 16, 2022
2 parents 49189ff + 8e84401 commit 2d35a6b
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 36 deletions.
145 changes: 109 additions & 36 deletions src/future/join/array.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use super::Join as JoinTrait;
use crate::utils::MaybeDone;
use crate::utils;
use crate::utils::PollState;

use core::array;
use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem::{self, MaybeUninit};
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::pin_project;
use pin_project::{pin_project, pinned_drop};

/// Waits for two similarly-typed futures to complete.
///
Expand All @@ -16,12 +19,46 @@ use pin_project::pin_project;
/// [`join`]: crate::future::Join::join
/// [`Join`]: crate::future::Join
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project]
#[pin_project(PinnedDrop)]
pub struct Join<Fut, const N: usize>
where
Fut: Future,
{
pub(crate) elems: [MaybeDone<Fut>; N],
consumed: bool,
pending: usize,
items: [MaybeUninit<<Fut as Future>::Output>; N],
state: [PollState; N],
#[pin]
futures: [Fut; N],
}

impl<Fut, const N: usize> Join<Fut, N>
where
Fut: Future,
{
#[inline]
pub(crate) fn new(futures: [Fut; N]) -> Self {
Join {
consumed: false,
pending: N,
items: array::from_fn(|_| MaybeUninit::uninit()),
state: [PollState::default(); N],
futures,
}
}
}

impl<Fut, const N: usize> JoinTrait for [Fut; N]
where
Fut: IntoFuture,
{
type Output = [Fut::Output; N];
type Future = Join<Fut::IntoFuture, N>;

#[inline]
fn join(self) -> Self::Future {
Join::new(self.map(IntoFuture::into_future))
}
}

impl<Fut, const N: usize> fmt::Debug for Join<Fut, N>
Expand All @@ -30,7 +67,7 @@ where
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.elems.iter()).finish()
f.debug_list().entries(self.state.iter()).finish()
}
}

Expand All @@ -40,65 +77,101 @@ where
{
type Output = [Fut::Output; N];

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut all_done = true;
let mut this = self.project();

let this = self.project();
assert!(
!*this.consumed,
"Futures must not be polled after completing"
);

for elem in this.elems.iter_mut() {
let elem = unsafe { Pin::new_unchecked(elem) };
if elem.poll(cx).is_pending() {
all_done = false;
// Poll all futures
for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() {
if this.state[i].is_pending() {
if let Poll::Ready(value) = fut.poll(cx) {
this.items[i] = MaybeUninit::new(value);
this.state[i] = PollState::Done;
*this.pending -= 1;
}
}
}

if all_done {
use core::array;
use core::mem::MaybeUninit;
// Check whether we're all done now or need to keep going.
if *this.pending == 0 {
// Mark all data as "consumed" before we take it
*this.consumed = true;
for state in this.state.iter_mut() {
debug_assert!(state.is_done(), "Future should have reached a `Done` state");
*state = PollState::Consumed;
}

// Create the result array based on the indices
// TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable
let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit());
let mut items = array::from_fn(|_| MaybeUninit::uninit());
mem::swap(this.items, &mut items);

// NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`.
#[allow(clippy::needless_range_loop)]
for (i, el) in this.elems.iter_mut().enumerate() {
let el = unsafe { Pin::new_unchecked(el) }.take().unwrap();
out[i] = MaybeUninit::new(el);
}
let result = unsafe { out.as_ptr().cast::<[Fut::Output; N]>().read() };
Poll::Ready(result)
// SAFETY: we've checked with the state that all of our outputs have been
// filled, which means we're ready to take the data and assume it's initialized.
let items = unsafe { utils::array_assume_init(items) };
Poll::Ready(items)
} else {
Poll::Pending
}
}
}

impl<Fut, const N: usize> JoinTrait for [Fut; N]
/// Drop the already initialized values on cancellation.
#[pinned_drop]
impl<Fut, const N: usize> PinnedDrop for Join<Fut, N>
where
Fut: IntoFuture,
Fut: Future,
{
type Output = [Fut::Output; N];
type Future = Join<Fut::IntoFuture, N>;
fn join(self) -> Self::Future {
Join {
elems: self.map(|fut| MaybeDone::new(fut.into_future())),
fn drop(self: Pin<&mut Self>) {
let this = self.project();

// Get the indexes of the initialized values.
let indexes = this
.state
.iter_mut()
.enumerate()
.filter(|(_, state)| state.is_done())
.map(|(i, _)| i);

// Drop each value at the index.
for i in indexes {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { this.items[i].assume_init_drop() };
}
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::utils::DummyWaker;

use std::future;
use std::future::Future;
use std::sync::Arc;
use std::task::Context;

#[test]
fn smoke() {
futures_lite::future::block_on(async {
let res = [future::ready("hello"), future::ready("world")]
.join()
.await;
assert_eq!(res, ["hello", "world"]);
let fut = [future::ready("hello"), future::ready("world")].join();
assert_eq!(fut.await, ["hello", "world"]);
});
}

#[test]
fn debug() {
let mut fut = [future::ready("hello"), future::ready("world")].join();
assert_eq!(format!("{:?}", fut), "[Pending, Pending]");
let mut fut = Pin::new(&mut fut);

let waker = Arc::new(DummyWaker()).into();
let mut cx = Context::from_waker(&waker);
let _ = fut.as_mut().poll(&mut cx);
assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]");
}
}
22 changes: 22 additions & 0 deletions src/utils/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::mem::{self, MaybeUninit};

/// Extracts the values from an array of `MaybeUninit` containers.
///
/// # Safety
///
/// It is up to the caller to guarantee that all elements of the array are
/// in an initialized state.
///
/// Inlined version of: https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.array_assume_init
pub(crate) unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] {
// SAFETY:
// * The caller guarantees that all elements of the array are initialized
// * `MaybeUninit<T>` and T are guaranteed to have the same layout
// * `MaybeUninit` does not drop, so there are no double-frees
// And thus the conversion is safe
let ret = unsafe { (&array as *const _ as *const [T; N]).read() };

// FIXME: required to avoid `~const Destruct` bound
mem::forget(array);
ret
}
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! This implementation was taken from the original `macro_rules` `join/try_join`
//! macros in the `futures-preview` crate.
mod array;
mod fuse;
mod maybe_done;
mod pin;
Expand All @@ -13,6 +14,7 @@ mod rng;
mod tuple;
mod wakers;

pub(crate) use array::array_assume_init;
pub(crate) use fuse::Fuse;
pub(crate) use maybe_done::MaybeDone;
pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec};
Expand Down
Empty file removed src/utils/slice_ext.rs
Empty file.

0 comments on commit 2d35a6b

Please sign in to comment.