From b29307d9419ec1bbebe3bf3df64efcdf33cbef2b Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Thu, 29 Dec 2022 06:48:52 +0000 Subject: [PATCH 01/10] shuffle vec/array/tuple futures/streams for benchmark (with a fixed seed) --- Cargo.toml | 1 + benches/utils.rs | 232 ----------------------------- benches/utils/countdown_futures.rs | 116 +++++++++++++++ benches/utils/countdown_streams.rs | 116 +++++++++++++++ benches/utils/mod.rs | 46 ++++++ 5 files changed, 279 insertions(+), 232 deletions(-) delete mode 100644 benches/utils.rs create mode 100644 benches/utils/countdown_futures.rs create mode 100644 benches/utils/countdown_streams.rs create mode 100644 benches/utils/mod.rs diff --git a/Cargo.toml b/Cargo.toml index f75584c..25f411f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,4 @@ futures-lite = "1.12.0" criterion = { version = "0.3", features = ["async", "async_futures", "html_reports"] } async-std = { version = "1.12.0", features = ["attributes"] } futures-time = "3.0.0" +rand = "0.8.5" diff --git a/benches/utils.rs b/benches/utils.rs deleted file mode 100644 index b04e4d1..0000000 --- a/benches/utils.rs +++ /dev/null @@ -1,232 +0,0 @@ -#![allow(unused)] - -use futures_core::Stream; -use futures_lite::prelude::*; -use pin_project::pin_project; - -use std::cell::RefCell; -use std::collections::VecDeque; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll, Waker}; - -pub fn futures_vec(len: usize) -> Vec { - let wakers = Rc::new(RefCell::new(VecDeque::new())); - let completed = Rc::new(RefCell::new(0)); - let futures: Vec<_> = (0..len) - .map(|n| CountdownFuture::new(n, len, wakers.clone(), completed.clone())) - .collect(); - futures -} - -pub fn futures_array() -> [CountdownFuture; N] { - let wakers = Rc::new(RefCell::new(VecDeque::new())); - let completed = Rc::new(RefCell::new(0)); - std::array::from_fn(|n| CountdownFuture::new(n, N, wakers.clone(), completed.clone())) -} - -pub fn futures_tuple() -> ( - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, - CountdownFuture, -) { - let len = 10; - let wakers = Rc::new(RefCell::new(VecDeque::new())); - let completed = Rc::new(RefCell::new(0)); - ( - CountdownFuture::new(0, len, wakers.clone(), completed.clone()), - CountdownFuture::new(1, len, wakers.clone(), completed.clone()), - CountdownFuture::new(2, len, wakers.clone(), completed.clone()), - CountdownFuture::new(3, len, wakers.clone(), completed.clone()), - CountdownFuture::new(4, len, wakers.clone(), completed.clone()), - CountdownFuture::new(5, len, wakers.clone(), completed.clone()), - CountdownFuture::new(6, len, wakers.clone(), completed.clone()), - CountdownFuture::new(7, len, wakers.clone(), completed.clone()), - CountdownFuture::new(8, len, wakers.clone(), completed.clone()), - CountdownFuture::new(9, len, wakers, completed), - ) -} - -pub fn streams_vec(len: usize) -> Vec { - let wakers = Rc::new(RefCell::new(VecDeque::new())); - let completed = Rc::new(RefCell::new(0)); - let streams: Vec<_> = (0..len) - .map(|n| CountdownStream::new(n, len, wakers.clone(), completed.clone())) - .collect(); - streams -} - -pub fn streams_array() -> [CountdownStream; N] { - let wakers = Rc::new(RefCell::new(VecDeque::new())); - let completed = Rc::new(RefCell::new(0)); - std::array::from_fn(|n| CountdownStream::new(n, N, wakers.clone(), completed.clone())) -} - -pub fn streams_tuple() -> ( - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, - CountdownStream, -) { - let len = 10; - let wakers = Rc::new(RefCell::new(VecDeque::new())); - let completed = Rc::new(RefCell::new(0)); - ( - CountdownStream::new(0, len, wakers.clone(), completed.clone()), - CountdownStream::new(1, len, wakers.clone(), completed.clone()), - CountdownStream::new(2, len, wakers.clone(), completed.clone()), - CountdownStream::new(3, len, wakers.clone(), completed.clone()), - CountdownStream::new(4, len, wakers.clone(), completed.clone()), - CountdownStream::new(5, len, wakers.clone(), completed.clone()), - CountdownStream::new(6, len, wakers.clone(), completed.clone()), - CountdownStream::new(7, len, wakers.clone(), completed.clone()), - CountdownStream::new(8, len, wakers.clone(), completed.clone()), - CountdownStream::new(9, len, wakers, completed), - ) -} - -#[derive(Clone, Copy)] -enum State { - Init, - Polled, - Done, -} - -/// A stream which will _eventually_ be ready, but needs to be polled N times before it is. -#[pin_project] -pub struct CountdownStream { - state: State, - wakers: Rc>>, - index: usize, - max_count: usize, - completed_count: Rc>, -} - -impl CountdownStream { - pub fn new( - index: usize, - max_count: usize, - wakers: Rc>>, - completed_count: Rc>, - ) -> Self { - Self { - state: State::Init, - wakers, - max_count, - index, - completed_count, - } - } -} -impl Stream for CountdownStream { - type Item = (); - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - // If we are the last stream to be polled, skip strait to the Polled state. - if this.wakers.borrow().len() + 1 == *this.max_count { - *this.state = State::Polled; - } - - match this.state { - State::Init => { - // Push our waker onto the stack so we get woken again someday. - this.wakers.borrow_mut().push_back(cx.waker().clone()); - *this.state = State::Polled; - Poll::Pending - } - State::Polled => { - // Wake up the next one - let _ = this.wakers.borrow_mut().pop_front().map(Waker::wake); - - if *this.completed_count.borrow() == *this.index { - *this.state = State::Done; - *this.completed_count.borrow_mut() += 1; - Poll::Ready(Some(())) - } else { - // We're not done yet, so schedule another wakeup - this.wakers.borrow_mut().push_back(cx.waker().clone()); - Poll::Pending - } - } - State::Done => Poll::Ready(None), - } - } -} - -/// A future which will _eventually_ be ready, but needs to be polled N times before it is. -#[pin_project] -pub struct CountdownFuture { - state: State, - wakers: Rc>>, - index: usize, - max_count: usize, - completed_count: Rc>, -} - -impl CountdownFuture { - pub fn new( - index: usize, - max_count: usize, - wakers: Rc>>, - completed_count: Rc>, - ) -> Self { - Self { - state: State::Init, - wakers, - max_count, - index, - completed_count, - } - } -} -impl Future for CountdownFuture { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - // If we are the last stream to be polled, skip strait to the Polled state. - if this.wakers.borrow().len() + 1 == *this.max_count { - *this.state = State::Polled; - } - - match this.state { - State::Init => { - // Push our waker onto the stack so we get woken again someday. - this.wakers.borrow_mut().push_back(cx.waker().clone()); - *this.state = State::Polled; - Poll::Pending - } - State::Polled => { - // Wake up the next one - let _ = this.wakers.borrow_mut().pop_front().map(Waker::wake); - - if *this.completed_count.borrow() == *this.index { - *this.state = State::Done; - *this.completed_count.borrow_mut() += 1; - Poll::Ready(()) - } else { - // We're not done yet, so schedule another wakeup - this.wakers.borrow_mut().push_back(cx.waker().clone()); - Poll::Pending - } - } - State::Done => Poll::Ready(()), - } - } -} diff --git a/benches/utils/countdown_futures.rs b/benches/utils/countdown_futures.rs new file mode 100644 index 0000000..c324281 --- /dev/null +++ b/benches/utils/countdown_futures.rs @@ -0,0 +1,116 @@ +use futures_core::Future; +use pin_project::pin_project; + +use std::cell::{Cell, RefCell}; +use std::collections::BinaryHeap; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use super::{shuffle, PrioritizedWaker, State}; + +pub fn futures_vec(len: usize) -> Vec { + let wakers = Rc::new(RefCell::new(BinaryHeap::new())); + let completed = Rc::new(Cell::new(0)); + let mut futures: Vec<_> = (0..len) + .map(|n| CountdownFuture::new(n, len, wakers.clone(), completed.clone())) + .collect(); + shuffle(&mut futures); + futures +} + +pub fn futures_array() -> [CountdownFuture; N] { + let wakers = Rc::new(RefCell::new(BinaryHeap::new())); + let completed = Rc::new(Cell::new(0)); + let mut futures = + std::array::from_fn(|n| CountdownFuture::new(n, N, wakers.clone(), completed.clone())); + shuffle(&mut futures); + futures +} + +pub fn futures_tuple() -> ( + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, + CountdownFuture, +) { + let [f0, f1, f2, f3, f4, f5, f6, f7, f8, f9] = futures_array::<10>(); + (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) +} + +/// A future which will _eventually_ be ready, but needs to be polled N times before it is. +#[pin_project] +pub struct CountdownFuture { + state: State, + wakers: Rc>>, + index: usize, + max_count: usize, + completed_count: Rc>, +} + +impl CountdownFuture { + pub fn new( + index: usize, + max_count: usize, + wakers: Rc>>, + completed_count: Rc>, + ) -> Self { + Self { + state: State::Init, + wakers, + max_count, + index, + completed_count, + } + } +} +impl Future for CountdownFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // If we are the last stream to be polled, skip strait to the Polled state. + if this.wakers.borrow().len() + 1 == *this.max_count { + *this.state = State::Polled; + } + + match this.state { + State::Init => { + // Push our waker onto the stack so we get woken again someday. + this.wakers + .borrow_mut() + .push(PrioritizedWaker(*this.index, cx.waker().clone())); + *this.state = State::Polled; + Poll::Pending + } + State::Polled => { + // Wake up the next one + let _ = this + .wakers + .borrow_mut() + .pop() + .map(|PrioritizedWaker(_, waker)| waker.wake()); + + if this.completed_count.get() == *this.index { + *this.state = State::Done; + this.completed_count.set(this.completed_count.get() + 1); + Poll::Ready(()) + } else { + // We're not done yet, so schedule another wakeup + this.wakers + .borrow_mut() + .push(PrioritizedWaker(*this.index, cx.waker().clone())); + Poll::Pending + } + } + State::Done => Poll::Ready(()), + } + } +} diff --git a/benches/utils/countdown_streams.rs b/benches/utils/countdown_streams.rs new file mode 100644 index 0000000..0fa1035 --- /dev/null +++ b/benches/utils/countdown_streams.rs @@ -0,0 +1,116 @@ +use futures_core::Stream; +use pin_project::pin_project; + +use std::cell::{Cell, RefCell}; +use std::collections::BinaryHeap; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use super::{shuffle, PrioritizedWaker, State}; + +pub fn streams_vec(len: usize) -> Vec { + let wakers = Rc::new(RefCell::new(BinaryHeap::new())); + let completed = Rc::new(Cell::new(0)); + let mut streams: Vec<_> = (0..len) + .map(|n| CountdownStream::new(n, len, wakers.clone(), completed.clone())) + .collect(); + shuffle(&mut streams); + streams +} + +pub fn streams_array() -> [CountdownStream; N] { + let wakers = Rc::new(RefCell::new(BinaryHeap::new())); + let completed = Rc::new(Cell::new(0)); + let mut streams = + std::array::from_fn(|n| CountdownStream::new(n, N, wakers.clone(), completed.clone())); + shuffle(&mut streams); + streams +} + +pub fn streams_tuple() -> ( + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, + CountdownStream, +) { + let [f0, f1, f2, f3, f4, f5, f6, f7, f8, f9] = streams_array::<10>(); + (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) +} + +/// A stream which will _eventually_ be ready, but needs to be polled N times before it is. +#[pin_project] +pub struct CountdownStream { + state: State, + wakers: Rc>>, + index: usize, + max_count: usize, + completed_count: Rc>, +} + +impl CountdownStream { + pub fn new( + index: usize, + max_count: usize, + wakers: Rc>>, + completed_count: Rc>, + ) -> Self { + Self { + state: State::Init, + wakers, + max_count, + index, + completed_count, + } + } +} +impl Stream for CountdownStream { + type Item = (); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + // If we are the last stream to be polled, skip strait to the Polled state. + if this.wakers.borrow().len() + 1 == *this.max_count { + *this.state = State::Polled; + } + + match this.state { + State::Init => { + // Push our waker onto the stack so we get woken again someday. + this.wakers + .borrow_mut() + .push(PrioritizedWaker(*this.index, cx.waker().clone())); + *this.state = State::Polled; + Poll::Pending + } + State::Polled => { + // Wake up the next one + let _ = this + .wakers + .borrow_mut() + .pop() + .map(|PrioritizedWaker(_, waker)| waker.wake()); + + if this.completed_count.get() == *this.index { + *this.state = State::Done; + this.completed_count.set(this.completed_count.get() + 1); + Poll::Ready(Some(())) + } else { + // We're not done yet, so schedule another wakeup + this.wakers + .borrow_mut() + .push(PrioritizedWaker(*this.index, cx.waker().clone())); + Poll::Pending + } + } + State::Done => Poll::Ready(None), + } + } +} diff --git a/benches/utils/mod.rs b/benches/utils/mod.rs new file mode 100644 index 0000000..971a5de --- /dev/null +++ b/benches/utils/mod.rs @@ -0,0 +1,46 @@ +mod countdown_futures; +mod countdown_streams; + +mod prioritized_waker { + use std::{cmp::Ordering, task::Waker}; + + // PrioritizedWaker(index, waker). + // Lowest index gets popped off the BinaryHeap first. + pub struct PrioritizedWaker(pub usize, pub Waker); + impl PartialEq for PrioritizedWaker { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } + impl Eq for PrioritizedWaker { + fn assert_receiver_is_total_eq(&self) {} + } + impl PartialOrd for PrioritizedWaker { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + impl Ord for PrioritizedWaker { + fn cmp(&self, other: &Self) -> Ordering { + self.0.cmp(&other.0).reverse() + } + } +} +use prioritized_waker::PrioritizedWaker; + +#[derive(Clone, Copy)] +enum State { + Init, + Polled, + Done, +} + +fn shuffle(slice: &mut [T]) { + use rand::seq::SliceRandom; + use rand::SeedableRng; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + slice.shuffle(&mut rng); +} + +pub use countdown_futures::*; +pub use countdown_streams::*; From ca246193abec454ab5c0d1120b5c6da709a2709f Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 9 Jan 2023 20:43:53 +0000 Subject: [PATCH 02/10] used a single shared Arc between wakers in WakerArray and WakerVec --- src/utils/wakers/array/mod.rs | 2 - src/utils/wakers/array/waker.rs | 31 --------- src/utils/wakers/array/waker_array.rs | 69 ++++++++++++++++---- src/utils/wakers/mod.rs | 4 +- src/utils/wakers/shared_arc.rs | 92 +++++++++++++++++++++++++++ src/utils/wakers/vec/mod.rs | 2 - src/utils/wakers/vec/waker.rs | 31 --------- src/utils/wakers/vec/waker_vec.rs | 67 +++++++++++++++---- 8 files changed, 206 insertions(+), 92 deletions(-) delete mode 100644 src/utils/wakers/array/waker.rs create mode 100644 src/utils/wakers/shared_arc.rs delete mode 100644 src/utils/wakers/vec/waker.rs diff --git a/src/utils/wakers/array/mod.rs b/src/utils/wakers/array/mod.rs index 7303d9c..20a9392 100644 --- a/src/utils/wakers/array/mod.rs +++ b/src/utils/wakers/array/mod.rs @@ -1,7 +1,5 @@ mod readiness; -mod waker; mod waker_array; pub(crate) use readiness::ReadinessArray; -pub(crate) use waker::InlineWakerArray; pub(crate) use waker_array::WakerArray; diff --git a/src/utils/wakers/array/waker.rs b/src/utils/wakers/array/waker.rs deleted file mode 100644 index 960dff0..0000000 --- a/src/utils/wakers/array/waker.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::sync::{Arc, Mutex}; -use std::task::Wake; - -use super::ReadinessArray; - -/// An efficient waker which delegates wake events. -#[derive(Debug, Clone)] -pub(crate) struct InlineWakerArray { - pub(crate) id: usize, - pub(crate) readiness: Arc>>, -} - -impl InlineWakerArray { - /// Create a new instance of `InlineWaker`. - pub(crate) fn new(id: usize, readiness: Arc>>) -> Self { - Self { id, readiness } - } -} - -impl Wake for InlineWakerArray { - fn wake(self: std::sync::Arc) { - let mut readiness = self.readiness.lock().unwrap(); - if !readiness.set_ready(self.id) { - readiness - .parent_waker() - .as_mut() - .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") - .wake_by_ref() - } - } -} diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index 293155b..345c46b 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -1,26 +1,52 @@ use core::array; -use std::sync::Arc; -use std::sync::Mutex; -use std::task::Waker; +use core::task::Waker; +use std::sync::{Arc, Mutex}; -use super::{InlineWakerArray, ReadinessArray}; +use super::{ + super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer}, + ReadinessArray, +}; /// A collection of wakers which delegate to an in-line waker. pub(crate) struct WakerArray { wakers: [Waker; N], - readiness: Arc>>, + inner: Arc>, +} + +/// See [super::super::shared_arc] for how this works. +struct WakerArrayInner { + wake_data: [*const Self; N], + readiness: Mutex>, } impl WakerArray { /// Create a new instance of `WakerArray`. pub(crate) fn new() -> Self { - let readiness = Arc::new(Mutex::new(ReadinessArray::new())); - Self { - wakers: array::from_fn(|i| { - Arc::new(InlineWakerArray::new(i, readiness.clone())).into() - }), - readiness, - } + let mut inner = Arc::new(WakerArrayInner { + readiness: Mutex::new(ReadinessArray::new()), + wake_data: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now. + }); + let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address. + + // At this point the strong count is 2. Decrement it to 1. + // Each time we create/clone a Waker the count will be incremented by 1. + // So N Wakers -> count = N+1. + unsafe { Arc::decrement_strong_count(raw) } + + // Make wake_data all point to the Arc itself. + Arc::get_mut(&mut inner).unwrap().wake_data = [raw; N]; + + // Now the wake_data array is complete. Time to create the actual Wakers. + let wakers = array::from_fn(|i| { + let data = inner.wake_data.get(i).unwrap(); + unsafe { + waker_for_wake_data_slot::>( + data as *const *const WakerArrayInner, + ) + } + }); + + Self { inner, wakers } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { @@ -29,6 +55,23 @@ impl WakerArray { /// Access the `Readiness`. pub(crate) fn readiness(&self) -> &Mutex> { - self.readiness.as_ref() + &self.inner.readiness + } +} + +impl WakeDataContainer for WakerArrayInner { + fn get_wake_data_slice(&self) -> &[*const Self] { + &self.wake_data + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") + .wake_by_ref(); + } } } diff --git a/src/utils/wakers/mod.rs b/src/utils/wakers/mod.rs index d5c7f1d..3eba976 100644 --- a/src/utils/wakers/mod.rs +++ b/src/utils/wakers/mod.rs @@ -1,7 +1,9 @@ mod array; +mod shared_arc; +mod vec; + #[cfg(test)] mod dummy; -mod vec; #[cfg(test)] pub(crate) use dummy::DummyWaker; diff --git a/src/utils/wakers/shared_arc.rs b/src/utils/wakers/shared_arc.rs new file mode 100644 index 0000000..cad8431 --- /dev/null +++ b/src/utils/wakers/shared_arc.rs @@ -0,0 +1,92 @@ +use core::task::{RawWaker, RawWakerVTable, Waker}; +use std::sync::Arc; + +// In the diagram below, `A` is the upper block. +// It is a struct that implements WakeDataContainer (so either WakerVecInner or WakerArrayInner). +// The lower block is either WakerVec or WakerArray. Each waker there points to a slot of wake_data in `A`. +// Every one of these slots contain a pointer to the Arc wrapping `A` itself. +// Wakers figure out their indices by comparing the address they are pointing to to `wake_data`'s start address. +// +// ┌───────────────────────────┬──────────────┬──────────────┐ +// │ │ │ │ +// │ / ┌─────────────┬──────┼───────┬──────┼───────┬──────┼───────┬─────┐ \ +// ▼ / │ │ │ │ │ │ │ │ │ \ +// Arc < │ Readiness │ wake_data[0] │ wake_data[1] │ wake_data[2] │ ... │ > +// ▲ \ │ │ │ │ │ │ / +// │ \ └─────────────┴──────▲───────┴──────▲───────┴──────▲───────┴─────┘ / +// │ │ │ │ +// └─┐ ┌───────────────┘ │ │ +// │ │ │ │ +// │ │ ┌──────────────────┘ │ +// │ │ │ │ +// │ │ │ ┌─────────────────────┘ +// │ │ │ │ +// │ │ │ │ +// ┌────┼────┬────┼──────┬────┼──────┬────┼──────┬─────┐ +// │ │ │ │ │ │ │ │ │ │ +// │ Inner │ wakers[0] │ wakers[1] │ wakers[2] │ ... │ +// │ │ │ │ │ │ +// └─────────┴───────────┴───────────┴───────────┴─────┘ + +// TODO: Right now each waker gets its own wake_data slot. +// We can save space by making size_of::() wakers share the same slot. +// With such change, in 64-bit system, the wake_data array/vec would only need ⌈N/8⌉ slots instead of N. + +pub(super) trait WakeDataContainer { + /// Get the reference of the wake_data slice. This is used to compute the index. + fn get_wake_data_slice(&self) -> &[*const Self]; + /// Called when the subfuture at the specified index should be polled. + fn wake_index(&self, index: usize); +} +pub(super) unsafe fn waker_for_wake_data_slot( + pointer: *const *const A, +) -> Waker { + unsafe fn clone_waker(pointer: *const ()) -> RawWaker { + let pointer = pointer as *const *const A; + let raw = *pointer; // This is the raw pointer of Arc. + + // We're creating a new Waker, so we need to increment the count. + Arc::increment_strong_count(raw); + + RawWaker::new(pointer as *const (), create_vtable::()) + } + + // Convert a pointer to a wake_data slot to the Arc. + unsafe fn to_arc(pointer: *const *const A) -> Arc { + let raw = *pointer; + Arc::from_raw(raw) + } + unsafe fn wake(pointer: *const ()) { + let pointer = pointer as *const *const A; + let arc = to_arc::(pointer); + // Calculate the index + let index = ((pointer as usize) // This is the slot our pointer points to. + - (arc.get_wake_data_slice() as *const [*const A] as *const () as usize)) // This is the starting address of wake_data. + / std::mem::size_of::<*const A>(); + + arc.wake_index(index); + + // Dropping the Arc would decrement the strong count. + // We only want to do that when we're not waking by ref. + if BY_REF { + std::mem::forget(arc); + } else { + std::mem::drop(arc); + } + } + unsafe fn drop_waker(pointer: *const ()) { + let pointer = pointer as *const *const A; + let arc = to_arc::(pointer); + // Decrement the strong count by dropping the Arc. + std::mem::drop(arc); + } + fn create_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_waker::, + wake::, + wake::, + drop_waker::, + ) + } + Waker::from_raw(clone_waker::(pointer as *const ())) +} diff --git a/src/utils/wakers/vec/mod.rs b/src/utils/wakers/vec/mod.rs index e002fd3..7f064fb 100644 --- a/src/utils/wakers/vec/mod.rs +++ b/src/utils/wakers/vec/mod.rs @@ -1,7 +1,5 @@ mod readiness; -mod waker; mod waker_vec; pub(crate) use readiness::ReadinessVec; -pub(crate) use waker::InlineWakerVec; pub(crate) use waker_vec::WakerVec; diff --git a/src/utils/wakers/vec/waker.rs b/src/utils/wakers/vec/waker.rs deleted file mode 100644 index cfd12ad..0000000 --- a/src/utils/wakers/vec/waker.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::sync::{Arc, Mutex}; -use std::task::Wake; - -use super::ReadinessVec; - -/// An efficient waker which delegates wake events. -#[derive(Debug, Clone)] -pub(crate) struct InlineWakerVec { - pub(crate) id: usize, - pub(crate) readiness: Arc>, -} - -impl InlineWakerVec { - /// Create a new instance of `InlineWaker`. - pub(crate) fn new(id: usize, readiness: Arc>) -> Self { - Self { id, readiness } - } -} - -impl Wake for InlineWakerVec { - fn wake(self: std::sync::Arc) { - let mut readiness = self.readiness.lock().unwrap(); - if !readiness.set_ready(self.id) { - readiness - .parent_waker() - .as_mut() - .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") - .wake_by_ref() - } - } -} diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index e317f0b..2e21f3f 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -1,31 +1,74 @@ -use std::sync::Arc; -use std::sync::Mutex; -use std::task::Waker; +use core::task::Waker; +use std::sync::{Arc, Mutex}; -use super::{InlineWakerVec, ReadinessVec}; +use super::{ + super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer}, + ReadinessVec, +}; -/// A collection of wakers which delegate to an in-line waker. +/// A collection of wakers sharing the same allocation. pub(crate) struct WakerVec { wakers: Vec, - readiness: Arc>, + inner: Arc, +} + +/// See [super::super::shared_arc] for how this works. +struct WakerVecInner { + wake_data: Vec<*const Self>, + readiness: Mutex, } impl WakerVec { /// Create a new instance of `WakerVec`. pub(crate) fn new(len: usize) -> Self { - let readiness = Arc::new(Mutex::new(ReadinessVec::new(len))); - let wakers = (0..len) - .map(|i| Arc::new(InlineWakerVec::new(i, readiness.clone())).into()) + let mut inner = Arc::new(WakerVecInner { + readiness: Mutex::new(ReadinessVec::new(len)), + wake_data: Vec::new(), + }); + let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address. + + // At this point the strong count is 2. Decrement it to 1. + // Each time we create/clone a Waker the count will be incremented by 1. + // So N Wakers -> count = N+1. + unsafe { Arc::decrement_strong_count(raw) } + + // Make wake_data all point to the Arc itself. + Arc::get_mut(&mut inner).unwrap().wake_data = vec![raw; len]; + + // Now the wake_data vec is complete. Time to create the actual Wakers. + let wakers = inner + .wake_data + .iter() + .map(|data| unsafe { + waker_for_wake_data_slot::(data as *const *const WakerVecInner) + }) .collect(); - Self { wakers, readiness } + + Self { inner, wakers } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> { self.wakers.get(index) } - /// Access the `Readiness`. pub(crate) fn readiness(&self) -> &Mutex { - self.readiness.as_ref() + &self.inner.readiness + } +} + +impl WakeDataContainer for WakerVecInner { + fn get_wake_data_slice(&self) -> &[*const Self] { + &self.wake_data + } + + fn wake_index(&self, index: usize) { + let mut readiness = self.readiness.lock().unwrap(); + if !readiness.set_ready(index) { + readiness + .parent_waker() + .as_ref() + .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") + .wake_by_ref(); + } } } From b396a85536c9e19d71531b4ac5ebf2b39ea73566 Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 9 Jan 2023 20:47:56 +0000 Subject: [PATCH 03/10] Use RawWaker to create allocation-free dummy waker --- src/future/join/array.rs | 5 ++--- src/future/join/vec.rs | 5 ++--- src/utils/mod.rs | 2 +- src/utils/wakers/dummy.rs | 18 ++++++++++++++---- src/utils/wakers/mod.rs | 2 +- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 9a6ec0c..cdedcc4 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -168,11 +168,10 @@ where #[cfg(test)] mod test { use super::*; - use crate::utils::DummyWaker; + use crate::utils::dummy_waker; use std::future; use std::future::Future; - use std::sync::Arc; use std::task::Context; #[test] @@ -189,7 +188,7 @@ mod test { assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); let mut fut = Pin::new(&mut fut); - let waker = Arc::new(DummyWaker()).into(); + let waker = dummy_waker(); let mut cx = Context::from_waker(&waker); let _ = fut.as_mut().poll(&mut cx); assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index 1817de6..31d643b 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -170,11 +170,10 @@ where #[cfg(test)] mod test { use super::*; - use crate::utils::DummyWaker; + use crate::utils::dummy_waker; use std::future; use std::future::Future; - use std::sync::Arc; use std::task::Context; #[test] @@ -191,7 +190,7 @@ mod test { assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); let mut fut = Pin::new(&mut fut); - let waker = Arc::new(DummyWaker()).into(); + let waker = dummy_waker(); let mut cx = Context::from_waker(&waker); let _ = fut.as_mut().poll(&mut cx); assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ac5d38e..ed9932a 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -16,7 +16,7 @@ pub(crate) use tuple::{gen_conditions, tuple_len}; pub(crate) use wakers::{WakerArray, WakerVec}; #[cfg(test)] -pub(crate) use wakers::DummyWaker; +pub(crate) use wakers::dummy_waker; #[cfg(test)] pub(crate) mod channel; diff --git a/src/utils/wakers/dummy.rs b/src/utils/wakers/dummy.rs index 0f454b0..b60c996 100644 --- a/src/utils/wakers/dummy.rs +++ b/src/utils/wakers/dummy.rs @@ -1,6 +1,16 @@ -use std::{sync::Arc, task::Wake}; +use core::task::{RawWaker, RawWakerVTable, Waker}; -pub(crate) struct DummyWaker(); -impl Wake for DummyWaker { - fn wake(self: Arc) {} +/// A Waker that doesn't do anything. +pub(crate) fn dummy_waker() -> Waker { + fn new_raw_waker() -> RawWaker { + unsafe fn no_op(_data: *const ()) {} + unsafe fn clone(_data: *const ()) -> RawWaker { + new_raw_waker() + } + RawWaker::new( + core::ptr::null() as *const usize as *const (), + &RawWakerVTable::new(clone, no_op, no_op, no_op), + ) + } + unsafe { Waker::from_raw(new_raw_waker()) } } diff --git a/src/utils/wakers/mod.rs b/src/utils/wakers/mod.rs index 3eba976..e4fe1e7 100644 --- a/src/utils/wakers/mod.rs +++ b/src/utils/wakers/mod.rs @@ -6,7 +6,7 @@ mod vec; mod dummy; #[cfg(test)] -pub(crate) use dummy::DummyWaker; +pub(crate) use dummy::dummy_waker; pub(crate) use array::*; pub(crate) use vec::*; From 1d0be87e06f0ef758dc8ea2807519b5c50a583c2 Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 9 Jan 2023 20:48:46 +0000 Subject: [PATCH 04/10] add a test for the new Arc-sharing WakerArray --- src/utils/wakers/array/waker_array.rs | 40 +++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index 345c46b..7797672 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -75,3 +75,43 @@ impl WakeDataContainer for WakerArrayInner { } } } + +#[cfg(test)] +mod tests { + use crate::utils::wakers::dummy_waker; + + use super::*; + #[test] + fn check_refcount() { + let mut wa = WakerArray::<5>::new(); + assert_eq!(Arc::strong_count(&wa.inner), 6); + wa.wakers[4] = dummy_waker(); + assert_eq!(Arc::strong_count(&wa.inner), 5); + let cloned = wa.wakers[3].clone(); + assert_eq!(Arc::strong_count(&wa.inner), 6); + wa.wakers[3] = wa.wakers[4].clone(); + assert_eq!(Arc::strong_count(&wa.inner), 5); + drop(cloned); + assert_eq!(Arc::strong_count(&wa.inner), 4); + + wa.wakers[0].wake_by_ref(); + wa.wakers[0].wake_by_ref(); + wa.wakers[0].wake_by_ref(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + + wa.wakers[0] = wa.wakers[1].clone(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + + let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker()); + assert_eq!(Arc::strong_count(&wa.inner), 4); + taken.wake_by_ref(); + taken.wake_by_ref(); + taken.wake_by_ref(); + assert_eq!(Arc::strong_count(&wa.inner), 4); + taken.wake(); + assert_eq!(Arc::strong_count(&wa.inner), 3); + + wa.wakers = array::from_fn(|_| dummy_waker()); + assert_eq!(Arc::strong_count(&wa.inner), 1); + } +} From b20cabfc36e8ce1292deadf1d8b9b5647a6f508f Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 9 Jan 2023 21:42:43 +0000 Subject: [PATCH 05/10] Readiness: keep track of indices that have woken to enable O(woken) polling --- src/utils/wakers/array/readiness.rs | 95 +++++++++++++------------- src/utils/wakers/array/waker_array.rs | 9 +-- src/utils/wakers/mod.rs | 5 +- src/utils/wakers/vec/readiness.rs | 96 ++++++++++++--------------- src/utils/wakers/vec/waker_vec.rs | 9 +-- 5 files changed, 90 insertions(+), 124 deletions(-) diff --git a/src/utils/wakers/array/readiness.rs b/src/utils/wakers/array/readiness.rs index 047103c..440035d 100644 --- a/src/utils/wakers/array/readiness.rs +++ b/src/utils/wakers/array/readiness.rs @@ -1,67 +1,62 @@ -use std::task::Waker; +use super::super::dummy_waker; + +use core::task::Waker; -/// Tracks which wakers are "ready" and should be polled. -#[derive(Debug)] pub(crate) struct ReadinessArray { - count: usize, - ready: [bool; N], - parent_waker: Option, + /// Whether each subfuture has woken. + awake_set: [bool; N], + /// Indices of subfutures that have woken. + awake_list: [usize; N], + /// Number of subfutures that have woken. + /// `awake_list` and `awake_list_len` together makes up something like ArrayVec. + // TODO: Maybe just use the ArrayVec crate? + awake_list_len: usize, + parent_waker: Waker, } impl ReadinessArray { - /// Create a new instance of readiness. pub(crate) fn new() -> Self { Self { - count: N, - ready: [true; N], // TODO: use a bitarray instead - parent_waker: None, + awake_set: [true; N], + awake_list: core::array::from_fn(core::convert::identity), + awake_list_len: N, + parent_waker: dummy_waker(), // parent waker is dummy at first } } - - /// Returns the old ready state for this id - pub(crate) fn set_ready(&mut self, id: usize) -> bool { - if !self.ready[id] { - self.count += 1; - self.ready[id] = true; - - false - } else { - true + pub(crate) fn set_parent_waker(&mut self, waker: &Waker) { + // If self.parent_waker and the given waker are the same then don't do the replacement. + if !self.parent_waker.will_wake(waker) { + self.parent_waker = waker.to_owned(); } } - - /// Set all markers to ready. - pub(crate) fn set_all_ready(&mut self) { - self.ready.fill(true); - self.count = N; - } - - /// Returns whether the task id was previously ready - pub(crate) fn clear_ready(&mut self, id: usize) -> bool { - if self.ready[id] { - self.count -= 1; - self.ready[id] = false; - - true - } else { - false + fn set_woken(&mut self, index: usize) -> bool { + let was_awake = std::mem::replace(&mut self.awake_set[index], true); + if !was_awake { + self.awake_list[self.awake_list_len] = index; + self.awake_list_len += 1; } + was_awake } - - /// Returns `true` if any of the wakers are ready. - pub(crate) fn any_ready(&self) -> bool { - self.count > 0 + pub(crate) fn wake(&mut self, index: usize) { + if !self.set_woken(index) && self.awake_list_len == 1 { + self.parent_waker.wake_by_ref(); + } } - - /// Access the parent waker. - #[inline] - pub(crate) fn parent_waker(&self) -> Option<&Waker> { - self.parent_waker.as_ref() + pub(crate) fn awake_list(&self) -> &[usize] { + &self.awake_list[..self.awake_list_len] } - - /// Set the parent `Waker`. This needs to be called at the start of every - /// `poll` function. - pub(crate) fn set_waker(&mut self, parent_waker: &Waker) { - self.parent_waker = Some(parent_waker.clone()); + const TRESHOLD: usize = N / 64; + pub(crate) fn clear(&mut self) { + // Depending on how many items was in the list, + // either use `fill` (memset) or iterate and set each. + // TODO: I came up with the 64 factor at random. Maybe test different factors? + if self.awake_list_len < Self::TRESHOLD { + self.awake_set.fill(false); + } else { + self.awake_list.iter().for_each(|&idx| { + self.awake_set[idx] = false; + }); + } + self.awake_list_len = 0; } } diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index 7797672..a892aa7 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -65,14 +65,7 @@ impl WakeDataContainer for WakerArrayInner { } fn wake_index(&self, index: usize) { - let mut readiness = self.readiness.lock().unwrap(); - if !readiness.set_ready(index) { - readiness - .parent_waker() - .as_ref() - .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") - .wake_by_ref(); - } + self.readiness.lock().unwrap().wake(index); } } diff --git a/src/utils/wakers/mod.rs b/src/utils/wakers/mod.rs index e4fe1e7..70e9c1a 100644 --- a/src/utils/wakers/mod.rs +++ b/src/utils/wakers/mod.rs @@ -1,11 +1,8 @@ mod array; +mod dummy; mod shared_arc; mod vec; -#[cfg(test)] -mod dummy; - -#[cfg(test)] pub(crate) use dummy::dummy_waker; pub(crate) use array::*; diff --git a/src/utils/wakers/vec/readiness.rs b/src/utils/wakers/vec/readiness.rs index 08b6045..9e2daac 100644 --- a/src/utils/wakers/vec/readiness.rs +++ b/src/utils/wakers/vec/readiness.rs @@ -1,70 +1,58 @@ -use bitvec::{bitvec, vec::BitVec}; -use std::task::Waker; +use super::super::dummy_waker; + +use core::task::Waker; + +use bitvec::vec::BitVec; -/// Tracks which wakers are "ready" and should be polled. -#[derive(Debug)] pub(crate) struct ReadinessVec { - count: usize, - max_count: usize, - ready: BitVec, - parent_waker: Option, + /// Whether each subfuture has woken. + awake_set: BitVec, + /// Indices of subfutures that have woken. + awake_list: Vec, + parent_waker: Waker, } impl ReadinessVec { - /// Create a new instance of readiness. - pub(crate) fn new(count: usize) -> Self { + pub(crate) fn new(len: usize) -> Self { + let awake_set = BitVec::repeat(true, len); Self { - count, - max_count: count, - ready: bitvec![true as usize; count], - parent_waker: None, + awake_set, + awake_list: (0..len).collect(), + parent_waker: dummy_waker(), } } - - /// Returns the old ready state for this id - pub(crate) fn set_ready(&mut self, id: usize) -> bool { - if !self.ready[id] { - self.count += 1; - self.ready.set(id, true); - - false - } else { - true + pub(crate) fn set_parent_waker(&mut self, waker: &Waker) { + // If self.parent_waker and the given waker are the same then don't do the replacement. + if !self.parent_waker.will_wake(waker) { + self.parent_waker = waker.to_owned(); } } - - /// Set all markers to ready. - pub(crate) fn set_all_ready(&mut self) { - self.ready.fill(true); - self.count = self.max_count; - } - - /// Returns whether the task id was previously ready - pub(crate) fn clear_ready(&mut self, id: usize) -> bool { - if self.ready[id] { - self.count -= 1; - self.ready.set(id, false); - - true - } else { - false + fn set_woken(&mut self, index: usize) -> bool { + let was_awake = self.awake_set.replace(index, true); + if !was_awake { + self.awake_list.push(index); } + was_awake } - - /// Returns `true` if any of the wakers are ready. - pub(crate) fn any_ready(&self) -> bool { - self.count > 0 + pub(crate) fn wake(&mut self, index: usize) { + if !self.set_woken(index) && self.awake_list.len() == 1 { + self.parent_waker.wake_by_ref(); + } } - - /// Access the parent waker. - #[inline] - pub(crate) fn parent_waker(&self) -> Option<&Waker> { - self.parent_waker.as_ref() + pub(crate) fn awake_list(&self) -> &Vec { + &self.awake_list } - - /// Set the parent `Waker`. This needs to be called at the start of every - /// `poll` function. - pub(crate) fn set_waker(&mut self, parent_waker: &Waker) { - self.parent_waker = Some(parent_waker.clone()); + pub(crate) fn clear(&mut self) { + // Depending on how many items was in the list, + // either use `fill` (memset) or iterate and set each. + // TODO: I came up with the 64 factor at random. Maybe test different factors? + if self.awake_list.len() * 64 < self.awake_set.len() { + self.awake_list.drain(..).for_each(|idx| { + self.awake_set.set(idx, false); + }); + } else { + self.awake_list.clear(); + self.awake_set.fill(false); + } } } diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index 2e21f3f..5a84a00 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -62,13 +62,6 @@ impl WakeDataContainer for WakerVecInner { } fn wake_index(&self, index: usize) { - let mut readiness = self.readiness.lock().unwrap(); - if !readiness.set_ready(index) { - readiness - .parent_waker() - .as_ref() - .expect("`parent_waker` not available from `Readiness`. Did you forget to call `Readiness::set_waker`?") - .wake_by_ref(); - } + self.readiness.lock().unwrap().wake(index); } } From 2be16841054ae505b5511a51fe5130812d672cc1 Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 9 Jan 2023 21:48:06 +0000 Subject: [PATCH 06/10] make method of WakerArray/WakerVec lock the mutex for convenience --- src/utils/wakers/array/waker_array.rs | 6 +++--- src/utils/wakers/vec/waker_vec.rs | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils/wakers/array/waker_array.rs b/src/utils/wakers/array/waker_array.rs index a892aa7..975a57f 100644 --- a/src/utils/wakers/array/waker_array.rs +++ b/src/utils/wakers/array/waker_array.rs @@ -1,6 +1,6 @@ use core::array; use core::task::Waker; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use super::{ super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer}, @@ -54,8 +54,8 @@ impl WakerArray { } /// Access the `Readiness`. - pub(crate) fn readiness(&self) -> &Mutex> { - &self.inner.readiness + pub(crate) fn readiness(&self) -> MutexGuard<'_, ReadinessArray> { + self.inner.readiness.lock().unwrap() } } diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index 5a84a00..0609cc3 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -1,5 +1,5 @@ use core::task::Waker; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use super::{ super::shared_arc::{waker_for_wake_data_slot, WakeDataContainer}, @@ -51,8 +51,8 @@ impl WakerVec { self.wakers.get(index) } - pub(crate) fn readiness(&self) -> &Mutex { - &self.inner.readiness + pub(crate) fn readiness(&self) -> MutexGuard<'_, ReadinessVec> { + self.inner.readiness.lock().unwrap() } } From 16d8b0fda3a14f7c373f0adde1a4458a625073ca Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Tue, 10 Jan 2023 00:50:16 +0000 Subject: [PATCH 07/10] O(woken) polling for all combinators except chain --- examples/happy_eyeballs.rs | 4 +- src/future/common/array.rs | 207 +++++++++++++++++++ src/future/common/mod.rs | 10 + src/future/common/tuple.rs | 239 ++++++++++++++++++++++ src/future/common/vec.rs | 161 +++++++++++++++ src/future/join/array.rs | 202 +++++------------- src/future/join/mod.rs | 3 + src/future/join/tuple.rs | 306 ++++------------------------ src/future/join/vec.rs | 171 ++-------------- src/future/mod.rs | 1 + src/future/race/array.rs | 64 ++---- src/future/race/mod.rs | 3 + src/future/race/tuple.rs | 138 ++++--------- src/future/race/vec.rs | 64 ++---- src/future/race_ok/array.rs | 113 ++++++++++ src/future/race_ok/array/error.rs | 54 ----- src/future/race_ok/array/mod.rs | 151 -------------- src/future/race_ok/error.rs | 21 ++ src/future/race_ok/mod.rs | 8 +- src/future/race_ok/tuple.rs | 124 +++++++++++ src/future/race_ok/tuple/error.rs | 54 ----- src/future/race_ok/tuple/mod.rs | 220 -------------------- src/future/race_ok/vec.rs | 113 ++++++++++ src/future/race_ok/vec/error.rs | 56 ----- src/future/race_ok/vec/mod.rs | 138 ------------- src/future/try_join/array.rs | 90 ++------ src/future/try_join/mod.rs | 8 +- src/future/try_join/tuple.rs | 89 ++++++++ src/future/try_join/vec.rs | 74 ++----- src/lib.rs | 9 +- src/stream/merge/array.rs | 119 +++++++---- src/stream/merge/tuple.rs | 225 ++++++++++---------- src/stream/merge/vec.rs | 99 +++++---- src/stream/zip/array.rs | 168 ++++++++------- src/stream/zip/tuple.rs | 224 ++++++++++++++++++++ src/stream/zip/vec.rs | 122 +++++------ src/utils/array_dequeue.rs | 50 +++++ src/utils/mod.rs | 11 +- src/utils/poll_state/array.rs | 40 ---- src/utils/poll_state/maybe_done.rs | 70 ------- src/utils/poll_state/mod.rs | 25 ++- src/utils/poll_state/poll_state.rs | 49 ----- src/utils/poll_state/vec.rs | 90 -------- src/utils/wakers/array/readiness.rs | 3 +- src/utils/wakers/vec/readiness.rs | 3 +- 45 files changed, 2031 insertions(+), 2162 deletions(-) create mode 100644 src/future/common/array.rs create mode 100644 src/future/common/mod.rs create mode 100644 src/future/common/tuple.rs create mode 100644 src/future/common/vec.rs create mode 100644 src/future/race_ok/array.rs delete mode 100644 src/future/race_ok/array/error.rs delete mode 100644 src/future/race_ok/array/mod.rs create mode 100644 src/future/race_ok/error.rs create mode 100644 src/future/race_ok/tuple.rs delete mode 100644 src/future/race_ok/tuple/error.rs delete mode 100644 src/future/race_ok/tuple/mod.rs create mode 100644 src/future/race_ok/vec.rs delete mode 100644 src/future/race_ok/vec/error.rs delete mode 100644 src/future/race_ok/vec/mod.rs create mode 100644 src/future/try_join/tuple.rs create mode 100644 src/utils/array_dequeue.rs delete mode 100644 src/utils/poll_state/array.rs delete mode 100644 src/utils/poll_state/maybe_done.rs delete mode 100644 src/utils/poll_state/poll_state.rs delete mode 100644 src/utils/poll_state/vec.rs diff --git a/examples/happy_eyeballs.rs b/examples/happy_eyeballs.rs index ed591cd..b0c36aa 100644 --- a/examples/happy_eyeballs.rs +++ b/examples/happy_eyeballs.rs @@ -1,12 +1,12 @@ use async_std::io::prelude::*; use futures::future::TryFutureExt; +use futures_concurrency::errors::AggregateError; use futures_concurrency::prelude::*; use futures_time::prelude::*; use async_std::io; use async_std::net::TcpStream; use futures::channel::oneshot; -use futures_concurrency::vec::AggregateError; use futures_time::time::Duration; use std::error; @@ -27,7 +27,7 @@ async fn open_tcp_socket( addr: &str, port: u16, attempts: u64, -) -> Result> { +) -> Result>> { let (mut sender, mut receiver) = oneshot::channel(); let mut futures = Vec::with_capacity(attempts as usize); diff --git a/src/future/common/array.rs b/src/future/common/array.rs new file mode 100644 index 0000000..99e3a07 --- /dev/null +++ b/src/future/common/array.rs @@ -0,0 +1,207 @@ +use crate::utils::{self, WakerArray}; + +use core::array; +use core::fmt; +use core::future::Future; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use pin_project::{pin_project, pinned_drop}; + +/// A trait for making CombinatorArray behave as Join/TryJoin/Race/RaceOk. +pub trait CombinatorBehaviorArray +where + Fut: Future, +{ + /// The output type of the future. + /// + /// Example: + /// for Join, this is [F::Output; N]. + /// for RaceOk, this is Result. + type Output; + + /// The type of item stored. + /// + /// Example: + /// for Join this is F::Output. + /// for RaceOk this is F::Error. + type StoredItem; + + /// Takes the output of a subfuture and decide what to do with it. + /// If this function returns Err(output), the combinator would early return Poll::Ready(output). + /// For Ok(item), the combinator would keep the item in an array. + /// If by the end, all items are kept (no early return made), + /// then `when_completed` will be called on the items array. + /// + /// Example: + /// Join will always wrap the output in Ok because it want to wait until all outputs are ready. + /// Race will always wrap the output in Err because it want to early return with the first output. + fn maybe_return(idx: usize, res: Fut::Output) -> Result; + + /// Called when all subfutures are completed and none caused the combinator to return early. + /// The argument is an array of the kept item from each subfuture. + fn when_completed(arr: [Self::StoredItem; N]) -> Self::Output; +} + +#[must_use = "futures do nothing unless you `.await` or poll them"] +#[pin_project(PinnedDrop)] +pub struct CombinatorArray +where + Fut: Future, + B: CombinatorBehaviorArray, +{ + behavior: PhantomData, + /// Number of subfutures that have not yet completed. + pending: usize, + wakers: WakerArray, + /// The stored items from each subfuture. + items: [MaybeUninit; N], + /// Whether each item in self.items is initialized. + /// Invariant: self.filled.count_falses() == self.pending. + filled: [bool; N], + /// A temporary buffer for indices that have woken. + /// The data here don't have to persist between each `poll`. + awake_list_buffer: [usize; N], + #[pin] + futures: [Fut; N], +} + +impl CombinatorArray +where + Fut: Future, + B: CombinatorBehaviorArray, +{ + #[inline] + pub(crate) fn new(futures: [Fut; N]) -> Self { + CombinatorArray { + behavior: PhantomData, + pending: N, + wakers: WakerArray::new(), + items: array::from_fn(|_| MaybeUninit::uninit()), + filled: [false; N], + // TODO: this is a temporary buffer so it can be MaybeUninit. + awake_list_buffer: [0; N], + futures, + } + } +} + +impl fmt::Debug for CombinatorArray +where + Fut: Future + fmt::Debug, + B: CombinatorBehaviorArray, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.futures.iter()).finish() + } +} + +impl Future for CombinatorArray +where + Fut: Future, + B: CombinatorBehaviorArray, +{ + type Output = B::Output; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + // If this.pending == 0, the future is done. + assert!( + N == 0 || *this.pending > 0, + "Futures must not be polled after completing" + ); + + let num_awake = { + // Lock the readiness Mutex. + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + + // Copy the list of indices that have woken.. + let awake_list = readiness.awake_list(); + let num_awake = awake_list.len(); + this.awake_list_buffer[..num_awake].copy_from_slice(awake_list); + + // Clear the list. + readiness.clear(); + num_awake + + // Awakeneess Mutex should be unlocked here. + }; + + // Iterate over the indices we've copied out of the Mutex. + for &idx in this.awake_list_buffer.iter().take(num_awake) { + let filled = &mut this.filled[idx]; + if *filled { + // Woken subfuture is already complete, don't poll it again. + // (Futures probably shouldn't wake after they are complete, but there's no guarantee.) + continue; + } + let fut = utils::get_pin_mut(this.futures.as_mut(), idx).unwrap(); + let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); + if let Poll::Ready(value) = fut.poll(&mut cx) { + match B::maybe_return(idx, value) { + // Keep the item for returning once every subfuture is done. + Ok(store) => { + this.items[idx].write(store); + *filled = true; + *this.pending -= 1; + } + // Early return. + Err(ret) => return Poll::Ready(ret), + } + } + } + + // Check whether we're all done now or need to keep going. + if *this.pending == 0 { + // Check an internal invariant. + // No matter how ill-behaved the subfutures are, this should be held. + debug_assert!( + this.filled.iter().all(|&filled| filled), + "Future should have filled items array" + ); + this.filled.fill(false); + + let mut items = array::from_fn(|_| MaybeUninit::uninit()); + core::mem::swap(this.items, &mut items); + + // SAFETY: this.pending is only decremented when an item slot is filled. + // pending reaching 0 means the entire items array is filled. + // + // For N > 0, we can only enter this if block once (because the assert at the top), + // so it is safe to take the data. + // For N == 0, we can enter this if block many times (in case of poll-after-done), + // but then the items array is empty anyway so we're fine. + let items = unsafe { utils::array_assume_init(items) }; + + // Let the Behavior do any final transformation. + // For example, TryJoin would wrap the whole thing in Ok. + Poll::Ready(B::when_completed(items)) + } else { + Poll::Pending + } + } +} + +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for CombinatorArray +where + Fut: Future, + B: CombinatorBehaviorArray, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + + for (&filled, output) in this.filled.iter().zip(this.items.iter_mut()) { + if filled { + // SAFETY: filled is only set to true for initialized items. + unsafe { output.assume_init_drop() }; + } + } + } +} diff --git a/src/future/common/mod.rs b/src/future/common/mod.rs new file mode 100644 index 0000000..cfb44e8 --- /dev/null +++ b/src/future/common/mod.rs @@ -0,0 +1,10 @@ +//! This module implements a combinator similar to TryJoin. +//! The actual TryJoin, along with Join, Race, and RaceOk, can delegate to this. + +mod array; +mod tuple; +mod vec; + +pub(crate) use array::{CombinatorArray, CombinatorBehaviorArray}; +pub(crate) use tuple::{CombineTuple, TupleMaybeReturn, TupleWhenCompleted}; +pub(crate) use vec::{CombinatorBehaviorVec, CombinatorVec}; diff --git a/src/future/common/tuple.rs b/src/future/common/tuple.rs new file mode 100644 index 0000000..d35f908 --- /dev/null +++ b/src/future/common/tuple.rs @@ -0,0 +1,239 @@ +use crate::utils::WakerArray; + +use core::fmt::{self, Debug}; +use core::future::Future; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// An internal trait. +/// Join/Race/RaceOk/TryJoin uses this to construct their CombinatorTupleN. +/// This trait is implemented on tuple of the format +/// ((fut1, fut2,...), Behavior ZST, PhantomData). +pub trait CombineTuple { + /// The resulting combinator future. + type Combined; + fn combine(self) -> Self::Combined; +} + +/// This and [TupleWhenCompleted] takes the role of [super::array::CombinatorBehaviorArray] but for tuples. +/// Type parameters: +/// R = the return type of a subfuture. +/// O = the return type of the combinator future. +pub trait TupleMaybeReturn { + /// The type of the item to store for this subfuture. + type StoredItem; + /// Take the return value of a subfuture and decide whether to store it or early return. + /// Ok(v) = store v. + /// Err(o) = early return o. + fn maybe_return(idx: usize, res: R) -> Result; +} +/// This and [TupleMaybeReturn] takes the role of [super::array::CombinatorBehaviorArray] but for tuples. +/// Type parameters: +/// S = the type of the stored tuples = (F1::StoredItem, F2::StoredItem, ...). +/// O = the return type of the combinator future. +pub trait TupleWhenCompleted { + /// Called when all subfutures are completed and none caused the combinator to return early. + /// The argument is an array of the kept item from each subfuture. + fn when_completed(stored_items: S) -> O; +} + +macro_rules! impl_common_tuple { + ($mod_name:ident $StructName:ident $($F:ident=$idx:tt)+) => { + mod $mod_name { + #[pin_project::pin_project] + pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ } + pub(super) const LEN: usize = [$($idx,)+].len(); + } + #[pin_project::pin_project(PinnedDrop)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[allow(non_snake_case)] + pub struct $StructName + where + $( + $F: Future, + )+ + $( + B: TupleMaybeReturn<$F::Output, O>, + )+ + B: TupleWhenCompleted<($(>::StoredItem,)+), O> + { + pending: usize, + items: ($(MaybeUninit<>::StoredItem>,)+), + wakers: WakerArray<{$mod_name::LEN}>, + filled: [bool; $mod_name::LEN], + awake_list_buffer: [usize; $mod_name::LEN], + #[pin] + futures: $mod_name::Futures<$($F,)+>, + phantom: PhantomData + } + + impl Debug for $StructName + where + $( + $F: Future + Debug, + )+ + $( + B: TupleMaybeReturn<$F::Output, O>, + )+ + B: TupleWhenCompleted<($(>::StoredItem,)+), O> + { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Join") + $(.field(&self.futures.$F))+ + .finish() + } + } + + impl CombineTuple for (($($F,)+), B, PhantomData) + where + $( + $F: Future, + )+ + $( + B: TupleMaybeReturn<$F::Output, O>, + )+ + B: TupleWhenCompleted<($(>::StoredItem,)+), O> + { + type Combined = $StructName; + fn combine(self) -> Self::Combined { + $StructName { + filled: [false; $mod_name::LEN], + items: ($(MaybeUninit::<>::StoredItem>::uninit(),)+), + wakers: WakerArray::new(), + pending: $mod_name::LEN, + awake_list_buffer: [0; $mod_name::LEN], + futures: $mod_name::Futures {$($F: self.0.$idx,)+}, + phantom: PhantomData + } + } + } + + #[allow(unused_mut)] + #[allow(unused_parens)] + #[allow(unused_variables)] + impl Future for $StructName + where + $( + $F: Future, + )+ + $( + B: TupleMaybeReturn<$F::Output, O>, + )+ + B: TupleWhenCompleted<($(>::StoredItem,)+), O> + + { + type Output = O; + + fn poll( + self: Pin<&mut Self>, cx: &mut Context<'_> + ) -> Poll { + let mut this = self.project(); + assert!( + *this.pending > 0, + "Futures must not be polled after completing" + ); + + let mut futures = this.futures.project(); + + let num_awake = { + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + let awake_list = readiness.awake_list(); + let num_awake = awake_list.len(); + this.awake_list_buffer[..num_awake].copy_from_slice(awake_list); + readiness.clear(); + num_awake + }; + + + for &idx in this.awake_list_buffer.iter().take(num_awake) { + let filled = &mut this.filled[idx]; + if *filled { + continue; + } + let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); + let ready = match idx { + $( + $idx => { + if let Poll::Ready(value) = futures.$F.as_mut().poll(&mut cx) { + match B::maybe_return($idx, value) { + Err(ret) => { + return Poll::Ready(ret); + }, + Ok(store) => { + this.items.$idx.write(store); + true + } + } + } + else { + false + } + }, + )* + _ => unreachable!() + }; + if ready { + *this.pending -= 1; + *filled = true; + } + } + + if *this.pending == 0 { + debug_assert!(this.filled.iter().all(|&filled| filled), "Future should have filled items array"); + this.filled.fill(false); + let out = { + let mut out = ($(MaybeUninit::<>::StoredItem>::uninit(),)+); + core::mem::swap(&mut out, this.items); + let ($($F,)+) = out; + // SAFETY: we've checked with the state that all of our outputs have been + // filled, which means we're ready to take the data and assume it's initialized. + unsafe { ($($F.assume_init(),)+) } + }; + Poll::Ready(B::when_completed(out)) + } + else { + Poll::Pending + } + } + } + + #[pin_project::pinned_drop] + impl PinnedDrop for $StructName + where + $( + $F: Future, + )+ + $( + B: TupleMaybeReturn<$F::Output, O>, + )+ + B: TupleWhenCompleted<($(>::StoredItem,)+), O> + { + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + $( + if this.filled[$idx] { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { this.items.$idx.assume_init_drop() }; + } + )+ + } + } + }; +} + +impl_common_tuple! { common1 CombinatorTuple1 A0=0 } +impl_common_tuple! { common2 CombinatorTuple2 A0=0 A1=1 } +impl_common_tuple! { common3 CombinatorTuple3 A0=0 A1=1 A2=2 } +impl_common_tuple! { common4 CombinatorTuple4 A0=0 A1=1 A2=2 A3=3 } +impl_common_tuple! { common5 CombinatorTuple5 A0=0 A1=1 A2=2 A3=3 A4=4 } +impl_common_tuple! { common6 CombinatorTuple6 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 } +impl_common_tuple! { common7 CombinatorTuple7 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 A6=6 } +impl_common_tuple! { common8 CombinatorTuple8 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 A6=6 A7=7 } +impl_common_tuple! { common9 CombinatorTuple9 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 A6=6 A7=7 A8=8 } +impl_common_tuple! { common10 CombinatorTuple10 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 A6=6 A7=7 A8=8 A9=9 } +impl_common_tuple! { common11 CombinatorTuple11 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 A6=6 A7=7 A8=8 A9=9 A10=10 } +impl_common_tuple! { common12 CombinatorTuple12 A0=0 A1=1 A2=2 A3=3 A4=4 A5=5 A6=6 A7=7 A8=8 A9=9 A10=10 A11=11 } diff --git a/src/future/common/vec.rs b/src/future/common/vec.rs new file mode 100644 index 0000000..d91a482 --- /dev/null +++ b/src/future/common/vec.rs @@ -0,0 +1,161 @@ +use crate::utils::{self, WakerVec}; + +use core::fmt; +use core::future::Future; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; +use std::vec::Vec; + +use bitvec::vec::BitVec; +use pin_project::{pin_project, pinned_drop}; + +// For code comments, see the array module. + +/// A trait for making CombinatorVec behave as Join/TryJoin/Race/RaceOk. +/// See [super::CombinatorBehaviorArray], which is very similar, for documentation. +pub trait CombinatorBehaviorVec +where + Fut: Future, +{ + type Output; + type StoredItem; + fn maybe_return(idx: usize, res: Fut::Output) -> Result; + fn when_completed(vec: Vec) -> Self::Output; +} + +/// See [super::CombinatorArray] for documentation. +#[must_use = "futures do nothing unless you `.await` or poll them"] +#[pin_project(PinnedDrop)] +pub struct CombinatorVec +where + Fut: Future, + B: CombinatorBehaviorVec, +{ + pending: usize, + items: Vec>, + wakers: WakerVec, + filled: BitVec, + awake_list_buffer: Vec, + #[pin] + futures: Vec, +} + +impl CombinatorVec +where + Fut: Future, + B: CombinatorBehaviorVec, +{ + pub(crate) fn new(futures: Vec) -> Self { + let len = futures.len(); + CombinatorVec { + pending: len, + items: std::iter::repeat_with(MaybeUninit::uninit) + .take(len) + .collect(), + wakers: WakerVec::new(len), + filled: BitVec::repeat(false, len), + awake_list_buffer: Vec::new(), + futures, + } + } +} + +impl fmt::Debug for CombinatorVec +where + Fut: Future + fmt::Debug, + B: CombinatorBehaviorVec, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.futures.iter()).finish() + } +} + +impl Future for CombinatorVec +where + Fut: Future, + B: CombinatorBehaviorVec, +{ + type Output = B::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + assert!( + *this.pending > 0 || this.items.is_empty(), + "Futures must not be polled after completing" + ); + + { + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + + this.awake_list_buffer.clone_from(readiness.awake_list()); + + readiness.clear(); + } + + for idx in this.awake_list_buffer.drain(..) { + if this.filled[idx] { + continue; + } + let fut = utils::get_pin_mut_from_vec(this.futures.as_mut(), idx).unwrap(); + let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); + if let Poll::Ready(value) = fut.poll(&mut cx) { + match B::maybe_return(idx, value) { + Ok(store) => { + this.items[idx].write(store); + this.filled.set(idx, true); + *this.pending -= 1; + } + Err(ret) => { + return Poll::Ready(ret); + } + } + } + } + + if *this.pending == 0 { + debug_assert!( + this.filled.iter().all(|filled| *filled), + "Future should have reached a `Ready` state" + ); + this.filled.fill(false); + + // SAFETY: this.pending is only decremented when an item slot is filled. + // pending reaching 0 means the entire items array is filled. + // + // For len > 0, we can only enter this if block once (because the assert at the top), + // so it is safe to take the data. + // For len == 0, we can enter this if block many times (in case of poll-after-done), + // but then the items array is empty anyway so we're fine. + let items = unsafe { + let items = core::mem::take(this.items); + core::mem::transmute::>, Vec>(items) + }; + + Poll::Ready(B::when_completed(items)) + } else { + Poll::Pending + } + } +} + +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for CombinatorVec +where + Fut: Future, + B: CombinatorBehaviorVec, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + + for (filled, output) in this.filled.iter().zip(this.items.iter_mut()) { + if *filled { + // SAFETY: filled is only set to true for initialized items. + unsafe { output.assume_init_drop() } + } + } + } +} diff --git a/src/future/join/array.rs b/src/future/join/array.rs index cdedcc4..53ff3d6 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -1,14 +1,7 @@ -use super::Join as JoinTrait; -use crate::utils::{self, PollArray, WakerArray}; +use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; +use super::{Join as JoinTrait, JoinBehavior}; -use core::array; -use core::fmt; use core::future::{Future, IntoFuture}; -use core::mem::{self, MaybeUninit}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::{pin_project, pinned_drop}; /// Waits for two similarly-typed futures to complete. /// @@ -17,35 +10,25 @@ use pin_project::{pin_project, pinned_drop}; /// /// [`join`]: crate::future::Join::join /// [`Join`]: crate::future::Join -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project(PinnedDrop)] -pub struct Join -where - Fut: Future, -{ - consumed: bool, - pending: usize, - items: [MaybeUninit<::Output>; N], - wakers: WakerArray, - state: PollArray, - #[pin] - futures: [Fut; N], -} +pub type Join = CombinatorArray; -impl Join +impl CombinatorBehaviorArray for JoinBehavior where Fut: Future, { - #[inline] - pub(crate) fn new(futures: [Fut; N]) -> Self { - Join { - consumed: false, - pending: N, - items: array::from_fn(|_| MaybeUninit::uninit()), - wakers: WakerArray::new(), - state: PollArray::new(), - futures, - } + type Output = [Fut::Output; N]; + + type StoredItem = Fut::Output; + + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + Ok(res) + } + + fn when_completed(arr: [Self::StoredItem; N]) -> Self::Output { + arr } } @@ -62,117 +45,14 @@ where } } -impl fmt::Debug for Join -where - Fut: Future + fmt::Debug, - Fut::Output: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.state.iter()).finish() - } -} - -impl Future for Join -where - Fut: Future, -{ - type Output = [Fut::Output; N]; - - #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - assert!( - !*this.consumed, - "Futures must not be polled after completing" - ); - - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } - - // Poll all ready futures - for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() { - if this.state[i].is_pending() && readiness.clear_ready(i) { - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); - - if let Poll::Ready(value) = fut.poll(&mut cx) { - this.items[i] = MaybeUninit::new(value); - this.state[i].set_ready(); - *this.pending -= 1; - } - - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); - } - } - - // Check whether we're all done now or need to keep going. - if *this.pending == 0 { - // Mark all data as "consumed" before we take it - *this.consumed = true; - for state in this.state.iter_mut() { - debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" - ); - state.set_consumed(); - } - - let mut items = array::from_fn(|_| MaybeUninit::uninit()); - mem::swap(this.items, &mut items); - - // SAFETY: we've checked with the state that all of our outputs have been - // filled, which means we're ready to take the data and assume it's initialized. - let items = unsafe { utils::array_assume_init(items) }; - Poll::Ready(items) - } else { - Poll::Pending - } - } -} - -/// Drop the already initialized values on cancellation. -#[pinned_drop] -impl PinnedDrop for Join -where - Fut: Future, -{ - fn drop(self: Pin<&mut Self>) { - let this = self.project(); - - // Get the indexes of the initialized values. - let indexes = this - .state - .iter_mut() - .enumerate() - .filter(|(_, state)| state.is_ready()) - .map(|(i, _)| i); - - // Drop each value at the index. - for i in indexes { - // SAFETY: we've just filtered down to *only* the initialized values. - // We can assume they're initialized, and this is where we drop them. - unsafe { this.items[i].assume_init_drop() }; - } - } -} - #[cfg(test)] mod test { + use futures_lite::future::yield_now; + use super::*; - use crate::utils::dummy_waker; + use std::cell::RefCell; use std::future; - use std::future::Future; - use std::task::Context; #[test] fn smoke() { @@ -183,14 +63,38 @@ mod test { } #[test] - fn debug() { - let mut fut = [future::ready("hello"), future::ready("world")].join(); - assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); - let mut fut = Pin::new(&mut fut); - - let waker = dummy_waker(); - let mut cx = Context::from_waker(&waker); - let _ = fut.as_mut().poll(&mut cx); - assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); + fn poll_order() { + let polled = RefCell::new(Vec::new()); + async fn record_poll(id: char, times: usize, target: &RefCell>) { + for _ in 0..times { + target.borrow_mut().push(id); + yield_now().await; + } + target.borrow_mut().push(id); + } + futures_lite::future::block_on( + [ + record_poll('a', 0, &polled), + record_poll('b', 1, &polled), + record_poll('c', 0, &polled), + ] + .join(), + ); + assert_eq!(&**polled.borrow(), ['a', 'b', 'c', 'b']); + + polled.borrow_mut().clear(); + futures_lite::future::block_on( + [ + record_poll('a', 2, &polled), + record_poll('b', 3, &polled), + record_poll('c', 1, &polled), + record_poll('d', 0, &polled), + ] + .join(), + ); + assert_eq!( + &**polled.borrow(), + ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'b', 'b'] + ); } } diff --git a/src/future/join/mod.rs b/src/future/join/mod.rs index 2cfbb13..8a5c7fd 100644 --- a/src/future/join/mod.rs +++ b/src/future/join/mod.rs @@ -23,3 +23,6 @@ pub trait Join { /// This function returns a new future which polls all futures concurrently. fn join(self) -> Self::Future; } + +#[derive(Debug)] +pub struct JoinBehavior; diff --git a/src/future/join/tuple.rs b/src/future/join/tuple.rs index cee4dd1..7893813 100644 --- a/src/future/join/tuple.rs +++ b/src/future/join/tuple.rs @@ -1,247 +1,64 @@ -use super::Join as JoinTrait; -use crate::utils::{PollArray, WakerArray}; +use super::super::common::{CombineTuple, TupleMaybeReturn, TupleWhenCompleted}; +use super::{Join as JoinTrait, JoinBehavior}; -use core::fmt::{self, Debug}; -use core::future::{Future, IntoFuture}; -use core::mem::MaybeUninit; -use core::pin::Pin; -use core::task::{Context, Poll}; +use core::future::IntoFuture; +use core::marker::PhantomData; -use pin_project::{pin_project, pinned_drop}; - -/// Generates the `poll` call for every `Future` inside `$futures`. -// This is implemented as a tt-muncher of the future name `$($F:ident)` -// and the future index `$($rest)`, taking advantage that we only support -// tuples up to 12 elements -// -// # References -// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html -macro_rules! poll { - (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => { - if $fut_idx == $iteration { - if let Poll::Ready(value) = $futures.$fut_name.as_mut().poll(&mut $cx) { - $this.outputs.$fut_idx.write(value); - *$this.completed += 1; - $this.state[$fut_idx].set_ready(); - } - } - poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*); - }; - - // base condition, no more futures to poll - (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {}; - - ($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => { - poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11); - }; +impl TupleMaybeReturn for JoinBehavior { + type StoredItem = T; + fn maybe_return(_: usize, res: T) -> Result { + Ok(res) + } } - -macro_rules! drop_outputs { - (@drop $output:ident, $($rem_outs:ident,)* | $states:expr, $stix:tt, $($rem_idx:tt,)*) => { - if $states[$stix].is_ready() { - // SAFETY: we're filtering out only the outputs marked as `ready`, - // which means that this memory is initialized - unsafe { $output.assume_init_drop() }; - $states[$stix].set_consumed(); - } - drop_outputs!(@drop $($rem_outs,)* | $states, $($rem_idx,)*); - }; - - // base condition, no more outputs to look - (@drop | $states:expr, $($rem_idx:tt,)*) => {}; - - ($($outs:ident,)+ | $states:expr) => { - drop_outputs!(@drop $($outs,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,); - }; +impl TupleWhenCompleted for JoinBehavior { + fn when_completed(stored_items: O) -> O { + stored_items + } } macro_rules! impl_join_tuple { - ($mod_name:ident $StructName:ident) => { - /// Waits for two similarly-typed futures to complete. - /// - /// This `struct` is created by the [`join`] method on the [`Join`] trait. See - /// its documentation for more. - /// - /// [`join`]: crate::future::Join::join - /// [`Join`]: crate::future::Join - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[allow(non_snake_case)] - pub struct $StructName {} - - impl fmt::Debug for $StructName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Join").finish() - } - } - - impl Future for $StructName { - type Output = (); - - fn poll( - self: Pin<&mut Self>, _cx: &mut Context<'_> - ) -> Poll { - Poll::Ready(()) - } - } - - impl JoinTrait for () { - type Output = (); - type Future = $StructName; - fn join(self) -> Self::Future { - $StructName {} - } - } - }; - ($mod_name:ident $StructName:ident $($F:ident)+) => { - mod $mod_name { - - #[pin_project::pin_project] - pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ } - - #[repr(u8)] - pub(super) enum Indexes { $($F,)+ } - - pub(super) const LEN: usize = [$(Indexes::$F,)+].len(); - } - - /// Waits for many similarly-typed futures to complete. - /// - /// This `struct` is created by the [`join`] method on the [`Join`] trait. See - /// its documentation for more. - /// - /// [`join`]: crate::future::Join::join - /// [`Join`]: crate::future::Join - #[pin_project(PinnedDrop)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[allow(non_snake_case)] - pub struct $StructName<$($F: Future),+> { - #[pin] futures: $mod_name::Futures<$($F,)+>, - outputs: ($(MaybeUninit<$F::Output>,)+), - // trace the state of outputs, marking them as ready or consumed - // then, drop the non-consumed values, if any - state: PollArray<{$mod_name::LEN}>, - wakers: WakerArray<{$mod_name::LEN}>, - completed: usize, - } - - impl<$($F),+> Debug for $StructName<$($F),+> - where $( - $F: Future + Debug, - $F::Output: Debug, - )+ { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Join") - $(.field(&self.futures.$F))+ - .finish() - } - } - - #[allow(unused_mut)] - #[allow(unused_parens)] - #[allow(unused_variables)] - impl<$($F: Future),+> Future for $StructName<$($F),+> { - type Output = ($($F::Output,)+); - - fn poll( - self: Pin<&mut Self>, cx: &mut Context<'_> - ) -> Poll { - const LEN: usize = $mod_name::LEN; - - let mut this = self.project(); - let all_completed = !(*this.completed == LEN); - assert!(all_completed, "Futures must not be polled after completing"); - - let mut futures = this.futures.project(); - - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - - for index in 0..LEN { - if !readiness.any_ready() { - // nothing ready yet - return Poll::Pending; - } - if !readiness.clear_ready(index) || this.state[index].is_ready() { - // future not ready yet or already polled to completion, skip - continue; - } - - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // obtain the intermediate waker - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - - // generate the needed code to poll `futures.{index}` - poll!(index, this, futures, cx, LEN, $($F,)+); - - if *this.completed == LEN { - let out = { - let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+); - core::mem::swap(&mut out, this.outputs); - let ($($F,)+) = out; - unsafe { ($($F.assume_init(),)+) } - }; - - this.state.set_all_completed(); - - return Poll::Ready(out); - } - readiness = this.wakers.readiness().lock().unwrap(); - } - - Poll::Pending - } - } - - #[pinned_drop] - impl<$($F: Future),+> PinnedDrop for $StructName<$($F),+> { - fn drop(self: Pin<&mut Self>) { - let this = self.project(); - - let ($(ref mut $F,)+) = this.outputs; - - let states = this.state; - drop_outputs!($($F,)+ | states); - } - } - - #[allow(unused_parens)] + ($($F:ident)+) => { impl<$($F),+> JoinTrait for ($($F,)+) where $( $F: IntoFuture, )+ { - type Output = ($($F::Output,)*); - type Future = $StructName<$($F::IntoFuture),*>; - + type Output = ($($F::Output,)+); + type Future = <(($($F::IntoFuture,)+), JoinBehavior, PhantomData<($($F::Output,)+)>) as CombineTuple>::Combined; fn join(self) -> Self::Future { - let ($($F,)+): ($($F,)+) = self; - $StructName { - futures: $mod_name::Futures {$($F: $F.into_future(),)+}, - state: PollArray::new(), - outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), - wakers: WakerArray::new(), - completed: 0, - } + let ($($F,)+) = self; + ( + ( + $($F.into_future(),)+ + ), + JoinBehavior, + PhantomData + ).combine() } } }; +} +impl JoinTrait for () { + type Output = (); + type Future = core::future::Ready<()>; + + fn join(self) -> Self::Future { + core::future::ready(()) + } } -impl_join_tuple! { join0 Join0 } -impl_join_tuple! { join1 Join1 A } -impl_join_tuple! { join2 Join2 A B } -impl_join_tuple! { join3 Join3 A B C } -impl_join_tuple! { join4 Join4 A B C D } -impl_join_tuple! { join5 Join5 A B C D E } -impl_join_tuple! { join6 Join6 A B C D E F } -impl_join_tuple! { join7 Join7 A B C D E F G } -impl_join_tuple! { join8 Join8 A B C D E F G H } -impl_join_tuple! { join9 Join9 A B C D E F G H I } -impl_join_tuple! { join10 Join10 A B C D E F G H I J } -impl_join_tuple! { join11 Join11 A B C D E F G H I J K } -impl_join_tuple! { join12 Join12 A B C D E F G H I J K L } +impl_join_tuple! { A0 } +impl_join_tuple! { A0 A1 } +impl_join_tuple! { A0 A1 A2 } +impl_join_tuple! { A0 A1 A2 A3 } +impl_join_tuple! { A0 A1 A2 A3 A4 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 A6 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 } +impl_join_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 } #[cfg(test)] mod test { @@ -282,39 +99,4 @@ mod test { assert_eq!((a, b, c).join().await, ("hello", "world", 12)); }); } - - #[test] - fn does_not_leak_memory() { - use core::cell::RefCell; - use futures_lite::future::pending; - - thread_local! { - static NOT_LEAKING: RefCell = RefCell::new(false); - }; - - struct FlipFlagAtDrop; - impl Drop for FlipFlagAtDrop { - fn drop(&mut self) { - NOT_LEAKING.with(|v| { - *v.borrow_mut() = true; - }); - } - } - - futures_lite::future::block_on(async { - // this will trigger Miri if we don't drop the memory - let string = future::ready("memory leak".to_owned()); - - // this will not flip the thread_local flag if we don't drop the memory - let flip = future::ready(FlipFlagAtDrop); - - let leak = (string, flip, pending::()).join(); - - _ = futures_lite::future::poll_once(leak).await; - }); - - NOT_LEAKING.with(|flag| { - assert!(*flag.borrow()); - }) - } } diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index 31d643b..1310ebf 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -1,15 +1,9 @@ -use super::Join as JoinTrait; -use crate::utils::{iter_pin_mut_vec, PollVec, WakerVec}; +use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; +use super::{Join as JoinTrait, JoinBehavior}; -use core::fmt; use core::future::{Future, IntoFuture}; -use core::pin::Pin; -use core::task::{Context, Poll}; -use std::mem::{self, MaybeUninit}; use std::vec::Vec; -use pin_project::{pin_project, pinned_drop}; - /// Waits for two similarly-typed futures to complete. /// /// This `struct` is created by the [`join`] method on the [`Join`] trait. See @@ -17,37 +11,25 @@ use pin_project::{pin_project, pinned_drop}; /// /// [`join`]: crate::future::Join::join /// [`Join`]: crate::future::Join -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project(PinnedDrop)] -pub struct Join -where - Fut: Future, -{ - consumed: bool, - pending: usize, - items: Vec::Output>>, - wakers: WakerVec, - state: PollVec, - #[pin] - futures: Vec, -} +pub type Join = CombinatorVec; -impl Join +impl CombinatorBehaviorVec for JoinBehavior where Fut: Future, { - pub(crate) fn new(futures: Vec) -> Self { - let len = futures.len(); - Join { - consumed: false, - pending: len, - items: std::iter::repeat_with(MaybeUninit::uninit) - .take(len) - .collect(), - wakers: WakerVec::new(len), - state: PollVec::new(len), - futures, - } + type Output = Vec; + + type StoredItem = Fut::Output; + + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + Ok(res) + } + + fn when_completed(vec: Vec) -> Self::Output { + vec } } @@ -63,118 +45,11 @@ where } } -impl fmt::Debug for Join -where - Fut: Future + fmt::Debug, - Fut::Output: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.state.iter()).finish() - } -} - -impl Future for Join -where - Fut: Future, -{ - type Output = Vec; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - assert!( - !*this.consumed, - "Futures must not be polled after completing" - ); - - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } - - // Poll all ready futures - let futures = this.futures.as_mut(); - let states = &mut this.state[..]; - for (i, fut) in iter_pin_mut_vec(futures).enumerate() { - if states[i].is_pending() && readiness.clear_ready(i) { - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); - - if let Poll::Ready(value) = fut.poll(&mut cx) { - this.items[i] = MaybeUninit::new(value); - states[i].set_ready(); - *this.pending -= 1; - } - - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); - } - } - - // Check whether we're all done now or need to keep going. - if *this.pending == 0 { - // Mark all data as "consumed" before we take it - *this.consumed = true; - this.state.iter_mut().for_each(|state| { - debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" - ); - state.set_consumed(); - }); - - // SAFETY: we've checked with the state that all of our outputs have been - // filled, which means we're ready to take the data and assume it's initialized. - let items = unsafe { - let items = mem::take(this.items); - mem::transmute::<_, Vec>(items) - }; - Poll::Ready(items) - } else { - Poll::Pending - } - } -} - -/// Drop the already initialized values on cancellation. -#[pinned_drop] -impl PinnedDrop for Join -where - Fut: Future, -{ - fn drop(self: Pin<&mut Self>) { - let this = self.project(); - - // Get the indexes of the initialized values. - let indexes = this - .state - .iter_mut() - .enumerate() - .filter(|(_, state)| state.is_ready()) - .map(|(i, _)| i); - - // Drop each value at the index. - for i in indexes { - // SAFETY: we've just filtered down to *only* the initialized values. - // We can assume they're initialized, and this is where we drop them. - unsafe { this.items[i].assume_init_drop() }; - } - } -} - #[cfg(test)] mod test { use super::*; - use crate::utils::dummy_waker; use std::future; - use std::future::Future; - use std::task::Context; #[test] fn smoke() { @@ -183,16 +58,4 @@ mod test { assert_eq!(fut.await, vec!["hello", "world"]); }); } - - #[test] - fn debug() { - let mut fut = vec![future::ready("hello"), future::ready("world")].join(); - assert_eq!(format!("{:?}", fut), "[Pending, Pending]"); - let mut fut = Pin::new(&mut fut); - - let waker = dummy_waker(); - let mut cx = Context::from_waker(&waker); - let _ = fut.as_mut().poll(&mut cx); - assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); - } } diff --git a/src/future/mod.rs b/src/future/mod.rs index f590be8..8eed438 100644 --- a/src/future/mod.rs +++ b/src/future/mod.rs @@ -73,6 +73,7 @@ pub use race::Race; pub use race_ok::RaceOk; pub use try_join::TryJoin; +mod common; pub(crate) mod join; pub(crate) mod race; pub(crate) mod race_ok; diff --git a/src/future/race/array.rs b/src/future/race/array.rs index 0c66962..8cd225b 100644 --- a/src/future/race/array.rs +++ b/src/future/race/array.rs @@ -1,13 +1,7 @@ -use crate::utils::{self, Indexer}; +use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; +use super::{Race as RaceTrait, RaceBehavior}; -use super::Race as RaceTrait; - -use core::fmt; use core::future::{Future, IntoFuture}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::pin_project; /// Wait for the first future to complete. /// @@ -16,49 +10,25 @@ use pin_project::pin_project; /// /// [`race`]: crate::future::Race::race /// [`Race`]: crate::future::Race -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] -pub struct Race -where - Fut: Future, -{ - #[pin] - futures: [Fut; N], - indexer: Indexer, - done: bool, -} +pub type Race = CombinatorArray; -impl fmt::Debug for Race -where - Fut: Future + fmt::Debug, - Fut::Output: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.futures.iter()).finish() - } -} - -impl Future for Race +impl CombinatorBehaviorArray for RaceBehavior where Fut: Future, { type Output = Fut::Output; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - assert!(!*this.done, "Futures must not be polled after completing"); + type StoredItem = core::convert::Infallible; + + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + Err(res) + } - for index in this.indexer.iter() { - let fut = utils::get_pin_mut(this.futures.as_mut(), index).unwrap(); - match fut.poll(cx) { - Poll::Ready(item) => { - *this.done = true; - return Poll::Ready(item); - } - Poll::Pending => continue, - } - } - Poll::Pending + fn when_completed(_arr: [Self::StoredItem; N]) -> Self::Output { + panic!("race only works on non-empty arrays"); } } @@ -70,11 +40,7 @@ where type Future = Race; fn race(self) -> Self::Future { - Race { - futures: self.map(|fut| fut.into_future()), - indexer: Indexer::new(N), - done: false, - } + Race::new(self.map(IntoFuture::into_future)) } } diff --git a/src/future/race/mod.rs b/src/future/race/mod.rs index 19ca85c..8710869 100644 --- a/src/future/race/mod.rs +++ b/src/future/race/mod.rs @@ -23,3 +23,6 @@ pub trait Race { /// This function returns a new future which polls all futures concurrently. fn race(self) -> Self::Future; } + +#[derive(Debug)] +pub struct RaceBehavior; diff --git a/src/future/race/tuple.rs b/src/future/race/tuple.rs index 0b501ac..26e7602 100644 --- a/src/future/race/tuple.rs +++ b/src/future/race/tuple.rs @@ -1,108 +1,60 @@ -use super::Race as RaceTrait; -use crate::utils; +use super::super::common::{CombineTuple, TupleMaybeReturn, TupleWhenCompleted}; +use super::{Race as RaceTrait, RaceBehavior}; -use core::fmt::{self, Debug}; +use core::convert::Infallible; use core::future::{Future, IntoFuture}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::pin_project; +use core::marker::PhantomData; + +impl TupleMaybeReturn for RaceBehavior { + // We early return as soon as any subfuture finishes. + // Results from subfutures are never stored. + type StoredItem = Infallible; + fn maybe_return(_: usize, res: T) -> Result { + // Err = early return. + Err(res) + } +} +impl TupleWhenCompleted for RaceBehavior { + // We always early return, so we should never get here. + fn when_completed(_: S) -> O { + unreachable!() // should have early returned + } +} macro_rules! impl_race_tuple { - ($StructName:ident $($F:ident)+) => { - /// Wait for the first future to complete. - /// - /// This `struct` is created by the [`race`] method on the [`Race`] trait. See - /// its documentation for more. - /// - /// [`race`]: crate::future::Race::race - /// [`Race`]: crate::future::Race - #[pin_project] - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[allow(non_snake_case)] - pub struct $StructName - where $( - $F: Future, - )* { - done: bool, - indexer: utils::Indexer, - $(#[pin] $F: $F,)* - } - - impl Debug for $StructName - where $( - $F: Future + Debug, - T: Debug, - )* { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Race") - $(.field(&self.$F))* - .finish() - } - } - - impl RaceTrait for ($($F,)*) + ($($F:ident)+) => { + impl RaceTrait for ($($F,)+) where $( $F: IntoFuture, - )* { - type Output = T; - type Future = $StructName; - + )+ { + type Output = ::Output; + type Future = <(($($F::IntoFuture,)+), RaceBehavior, PhantomData) as CombineTuple>::Combined; fn race(self) -> Self::Future { - let ($($F,)*): ($($F,)*) = self; - $StructName { - done: false, - indexer: utils::Indexer::new(utils::tuple_len!($($F,)*)), - $($F: $F.into_future()),* - } - } - } - - impl Future for $StructName - where - $($F: Future),* - { - type Output = T; - - fn poll( - self: Pin<&mut Self>, cx: &mut Context<'_> - ) -> Poll { - let mut this = self.project(); - assert!(!*this.done, "Futures must not be polled after completing"); - - #[repr(usize)] - enum Indexes { - $($F),* - } - - for i in this.indexer.iter() { - utils::gen_conditions!(i, this, cx, poll, $((Indexes::$F as usize; $F, { - Poll::Ready(output) => { - *this.done = true; - return Poll::Ready(output); - }, - _ => continue, - }))*); - } - - Poll::Pending + let ($($F,)+) = self; + ( + ( + $($F.into_future(),)+ + ), + RaceBehavior, + PhantomData + ).combine() } } }; } -impl_race_tuple! { Race1 A } -impl_race_tuple! { Race2 A B } -impl_race_tuple! { Race3 A B C } -impl_race_tuple! { Race4 A B C D } -impl_race_tuple! { Race5 A B C D E } -impl_race_tuple! { Race6 A B C D E F } -impl_race_tuple! { Race7 A B C D E F G } -impl_race_tuple! { Race8 A B C D E F G H } -impl_race_tuple! { Race9 A B C D E F G H I } -impl_race_tuple! { Race10 A B C D E F G H I J } -impl_race_tuple! { Race11 A B C D E F G H I J K } -impl_race_tuple! { Race12 A B C D E F G H I J K L } +impl_race_tuple! { A0 } +impl_race_tuple! { A0 A1 } +impl_race_tuple! { A0 A1 A2 } +impl_race_tuple! { A0 A1 A2 A3 } +impl_race_tuple! { A0 A1 A2 A3 A4 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 A6 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 } +impl_race_tuple! { A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 } #[cfg(test)] mod test { diff --git a/src/future/race/vec.rs b/src/future/race/vec.rs index de664d8..327fa3f 100644 --- a/src/future/race/vec.rs +++ b/src/future/race/vec.rs @@ -1,13 +1,7 @@ -use crate::utils::{self, Indexer}; +use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; +use super::{Race as RaceTrait, RaceBehavior}; -use super::Race as RaceTrait; - -use core::fmt; use core::future::{Future, IntoFuture}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::pin_project; /// Wait for the first future to complete. /// @@ -16,49 +10,25 @@ use pin_project::pin_project; /// /// [`race`]: crate::future::Race::race /// [`Race`]: crate::future::Race -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] -pub struct Race -where - Fut: Future, -{ - #[pin] - futures: Vec, - indexer: Indexer, - done: bool, -} +pub type Race = CombinatorVec; -impl fmt::Debug for Race -where - Fut: Future + fmt::Debug, - Fut::Output: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.futures.iter()).finish() - } -} - -impl Future for Race +impl CombinatorBehaviorVec for RaceBehavior where Fut: Future, { type Output = Fut::Output; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - assert!(!*this.done, "Futures must not be polled after completing"); + type StoredItem = core::convert::Infallible; + + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + Err(res) + } - for index in this.indexer.iter() { - let fut = utils::get_pin_mut_from_vec(this.futures.as_mut(), index).unwrap(); - match fut.poll(cx) { - Poll::Ready(item) => { - *this.done = true; - return Poll::Ready(item); - } - Poll::Pending => continue, - } - } - Poll::Pending + fn when_completed(_vec: Vec) -> Self::Output { + panic!("race only works on non-empty arrays"); } } @@ -70,11 +40,7 @@ where type Future = Race; fn race(self) -> Self::Future { - Race { - indexer: Indexer::new(self.len()), - futures: self.into_iter().map(|fut| fut.into_future()).collect(), - done: false, - } + Race::new(self.into_iter().map(IntoFuture::into_future).collect()) } } diff --git a/src/future/race_ok/array.rs b/src/future/race_ok/array.rs new file mode 100644 index 0000000..4f3b548 --- /dev/null +++ b/src/future/race_ok/array.rs @@ -0,0 +1,113 @@ +use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; +use super::error::AggregateError; +use super::{RaceOk as RaceOkTrait, RaceOkBehavior}; + +use core::future::{Future, IntoFuture}; + +/// Wait for the first successful future to complete. +/// +/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See +/// its documentation for more. +/// +/// [`race_ok`]: crate::future::RaceOk::race_ok +/// [`RaceOk`]: crate::future::RaceOk +pub type RaceOk = CombinatorArray; + +impl CombinatorBehaviorArray for RaceOkBehavior +where + Fut: Future>, +{ + type Output = Result>; + + type StoredItem = E; + + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + match res { + Ok(v) => Err(Ok(v)), + Err(e) => Ok(e), + } + } + + fn when_completed(errors: [Self::StoredItem; N]) -> Self::Output { + Err(AggregateError { errors }) + } +} + +impl RaceOkTrait for [Fut; N] +where + Fut: IntoFuture, + Fut::IntoFuture: Future>, +{ + type Ok = T; + type Error = AggregateError<[E; N]>; + type Future = RaceOk; + + fn race_ok(self) -> Self::Future { + RaceOk::new(self.map(IntoFuture::into_future)) + } +} + +mod err { + use std::{error::Error, fmt::Display}; + + use crate::future::race_ok::error::AggregateError; + impl Display for AggregateError<[E; N]> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "multiple errors occurred: [")?; + for e in self.errors.iter() { + write!(f, "\n{}", e)?; + } + write!(f, "]") + } + } + + impl Error for AggregateError<[E; N]> {} +} + +#[cfg(test)] +mod test { + use super::*; + use std::future; + use std::io::{Error, ErrorKind}; + + #[test] + fn all_ok() { + futures_lite::future::block_on(async { + let res = [ + future::ready(Ok::<_, ()>("hello")), + future::ready(Ok("world")), + ] + .race_ok() + .await; + assert!(res.is_ok()); + }) + } + + #[test] + fn one_err() { + futures_lite::future::block_on(async { + let err = Error::new(ErrorKind::Other, "oh no"); + let res = [future::ready(Ok("hello")), future::ready(Err(err))] + .race_ok() + .await; + assert_eq!(res.unwrap(), "hello"); + }); + } + + #[test] + fn all_err() { + futures_lite::future::block_on(async { + let err1 = Error::new(ErrorKind::Other, "oops"); + let err2 = Error::new(ErrorKind::Other, "oh no"); + let res = [future::ready(Err::<(), _>(err1)), future::ready(Err(err2))] + .race_ok() + .await; + let err = res.unwrap_err(); + assert_eq!(err.errors[0].to_string(), "oops"); + assert_eq!(err.errors[1].to_string(), "oh no"); + }); + } +} diff --git a/src/future/race_ok/array/error.rs b/src/future/race_ok/array/error.rs deleted file mode 100644 index 23e3d6e..0000000 --- a/src/future/race_ok/array/error.rs +++ /dev/null @@ -1,54 +0,0 @@ -use core::fmt; -use core::ops::{Deref, DerefMut}; -use std::error::Error; - -/// A collection of errors. -#[repr(transparent)] -pub struct AggregateError { - inner: [E; N], -} - -impl AggregateError { - pub(super) fn new(inner: [E; N]) -> Self { - Self { inner } - } -} - -impl fmt::Debug for AggregateError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{self}:")?; - - for (i, err) in self.inner.iter().enumerate() { - writeln!(f, "- Error {}: {err}", i + 1)?; - let mut source = err.source(); - while let Some(err) = source { - writeln!(f, " ↳ Caused by: {err}")?; - source = err.source(); - } - } - - Ok(()) - } -} - -impl fmt::Display for AggregateError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} errors occured", self.inner.len()) - } -} - -impl Deref for AggregateError { - type Target = [E; N]; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for AggregateError { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl std::error::Error for AggregateError {} diff --git a/src/future/race_ok/array/mod.rs b/src/future/race_ok/array/mod.rs deleted file mode 100644 index 08f4f15..0000000 --- a/src/future/race_ok/array/mod.rs +++ /dev/null @@ -1,151 +0,0 @@ -use super::RaceOk as RaceOkTrait; -use crate::utils::array_assume_init; -use crate::utils::iter_pin_mut; - -use core::array; -use core::fmt; -use core::future::{Future, IntoFuture}; -use core::mem::{self, MaybeUninit}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::pin_project; - -mod error; - -pub use error::AggregateError; - -/// Wait for the first successful future to complete. -/// -/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See -/// its documentation for more. -/// -/// [`race_ok`]: crate::future::RaceOk::race_ok -/// [`RaceOk`]: crate::future::RaceOk -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] -pub struct RaceOk -where - T: fmt::Debug, - Fut: Future>, -{ - #[pin] - futures: [Fut; N], - errors: [MaybeUninit; N], - completed: usize, -} - -impl fmt::Debug for RaceOk -where - Fut: Future> + fmt::Debug, - Fut::Output: fmt::Debug, - T: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.futures.iter()).finish() - } -} - -impl Future for RaceOk -where - T: fmt::Debug, - Fut: Future>, - E: fmt::Debug, -{ - type Output = Result>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - let futures = iter_pin_mut(this.futures); - - for (fut, out) in futures.zip(this.errors.iter_mut()) { - if let Poll::Ready(output) = fut.poll(cx) { - match output { - Ok(ok) => return Poll::Ready(Ok(ok)), - Err(err) => { - *out = MaybeUninit::new(err); - *this.completed += 1; - } - } - } - } - - let all_completed = *this.completed == N; - if all_completed { - let mut errors = array::from_fn(|_| MaybeUninit::uninit()); - mem::swap(&mut errors, this.errors); - - // SAFETY: we know that all futures are properly initialized because they're all completed - let result = unsafe { array_assume_init(errors) }; - - Poll::Ready(Err(AggregateError::new(result))) - } else { - Poll::Pending - } - } -} - -impl RaceOkTrait for [Fut; N] -where - T: fmt::Debug, - Fut: IntoFuture>, - E: fmt::Debug, -{ - type Output = T; - type Error = AggregateError; - type Future = RaceOk; - - fn race_ok(self) -> Self::Future { - RaceOk { - futures: self.map(|fut| fut.into_future()), - errors: array::from_fn(|_| MaybeUninit::uninit()), - completed: 0, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use std::future; - use std::io::{Error, ErrorKind}; - - #[test] - fn all_ok() { - futures_lite::future::block_on(async { - let res: Result<&str, AggregateError> = - [future::ready(Ok("hello")), future::ready(Ok("world"))] - .race_ok() - .await; - assert!(res.is_ok()); - }) - } - - #[test] - fn one_err() { - futures_lite::future::block_on(async { - let err = Error::new(ErrorKind::Other, "oh no"); - let res: Result<&str, AggregateError> = - [future::ready(Ok("hello")), future::ready(Err(err))] - .race_ok() - .await; - assert_eq!(res.unwrap(), "hello"); - }); - } - - #[test] - fn all_err() { - futures_lite::future::block_on(async { - let err1 = Error::new(ErrorKind::Other, "oops"); - let err2 = Error::new(ErrorKind::Other, "oh no"); - let res: Result<&str, AggregateError> = - [future::ready(Err(err1)), future::ready(Err(err2))] - .race_ok() - .await; - let errs = res.unwrap_err(); - assert_eq!(errs[0].to_string(), "oops"); - assert_eq!(errs[1].to_string(), "oh no"); - }); - } -} diff --git a/src/future/race_ok/error.rs b/src/future/race_ok/error.rs new file mode 100644 index 0000000..62973ee --- /dev/null +++ b/src/future/race_ok/error.rs @@ -0,0 +1,21 @@ +/// A collection of errors returned when [super::RaceOk] fails. +/// +/// Example: +/// ``` +/// # use std::future; +/// # use std::io::{Error, ErrorKind}; +/// # use futures_concurrency::errors::AggregateError; +/// # use futures_concurrency::future::RaceOk; +/// # futures_lite::future::block_on(async { +/// let err = Error::new(ErrorKind::Other, "oh no"); +/// let res: Result<&str, AggregateError<[Error; 2]>> = [future::ready(Ok("hello")), future::ready(Err(err))] +/// .race_ok() +/// .await; +/// assert_eq!(res.unwrap(), "hello"); +/// # }); +/// ``` +#[derive(Debug)] +pub struct AggregateError { + /// The errors. Can be a Vec, and Array, or a Tuple. + pub errors: E, +} diff --git a/src/future/race_ok/mod.rs b/src/future/race_ok/mod.rs index 5bc3e99..9ce8c1d 100644 --- a/src/future/race_ok/mod.rs +++ b/src/future/race_ok/mod.rs @@ -1,6 +1,7 @@ use core::future::Future; pub(crate) mod array; +pub(crate) mod error; pub(crate) mod tuple; pub(crate) mod vec; @@ -11,14 +12,17 @@ pub(crate) mod vec; /// aggregate error of all failed futures. pub trait RaceOk { /// The resulting output type. - type Output; + type Ok; /// The resulting error type. type Error; /// Which kind of future are we turning this into? - type Future: Future>; + type Future: Future>; /// Waits for the first successful future to complete. fn race_ok(self) -> Self::Future; } + +#[derive(Debug)] +pub struct RaceOkBehavior; diff --git a/src/future/race_ok/tuple.rs b/src/future/race_ok/tuple.rs new file mode 100644 index 0000000..8ab855e --- /dev/null +++ b/src/future/race_ok/tuple.rs @@ -0,0 +1,124 @@ +use super::super::common::{CombineTuple, TupleMaybeReturn, TupleWhenCompleted}; +use super::error::AggregateError; +use super::{RaceOk as RaceOkTrait, RaceOkBehavior}; + +use core::future::IntoFuture; +use core::marker::PhantomData; +use std::{error::Error, fmt::Display}; + +impl TupleMaybeReturn, Result> for RaceOkBehavior { + type StoredItem = E; + fn maybe_return(_: usize, res: Result) -> Result> { + match res { + // If subfuture returns Ok we want to early return from the combinator. + // We do this by returning Err to the combinator. + Ok(t) => Err(Ok(t)), + // If subfuture returns Err, we keep the error for potential use in AggregateError. + Err(e) => Ok(e), + } + } +} +impl TupleWhenCompleted>> for RaceOkBehavior { + // If we get here, it must have been that none of the subfutures early returned. + // This means all of them failed. In this case we returned an AggregateError with the errors we kept. + fn when_completed(errors: AggE) -> Result> { + Err(AggregateError { errors }) + } +} + +macro_rules! impl_race_ok_tuple { + ($(($F:ident $E:ident))+) => { + impl RaceOkTrait for ($($F,)+) + where $( + $F: IntoFuture>, + )+ { + type Ok = T; + type Error = AggregateError<($($E, )+)>; + type Future = <(($($F::IntoFuture,)+), RaceOkBehavior, PhantomData>>) as CombineTuple>::Combined; + fn race_ok(self) -> Self::Future { + let ($($F,)+) = self; + ( + ( + $($F.into_future(),)+ + ), + RaceOkBehavior, + PhantomData + ).combine() + } + } + impl<$($E: Display),+> Display for AggregateError<($($E,)+)> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "multiple errors occurred: [")?; + let ($($E,)+) = &self.errors; + $( + write!(f, "{}", $E)?; + )+ + write!(f, "]") + } + } + + impl<$($E: Error),+> Error for AggregateError<($($E,)+)> {} + }; +} + +impl_race_ok_tuple! { (A0 E0) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) (A5 E5) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) (A5 E5) (A6 E6) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) (A5 E5) (A6 E6) (A7 E7) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) (A5 E5) (A6 E6) (A7 E7) (A8 E8) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) (A5 E5) (A6 E6) (A7 E7) (A8 E8) (A9 E9) } +impl_race_ok_tuple! { (A0 E0) (A1 E1) (A2 E2) (A3 E3) (A4 E4) (A5 E5) (A6 E6) (A7 E7) (A8 E8) (A9 E9) (A10 E10) } + +#[cfg(test)] +mod test { + use super::*; + use core::future; + use std::error::Error; + + type DynError = Box; + + #[test] + fn race_ok_1() { + futures_lite::future::block_on(async { + let a = async { Ok::<_, DynError>("world") }; + let res = (a,).race_ok().await; + assert!(matches!(res.unwrap(), "world")); + }); + } + + #[test] + fn race_ok_2() { + futures_lite::future::block_on(async { + let a = future::pending::>(); + let b = async { Ok::<_, DynError>("world") }; + let res = (a, b).race_ok().await; + assert!(matches!(res.unwrap(), "world")); + }); + } + + #[test] + fn race_ok_3() { + futures_lite::future::block_on(async { + let a = future::pending::>(); + let b = async { Ok::<_, DynError>("hello") }; + let c = async { Ok::<_, DynError>("world") }; + let result = (a, b, c).race_ok().await; + assert!(matches!(result.unwrap(), "hello" | "world")); + }); + } + + #[test] + fn race_ok_err() { + futures_lite::future::block_on(async { + let a = async { Err::<(), _>("hello") }; + let b = async { Err::<(), _>("world") }; + let AggregateError { errors } = (a, b).race_ok().await.unwrap_err(); + assert_eq!(errors.0, "hello"); + assert_eq!(errors.1, "world"); + }); + } +} diff --git a/src/future/race_ok/tuple/error.rs b/src/future/race_ok/tuple/error.rs deleted file mode 100644 index 23e3d6e..0000000 --- a/src/future/race_ok/tuple/error.rs +++ /dev/null @@ -1,54 +0,0 @@ -use core::fmt; -use core::ops::{Deref, DerefMut}; -use std::error::Error; - -/// A collection of errors. -#[repr(transparent)] -pub struct AggregateError { - inner: [E; N], -} - -impl AggregateError { - pub(super) fn new(inner: [E; N]) -> Self { - Self { inner } - } -} - -impl fmt::Debug for AggregateError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{self}:")?; - - for (i, err) in self.inner.iter().enumerate() { - writeln!(f, "- Error {}: {err}", i + 1)?; - let mut source = err.source(); - while let Some(err) = source { - writeln!(f, " ↳ Caused by: {err}")?; - source = err.source(); - } - } - - Ok(()) - } -} - -impl fmt::Display for AggregateError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} errors occured", self.inner.len()) - } -} - -impl Deref for AggregateError { - type Target = [E; N]; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for AggregateError { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl std::error::Error for AggregateError {} diff --git a/src/future/race_ok/tuple/mod.rs b/src/future/race_ok/tuple/mod.rs deleted file mode 100644 index 4e64142..0000000 --- a/src/future/race_ok/tuple/mod.rs +++ /dev/null @@ -1,220 +0,0 @@ -use super::RaceOk; -use crate::utils::{self, PollArray}; - -use core::array; -use core::fmt; -use core::future::{Future, IntoFuture}; -use core::mem::{self, MaybeUninit}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::{pin_project, pinned_drop}; - -mod error; -pub(crate) use error::AggregateError; - -macro_rules! impl_race_ok_tuple { - ($StructName:ident $($F:ident)+) => { - /// Wait for the first successful future to complete. - /// - /// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See - /// its documentation for more. - /// - /// [`race_ok`]: crate::future::RaceOk::race_ok - /// [`RaceOk`]: crate::future::RaceOk - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[allow(non_snake_case)] - #[pin_project(PinnedDrop)] - pub struct $StructName - where - $( $F: Future>, )* - ERR: fmt::Debug, - { - completed: usize, - done: bool, - indexer: utils::Indexer, - errors: [MaybeUninit; { utils::tuple_len!($($F,)*) }], - errors_states: PollArray<{ utils::tuple_len!($($F,)*) }>, - $( #[pin] $F: $F, )* - } - - impl fmt::Debug for $StructName - where - $( $F: Future> + fmt::Debug, )* - T: fmt::Debug, - ERR: fmt::Debug, - { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Race") - $(.field(&self.$F))* - .finish() - } - } - - impl RaceOk for ($($F,)*) - where - $( $F: IntoFuture>, )* - ERR: fmt::Debug, - { - type Output = T; - type Error = AggregateError; - type Future = $StructName; - - fn race_ok(self) -> Self::Future { - let ($($F,)*): ($($F,)*) = self; - $StructName { - completed: 0, - done: false, - indexer: utils::Indexer::new(utils::tuple_len!($($F,)*)), - errors: array::from_fn(|_| MaybeUninit::uninit()), - errors_states: PollArray::new(), - $($F: $F.into_future()),* - } - } - } - - impl Future for $StructName - where - $( $F: Future>, )* - ERR: fmt::Debug, - { - type Output = Result>; - - fn poll( - self: Pin<&mut Self>, cx: &mut Context<'_> - ) -> Poll { - const LEN: usize = utils::tuple_len!($($F,)*); - - let mut this = self.project(); - - let can_poll = !*this.done; - assert!(can_poll, "Futures must not be polled after completing"); - - #[repr(usize)] - enum Indexes { - $($F),* - } - - for i in this.indexer.iter() { - utils::gen_conditions!(i, this, cx, poll, $((Indexes::$F as usize; $F, { - Poll::Ready(output) => match output { - Ok(output) => { - *this.done = true; - *this.completed += 1; - return Poll::Ready(Ok(output)); - }, - Err(err) => { - this.errors[i] = MaybeUninit::new(err); - this.errors_states[i].set_ready(); - *this.completed += 1; - continue; - }, - }, - _ => continue, - }))*); - } - - let all_completed = *this.completed == LEN; - if all_completed { - // mark all error states as consumed before we return it - this.errors_states.set_all_completed(); - - let mut errors = array::from_fn(|_| MaybeUninit::uninit()); - mem::swap(&mut errors, this.errors); - - let result = unsafe { utils::array_assume_init(errors) }; - - *this.done = true; - return Poll::Ready(Err(AggregateError::new(result))); - } - - Poll::Pending - } - } - - #[pinned_drop] - impl PinnedDrop for $StructName - where - $( $F: Future>, )* - ERR: fmt::Debug, - { - fn drop(self: Pin<&mut Self>) { - let this = self.project(); - - this - .errors_states - .iter_mut() - .zip(this.errors.iter_mut()) - .filter(|(st, _err)| st.is_ready()) - .for_each(|(st, err)| { - // SAFETY: we've filtered down to only the `ready`/initialized data - unsafe { err.assume_init_drop() }; - st.set_consumed(); - }); - } - } - }; -} - -impl_race_ok_tuple! { RaceOk1 A } -impl_race_ok_tuple! { RaceOk2 A B } -impl_race_ok_tuple! { RaceOk3 A B C } -impl_race_ok_tuple! { RaceOk4 A B C D } -impl_race_ok_tuple! { RaceOk5 A B C D E } -impl_race_ok_tuple! { RaceOk6 A B C D E F } -impl_race_ok_tuple! { RaceOk7 A B C D E F G } -impl_race_ok_tuple! { RaceOk8 A B C D E F G H } -impl_race_ok_tuple! { RaceOk9 A B C D E F G H I } -impl_race_ok_tuple! { RaceOk10 A B C D E F G H I J } -impl_race_ok_tuple! { RaceOk11 A B C D E F G H I J K } -impl_race_ok_tuple! { RaceOk12 A B C D E F G H I J K L } - -#[cfg(test)] -mod test { - use super::*; - use core::future; - use std::error::Error; - - type DynError = Box; - - #[test] - fn race_ok_1() { - futures_lite::future::block_on(async { - let a = async { Ok::<_, DynError>("world") }; - let res = (a,).race_ok().await; - assert!(matches!(res, Ok("world"))); - }); - } - - #[test] - fn race_ok_2() { - futures_lite::future::block_on(async { - let a = future::pending(); - let b = async { Ok::<_, DynError>("world") }; - let res = (a, b).race_ok().await; - assert!(matches!(res, Ok("world"))); - }); - } - - #[test] - fn race_ok_3() { - futures_lite::future::block_on(async { - let a = future::pending(); - let b = async { Ok::<_, DynError>("hello") }; - let c = async { Ok::<_, DynError>("world") }; - let result = (a, b, c).race_ok().await; - assert!(matches!(result, Ok("hello") | Ok("world"))); - }); - } - - #[test] - fn race_ok_err() { - futures_lite::future::block_on(async { - let a = async { Err::<(), _>("hello") }; - let b = async { Err::<(), _>("world") }; - let errors = (a, b).race_ok().await.unwrap_err(); - assert_eq!(errors[0], "hello"); - assert_eq!(errors[1], "world"); - }); - } -} diff --git a/src/future/race_ok/vec.rs b/src/future/race_ok/vec.rs new file mode 100644 index 0000000..80adc63 --- /dev/null +++ b/src/future/race_ok/vec.rs @@ -0,0 +1,113 @@ +use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; +use super::error::AggregateError; +use super::{RaceOk as RaceOkTrait, RaceOkBehavior}; + +use core::future::{Future, IntoFuture}; +use std::vec::Vec; + +/// Wait for the first successful future to complete. +/// +/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See +/// its documentation for more. +/// +/// [`race_ok`]: crate::future::RaceOk::race_ok +/// [`RaceOk`]: crate::future::RaceOk +pub type RaceOk = CombinatorVec; + +impl CombinatorBehaviorVec for RaceOkBehavior +where + Fut: Future>, +{ + type Output = Result>>; + + type StoredItem = E; + + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + match res { + Ok(v) => Err(Ok(v)), + Err(e) => Ok(e), + } + } + + fn when_completed(errors: Vec) -> Self::Output { + Err(AggregateError { errors }) + } +} + +impl RaceOkTrait for Vec +where + Fut: IntoFuture>, +{ + type Ok = T; + type Error = AggregateError>; + type Future = RaceOk; + + fn race_ok(self) -> Self::Future { + RaceOk::new(self.into_iter().map(IntoFuture::into_future).collect()) + } +} + +mod err { + use std::{error::Error, fmt::Display}; + + use crate::future::race_ok::error::AggregateError; + impl Display for AggregateError> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "multiple errors occurred: [")?; + for e in self.errors.iter() { + write!(f, "\n{}", e)?; + } + write!(f, "]") + } + } + + impl Error for AggregateError> {} +} + +#[cfg(test)] +mod test { + use super::*; + use std::future; + use std::io::{Error, ErrorKind}; + + #[test] + fn all_ok() { + futures_lite::future::block_on(async { + let res = vec![ + future::ready(Ok::<_, ()>("hello")), + future::ready(Ok("world")), + ] + .race_ok() + .await; + assert!(res.is_ok()); + }) + } + + #[test] + fn one_err() { + futures_lite::future::block_on(async { + let err = Error::new(ErrorKind::Other, "oh no"); + let res = vec![future::ready(Ok("hello")), future::ready(Err(err))] + .race_ok() + .await; + assert_eq!(res.unwrap(), "hello"); + }); + } + + #[test] + fn all_err() { + futures_lite::future::block_on(async { + let err1 = Error::new(ErrorKind::Other, "oops"); + let err2 = Error::new(ErrorKind::Other, "oh no"); + let res = vec![future::ready(Err::<(), _>(err1)), future::ready(Err(err2))] + .race_ok() + .await; + let err = res.unwrap_err(); + assert_eq!(err.errors[0].to_string(), "oops"); + assert_eq!(err.errors[1].to_string(), "oh no"); + }); + } +} diff --git a/src/future/race_ok/vec/error.rs b/src/future/race_ok/vec/error.rs deleted file mode 100644 index e5977d1..0000000 --- a/src/future/race_ok/vec/error.rs +++ /dev/null @@ -1,56 +0,0 @@ -use core::fmt; -use std::error::Error; -use std::ops::Deref; -use std::ops::DerefMut; -use std::vec::Vec; - -/// A collection of errors. -#[repr(transparent)] -pub struct AggregateError { - pub(crate) inner: Vec, -} - -impl AggregateError { - pub(crate) fn new(inner: Vec) -> Self { - Self { inner } - } -} - -impl fmt::Debug for AggregateError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{self}:")?; - - for (i, err) in self.inner.iter().enumerate() { - writeln!(f, "- Error {}: {err}", i + 1)?; - let mut source = err.source(); - while let Some(err) = source { - writeln!(f, " ↳ Caused by: {err}")?; - source = err.source(); - } - } - - Ok(()) - } -} - -impl fmt::Display for AggregateError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} errors occurred", self.inner.len()) - } -} - -impl Deref for AggregateError { - type Target = Vec; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for AggregateError { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl Error for AggregateError {} diff --git a/src/future/race_ok/vec/mod.rs b/src/future/race_ok/vec/mod.rs deleted file mode 100644 index 3267959..0000000 --- a/src/future/race_ok/vec/mod.rs +++ /dev/null @@ -1,138 +0,0 @@ -use super::RaceOk as RaceOkTrait; -use crate::utils::iter_pin_mut; -use crate::utils::MaybeDone; - -use core::fmt; -use core::future::{Future, IntoFuture}; -use core::mem; -use core::pin::Pin; -use core::task::{Context, Poll}; -use std::boxed::Box; -use std::vec::Vec; - -pub use error::AggregateError; - -mod error; - -/// Wait for the first successful future to complete. -/// -/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See -/// its documentation for more. -/// -/// [`race_ok`]: crate::future::RaceOk::race_ok -/// [`RaceOk`]: crate::future::RaceOk -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct RaceOk -where - Fut: Future>, -{ - elems: Pin]>>, -} - -impl fmt::Debug for RaceOk -where - Fut: Future> + fmt::Debug, - Fut::Output: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.elems.iter()).finish() - } -} - -impl Future for RaceOk -where - T: std::fmt::Debug, - E: fmt::Debug, - Fut: Future>, -{ - type Output = Result>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; - - for mut elem in iter_pin_mut(self.elems.as_mut()) { - if elem.as_mut().poll(cx).is_pending() { - all_done = false - } else if let Some(Ok(_)) = elem.as_ref().output() { - return Poll::Ready(Ok(elem.take().unwrap().unwrap())); - } - } - - if all_done { - let mut elems = mem::replace(&mut self.elems, Box::pin([])); - let result: Vec = iter_pin_mut(elems.as_mut()) - .map(|e| e.take().unwrap().unwrap_err()) - .collect(); - Poll::Ready(Err(AggregateError::new(result))) - } else { - Poll::Pending - } - } -} - -impl RaceOkTrait for Vec -where - T: fmt::Debug, - E: fmt::Debug, - Fut: IntoFuture>, -{ - type Output = T; - type Error = AggregateError; - type Future = RaceOk; - - fn race_ok(self) -> Self::Future { - let elems: Box<[_]> = self - .into_iter() - .map(|fut| MaybeDone::new(fut.into_future())) - .collect(); - RaceOk { - elems: elems.into(), - } - } -} - -#[cfg(test)] -mod test { - use super::error::AggregateError; - use super::*; - use std::future; - use std::io::{Error, ErrorKind}; - - #[test] - fn all_ok() { - futures_lite::future::block_on(async { - let res: Result<&str, AggregateError> = - vec![future::ready(Ok("hello")), future::ready(Ok("world"))] - .race_ok() - .await; - assert!(res.is_ok()); - }) - } - - #[test] - fn one_err() { - futures_lite::future::block_on(async { - let err = Error::new(ErrorKind::Other, "oh no"); - let res: Result<&str, AggregateError> = - vec![future::ready(Ok("hello")), future::ready(Err(err))] - .race_ok() - .await; - assert_eq!(res.unwrap(), "hello"); - }); - } - - #[test] - fn all_err() { - futures_lite::future::block_on(async { - let err1 = Error::new(ErrorKind::Other, "oops"); - let err2 = Error::new(ErrorKind::Other, "oh no"); - let res: Result<&str, AggregateError> = - vec![future::ready(Err(err1)), future::ready(Err(err2))] - .race_ok() - .await; - let errs = res.unwrap_err(); - assert_eq!(errs[0].to_string(), "oops"); - assert_eq!(errs[1].to_string(), "oh no"); - }); - } -} diff --git a/src/future/try_join/array.rs b/src/future/try_join/array.rs index adc919e..10fea2f 100644 --- a/src/future/try_join/array.rs +++ b/src/future/try_join/array.rs @@ -1,12 +1,6 @@ -use super::TryJoin as TryJoinTrait; -use crate::utils::MaybeDone; - -use core::fmt; +use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; +use super::{TryJoin as TryJoinTrait, TryJoinBehavior}; use core::future::{Future, IntoFuture}; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use pin_project::pin_project; /// Wait for all futures to complete successfully, or abort early on error. /// @@ -15,87 +9,41 @@ use pin_project::pin_project; /// /// [`try_join`]: crate::future::TryJoin::try_join /// [`TryJoin`]: crate::future::TryJoin -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] -pub struct TryJoin -where - T: fmt::Debug, - Fut: Future>, -{ - elems: [MaybeDone; N], -} +pub type TryJoin = CombinatorArray; -impl fmt::Debug for TryJoin +impl CombinatorBehaviorArray for TryJoinBehavior where - Fut: Future> + fmt::Debug, - Fut::Output: fmt::Debug, - T: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.elems.iter()).finish() - } -} - -impl Future for TryJoin -where - T: fmt::Debug, Fut: Future>, - E: fmt::Debug, { type Output = Result<[T; N], E>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; + type StoredItem = T; - let this = self.project(); - - for elem in this.elems.iter_mut() { - // SAFETY: we don't ever move the pinned container here; we only pin project - let mut elem = unsafe { Pin::new_unchecked(elem) }; - if elem.as_mut().poll(cx).is_pending() { - all_done = false - } else if let Some(Err(_)) = elem.as_ref().output() { - return Poll::Ready(Err(elem.take().unwrap().unwrap_err())); - } + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + match res { + Ok(v) => Ok(v), + Err(e) => Err(Err(e)), } + } - if all_done { - use core::array; - use core::mem::MaybeUninit; - - // Create the result array based on the indices - // TODO: replace with `MaybeUninit::uninit_array()` when it becomes stable - let mut out: [_; N] = array::from_fn(|_| MaybeUninit::uninit()); - - // NOTE: this clippy attribute can be removed once we can `collect` into `[usize; K]`. - #[allow(clippy::needless_range_loop)] - for (i, el) in this.elems.iter_mut().enumerate() { - // SAFETY: we don't ever move the pinned container here; we only pin project - let el = unsafe { Pin::new_unchecked(el) }.take().unwrap().unwrap(); - out[i] = MaybeUninit::new(el); - } - let result = unsafe { out.as_ptr().cast::<[T; N]>().read() }; - Poll::Ready(Ok(result)) - } else { - Poll::Pending - } + fn when_completed(arr: [Self::StoredItem; N]) -> Self::Output { + Ok(arr) } } -impl TryJoinTrait for [Fut; N] +impl TryJoinTrait for [Fut; N] where - T: std::fmt::Debug, Fut: IntoFuture>, - E: fmt::Debug, { - type Output = [T; N]; + type Ok = [T; N]; type Error = E; - type Future = TryJoin; + type Future = TryJoin; fn try_join(self) -> Self::Future { - TryJoin { - elems: self.map(|fut| MaybeDone::new(fut.into_future())), - } + TryJoin::new(self.map(IntoFuture::into_future)) } } diff --git a/src/future/try_join/mod.rs b/src/future/try_join/mod.rs index 085fb84..373caf3 100644 --- a/src/future/try_join/mod.rs +++ b/src/future/try_join/mod.rs @@ -1,6 +1,7 @@ use core::future::Future; pub(crate) mod array; +pub(crate) mod tuple; pub(crate) mod vec; /// Wait for all futures to complete successfully, or abort early on error. @@ -12,16 +13,19 @@ pub(crate) mod vec; /// operation. pub trait TryJoin { /// The resulting output type. - type Output; + type Ok; /// The resulting error type. type Error; /// Which kind of future are we turning this into? - type Future: Future>; + type Future: Future>; /// Waits for multiple futures to complete, either returning when all /// futures complete successfully, or return early when any future completes /// with an error. fn try_join(self) -> Self::Future; } + +#[derive(Debug)] +pub struct TryJoinBehavior; diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs new file mode 100644 index 0000000..a018127 --- /dev/null +++ b/src/future/try_join/tuple.rs @@ -0,0 +1,89 @@ +use super::super::common::{CombineTuple, TupleMaybeReturn, TupleWhenCompleted}; +use super::{TryJoin as TryJoinTrait, TryJoinBehavior}; + +use core::future::IntoFuture; +use core::marker::PhantomData; + +use futures_core::TryFuture; + +impl TupleMaybeReturn, Result> for TryJoinBehavior { + type StoredItem = T; + fn maybe_return(_: usize, res: Result) -> Result> { + match res { + Ok(t) => Ok(t), + Err(e) => Err(Err(e)), + } + } +} +impl TupleWhenCompleted> for TryJoinBehavior { + fn when_completed(stored_items: AggT) -> Result { + Ok(stored_items) + } +} + +macro_rules! impl_try_join_tuple { + ($(($F:ident $T:ident))+) => { + impl TryJoinTrait for ($($F,)+) + where $( + $F: IntoFuture>, + )+ { + type Ok = ($(<$F::IntoFuture as TryFuture>::Ok,)+); + type Error = E; + type Future = <(($($F::IntoFuture,)+), TryJoinBehavior, PhantomData>) as CombineTuple>::Combined; + fn try_join(self) -> Self::Future { + let ($($F,)+) = self; + ( + ($($F.into_future(),)+), + TryJoinBehavior, + PhantomData + ).combine() + } + } + }; +} + +impl_try_join_tuple! { (A0 T0) } +impl_try_join_tuple! { (A0 T0) (A1 T1) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) (A5 T5) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) (A5 T5) (A6 T6) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) (A5 T5) (A6 T6) (A7 T7) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) (A5 T5) (A6 T6) (A7 T7) (A8 T8) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) (A5 T5) (A6 T6) (A7 T7) (A8 T8) (A9 T9) } +impl_try_join_tuple! { (A0 T0) (A1 T1) (A2 T2) (A3 T3) (A4 T4) (A5 T5) (A6 T6) (A7 T7) (A8 T8) (A9 T9) (A10 T10) } + +#[cfg(test)] +mod test { + use super::*; + use std::future; + use std::io::{self, Error, ErrorKind}; + + #[test] + fn ok() { + futures_lite::future::block_on(async { + let res = ( + future::ready(Result::<_, io::Error>::Ok(42)), + future::ready(Result::<_, io::Error>::Ok("world")), + ) + .try_join() + .await; + assert_eq!(res.unwrap(), (42, "world")); + }) + } + + #[test] + fn err() { + futures_lite::future::block_on(async { + let err = Error::new(ErrorKind::Other, "oh no"); + let res = ( + future::ready(io::Result::Ok("hello")), + future::ready(Result::::Err(err)), + ) + .try_join() + .await; + assert_eq!(res.unwrap_err().kind(), ErrorKind::Other); + }); + } +} diff --git a/src/future/try_join/vec.rs b/src/future/try_join/vec.rs index 72cc1fb..e5a6a51 100644 --- a/src/future/try_join/vec.rs +++ b/src/future/try_join/vec.rs @@ -1,13 +1,7 @@ -use super::TryJoin as TryJoinTrait; -use crate::utils::iter_pin_mut; -use crate::utils::MaybeDone; +use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; +use super::{TryJoin as TryJoinTrait, TryJoinBehavior}; -use core::fmt; use core::future::{Future, IntoFuture}; -use core::mem; -use core::pin::Pin; -use core::task::{Context, Poll}; -use std::boxed::Box; use std::vec::Vec; /// Wait for all futures to complete successfully, or abort early on error. @@ -17,71 +11,41 @@ use std::vec::Vec; /// /// [`try_join`]: crate::future::TryJoin::try_join /// [`TryJoin`]: crate::future::TryJoin -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct TryJoin -where - Fut: Future>, -{ - elems: Pin]>>, -} - -impl fmt::Debug for TryJoin -where - Fut: Future> + fmt::Debug, - Fut::Output: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.elems.iter()).finish() - } -} +pub type TryJoin = CombinatorVec; -impl Future for TryJoin +impl CombinatorBehaviorVec for TryJoinBehavior where - T: std::fmt::Debug, Fut: Future>, { type Output = Result, E>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut all_done = true; + type StoredItem = T; - for mut elem in iter_pin_mut(self.elems.as_mut()) { - if elem.as_mut().poll(cx).is_pending() { - all_done = false - } else if let Some(Err(_)) = elem.as_ref().output() { - return Poll::Ready(Err(elem.take().unwrap().unwrap_err())); - } + fn maybe_return( + _idx: usize, + res: ::Output, + ) -> Result { + match res { + Ok(v) => Ok(v), + Err(e) => Err(Err(e)), } + } - if all_done { - let mut elems = mem::replace(&mut self.elems, Box::pin([])); - let result = iter_pin_mut(elems.as_mut()) - .map(|e| e.take().unwrap()) - .collect(); - Poll::Ready(result) - } else { - Poll::Pending - } + fn when_completed(vec: Vec) -> Self::Output { + Ok(vec) } } -impl TryJoinTrait for Vec +impl TryJoinTrait for Vec where - T: std::fmt::Debug, Fut: IntoFuture>, { - type Output = Vec; + type Ok = Vec; type Error = E; - type Future = TryJoin; + type Future = TryJoin; fn try_join(self) -> Self::Future { - let elems: Box<[_]> = self - .into_iter() - .map(|fut| MaybeDone::new(fut.into_future())) - .collect(); - TryJoin { - elems: elems.into(), - } + TryJoin::new(self.into_iter().map(IntoFuture::into_future).collect()) } } diff --git a/src/lib.rs b/src/lib.rs index d9778bc..18977d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,11 +67,16 @@ pub mod prelude { pub mod future; pub mod stream; +/// Error types. +pub mod errors { + pub use crate::future::race_ok::error::AggregateError; +} + /// Helper functions and types for fixed-length arrays. pub mod array { pub use crate::future::join::array::Join; pub use crate::future::race::array::Race; - pub use crate::future::race_ok::array::{AggregateError, RaceOk}; + pub use crate::future::race_ok::array::RaceOk; pub use crate::future::try_join::array::TryJoin; pub use crate::stream::chain::array::Chain; pub use crate::stream::merge::array::Merge; @@ -82,7 +87,7 @@ pub mod array { pub mod vec { pub use crate::future::join::vec::Join; pub use crate::future::race::vec::Race; - pub use crate::future::race_ok::vec::{AggregateError, RaceOk}; + pub use crate::future::race_ok::vec::RaceOk; pub use crate::future::try_join::vec::TryJoin; pub use crate::stream::chain::vec::Chain; pub use crate::stream::merge::vec::Merge; diff --git a/src/stream/merge/array.rs b/src/stream/merge/array.rs index 8a5fe39..f7cc482 100644 --- a/src/stream/merge/array.rs +++ b/src/stream/merge/array.rs @@ -1,11 +1,12 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils::{self, Indexer, PollArray, WakerArray}; +use crate::utils::{self, ArrayDequeue, PollState, WakerArray}; use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; + use futures_core::Stream; -use std::pin::Pin; -use std::task::{Context, Poll}; /// A stream that merges multiple streams into a single stream. /// @@ -21,10 +22,20 @@ where { #[pin] streams: [S; N], - indexer: Indexer, wakers: WakerArray, - state: PollArray, - complete: usize, + /// Number of substreams that haven't completed. + pending: usize, + /// The states of the N streams. + /// Pending = stream is sleeping + /// Ready = stream is awake + /// Consumed = stream is complete + state: [PollState; N], + /// List of awoken streams. + awake_list: ArrayDequeue, + /// Streams should not be polled after complete. + /// In debug, we panic to the user. + /// In release, we might sleep or poll substreams after completion. + #[cfg(debug_assertions)] done: bool, } @@ -35,10 +46,13 @@ where pub(crate) fn new(streams: [S; N]) -> Self { Self { streams, - indexer: Indexer::new(N), wakers: WakerArray::new(), - state: PollArray::new(), - complete: 0, + pending: N, + // Start with every substream awake. + state: [PollState::Ready; N], + // Start with indices of every substream since they're all awake. + awake_list: ArrayDequeue::new(core::array::from_fn(core::convert::identity), N), + #[cfg(debug_assertions)] done: false, } } @@ -62,48 +76,75 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - - // Iterate over our streams one-by-one. If a stream yields a value, - // we exit early. By default we'll return `Poll::Ready(None)`, but - // this changes if we encounter a `Poll::Pending`. - for index in this.indexer.iter() { - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { - continue; - } + #[cfg(debug_assertions)] + assert!(!*this.done, "Stream should not be polled after completing"); + + { + // Lock the readiness Mutex. + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + + // Copy over the indices of awake substreams. + let awake_list = readiness.awake_list(); + let states = &mut *this.state; + this.awake_list.extend(awake_list.iter().filter_map(|&idx| { + // Only add to our list if the substream is actually pending. + // Our awake list will never contain duplicate indices. + let state = &mut states[idx]; + match state { + PollState::Pending => { + // Set the state to awake. + *state = PollState::Ready; + Some(idx) + } + _ => None, + } + })); + + // Clear the list in the Mutex. + readiness.clear(); - // unlock readiness so we don't deadlock when polling - drop(readiness); + // The Mutex should be unlocked here. + } - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + for idx in this.awake_list.drain() { + let state = &mut this.state[idx]; + // At this point state must be PollState::Ready (substream is awake). - let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap(); + let waker = this.wakers.get(idx).unwrap(); + let mut cx = Context::from_waker(waker); + let stream = utils::get_pin_mut(this.streams.as_mut(), idx).unwrap(); match stream.poll_next(&mut cx) { Poll::Ready(Some(item)) => { - // Mark ourselves as ready again because we need to poll for the next item. - this.wakers.readiness().lock().unwrap().set_ready(index); + // Queue the substream to be polled again next time. + // Todo: figure out how to do this without locking the Mutex. + // We can do `this.awake_list.push_back(idx)` + // but that will cause this substream to be scheduled before the others that have woken + // between the indices-copying and now, leading to unfairness. + waker.wake_by_ref(); + *state = PollState::Pending; + return Poll::Ready(Some(item)); } Poll::Ready(None) => { - *this.complete += 1; - this.state[index].set_consumed(); - if *this.complete == this.streams.len() { - return Poll::Ready(None); - } + *this.pending -= 1; + *state = PollState::Consumed; + } + Poll::Pending => { + *state = PollState::Pending; } - Poll::Pending => {} } - - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); } - Poll::Pending + if *this.pending == 0 { + #[cfg(debug_assertions)] + { + *this.done = true; + } + Poll::Ready(None) + } else { + Poll::Pending + } } } diff --git a/src/stream/merge/tuple.rs b/src/stream/merge/tuple.rs index 093bc5b..0cf1d2f 100644 --- a/src/stream/merge/tuple.rs +++ b/src/stream/merge/tuple.rs @@ -1,82 +1,20 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils::{self, PollArray, WakerArray}; +use crate::utils::{ArrayDequeue, PollState, WakerArray}; use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; + use futures_core::Stream; -use std::pin::Pin; -use std::task::{Context, Poll}; - -macro_rules! poll_stream { - ($stream_idx:tt, $iteration:ident, $this:ident, $streams:ident . $stream_member:ident, $cx:ident, $len_streams:ident) => { - if $stream_idx == $iteration { - match unsafe { Pin::new_unchecked(&mut $streams.$stream_member) }.poll_next(&mut $cx) { - Poll::Ready(Some(item)) => { - // Mark ourselves as ready again because we need to poll for the next item. - $this - .wakers - .readiness() - .lock() - .unwrap() - .set_ready($stream_idx); - return Poll::Ready(Some(item)); - } - Poll::Ready(None) => { - *$this.completed += 1; - $this.state[$stream_idx].set_consumed(); - if *$this.completed == $len_streams { - return Poll::Ready(None); - } - } - Poll::Pending => {} - } - } - }; -} macro_rules! impl_merge_tuple { - ($ignore:ident $StructName:ident) => { - /// A stream that merges multiple streams into a single stream. - /// - /// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its - /// documentation for more. - /// - /// [`merge`]: trait.Merge.html#method.merge - /// [`Merge`]: trait.Merge.html - pub struct $StructName {} - - impl fmt::Debug for $StructName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Merge").finish() - } - } - - impl Stream for $StructName { - type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib - - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(None) - } - } - - impl MergeTrait for () { - type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib - type Stream = $StructName; - - fn merge(self) -> Self::Stream { - $StructName { } - } - } - }; - ($mod_name:ident $StructName:ident $($F:ident)+) => { + ($mod_name:ident $StructName:ident $($F:ident=$fut_idx:tt)+) => { mod $mod_name { #[pin_project::pin_project] pub(super) struct Streams<$($F,)+> { $(#[pin] pub(super) $F: $F),+ } - #[repr(usize)] - pub(super) enum Indexes { $($F),+ } - - pub(super) const LEN: usize = [$(Indexes::$F),+].len(); + pub(super) const LEN: usize = [$($fut_idx),+].len(); } /// A stream that merges multiple streams into a single stream. @@ -92,10 +30,12 @@ macro_rules! impl_merge_tuple { $F: Stream, )* { #[pin] streams: $mod_name::Streams<$($F,)+>, - indexer: utils::Indexer, wakers: WakerArray<{$mod_name::LEN}>, - state: PollArray<{$mod_name::LEN}>, - completed: u8, + pending: usize, + state: [PollState; $mod_name::LEN], + awake_list: ArrayDequeue, + #[cfg(debug_assertions)] + done: bool } impl fmt::Debug for $StructName @@ -119,47 +59,70 @@ macro_rules! impl_merge_tuple { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - - const LEN: u8 = $mod_name::LEN as u8; + #[cfg(debug_assertions)] +assert!(!*this.done, "Stream should not be polled after completing"); + + { + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + let awake_list = readiness.awake_list(); + let states = &mut *this.state; + this.awake_list.extend(awake_list.iter().filter_map(|&idx| { + let state = &mut states[idx]; + match state { + PollState::Pending => { + *state = PollState::Ready; + Some(idx) + }, + _ => None + } + })); + readiness.clear(); + } let mut streams = this.streams.project(); - // Iterate over our streams one-by-one. If a stream yields a value, - // we exit early. By default we'll return `Poll::Ready(None)`, but - // this changes if we encounter a `Poll::Pending`. - for index in this.indexer.iter() { - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + for idx in this.awake_list.drain() { + let state = &mut this.state[idx]; + if let PollState::Consumed = *state { continue; } - - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - - $( - let stream_index = $mod_name::Indexes::$F as usize; - poll_stream!( - stream_index, - index, - this, - streams . $F, - cx, - LEN - ); - )+ - - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); + let waker = this.wakers.get(idx).unwrap(); + let mut cx = Context::from_waker(waker); + + let poll_res = match idx { + $( + $fut_idx => { + streams.$F.as_mut().poll_next(&mut cx) + } + ),+ + _ => unreachable!() + }; + match poll_res { + Poll::Ready(Some(item)) => { + waker.wake_by_ref(); + *state = PollState::Pending; + return Poll::Ready(Some(item)); + } + Poll::Ready(None) => { + *this.pending -= 1; + *state = PollState::Consumed; + } + Poll::Pending => { + *state = PollState::Pending; + } + } + } + if *this.pending == 0 { + #[cfg(debug_assertions)] + { + *this.done = true; + } + Poll::Ready(None) + } + else { + Poll::Pending } - - Poll::Pending } } @@ -174,29 +137,45 @@ macro_rules! impl_merge_tuple { let ($($F,)*): ($($F,)*) = self; $StructName { streams: $mod_name::Streams { $($F: $F.into_stream()),+ }, - indexer: utils::Indexer::new(utils::tuple_len!($($F,)*)), wakers: WakerArray::new(), - state: PollArray::new(), - completed: 0, + pending: $mod_name::LEN, + state: [PollState::Ready; $mod_name::LEN], + awake_list: ArrayDequeue::new(core::array::from_fn(core::convert::identity), $mod_name::LEN), + #[cfg(debug_assertions)] + done: false } } } }; } - -impl_merge_tuple! { merge0 Merge0 } -impl_merge_tuple! { merge1 Merge1 A } -impl_merge_tuple! { merge2 Merge2 A B } -impl_merge_tuple! { merge3 Merge3 A B C } -impl_merge_tuple! { merge4 Merge4 A B C D } -impl_merge_tuple! { merge5 Merge5 A B C D E } -impl_merge_tuple! { merge6 Merge6 A B C D E F } -impl_merge_tuple! { merge7 Merge7 A B C D E F G } -impl_merge_tuple! { merge8 Merge8 A B C D E F G H } -impl_merge_tuple! { merge9 Merge9 A B C D E F G H I } -impl_merge_tuple! { merge10 Merge10 A B C D E F G H I J } -impl_merge_tuple! { merge11 Merge11 A B C D E F G H I J K } -impl_merge_tuple! { merge12 Merge12 A B C D E F G H I J K L } +impl_merge_tuple! { merge1 Merge1 A=0 } +impl_merge_tuple! { merge2 Merge2 A=0 B=1 } +impl_merge_tuple! { merge3 Merge3 A=0 B=1 C=2 } +impl_merge_tuple! { merge4 Merge4 A=0 B=1 C=2 D=3 } +impl_merge_tuple! { merge5 Merge5 A=0 B=1 C=2 D=3 E=4 } +impl_merge_tuple! { merge6 Merge6 A=0 B=1 C=2 D=3 E=4 F=5 } +impl_merge_tuple! { merge7 Merge7 A=0 B=1 C=2 D=3 E=4 F=5 G=6 } +impl_merge_tuple! { merge8 Merge8 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 } +impl_merge_tuple! { merge9 Merge9 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 } +impl_merge_tuple! { merge10 Merge10 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 J=9 } +impl_merge_tuple! { merge11 Merge11 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 J=9 K=10 } +impl_merge_tuple! { merge12 Merge12 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 J=9 K=10 L=11 } + +impl MergeTrait for () { + type Item = core::convert::Infallible; + type Stream = Merge0; + fn merge(self) -> Self::Stream { + Merge0 + } +} +#[derive(Debug)] +pub struct Merge0; +impl Stream for Merge0 { + type Item = core::convert::Infallible; + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } +} #[cfg(test)] mod tests { diff --git a/src/stream/merge/vec.rs b/src/stream/merge/vec.rs index 80869b4..4fd7a8d 100644 --- a/src/stream/merge/vec.rs +++ b/src/stream/merge/vec.rs @@ -1,11 +1,16 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils::{self, Indexer, PollVec, WakerVec}; +use crate::utils::{self, WakerVec}; use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use std::collections::VecDeque; + +use bitvec::vec::BitVec; use futures_core::Stream; -use std::pin::Pin; -use std::task::{Context, Poll}; + +// For code comments, see the array merge code, which is very similar. /// A stream that merges multiple streams into a single stream. /// @@ -21,10 +26,12 @@ where { #[pin] streams: Vec, - indexer: Indexer, - complete: usize, wakers: WakerVec, - state: PollVec, + pending: usize, + consumed: BitVec, + awake_set: BitVec, + awake_list: VecDeque, + #[cfg(debug_assertions)] done: bool, } @@ -35,11 +42,18 @@ where pub(crate) fn new(streams: Vec) -> Self { let len = streams.len(); Self { - wakers: WakerVec::new(len), - state: PollVec::new(len), - indexer: Indexer::new(len), streams, - complete: 0, + wakers: WakerVec::new(len), + pending: len, + // Instead of using Vec, we use two bitvecs. + // !consumed && !awake = PollState::Pending + // !consumed && awake = PollState::Ready + // consumed = PollState::Consumed + // TODO: is this space-saving (from 8N to 2N bits) really worth it? + consumed: BitVec::repeat(false, len), + awake_set: BitVec::repeat(false, len), + awake_list: VecDeque::with_capacity(len), + #[cfg(debug_assertions)] done: false, } } @@ -63,48 +77,49 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - - // Iterate over our streams one-by-one. If a stream yields a value, - // we exit early. By default we'll return `Poll::Ready(None)`, but - // this changes if we encounter a `Poll::Pending`. - for index in this.indexer.iter() { - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { - continue; - } - - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + #[cfg(debug_assertions)] + assert!(!*this.done, "Stream should not be polled after completing"); + + { + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + let awake_list = readiness.awake_list(); + let awake_set = &mut *this.awake_set; + let consumed = &mut *this.consumed; + this.awake_list.extend(awake_list.iter().filter_map(|&idx| { + // Only add substream that is in !awake && !consumed state. + // Set the state to awake in the process. + (!awake_set.replace(idx, true) && !consumed[idx]).then_some(idx) + })); + readiness.clear(); + } - let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap(); + while let Some(idx) = this.awake_list.pop_front() { + this.awake_set.set(idx, false); + let waker = this.wakers.get(idx).unwrap(); + let mut cx = Context::from_waker(waker); + let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), idx).unwrap(); match stream.poll_next(&mut cx) { Poll::Ready(Some(item)) => { - // Mark ourselves as ready again because we need to poll for the next item. - this.wakers.readiness().lock().unwrap().set_ready(index); + waker.wake_by_ref(); return Poll::Ready(Some(item)); } Poll::Ready(None) => { - *this.complete += 1; - this.state[index].set_consumed(); - if *this.complete == this.streams.len() { - return Poll::Ready(None); - } + *this.pending -= 1; + this.consumed.set(idx, true); } Poll::Pending => {} } - - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); } - - Poll::Pending + if *this.pending == 0 { + #[cfg(debug_assertions)] + { + *this.done = true; + } + Poll::Ready(None) + } else { + Poll::Pending + } } } diff --git a/src/stream/zip/array.rs b/src/stream/zip/array.rs index f9d6acb..6c94156 100644 --- a/src/stream/zip/array.rs +++ b/src/stream/zip/array.rs @@ -1,13 +1,12 @@ use super::Zip as ZipTrait; use crate::stream::IntoStream; -use crate::utils::{self, PollArray, PollState, WakerArray}; +use crate::utils::{self, WakerArray}; use core::array; use core::fmt; use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; -use std::mem; use futures_core::Stream; use pin_project::{pin_project, pinned_drop}; @@ -24,11 +23,32 @@ pub struct Zip where S: Stream, { + // Number of substreams that we're waiting for. + // MAGIC VALUE: pending == usize::MAX is used to signal that + // we just yielded a zipped value and every stream should be woken up again. + // + // pending value goes like + // N,N-1,...,1,0, + // ,usize::MAX, + // N,N-1,...,1,0, + // ,usize::MAX,... + pending: usize, + wakers: WakerArray, + /// The stored output from each substream. + items: [MaybeUninit<::Item>; N], + /// Whether each item in self.items is initialized. + /// Invariant: self.filled.count_falses() == self.pending + /// EXCEPT when self.pending==usize::MAX, self.filled must be all false. + filled: [bool; N], + /// A temporary buffer for indices that have woken. + /// The data here don't have to persist between each `poll_next`. + awake_list_buffer: [usize; N], #[pin] streams: [S; N], - output: [MaybeUninit<::Item>; N], - wakers: WakerArray, - state: PollArray, + /// Streams should not be polled after complete. + /// In debug, we panic to the user. + /// In release, we might sleep or poll substreams after completion. + #[cfg(debug_assertions)] done: bool, } @@ -39,9 +59,13 @@ where pub(crate) fn new(streams: [S; N]) -> Self { Self { streams, - output: array::from_fn(|_| MaybeUninit::uninit()), - state: PollArray::new(), + items: array::from_fn(|_| MaybeUninit::uninit()), + filled: [false; N], wakers: WakerArray::new(), + // TODO: this is a temporary buffer so it can be MaybeUninit. + awake_list_buffer: [0; N], + pending: N, + #[cfg(debug_assertions)] done: false, } } @@ -65,62 +89,82 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - assert!(!*this.done, "Stream should not be polled after completion"); - - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - for index in 0..N { - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } else if this.state[index].is_ready() || !readiness.clear_ready(index) { - // We already have data stored for this stream, - // Or this waker isn't ready yet + #[cfg(debug_assertions)] + assert!(!*this.done, "Stream should not be polled after completing"); + + let num_awake = { + // Lock the readiness Mutex. + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + + let num_awake = if *this.pending == usize::MAX { + // pending = usize::MAX is a special value used to communicate that + // a zipped value has been yielded and everything should be restarted. + *this.pending = N; + // Fill the awake_list_buffer with 0..N. + *this.awake_list_buffer = array::from_fn(core::convert::identity); + N + } else { + // Copy the awake list out of the Mutex. + let awake_list = readiness.awake_list(); + let num_awake = awake_list.len(); + this.awake_list_buffer[..num_awake].copy_from_slice(awake_list); + num_awake + }; + // Clear the list in the Mutex. + readiness.clear(); + num_awake + }; + + // Iterate over the awake list. + for &idx in this.awake_list_buffer.iter().take(num_awake) { + let filled = &mut this.filled[idx]; + if *filled { + // Woken substream has already yielded. continue; } - - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - - let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap(); + let stream = utils::get_pin_mut(this.streams.as_mut(), idx).unwrap(); + let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); match stream.poll_next(&mut cx) { - Poll::Ready(Some(item)) => { - this.output[index] = MaybeUninit::new(item); - this.state[index].set_ready(); - - let all_ready = this.state.iter().all(|state| state.is_ready()); - if all_ready { - // Reset the future's state. - readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_all_ready(); - this.state.fill_with(PollState::default); - - // Take the output - // - // SAFETY: we just validated all our data is populated, meaning - // we can assume this is initialized. - let mut output = array::from_fn(|_| MaybeUninit::uninit()); - mem::swap(this.output, &mut output); - let output = unsafe { array_assume_init(output) }; - return Poll::Ready(Some(output)); - } + Poll::Ready(Some(value)) => { + this.items[idx].write(value); + *filled = true; + *this.pending -= 1; } Poll::Ready(None) => { - // If one stream returns `None`, we can no longer return - // pairs - meaning the stream is over. - *this.done = true; + // If one substream ends, the entire Zip ends. + + #[cfg(debug_assertions)] + { + *this.done = true; + } + return Poll::Ready(None); } Poll::Pending => {} } + } - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); + if *this.pending == 0 { + debug_assert!( + this.filled.iter().all(|&filled| filled), + "The items array should have been filled" + ); + this.filled.fill(false); + + // Set this to the magic value so that the wakers get restarted next time. + *this.pending = usize::MAX; + + let mut items = array::from_fn(|_| MaybeUninit::uninit()); + core::mem::swap(this.items, &mut items); + + // SAFETY: this.pending is only decremented when an item slot is filled. + // pending reaching 0 means the entire items array is filled. + let items = unsafe { utils::array_assume_init(items) }; + Poll::Ready(Some(items)) + } else { + Poll::Pending } - Poll::Pending } } @@ -133,10 +177,9 @@ where fn drop(self: Pin<&mut Self>) { let this = self.project(); - for (state, output) in this.state.iter_mut().zip(this.output.iter_mut()) { - if state.is_ready() { - // SAFETY: we've just filtered down to *only* the initialized values. - // We can assume they're initialized, and this is where we drop them. + for (&filled, output) in this.filled.iter().zip(this.items.iter_mut()) { + if filled { + // SAFETY: when filled is true the item must be initialized. unsafe { output.assume_init_drop() }; } } @@ -176,16 +219,3 @@ mod tests { }) } } - -// Inlined version of the unstable `MaybeUninit::array_assume_init` feature. -// FIXME: replace with `utils::array_assume_init` -unsafe fn array_assume_init(array: [MaybeUninit; N]) -> [T; N] { - // SAFETY: - // * The caller guarantees that all elements of the array are initialized - // * `MaybeUninit` and T are guaranteed to have the same layout - // * `MaybeUninit` does not drop, so there are no double-frees - // And thus the conversion is safe - let ret = unsafe { (&array as *const _ as *const [T; N]).read() }; - mem::forget(array); - ret -} diff --git a/src/stream/zip/tuple.rs b/src/stream/zip/tuple.rs index 8b13789..60eeedd 100644 --- a/src/stream/zip/tuple.rs +++ b/src/stream/zip/tuple.rs @@ -1 +1,225 @@ +use super::Zip as ZipTrait; +use crate::stream::IntoStream; +use crate::utils::WakerArray; +use core::fmt; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures_core::Stream; +use pin_project::{pin_project, pinned_drop}; + +macro_rules! impl_zip_tuple { + ($mod_name:ident $StructName:ident $($F:ident=$fut_idx:tt)+) => { + mod $mod_name { + #[pin_project::pin_project] + pub(super) struct Streams<$($F,)+> { $(#[pin] pub(super) $F: $F),+ } + + pub(super) const LEN: usize = [$($fut_idx),+].len(); + } + + /// A stream that zips multiple streams into a single stream. + /// + /// This `struct` is created by the [`zip`] method on the [`Zip`] trait. See its + /// documentation for more. + /// + /// [`zip`]: trait.Zip.html#method.zip + /// [`Zip`]: trait.Zip.html + #[pin_project(PinnedDrop)] + pub struct $StructName<$($F),*> + where $( + $F: Stream, + )* { + #[pin] streams: $mod_name::Streams<$($F,)+>, + items: ($(MaybeUninit<$F::Item>,)+), + wakers: WakerArray<{$mod_name::LEN}>, + filled: [bool; $mod_name::LEN], + awake_list_buffer: [usize; $mod_name::LEN], + pending: usize, + #[cfg(debug_assertions)] + done: bool + } + + impl<$($F),*> fmt::Debug for $StructName<$($F),*> + where $( + $F: Stream + fmt::Debug, + )* { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Merge") + $( .field(&self.streams.$F) )* // Hides implementation detail of Streams struct + .finish() + } + } + + impl<$($F),*> Stream for $StructName<$($F),*> + where $( + $F: Stream, + )* { + type Item = ($($F::Item,)+); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + #[cfg(debug_assertions)] +assert!(!*this.done, "Stream should not be polled after completing"); + + let num_awake = { + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + let num_awake = if *this.pending == usize::MAX { + *this.pending = $mod_name::LEN; + *this.awake_list_buffer = core::array::from_fn(core::convert::identity); + $mod_name::LEN + } + else { + let awake_list = readiness.awake_list(); + let num_awake = awake_list.len(); + this.awake_list_buffer[..num_awake].copy_from_slice(awake_list); + num_awake + }; + readiness.clear(); + num_awake + }; + + let mut streams = this.streams.project(); + + for &idx in this.awake_list_buffer.iter().take(num_awake) { + let filled = &mut this.filled[idx]; + if *filled { + continue; + } + let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); + + match idx { + $( + $fut_idx => { + match streams.$F.as_mut().poll_next(&mut cx) { + Poll::Ready(Some(value)) => { + this.items.$fut_idx.write(value); + *filled = true; + *this.pending -= 1; + } + Poll::Ready(None) => { + #[cfg(debug_assertions)] + { + *this.done = true; + } + return Poll::Ready(None); + } + Poll::Pending => {} + } + } + ),+ + _ => unreachable!() + }; + } + if *this.pending == 0 { + debug_assert!( + this.filled.iter().all(|filled| *filled), + "The items array should have been filled" + ); + this.filled.fill(false); + + *this.pending = usize::MAX; + let mut out = ($(MaybeUninit::<$F::Item>::uninit(),)+); + core::mem::swap(&mut out, this.items); + // SAFETY: we've checked with the state that all of our outputs have been + // filled, which means we're ready to take the data and assume it's initialized. + Poll::Ready(Some(unsafe { ($(out.$fut_idx.assume_init(),)+) })) + } + else { + Poll::Pending + } + } + } + + impl<$($F),*> ZipTrait for ($($F,)*) + where $( + $F: IntoStream, + )* { + type Item = ($($F::Item,)+); + type Stream = $StructName<$($F::IntoStream),*>; + + fn zip(self) -> Self::Stream { + let ($($F,)*): ($($F,)*) = self; + $StructName { + streams: $mod_name::Streams { $($F: $F.into_stream()),+ }, + items: ($(MaybeUninit::<$F::Item>::uninit(),)+), + wakers: WakerArray::new(), + filled: [false; $mod_name::LEN], + awake_list_buffer: [0; $mod_name::LEN], + pending: $mod_name::LEN, + #[cfg(debug_assertions)] + done: false + } + } + } + #[pinned_drop] + impl<$($F),*> PinnedDrop for $StructName<$($F),*> + where $( + $F: Stream, + )* { + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + $( + if this.filled[$fut_idx] { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { this.items.$fut_idx.assume_init_drop() }; + } + )+ + } + } + }; +} + +impl_zip_tuple! { zip1 Zip1 A=0 } +impl_zip_tuple! { zip2 Zip2 A=0 B=1 } +impl_zip_tuple! { zip3 Zip3 A=0 B=1 C=2 } +impl_zip_tuple! { zip4 Zip4 A=0 B=1 C=2 D=3 } +impl_zip_tuple! { zip5 Zip5 A=0 B=1 C=2 D=3 E=4 } +impl_zip_tuple! { zip6 Zip6 A=0 B=1 C=2 D=3 E=4 F=5 } +impl_zip_tuple! { zip7 Zip7 A=0 B=1 C=2 D=3 E=4 F=5 G=6 } +impl_zip_tuple! { zip8 Zip8 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 } +impl_zip_tuple! { zip9 Zip9 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 } +impl_zip_tuple! { zip10 Zip10 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 J=9 } +impl_zip_tuple! { zip11 Zip11 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 J=9 K=10 } +impl_zip_tuple! { zip12 Zip12 A=0 B=1 C=2 D=3 E=4 F=5 G=6 H=7 I=8 J=9 K=10 L=11 } + +impl ZipTrait for () { + type Item = (); + type Stream = Zip0; + + fn zip(self) -> Self::Stream { + Zip0 + } +} +#[derive(Debug)] +pub struct Zip0; +impl Stream for Zip0 { + type Item = (); + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Some(())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_lite::{future::block_on, stream, StreamExt}; + #[test] + fn zip_tuple_3() { + block_on(async { + let mut s = ( + stream::repeat(3), + stream::repeat("hello"), + stream::once(1).chain(stream::once(5)), + ) + .zip(); + assert_eq!(s.next().await.unwrap(), (3, "hello", 1)); + assert_eq!(s.next().await.unwrap(), (3, "hello", 5)); + assert_eq!(s.next().await, None); + }) + } +} diff --git a/src/stream/zip/vec.rs b/src/stream/zip/vec.rs index 7ebf5f5..5a6a1ce 100644 --- a/src/stream/zip/vec.rs +++ b/src/stream/zip/vec.rs @@ -1,16 +1,18 @@ use super::Zip as ZipTrait; use crate::stream::IntoStream; -use crate::utils::{self, PollState, WakerVec}; +use crate::utils::{self, WakerVec}; use core::fmt; use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; -use std::mem; +use bitvec::vec::BitVec; use futures_core::Stream; use pin_project::{pin_project, pinned_drop}; +// For code comments, see the array zip code, which is very similar. + /// A stream that ‘zips up’ multiple streams into a single stream of pairs. /// /// This `struct` is created by the [`zip`] method on the [`Zip`] trait. See its @@ -25,11 +27,13 @@ where { #[pin] streams: Vec, - output: Vec::Item>>, + items: Vec::Item>>, wakers: WakerVec, - state: Vec, + filled: BitVec, + awake_list_buffer: Vec, + pending: usize, + #[cfg(debug_assertions)] done: bool, - len: usize, } impl Zip @@ -39,11 +43,13 @@ where pub(crate) fn new(streams: Vec) -> Self { let len = streams.len(); Self { - len, streams, wakers: WakerVec::new(len), - output: (0..len).map(|_| MaybeUninit::uninit()).collect(), - state: (0..len).map(|_| PollState::default()).collect(), + items: (0..len).map(|_| MaybeUninit::uninit()).collect(), + filled: BitVec::repeat(false, len), + awake_list_buffer: Vec::new(), + pending: len, + #[cfg(debug_assertions)] done: false, } } @@ -67,62 +73,64 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - assert!(!*this.done, "Stream should not be polled after completion"); - - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - for index in 0..*this.len { - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; - } else if this.state[index].is_ready() || !readiness.clear_ready(index) { - // We already have data stored for this stream, - // Or this waker isn't ready yet - continue; + #[cfg(debug_assertions)] + assert!(!*this.done, "Stream should not be polled after completing"); + + let len = this.streams.len(); + { + let mut readiness = this.wakers.readiness(); + readiness.set_parent_waker(cx.waker()); + if *this.pending == usize::MAX { + *this.pending = len; + this.awake_list_buffer.clear(); + this.awake_list_buffer.extend(0..len); + } else { + this.awake_list_buffer.clone_from(readiness.awake_list()); } + readiness.clear(); + } - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Obtain the intermediate waker. - let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); - - let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap(); + for idx in this.awake_list_buffer.drain(..) { + let mut filled = this.filled.get_mut(idx).unwrap(); + if *filled { + continue; + } + let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), idx).unwrap(); + let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); match stream.poll_next(&mut cx) { - Poll::Ready(Some(item)) => { - this.output[index] = MaybeUninit::new(item); - this.state[index].set_ready(); - - let all_ready = this.state.iter().all(|state| state.is_ready()); - if all_ready { - // Reset the future's state. - readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_all_ready(); - this.state.fill_with(PollState::default); - - // Take the output - // - // SAFETY: we just validated all our data is populated, meaning - // we can assume this is initialized. - let mut output = (0..*this.len).map(|_| MaybeUninit::uninit()).collect(); - mem::swap(this.output, &mut output); - let output = unsafe { vec_assume_init(output) }; - return Poll::Ready(Some(output)); - } + Poll::Ready(Some(value)) => { + this.items[idx].write(value); + filled.set(true); + *this.pending -= 1; } Poll::Ready(None) => { - // If one stream returns `None`, we can no longer return - // pairs - meaning the stream is over. - *this.done = true; + #[cfg(debug_assertions)] + { + *this.done = true; + } return Poll::Ready(None); } Poll::Pending => {} } + } + + if *this.pending == 0 { + debug_assert!( + this.filled.iter().all(|filled| *filled), + "Future should have reached a `Ready` state" + ); + this.filled.fill(false); + + *this.pending = usize::MAX; + + let mut output = (0..len).map(|_| MaybeUninit::uninit()).collect(); + core::mem::swap(this.items, &mut output); - // Lock readiness so we can use it again - readiness = this.wakers.readiness().lock().unwrap(); + let output = unsafe { vec_assume_init(output) }; + Poll::Ready(Some(output)) + } else { + Poll::Pending } - Poll::Pending } } @@ -135,8 +143,8 @@ where fn drop(self: Pin<&mut Self>) { let this = self.project(); - for (state, output) in this.state.iter_mut().zip(this.output.iter_mut()) { - if state.is_ready() { + for (filled, output) in this.filled.iter().zip(this.items.iter_mut()) { + if *filled { // SAFETY: we've just filtered down to *only* the initialized values. // We can assume they're initialized, and this is where we drop them. unsafe { output.assume_init_drop() }; @@ -165,7 +173,7 @@ mod tests { use futures_lite::stream; #[test] - fn zip_array_3() { + fn zip_vec_3() { block_on(async { let a = stream::repeat(1).take(2); let b = stream::repeat(2).take(2); @@ -188,6 +196,6 @@ unsafe fn vec_assume_init(vec: Vec>) -> Vec { // * `MaybeUninit` does not drop, so there are no double-frees // And thus the conversion is safe let ret = unsafe { (&vec as *const _ as *const Vec).read() }; - mem::forget(vec); + core::mem::forget(vec); ret } diff --git a/src/utils/array_dequeue.rs b/src/utils/array_dequeue.rs new file mode 100644 index 0000000..eb52370 --- /dev/null +++ b/src/utils/array_dequeue.rs @@ -0,0 +1,50 @@ +pub(crate) struct ArrayDequeue { + data: [T; N], + start: usize, + len: usize, +} + +impl ArrayDequeue { + pub(crate) fn new(data: [T; N], len: usize) -> Self { + Self { + data, + start: 0, + len, + } + } + pub(crate) fn push_back(&mut self, elem: T) { + assert!(self.len < N, "array is full"); + self.data[(self.start + self.len) % N] = elem; + self.len += 1; + } +} + +struct ArrayDequeueDrain<'a, T: Copy, const N: usize> { + arr: &'a mut ArrayDequeue, +} +impl<'a, T: Copy, const N: usize> Iterator for ArrayDequeueDrain<'a, T, N> { + type Item = T; + fn next(&mut self) -> Option { + if self.arr.len > 0 { + let elem = self.arr.data[self.arr.start]; + self.arr.start = (self.arr.start + 1) % N; + self.arr.len -= 1; + Some(elem) + } else { + None + } + } +} + +impl ArrayDequeue { + pub(crate) fn drain(&mut self) -> impl Iterator + '_ { + ArrayDequeueDrain { arr: self } + } +} +impl Extend for ArrayDequeue { + fn extend>(&mut self, iter: I) { + iter.into_iter().for_each(|elem| { + self.push_back(elem); + }); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ed9932a..370934d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,7 @@ //! Utilities to implement the different futures of this crate. mod array; +mod array_dequeue; mod indexer; mod pin; mod poll_state; @@ -8,15 +9,11 @@ mod tuple; mod wakers; pub(crate) use array::array_assume_init; -pub(crate) use indexer::Indexer; +pub(crate) use array_dequeue::ArrayDequeue; pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec}; -pub(crate) use poll_state::MaybeDone; -pub(crate) use poll_state::{PollArray, PollState, PollVec}; -pub(crate) use tuple::{gen_conditions, tuple_len}; +pub(crate) use poll_state::PollState; pub(crate) use wakers::{WakerArray, WakerVec}; -#[cfg(test)] -pub(crate) use wakers::dummy_waker; - #[cfg(test)] pub(crate) mod channel; +pub(crate) use wakers::dummy_waker; diff --git a/src/utils/poll_state/array.rs b/src/utils/poll_state/array.rs deleted file mode 100644 index 04baf0a..0000000 --- a/src/utils/poll_state/array.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -use super::PollState; - -pub(crate) struct PollArray { - state: [PollState; N], -} - -impl PollArray { - pub(crate) fn new() -> Self { - Self { - state: [PollState::default(); N], - } - } - - #[inline] - pub(crate) fn set_all_completed(&mut self) { - self.iter_mut().for_each(|state| { - debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" - ); - state.set_consumed(); - }) - } -} - -impl Deref for PollArray { - type Target = [PollState]; - - fn deref(&self) -> &Self::Target { - &self.state - } -} - -impl DerefMut for PollArray { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.state - } -} diff --git a/src/utils/poll_state/maybe_done.rs b/src/utils/poll_state/maybe_done.rs deleted file mode 100644 index 42aed43..0000000 --- a/src/utils/poll_state/maybe_done.rs +++ /dev/null @@ -1,70 +0,0 @@ -use core::future::Future; -use core::mem; -use core::pin::Pin; -use core::task::{ready, Context, Poll}; - -/// A future that may have completed. -#[derive(Debug)] -pub(crate) enum MaybeDone { - /// A not-yet-completed future - Future(Fut), - - /// The output of the completed future - Done(Fut::Output), - - /// The empty variant after the result of a [`MaybeDone`] has been - /// taken using the [`take`](MaybeDone::take) method. - Gone, -} - -impl MaybeDone { - /// Create a new instance of `MaybeDone`. - pub(crate) fn new(future: Fut) -> MaybeDone { - Self::Future(future) - } - - /// Returns an [`Option`] containing a reference to the output of the future. - /// The output of this method will be [`Some`] if and only if the inner - /// future has been completed and [`take`](MaybeDone::take) - /// has not yet been called. - #[inline] - pub(crate) fn output(self: Pin<&Self>) -> Option<&Fut::Output> { - let this = self.get_ref(); - match this { - MaybeDone::Done(res) => Some(res), - _ => None, - } - } - - /// Attempt to take the output of a `MaybeDone` without driving it - /// towards completion. - #[inline] - pub(crate) fn take(self: Pin<&mut Self>) -> Option { - let this = unsafe { self.get_unchecked_mut() }; - match this { - MaybeDone::Done(_) => {} - MaybeDone::Future(_) | MaybeDone::Gone => return None, - }; - if let MaybeDone::Done(output) = mem::replace(this, MaybeDone::Gone) { - Some(output) - } else { - unreachable!() - } - } -} - -impl Future for MaybeDone { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = unsafe { - match Pin::as_mut(&mut self).get_unchecked_mut() { - MaybeDone::Future(a) => ready!(Pin::new_unchecked(a).poll(cx)), - MaybeDone::Done(_) => return Poll::Ready(()), - MaybeDone::Gone => panic!("MaybeDone polled after value taken"), - } - }; - self.set(MaybeDone::Done(res)); - Poll::Ready(()) - } -} diff --git a/src/utils/poll_state/mod.rs b/src/utils/poll_state/mod.rs index 90a4d2c..62b2b04 100644 --- a/src/utils/poll_state/mod.rs +++ b/src/utils/poll_state/mod.rs @@ -1,11 +1,14 @@ -#![allow(clippy::module_inception)] - -mod array; -mod maybe_done; -mod poll_state; -mod vec; - -pub(crate) use array::PollArray; -pub(crate) use maybe_done::MaybeDone; -pub(crate) use poll_state::PollState; -pub(crate) use vec::PollVec; +/// Enumerate the current poll state. +#[derive(Debug, Clone, Copy, Default)] +#[repr(u8)] +pub(crate) enum PollState { + /// Polling the underlying future or stream. + #[default] + Pending, + /// Data has been written to the output structure, and is now ready to be + /// read. + Ready, + /// The underlying future or stream has finished yielding data and all data + /// has been read. We can now stop reasoning about it. + Consumed, +} diff --git a/src/utils/poll_state/poll_state.rs b/src/utils/poll_state/poll_state.rs deleted file mode 100644 index 35be7fe..0000000 --- a/src/utils/poll_state/poll_state.rs +++ /dev/null @@ -1,49 +0,0 @@ -/// Enumerate the current poll state. -#[derive(Debug, Clone, Copy, Default)] -#[repr(u8)] -pub(crate) enum PollState { - /// Polling the underlying future or stream. - #[default] - Pending, - /// Data has been written to the output structure, and is now ready to be - /// read. - Ready, - /// The underlying future or stream has finished yielding data and all data - /// has been read. We can now stop reasoning about it. - Consumed, -} - -impl PollState { - /// Returns `true` if the metadata is [`Pending`][Self::Pending]. - #[must_use] - #[inline] - pub(crate) fn is_pending(&self) -> bool { - matches!(self, Self::Pending) - } - - /// Returns `true` if the poll state is [`Ready`][Self::Ready]. - #[must_use] - #[inline] - pub(crate) fn is_ready(&self) -> bool { - matches!(self, Self::Ready) - } - - /// Sets the poll state to [`Ready`][Self::Ready]. - #[inline] - pub(crate) fn set_ready(&mut self) { - *self = PollState::Ready; - } - - /// Returns `true` if the poll state is [`Consumed`][Self::Consumed]. - #[must_use] - #[inline] - pub(crate) fn is_consumed(&self) -> bool { - matches!(self, Self::Consumed) - } - - /// Sets the poll state to [`Consumed`][Self::Consumed]. - #[inline] - pub(crate) fn set_consumed(&mut self) { - *self = PollState::Consumed; - } -} diff --git a/src/utils/poll_state/vec.rs b/src/utils/poll_state/vec.rs deleted file mode 100644 index 3f9807e..0000000 --- a/src/utils/poll_state/vec.rs +++ /dev/null @@ -1,90 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -use super::PollState; - -/// The maximum number of entries that `PollStates` can store without -/// dynamic memory allocation. -/// -/// The `Boxed` variant is the minimum size the data structure can have. -/// It consists of a boxed slice (=2 usizes) and space for the enum -/// tag (another usize because of padding), so 3 usizes. -/// The inline variant then consists of `3 * size_of(usize) - 2` entries. -/// Each entry is a byte and we subtract one byte for a length field, -/// and another byte for the enum tag. -/// -/// ```txt -/// Boxed -/// vvvvv -/// tag -/// | <-------padding----> <--- Box<[T]>::len ---> <--- Box<[T]>::ptr ---> -/// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 -/// tag | -/// len ^^^^^ -/// Inline -/// ``` -const MAX_INLINE_ENTRIES: usize = std::mem::size_of::() * 3 - 2; - -pub(crate) enum PollVec { - Inline(u8, [PollState; MAX_INLINE_ENTRIES]), - Boxed(Box<[PollState]>), -} - -impl PollVec { - pub(crate) fn new(len: usize) -> Self { - assert!(MAX_INLINE_ENTRIES <= u8::MAX as usize); - - if len <= MAX_INLINE_ENTRIES { - Self::Inline(len as u8, Default::default()) - } else { - // Make sure that we don't reallocate the vec's memory - // during `Vec::into_boxed_slice()`. - let mut states = Vec::new(); - debug_assert_eq!(states.capacity(), 0); - states.reserve_exact(len); - debug_assert_eq!(states.capacity(), len); - states.resize(len, PollState::default()); - debug_assert_eq!(states.capacity(), len); - Self::Boxed(states.into_boxed_slice()) - } - } -} - -impl Deref for PollVec { - type Target = [PollState]; - - fn deref(&self) -> &Self::Target { - match self { - PollVec::Inline(len, states) => &states[..*len as usize], - Self::Boxed(states) => &states[..], - } - } -} - -impl DerefMut for PollVec { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - PollVec::Inline(len, states) => &mut states[..*len as usize], - Self::Boxed(states) => &mut states[..], - } - } -} - -#[cfg(test)] -mod tests { - use super::{PollVec, MAX_INLINE_ENTRIES}; - - #[test] - fn type_size() { - assert_eq!( - std::mem::size_of::(), - std::mem::size_of::() * 3 - ); - } - - #[test] - fn boxed_does_not_allocate_twice() { - // Make sure the debug_assertions in PollStates::new() don't fail. - let _ = PollVec::new(MAX_INLINE_ENTRIES + 10); - } -} diff --git a/src/utils/wakers/array/readiness.rs b/src/utils/wakers/array/readiness.rs index 440035d..7a5eded 100644 --- a/src/utils/wakers/array/readiness.rs +++ b/src/utils/wakers/array/readiness.rs @@ -53,8 +53,9 @@ impl ReadinessArray { if self.awake_list_len < Self::TRESHOLD { self.awake_set.fill(false); } else { + let awake_set = &mut self.awake_set; self.awake_list.iter().for_each(|&idx| { - self.awake_set[idx] = false; + awake_set[idx] = false; }); } self.awake_list_len = 0; diff --git a/src/utils/wakers/vec/readiness.rs b/src/utils/wakers/vec/readiness.rs index 9e2daac..9c67189 100644 --- a/src/utils/wakers/vec/readiness.rs +++ b/src/utils/wakers/vec/readiness.rs @@ -47,8 +47,9 @@ impl ReadinessVec { // either use `fill` (memset) or iterate and set each. // TODO: I came up with the 64 factor at random. Maybe test different factors? if self.awake_list.len() * 64 < self.awake_set.len() { + let awake_set = &mut self.awake_set; self.awake_list.drain(..).for_each(|idx| { - self.awake_set.set(idx, false); + awake_set.set(idx, false); }); } else { self.awake_list.clear(); From ca77d396af6adc205a41144320e4869ad07f6522 Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Tue, 10 Jan 2023 03:17:06 +0000 Subject: [PATCH 08/10] remove dead code --- src/utils/indexer.rs | 42 ------------------------------------------ src/utils/mod.rs | 3 --- src/utils/tuple.rs | 37 ------------------------------------- 3 files changed, 82 deletions(-) delete mode 100644 src/utils/indexer.rs delete mode 100644 src/utils/tuple.rs diff --git a/src/utils/indexer.rs b/src/utils/indexer.rs deleted file mode 100644 index 9da6227..0000000 --- a/src/utils/indexer.rs +++ /dev/null @@ -1,42 +0,0 @@ -use core::ops; - -/// Generate an iteration sequence. This provides *fair* iteration when multiple -/// futures need to be polled concurrently. -pub(crate) struct Indexer { - offset: usize, - max: usize, -} - -impl Indexer { - pub(crate) fn new(max: usize) -> Self { - Self { offset: 0, max } - } - - /// Generate a range between `0..max`, incrementing the starting point - /// for the next iteration. - pub(crate) fn iter(&mut self) -> IndexIter { - // Increment the starting point for next time. - let offset = self.offset; - self.offset = (self.offset + 1).wrapping_rem(self.max); - - IndexIter { - iter: (0..self.max), - offset, - } - } -} - -pub(crate) struct IndexIter { - iter: ops::Range, - offset: usize, -} - -impl Iterator for IndexIter { - type Item = usize; - - fn next(&mut self) -> Option { - self.iter - .next() - .map(|pos| (pos + self.offset).wrapping_rem(self.iter.end)) - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 370934d..12541c2 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -2,10 +2,8 @@ mod array; mod array_dequeue; -mod indexer; mod pin; mod poll_state; -mod tuple; mod wakers; pub(crate) use array::array_assume_init; @@ -16,4 +14,3 @@ pub(crate) use wakers::{WakerArray, WakerVec}; #[cfg(test)] pub(crate) mod channel; -pub(crate) use wakers::dummy_waker; diff --git a/src/utils/tuple.rs b/src/utils/tuple.rs deleted file mode 100644 index 4afbd64..0000000 --- a/src/utils/tuple.rs +++ /dev/null @@ -1,37 +0,0 @@ -/// Generate the `match` conditions inside the main polling body. This macro -/// chooses a random starting point on each call to the given method, making -/// it "fair". -/// -/// The way this algorithm works is: we generate a random number between 0 and -/// the length of the tuple we have. This number determines which element we -/// start with. All other cases are mapped as `r + index`, and after we have the -/// first one, we'll sequentially iterate over all others. The starting point of -/// the stream is random, but the iteration order of all others is not. -// NOTE(yosh): this macro monstrosity is needed so we can increment each `else -// if` branch with + 1. When RFC 3086 becomes available to us, we can replace -// this with `${index($F)}` to get the current iteration. -// -// # References -// - https://twitter.com/maybewaffle/status/1588426440835727360 -// - https://twitter.com/Veykril/status/1588231414998335490 -// - https://rust-lang.github.io/rfcs/3086-macro-metavar-expr.html -macro_rules! gen_conditions { - // Base condition, setup the depth counter. - ($i:expr, $this:expr, $cx:expr, $method:ident, $(($F_index: expr; $F:ident, { $($arms:pat => $foo:expr,)* }))*) => { - $( - if $i == $F_index { - match unsafe { Pin::new_unchecked(&mut $this.$F) }.$method($cx) { - $($arms => $foo,)* - } - } - )* - } -} -pub(crate) use gen_conditions; - -/// Calculate the number of tuples currently being operated on. -macro_rules! tuple_len { - (@count_one $F:ident) => (1); - ($($F:ident,)*) => (0 $(+ crate::utils::tuple_len!(@count_one $F))*); -} -pub(crate) use tuple_len; From 74e13e5c5f17672c77d34b6ef7072343b4fe9502 Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Tue, 10 Jan 2023 03:44:29 +0000 Subject: [PATCH 09/10] fix formatting for AggregateError --- src/future/race_ok/array.rs | 2 +- src/future/race_ok/tuple.rs | 2 +- src/future/race_ok/vec.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/future/race_ok/array.rs b/src/future/race_ok/array.rs index 4f3b548..e0789df 100644 --- a/src/future/race_ok/array.rs +++ b/src/future/race_ok/array.rs @@ -58,7 +58,7 @@ mod err { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "multiple errors occurred: [")?; for e in self.errors.iter() { - write!(f, "\n{}", e)?; + write!(f, "\n{e},")?; } write!(f, "]") } diff --git a/src/future/race_ok/tuple.rs b/src/future/race_ok/tuple.rs index 8ab855e..b54acab 100644 --- a/src/future/race_ok/tuple.rs +++ b/src/future/race_ok/tuple.rs @@ -51,7 +51,7 @@ macro_rules! impl_race_ok_tuple { write!(f, "multiple errors occurred: [")?; let ($($E,)+) = &self.errors; $( - write!(f, "{}", $E)?; + write!(f, "\n{},", $E)?; )+ write!(f, "]") } diff --git a/src/future/race_ok/vec.rs b/src/future/race_ok/vec.rs index 80adc63..9438c45 100644 --- a/src/future/race_ok/vec.rs +++ b/src/future/race_ok/vec.rs @@ -58,7 +58,7 @@ mod err { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "multiple errors occurred: [")?; for e in self.errors.iter() { - write!(f, "\n{}", e)?; + write!(f, "\n{e},")?; } write!(f, "]") } From e98bc03b3c68577bafb9bc22fca3e3a5ff06307d Mon Sep 17 00:00:00 2001 From: Wisha Wa Date: Mon, 30 Jan 2023 17:46:03 +0000 Subject: [PATCH 10/10] use ControlFlow instead of Result for future::common. ControlFlow Continue/Break captures the Keep/EarlyReturn API in future::common better than Ok/Err. --- src/future/common/array.rs | 15 ++++++++------- src/future/common/tuple.rs | 11 ++++++----- src/future/common/vec.rs | 7 ++++--- src/future/join/array.rs | 6 ++++-- src/future/join/tuple.rs | 5 +++-- src/future/join/vec.rs | 5 +++-- src/future/race/array.rs | 6 ++++-- src/future/race/tuple.rs | 6 +++--- src/future/race/vec.rs | 5 +++-- src/future/race_ok/array.rs | 9 ++++++--- src/future/race_ok/tuple.rs | 10 ++++------ src/future/race_ok/vec.rs | 7 ++++--- src/future/try_join/array.rs | 10 +++++++--- src/future/try_join/tuple.rs | 7 ++++--- src/future/try_join/vec.rs | 7 ++++--- 15 files changed, 67 insertions(+), 49 deletions(-) diff --git a/src/future/common/array.rs b/src/future/common/array.rs index 99e3a07..ff71236 100644 --- a/src/future/common/array.rs +++ b/src/future/common/array.rs @@ -5,6 +5,7 @@ use core::fmt; use core::future::Future; use core::marker::PhantomData; use core::mem::MaybeUninit; +use core::ops::ControlFlow; use core::pin::Pin; use core::task::{Context, Poll}; @@ -30,15 +31,15 @@ where type StoredItem; /// Takes the output of a subfuture and decide what to do with it. - /// If this function returns Err(output), the combinator would early return Poll::Ready(output). - /// For Ok(item), the combinator would keep the item in an array. + /// If this function returns ControlFlow::Break(output), the combinator would early return Poll::Ready(output). + /// For ControlFlow::Continue(item), the combinator would keep the item in an array. /// If by the end, all items are kept (no early return made), /// then `when_completed` will be called on the items array. /// /// Example: - /// Join will always wrap the output in Ok because it want to wait until all outputs are ready. - /// Race will always wrap the output in Err because it want to early return with the first output. - fn maybe_return(idx: usize, res: Fut::Output) -> Result; + /// Join will always wrap the output in ControlFlow::Continue because it want to wait until all outputs are ready. + /// Race will always wrap the output in ControlFlow::Break because it want to early return with the first output. + fn maybe_return(idx: usize, res: Fut::Output) -> ControlFlow; /// Called when all subfutures are completed and none caused the combinator to return early. /// The argument is an array of the kept item from each subfuture. @@ -145,13 +146,13 @@ where if let Poll::Ready(value) = fut.poll(&mut cx) { match B::maybe_return(idx, value) { // Keep the item for returning once every subfuture is done. - Ok(store) => { + ControlFlow::Continue(store) => { this.items[idx].write(store); *filled = true; *this.pending -= 1; } // Early return. - Err(ret) => return Poll::Ready(ret), + ControlFlow::Break(ret) => return Poll::Ready(ret), } } } diff --git a/src/future/common/tuple.rs b/src/future/common/tuple.rs index d35f908..deffba5 100644 --- a/src/future/common/tuple.rs +++ b/src/future/common/tuple.rs @@ -4,6 +4,7 @@ use core::fmt::{self, Debug}; use core::future::Future; use core::marker::PhantomData; use core::mem::MaybeUninit; +use core::ops::ControlFlow; use core::pin::Pin; use core::task::{Context, Poll}; @@ -25,9 +26,9 @@ pub trait TupleMaybeReturn { /// The type of the item to store for this subfuture. type StoredItem; /// Take the return value of a subfuture and decide whether to store it or early return. - /// Ok(v) = store v. - /// Err(o) = early return o. - fn maybe_return(idx: usize, res: R) -> Result; + /// ControlFlow::Continue(v) = store v. + /// ControlFlow::Break(o) = early return o. + fn maybe_return(idx: usize, res: R) -> ControlFlow; } /// This and [TupleMaybeReturn] takes the role of [super::array::CombinatorBehaviorArray] but for tuples. /// Type parameters: @@ -159,10 +160,10 @@ macro_rules! impl_common_tuple { $idx => { if let Poll::Ready(value) = futures.$F.as_mut().poll(&mut cx) { match B::maybe_return($idx, value) { - Err(ret) => { + ControlFlow::Break(ret) => { return Poll::Ready(ret); }, - Ok(store) => { + ControlFlow::Continue(store) => { this.items.$idx.write(store); true } diff --git a/src/future/common/vec.rs b/src/future/common/vec.rs index d91a482..d15019d 100644 --- a/src/future/common/vec.rs +++ b/src/future/common/vec.rs @@ -3,6 +3,7 @@ use crate::utils::{self, WakerVec}; use core::fmt; use core::future::Future; use core::mem::MaybeUninit; +use core::ops::ControlFlow; use core::pin::Pin; use core::task::{Context, Poll}; use std::vec::Vec; @@ -20,7 +21,7 @@ where { type Output; type StoredItem; - fn maybe_return(idx: usize, res: Fut::Output) -> Result; + fn maybe_return(idx: usize, res: Fut::Output) -> ControlFlow; fn when_completed(vec: Vec) -> Self::Output; } @@ -103,12 +104,12 @@ where let mut cx = Context::from_waker(this.wakers.get(idx).unwrap()); if let Poll::Ready(value) = fut.poll(&mut cx) { match B::maybe_return(idx, value) { - Ok(store) => { + ControlFlow::Continue(store) => { this.items[idx].write(store); this.filled.set(idx, true); *this.pending -= 1; } - Err(ret) => { + ControlFlow::Break(ret) => { return Poll::Ready(ret); } } diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 53ff3d6..d0de047 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -2,6 +2,7 @@ use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; use super::{Join as JoinTrait, JoinBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; /// Waits for two similarly-typed futures to complete. /// @@ -23,8 +24,9 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { - Ok(res) + ) -> ControlFlow { + // Continue with other subfutures + ControlFlow::Continue(res) } fn when_completed(arr: [Self::StoredItem; N]) -> Self::Output { diff --git a/src/future/join/tuple.rs b/src/future/join/tuple.rs index 7893813..6fb68b4 100644 --- a/src/future/join/tuple.rs +++ b/src/future/join/tuple.rs @@ -3,11 +3,12 @@ use super::{Join as JoinTrait, JoinBehavior}; use core::future::IntoFuture; use core::marker::PhantomData; +use core::ops::ControlFlow; impl TupleMaybeReturn for JoinBehavior { type StoredItem = T; - fn maybe_return(_: usize, res: T) -> Result { - Ok(res) + fn maybe_return(_: usize, res: T) -> ControlFlow { + ControlFlow::Continue(res) } } impl TupleWhenCompleted for JoinBehavior { diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index 1310ebf..fcdf8a9 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -2,6 +2,7 @@ use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; use super::{Join as JoinTrait, JoinBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; use std::vec::Vec; /// Waits for two similarly-typed futures to complete. @@ -24,8 +25,8 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { - Ok(res) + ) -> ControlFlow { + ControlFlow::Continue(res) } fn when_completed(vec: Vec) -> Self::Output { diff --git a/src/future/race/array.rs b/src/future/race/array.rs index 8cd225b..f7716ec 100644 --- a/src/future/race/array.rs +++ b/src/future/race/array.rs @@ -2,6 +2,7 @@ use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; use super::{Race as RaceTrait, RaceBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; /// Wait for the first future to complete. /// @@ -23,8 +24,9 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { - Err(res) + ) -> ControlFlow { + // Subfuture finished, so the race is over. Break now. + ControlFlow::Break(res) } fn when_completed(_arr: [Self::StoredItem; N]) -> Self::Output { diff --git a/src/future/race/tuple.rs b/src/future/race/tuple.rs index 26e7602..dcd768e 100644 --- a/src/future/race/tuple.rs +++ b/src/future/race/tuple.rs @@ -4,14 +4,14 @@ use super::{Race as RaceTrait, RaceBehavior}; use core::convert::Infallible; use core::future::{Future, IntoFuture}; use core::marker::PhantomData; +use core::ops::ControlFlow; impl TupleMaybeReturn for RaceBehavior { // We early return as soon as any subfuture finishes. // Results from subfutures are never stored. type StoredItem = Infallible; - fn maybe_return(_: usize, res: T) -> Result { - // Err = early return. - Err(res) + fn maybe_return(_: usize, res: T) -> ControlFlow { + ControlFlow::Break(res) } } impl TupleWhenCompleted for RaceBehavior { diff --git a/src/future/race/vec.rs b/src/future/race/vec.rs index 327fa3f..8534a5c 100644 --- a/src/future/race/vec.rs +++ b/src/future/race/vec.rs @@ -2,6 +2,7 @@ use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; use super::{Race as RaceTrait, RaceBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; /// Wait for the first future to complete. /// @@ -23,8 +24,8 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { - Err(res) + ) -> ControlFlow { + ControlFlow::Break(res) } fn when_completed(_vec: Vec) -> Self::Output { diff --git a/src/future/race_ok/array.rs b/src/future/race_ok/array.rs index e0789df..d5d7c34 100644 --- a/src/future/race_ok/array.rs +++ b/src/future/race_ok/array.rs @@ -3,6 +3,7 @@ use super::error::AggregateError; use super::{RaceOk as RaceOkTrait, RaceOkBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; /// Wait for the first successful future to complete. /// @@ -24,10 +25,12 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { + ) -> ControlFlow { match res { - Ok(v) => Err(Ok(v)), - Err(e) => Ok(e), + // Got an Ok result. Break now. + Ok(v) => ControlFlow::Break(Ok(v)), + // Err result. Continue polling other subfutures. + Err(e) => ControlFlow::Continue(e), } } diff --git a/src/future/race_ok/tuple.rs b/src/future/race_ok/tuple.rs index b54acab..d9c8895 100644 --- a/src/future/race_ok/tuple.rs +++ b/src/future/race_ok/tuple.rs @@ -4,17 +4,15 @@ use super::{RaceOk as RaceOkTrait, RaceOkBehavior}; use core::future::IntoFuture; use core::marker::PhantomData; +use core::ops::ControlFlow; use std::{error::Error, fmt::Display}; impl TupleMaybeReturn, Result> for RaceOkBehavior { type StoredItem = E; - fn maybe_return(_: usize, res: Result) -> Result> { + fn maybe_return(_: usize, res: Result) -> ControlFlow, Self::StoredItem> { match res { - // If subfuture returns Ok we want to early return from the combinator. - // We do this by returning Err to the combinator. - Ok(t) => Err(Ok(t)), - // If subfuture returns Err, we keep the error for potential use in AggregateError. - Err(e) => Ok(e), + Ok(t) => ControlFlow::Break(Ok(t)), + Err(e) => ControlFlow::Continue(e), } } } diff --git a/src/future/race_ok/vec.rs b/src/future/race_ok/vec.rs index 9438c45..b636d3c 100644 --- a/src/future/race_ok/vec.rs +++ b/src/future/race_ok/vec.rs @@ -3,6 +3,7 @@ use super::error::AggregateError; use super::{RaceOk as RaceOkTrait, RaceOkBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; use std::vec::Vec; /// Wait for the first successful future to complete. @@ -25,10 +26,10 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { + ) -> ControlFlow { match res { - Ok(v) => Err(Ok(v)), - Err(e) => Ok(e), + Ok(v) => ControlFlow::Break(Ok(v)), + Err(e) => ControlFlow::Continue(e), } } diff --git a/src/future/try_join/array.rs b/src/future/try_join/array.rs index 10fea2f..ba3499d 100644 --- a/src/future/try_join/array.rs +++ b/src/future/try_join/array.rs @@ -1,6 +1,8 @@ use super::super::common::{CombinatorArray, CombinatorBehaviorArray}; use super::{TryJoin as TryJoinTrait, TryJoinBehavior}; + use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; /// Wait for all futures to complete successfully, or abort early on error. /// @@ -22,10 +24,12 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { + ) -> ControlFlow { match res { - Ok(v) => Ok(v), - Err(e) => Err(Err(e)), + // Got an Ok result. Keep it. + Ok(v) => ControlFlow::Continue(v), + // An error happended. Break now. + Err(e) => ControlFlow::Break(Err(e)), } } diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index a018127..c4ca244 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -3,15 +3,16 @@ use super::{TryJoin as TryJoinTrait, TryJoinBehavior}; use core::future::IntoFuture; use core::marker::PhantomData; +use core::ops::ControlFlow; use futures_core::TryFuture; impl TupleMaybeReturn, Result> for TryJoinBehavior { type StoredItem = T; - fn maybe_return(_: usize, res: Result) -> Result> { + fn maybe_return(_: usize, res: Result) -> ControlFlow, Self::StoredItem> { match res { - Ok(t) => Ok(t), - Err(e) => Err(Err(e)), + Ok(t) => ControlFlow::Continue(t), + Err(e) => ControlFlow::Break(Err(e)), } } } diff --git a/src/future/try_join/vec.rs b/src/future/try_join/vec.rs index e5a6a51..c625bf9 100644 --- a/src/future/try_join/vec.rs +++ b/src/future/try_join/vec.rs @@ -2,6 +2,7 @@ use super::super::common::{CombinatorBehaviorVec, CombinatorVec}; use super::{TryJoin as TryJoinTrait, TryJoinBehavior}; use core::future::{Future, IntoFuture}; +use core::ops::ControlFlow; use std::vec::Vec; /// Wait for all futures to complete successfully, or abort early on error. @@ -24,10 +25,10 @@ where fn maybe_return( _idx: usize, res: ::Output, - ) -> Result { + ) -> ControlFlow { match res { - Ok(v) => Ok(v), - Err(e) => Err(Err(e)), + Ok(v) => ControlFlow::Continue(v), + Err(e) => ControlFlow::Break(Err(e)), } }