From 74a8095d56d7f90f7f9220be5b48d00179e91843 Mon Sep 17 00:00:00 2001 From: Yoshua Wuyts <2467194+yoshuawuyts@users.noreply.github.com> Date: Mon, 7 Nov 2022 18:21:19 +0100 Subject: [PATCH 1/2] Remove `MaybeDone` from `impl Join for Vec` --- src/future/into_future.rs | 8 +-- src/future/join/vec.rs | 138 ++++++++++++++++++++++++++++---------- src/utils/metadata.rs | 105 +++++++++++++++++++++++++++++ src/utils/mod.rs | 4 +- src/utils/pin.rs | 19 +++--- 5 files changed, 222 insertions(+), 52 deletions(-) create mode 100644 src/utils/metadata.rs diff --git a/src/future/into_future.rs b/src/future/into_future.rs index 252ef70..4909e2a 100644 --- a/src/future/into_future.rs +++ b/src/future/into_future.rs @@ -19,12 +19,8 @@ impl IntoFuture for Vec { type IntoFuture = crate::future::join::vec::Join; fn into_future(self) -> Self::IntoFuture { - let elems = self - .into_iter() - .map(|fut| MaybeDone::new(core::future::IntoFuture::into_future(fut))) - .collect::>() - .into(); - crate::future::join::vec::Join::new(elems) + use crate::future::join::vec::Join; + Join::new(self.into_iter().collect()) } } diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index b9eb6f1..7d807d8 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -1,50 +1,58 @@ use super::Join as JoinTrait; -use crate::utils::iter_pin_mut; -use crate::utils::MaybeDone; +use crate::utils::{iter_pin_mut_vec, Metadata}; use core::fmt; use core::future::{Future, IntoFuture}; -use core::mem; use core::pin::Pin; use core::task::{Context, Poll}; -use std::boxed::Box; +use std::mem::{self, MaybeUninit}; use std::vec::Vec; -impl JoinTrait for Vec -where - Fut: IntoFuture, -{ - type Output = Vec; - type Future = Join; - - fn join(self) -> Self::Future { - let elems = self - .into_iter() - .map(|fut| MaybeDone::new(fut.into_future())) - .collect::>() - .into(); - Join::new(elems) - } -} +use pin_project::{pin_project, pinned_drop}; /// Waits for two similarly-typed futures to complete. /// /// Awaits multiple futures simultaneously, returning the output of the /// futures once both complete. #[must_use = "futures do nothing unless you `.await` or poll them"] +#[pin_project(PinnedDrop)] pub struct Join where Fut: Future, { - elems: Pin]>>, + #[pin] + futures: Vec, + items: Vec::Output>>, + metadata: Vec, } impl Join where Fut: Future, { - pub(crate) fn new(elems: Pin]>>) -> Self { - Self { elems } + pub(crate) fn new(futures: Vec) -> Self { + Join { + items: std::iter::repeat_with(|| MaybeUninit::uninit()) + .take(futures.len()) + .collect(), + metadata: std::iter::successors(Some(0), |prev| Some(prev + 1)) + .take(futures.len()) + .map(Metadata::new) + .collect(), + futures, + } + } +} + +impl JoinTrait for Vec +where + Fut: IntoFuture, +{ + type Output = Vec; + type Future = Join; + + fn join(self) -> Self::Future { + Join::new(self.into_iter().map(IntoFuture::into_future).collect()) } } @@ -54,7 +62,8 @@ where Fut::Output: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Join").field("elems", &self.elems).finish() + // TODO: fix debug output + f.debug_struct("Join").finish() } } @@ -64,23 +73,82 @@ where { type Output = Vec; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; + // SAFETY: see https://github.com/rust-lang/rust/issues/104108, + // projecting through slices is fine now, but it's not yet guaranteed to + // work. We need to guarantee structural pinning works as expected for it to + // be provably sound. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + // Poll all futures + let futures = this.futures.as_mut(); + for (i, fut) in iter_pin_mut_vec(futures).enumerate() { + if this.metadata[i].is_done() { + continue; + } - for elem in iter_pin_mut(self.elems.as_mut()) { - if elem.poll(cx).is_pending() { - all_done = false; + if let Poll::Ready(value) = fut.poll(cx) { + this.items[i] = MaybeUninit::new(value); + this.metadata[i].set_done(); } } - if all_done { - let mut elems = mem::replace(&mut self.elems, Box::pin([])); - let result = iter_pin_mut(elems.as_mut()) - .map(|e| e.take().unwrap()) - .collect(); - Poll::Ready(result) + // Check whether we're all done now or need to keep going. + if this.metadata.iter().all(|meta| meta.is_done()) { + // Mark all data as "taken" before we actually take it. + this.metadata.iter_mut().for_each(|meta| meta.set_taken()); + + // SAFETY: we've checked with the metadata that all of our outputs have been + // filled, which means we're ready to take the data and assume it's initialized. + let items = unsafe { + let items = mem::take(this.items); + mem::transmute::<_, Vec>(items) + }; + Poll::Ready(items) } else { Poll::Pending } } } + +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for Join +where + Fut: Future, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + + // Get the indexes of the initialized values. + let indexes = this + .metadata + .iter_mut() + .filter(|meta| meta.is_done()) + .map(|meta| meta.index()); + + // Drop each value at the index. + for i in indexes { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { this.items[i].assume_init_drop() }; + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::future; + + // NOTE: we should probably poll in random order. + #[test] + fn no_fairness() { + futures_lite::future::block_on(async { + let res = vec![future::ready("hello"), future::ready("world")] + .join() + .await; + assert_eq!(res, vec!["hello", "world"]); + }); + } +} diff --git a/src/utils/metadata.rs b/src/utils/metadata.rs new file mode 100644 index 0000000..7e3d989 --- /dev/null +++ b/src/utils/metadata.rs @@ -0,0 +1,105 @@ +/// Enumerate the current poll state. +#[derive(Debug, Clone, Copy)] +pub(crate) enum PollState { + /// Actively polling the underlying future. + Active, + /// Data has been written to the output structure + /// and the future should no longer be polled. + Written, + /// Data has been taken from the output structure, + /// and we no longer need to reason about it. + Taken, +} + +impl PollState { + /// Returns `true` if the poll state is [`Active`]. + /// + /// [`Active`]: PollState::Active + #[must_use] + fn is_active(&self) -> bool { + matches!(self, Self::Active) + } + + /// Returns `true` if the poll state is [`Done`]. + /// + /// [`Done`]: PollState::Done + #[must_use] + fn is_done(&self) -> bool { + matches!(self, Self::Written) + } + + /// Returns `true` if the poll state is [`Taken`]. + /// + /// [`Taken`]: PollState::Taken + #[must_use] + pub(crate) fn is_taken(&self) -> bool { + matches!(self, Self::Taken) + } +} + +#[derive(Debug)] +pub(crate) struct Metadata { + index: usize, + poll_state: PollState, +} + +impl Metadata { + /// Create a new instance of `Metadata`, positioned at a certain index. + pub(crate) fn new(index: usize) -> Self { + Self { + index, + poll_state: PollState::Active, + } + } + + /// Get the index of the metadata. + pub(crate) fn index(&self) -> usize { + self.index + } + + /// Get the current poll state. + pub(crate) fn poll_state(&self) -> PollState { + self.poll_state + } + + /// Set the current poll state. + pub(crate) fn set_poll_state(&mut self, poll_state: PollState) { + self.poll_state = poll_state; + } + + /// Set the current poll state to `Active`. + pub(crate) fn set_active(&mut self) { + self.poll_state = PollState::Active; + } + + /// Set the current poll state to `Done`. + pub(crate) fn set_done(&mut self) { + self.poll_state = PollState::Written; + } + + /// Set the current poll state to `Taken`. + pub(crate) fn set_taken(&mut self) { + self.poll_state = PollState::Taken; + } + + /// Returns `true` if the poll state is [`Active`]. + /// + /// [`Active`]: PollState::Active + pub(crate) fn is_active(&self) -> bool { + self.poll_state.is_active() + } + + /// Returns `true` if the poll state is [`Done`]. + /// + /// [`Done`]: PollState::Done + pub(crate) fn is_done(&self) -> bool { + self.poll_state.is_done() + } + + /// Returns `true` if the poll state is [`Taken`]. + /// + /// [`Taken`]: PollState::Taken + pub(crate) fn is_taken(&self) -> bool { + self.poll_state.is_taken() + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 76a82d3..3fe2216 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,10 +7,12 @@ mod fuse; mod maybe_done; +mod metadata; mod pin; mod rng; pub(crate) use fuse::Fuse; pub(crate) use maybe_done::MaybeDone; -pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut}; +pub(crate) use metadata::{Metadata, PollState}; +pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec}; pub(crate) use rng::random; diff --git a/src/utils/pin.rs b/src/utils/pin.rs index c8b5863..d416869 100644 --- a/src/utils/pin.rs +++ b/src/utils/pin.rs @@ -11,16 +11,15 @@ pub(crate) fn iter_pin_mut(slice: Pin<&mut [T]>) -> impl Iterator(slice: Pin<&mut [T; N]>) -> [Pin<&mut T>; N] { -// // SAFETY: `std` _could_ make this unsound if it were to decide Pin's -// // invariants aren't required to transmit through arrays. Otherwise this has -// // the same safety as a normal field pin projection. -// unsafe { slice.get_unchecked_mut() } -// .each_mut() -// .map(|t| unsafe { Pin::new_unchecked(t) }) -// } +// From: `futures_rs::join_all!` -- https://github.com/rust-lang/futures-rs/blob/b48eb2e9a9485ef7388edc2f177094a27e08e28b/futures-util/src/future/join_all.rs#L18-L23 +pub(crate) fn iter_pin_mut_vec(slice: Pin<&mut Vec>) -> impl Iterator> { + // SAFETY: `std` _could_ make this unsound if it were to decide Pin's + // invariants aren't required to transmit through slices. Otherwise this has + // the same safety as a normal field pin projection. + unsafe { slice.get_unchecked_mut() } + .iter_mut() + .map(|t| unsafe { Pin::new_unchecked(t) }) +} /// Returns a pinned mutable reference to an element or subslice depending on the /// type of index (see `get`) or `None` if the index is out of bounds. From 1dd997c464cb1b518a5d9cd7bfaade57f36db682 Mon Sep 17 00:00:00 2001 From: Yoshua Wuyts <2467194+yoshuawuyts@users.noreply.github.com> Date: Mon, 7 Nov 2022 23:47:57 +0100 Subject: [PATCH 2/2] implement feedback from review --- src/future/join/vec.rs | 64 ++++++++++++------------ src/utils/metadata.rs | 105 ---------------------------------------- src/utils/mod.rs | 4 +- src/utils/poll_state.rs | 39 +++++++++++++++ 4 files changed, 75 insertions(+), 137 deletions(-) delete mode 100644 src/utils/metadata.rs create mode 100644 src/utils/poll_state.rs diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index 7d807d8..f7dab06 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -1,5 +1,5 @@ use super::Join as JoinTrait; -use crate::utils::{iter_pin_mut_vec, Metadata}; +use crate::utils::{iter_pin_mut_vec, PollState}; use core::fmt; use core::future::{Future, IntoFuture}; @@ -20,10 +20,12 @@ pub struct Join where Fut: Future, { + consumed: bool, + pending: usize, + items: Vec::Output>>, + state: Vec, #[pin] futures: Vec, - items: Vec::Output>>, - metadata: Vec, } impl Join @@ -32,13 +34,12 @@ where { pub(crate) fn new(futures: Vec) -> Self { Join { + consumed: false, + pending: futures.len(), items: std::iter::repeat_with(|| MaybeUninit::uninit()) .take(futures.len()) .collect(), - metadata: std::iter::successors(Some(0), |prev| Some(prev + 1)) - .take(futures.len()) - .map(Metadata::new) - .collect(), + state: vec![PollState::default(); futures.len()], futures, } } @@ -62,8 +63,7 @@ where Fut::Output: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // TODO: fix debug output - f.debug_struct("Join").finish() + f.debug_list().entries(self.state.iter()).finish() } } @@ -73,32 +73,36 @@ where { type Output = Vec; - // SAFETY: see https://github.com/rust-lang/rust/issues/104108, - // projecting through slices is fine now, but it's not yet guaranteed to - // work. We need to guarantee structural pinning works as expected for it to - // be provably sound. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); + assert!( + !*this.consumed, + "Futures must not be polled after completing" + ); + // Poll all futures let futures = this.futures.as_mut(); for (i, fut) in iter_pin_mut_vec(futures).enumerate() { - if this.metadata[i].is_done() { - continue; - } - - if let Poll::Ready(value) = fut.poll(cx) { - this.items[i] = MaybeUninit::new(value); - this.metadata[i].set_done(); + if this.state[i].is_pending() { + if let Poll::Ready(value) = fut.poll(cx) { + this.items[i] = MaybeUninit::new(value); + this.state[i] = PollState::Done; + *this.pending -= 1; + } } } // Check whether we're all done now or need to keep going. - if this.metadata.iter().all(|meta| meta.is_done()) { - // Mark all data as "taken" before we actually take it. - this.metadata.iter_mut().for_each(|meta| meta.set_taken()); - - // SAFETY: we've checked with the metadata that all of our outputs have been + if *this.pending == 0 { + // Mark all data as "consumed" before we take it + *this.consumed = true; + this.state.iter_mut().for_each(|state| { + debug_assert!(state.is_done(), "Future should have reached a `Done` state"); + *state = PollState::Consumed; + }); + + // SAFETY: we've checked with the state that all of our outputs have been // filled, which means we're ready to take the data and assume it's initialized. let items = unsafe { let items = mem::take(this.items); @@ -122,10 +126,11 @@ where // Get the indexes of the initialized values. let indexes = this - .metadata + .state .iter_mut() - .filter(|meta| meta.is_done()) - .map(|meta| meta.index()); + .enumerate() + .filter(|(_, state)| state.is_done()) + .map(|(i, _)| i); // Drop each value at the index. for i in indexes { @@ -141,9 +146,8 @@ mod test { use super::*; use std::future; - // NOTE: we should probably poll in random order. #[test] - fn no_fairness() { + fn smoke() { futures_lite::future::block_on(async { let res = vec![future::ready("hello"), future::ready("world")] .join() diff --git a/src/utils/metadata.rs b/src/utils/metadata.rs deleted file mode 100644 index 7e3d989..0000000 --- a/src/utils/metadata.rs +++ /dev/null @@ -1,105 +0,0 @@ -/// Enumerate the current poll state. -#[derive(Debug, Clone, Copy)] -pub(crate) enum PollState { - /// Actively polling the underlying future. - Active, - /// Data has been written to the output structure - /// and the future should no longer be polled. - Written, - /// Data has been taken from the output structure, - /// and we no longer need to reason about it. - Taken, -} - -impl PollState { - /// Returns `true` if the poll state is [`Active`]. - /// - /// [`Active`]: PollState::Active - #[must_use] - fn is_active(&self) -> bool { - matches!(self, Self::Active) - } - - /// Returns `true` if the poll state is [`Done`]. - /// - /// [`Done`]: PollState::Done - #[must_use] - fn is_done(&self) -> bool { - matches!(self, Self::Written) - } - - /// Returns `true` if the poll state is [`Taken`]. - /// - /// [`Taken`]: PollState::Taken - #[must_use] - pub(crate) fn is_taken(&self) -> bool { - matches!(self, Self::Taken) - } -} - -#[derive(Debug)] -pub(crate) struct Metadata { - index: usize, - poll_state: PollState, -} - -impl Metadata { - /// Create a new instance of `Metadata`, positioned at a certain index. - pub(crate) fn new(index: usize) -> Self { - Self { - index, - poll_state: PollState::Active, - } - } - - /// Get the index of the metadata. - pub(crate) fn index(&self) -> usize { - self.index - } - - /// Get the current poll state. - pub(crate) fn poll_state(&self) -> PollState { - self.poll_state - } - - /// Set the current poll state. - pub(crate) fn set_poll_state(&mut self, poll_state: PollState) { - self.poll_state = poll_state; - } - - /// Set the current poll state to `Active`. - pub(crate) fn set_active(&mut self) { - self.poll_state = PollState::Active; - } - - /// Set the current poll state to `Done`. - pub(crate) fn set_done(&mut self) { - self.poll_state = PollState::Written; - } - - /// Set the current poll state to `Taken`. - pub(crate) fn set_taken(&mut self) { - self.poll_state = PollState::Taken; - } - - /// Returns `true` if the poll state is [`Active`]. - /// - /// [`Active`]: PollState::Active - pub(crate) fn is_active(&self) -> bool { - self.poll_state.is_active() - } - - /// Returns `true` if the poll state is [`Done`]. - /// - /// [`Done`]: PollState::Done - pub(crate) fn is_done(&self) -> bool { - self.poll_state.is_done() - } - - /// Returns `true` if the poll state is [`Taken`]. - /// - /// [`Taken`]: PollState::Taken - pub(crate) fn is_taken(&self) -> bool { - self.poll_state.is_taken() - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 3fe2216..8cdedc8 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,12 +7,12 @@ mod fuse; mod maybe_done; -mod metadata; mod pin; +mod poll_state; mod rng; pub(crate) use fuse::Fuse; pub(crate) use maybe_done::MaybeDone; -pub(crate) use metadata::{Metadata, PollState}; pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec}; +pub(crate) use poll_state::PollState; pub(crate) use rng::random; diff --git a/src/utils/poll_state.rs b/src/utils/poll_state.rs new file mode 100644 index 0000000..b167758 --- /dev/null +++ b/src/utils/poll_state.rs @@ -0,0 +1,39 @@ +/// Enumerate the current poll state. +#[derive(Debug, Clone, Copy, Default)] +pub(crate) enum PollState { + /// Polling the underlying future. + #[default] + Pending, + /// Data has been written to the output structure + /// and the future should no longer be polled. + Done, + /// Data has been consumed from the output structure, + /// and we should no longer reason about it. + Consumed, +} + +impl PollState { + /// Returns `true` if the metadata is [`Pending`]. + /// + /// [`Pending`]: Metadata::Pending + #[must_use] + pub(crate) fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the poll state is [`Done`]. + /// + /// [`Done`]: PollState::Done + #[must_use] + pub(crate) fn is_done(&self) -> bool { + matches!(self, Self::Done) + } + + /// Returns `true` if the poll state is [`Consumed`]. + /// + /// [`Consumed`]: PollState::Consumed + #[must_use] + pub(crate) fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed) + } +}