Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove MaybeDone from impl Join for array #72

Merged
merged 2 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 109 additions & 36 deletions src/future/join/array.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use super::Join as JoinTrait;
use crate::utils::MaybeDone;
use crate::utils;
use crate::utils::PollState;

use core::array;
use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem::{self, MaybeUninit};
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::pin_project;
use pin_project::{pin_project, pinned_drop};

/// Waits for two similarly-typed futures to complete.
///
Expand All @@ -16,12 +19,46 @@ use pin_project::pin_project;
/// [`join`]: crate::future::Join::join
/// [`Join`]: crate::future::Join
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project]
#[pin_project(PinnedDrop)]
pub struct Join<Fut, const N: usize>
where
Fut: Future,
{
pub(crate) elems: [MaybeDone<Fut>; N],
consumed: bool,
pending: usize,
items: [MaybeUninit<<Fut as Future>::Output>; N],
state: [PollState; N],
#[pin]
futures: [Fut; N],
}

impl<Fut, const N: usize> Join<Fut, N>
where
Fut: Future,
{
#[inline]
pub(crate) fn new(futures: [Fut; N]) -> Self {
Join {
consumed: false,
pending: N,
items: array::from_fn(|_| MaybeUninit::uninit()),
state: [PollState::default(); N],
futures,
}
}
}

impl<Fut, const N: usize> JoinTrait for [Fut; N]
where
Fut: IntoFuture,
{
type Output = [Fut::Output; N];
type Future = Join<Fut::IntoFuture, N>;

#[inline]
fn join(self) -> Self::Future {
yoshuawuyts marked this conversation as resolved.
Show resolved Hide resolved
Join::new(self.map(IntoFuture::into_future))
}
}

impl<Fut, const N: usize> fmt::Debug for Join<Fut, N>
Expand All @@ -30,7 +67,7 @@ where
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.elems.iter()).finish()
f.debug_list().entries(self.state.iter()).finish()
}
}

Expand All @@ -40,65 +77,101 @@ where
{
type Output = [Fut::Output; N];

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut all_done = true;
let mut this = self.project();

let this = self.project();
assert!(
!*this.consumed,
"Futures must not be polled after completing"
);

for elem in this.elems.iter_mut() {
let elem = unsafe { Pin::new_unchecked(elem) };
if elem.poll(cx).is_pending() {
all_done = false;
// Poll all futures
for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() {
if this.state[i].is_pending() {
if let Poll::Ready(value) = fut.poll(cx) {
this.items[i] = MaybeUninit::new(value);
this.state[i] = PollState::Done;
*this.pending -= 1;
}
}
}

if all_done {
use core::array;
use core::mem::MaybeUninit;
// Check whether we're all done now or need to keep going.
if *this.pending == 0 {
// Mark all data as "consumed" before we take it
*this.consumed = true;
for state in this.state.iter_mut() {
debug_assert!(state.is_done(), "Future should have reached a `Done` state");
*state = PollState::Consumed;
}

// Create the result array based on the indices
// TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable
let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit());
let mut items = array::from_fn(|_| MaybeUninit::uninit());
mem::swap(this.items, &mut items);

// NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`.
#[allow(clippy::needless_range_loop)]
for (i, el) in this.elems.iter_mut().enumerate() {
let el = unsafe { Pin::new_unchecked(el) }.take().unwrap();
out[i] = MaybeUninit::new(el);
}
let result = unsafe { out.as_ptr().cast::<[Fut::Output; N]>().read() };
Poll::Ready(result)
// SAFETY: we've checked with the state that all of our outputs have been
// filled, which means we're ready to take the data and assume it's initialized.
let items = unsafe { utils::array_assume_init(items) };
Poll::Ready(items)
} else {
Poll::Pending
}
}
}

impl<Fut, const N: usize> JoinTrait for [Fut; N]
/// Drop the already initialized values on cancellation.
#[pinned_drop]
impl<Fut, const N: usize> PinnedDrop for Join<Fut, N>
where
Fut: IntoFuture,
Fut: Future,
{
type Output = [Fut::Output; N];
type Future = Join<Fut::IntoFuture, N>;
fn join(self) -> Self::Future {
Join {
elems: self.map(|fut| MaybeDone::new(fut.into_future())),
fn drop(self: Pin<&mut Self>) {
let this = self.project();

// Get the indexes of the initialized values.
let indexes = this
.state
.iter_mut()
.enumerate()
.filter(|(_, state)| state.is_done())
.map(|(i, _)| i);

// Drop each value at the index.
for i in indexes {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { this.items[i].assume_init_drop() };
}
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::utils::DummyWaker;

use std::future;
use std::future::Future;
use std::sync::Arc;
use std::task::Context;

#[test]
fn smoke() {
futures_lite::future::block_on(async {
let res = [future::ready("hello"), future::ready("world")]
.join()
.await;
assert_eq!(res, ["hello", "world"]);
let fut = [future::ready("hello"), future::ready("world")].join();
assert_eq!(fut.await, ["hello", "world"]);
});
}

#[test]
fn debug() {
let mut fut = [future::ready("hello"), future::ready("world")].join();
assert_eq!(format!("{:?}", fut), "[Pending, Pending]");
let mut fut = Pin::new(&mut fut);

let waker = Arc::new(DummyWaker()).into();
let mut cx = Context::from_waker(&waker);
let _ = fut.as_mut().poll(&mut cx);
assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]");
}
}
22 changes: 22 additions & 0 deletions src/utils/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::mem::{self, MaybeUninit};

/// Extracts the values from an array of `MaybeUninit` containers.
///
/// # Safety
///
/// It is up to the caller to guarantee that all elements of the array are
/// in an initialized state.
///
/// Inlined version of: https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.array_assume_init
pub(crate) unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] {
// SAFETY:
// * The caller guarantees that all elements of the array are initialized
// * `MaybeUninit<T>` and T are guaranteed to have the same layout
// * `MaybeUninit` does not drop, so there are no double-frees
// And thus the conversion is safe
let ret = unsafe { (&array as *const _ as *const [T; N]).read() };

// FIXME: required to avoid `~const Destruct` bound
mem::forget(array);
ret
}
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! This implementation was taken from the original `macro_rules` `join/try_join`
//! macros in the `futures-preview` crate.

mod array;
mod fuse;
mod maybe_done;
mod pin;
Expand All @@ -13,6 +14,7 @@ mod rng;
mod tuple;
mod waker;

pub(crate) use array::array_assume_init;
pub(crate) use fuse::Fuse;
pub(crate) use maybe_done::MaybeDone;
pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec};
Expand Down
Empty file removed src/utils/slice_ext.rs
Empty file.