diff --git a/Cargo.toml b/Cargo.toml index 41504a6..dc2257d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,13 +6,16 @@ repository = "https://github.com/yoshuawuyts/futures-concurrency" documentation = "https://docs.rs/futures-concurrency" description = "Structured concurrency operations for async Rust" readme = "README.md" -edition = "2018" +edition = "2021" keywords = [] categories = [] authors = [ "Yoshua Wuyts " ] +[profile.bench] +debug = true + [lib] bench = false @@ -28,6 +31,8 @@ harness = false bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] } futures-core = "0.3" pin-project = "1.0.8" +slab = "0.4.8" +smallvec = "1.11.0" [dev-dependencies] futures = "0.3.25" diff --git a/benches/bench.rs b/benches/bench.rs index 92e680c..d2ab343 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -2,7 +2,70 @@ mod utils; -criterion::criterion_main!(merge::merge_benches, join::join_benches, race::race_benches); +// #[global_allocator] +// static ALLOC: dhat::Alloc = dhat::Alloc; + +fn main() { + // let _profiler = dhat::Profiler::new_heap(); + criterion::criterion_main!( + merge::merge_benches, + join::join_benches, + race::race_benches, + stream_group::stream_group_benches + ); + main() +} + +mod stream_group { + use criterion::async_executor::FuturesExecutor; + use criterion::{black_box, criterion_group, BatchSize, BenchmarkId, Criterion}; + use futures::stream::SelectAll; + use futures_concurrency::stream::StreamGroup; + use futures_lite::prelude::*; + + use crate::utils::{make_select_all, make_stream_group}; + criterion_group! { + name = stream_group_benches; + // This can be any expression that returns a `Criterion` object. + config = Criterion::default(); + targets = stream_set_bench + } + + fn stream_set_bench(c: &mut Criterion) { + let mut group = c.benchmark_group("stream_group"); + for i in [10, 100, 1000].iter() { + group.bench_with_input(BenchmarkId::new("StreamGroup", i), i, |b, i| { + let setup = || make_stream_group(*i); + let routine = |mut group: StreamGroup<_>| async move { + let mut counter = 0; + black_box({ + while group.next().await.is_some() { + counter += 1; + } + assert_eq!(counter, *i); + }); + }; + b.to_async(FuturesExecutor) + .iter_batched(setup, routine, BatchSize::SmallInput) + }); + group.bench_with_input(BenchmarkId::new("SelectAll", i), i, |b, i| { + let setup = || make_select_all(*i); + let routine = |mut group: SelectAll<_>| async move { + let mut counter = 0; + black_box({ + while group.next().await.is_some() { + counter += 1; + } + assert_eq!(counter, *i); + }); + }; + b.to_async(FuturesExecutor) + .iter_batched(setup, routine, BatchSize::SmallInput) + }); + } + group.finish(); + } +} mod merge { use criterion::async_executor::FuturesExecutor; @@ -17,7 +80,7 @@ mod merge { merge_benches, vec_merge_bench, array_merge_bench, - tuple_merge_bench + tuple_merge_bench, ); fn vec_merge_bench(c: &mut Criterion) { diff --git a/benches/utils/countdown_streams.rs b/benches/utils/countdown_streams.rs index f090f36..29dd91b 100644 --- a/benches/utils/countdown_streams.rs +++ b/benches/utils/countdown_streams.rs @@ -1,3 +1,4 @@ +use futures_concurrency::stream::StreamGroup; use futures_core::Stream; use std::cell::{Cell, RefCell}; @@ -19,6 +20,24 @@ pub fn streams_vec(len: usize) -> Vec { streams } +#[allow(unused)] +pub fn make_stream_group(len: usize) -> StreamGroup { + let wakers = Rc::new(RefCell::new(BinaryHeap::new())); + let completed = Rc::new(Cell::new(0)); + (0..len) + .map(|n| CountdownStream::new(n, len, wakers.clone(), completed.clone())) + .collect() +} + +#[allow(unused)] +pub fn make_select_all(len: usize) -> futures::stream::SelectAll { + let wakers = Rc::new(RefCell::new(BinaryHeap::new())); + let completed = Rc::new(Cell::new(0)); + (0..len) + .map(|n| CountdownStream::new(n, len, wakers.clone(), completed.clone())) + .collect() +} + pub fn streams_array() -> [CountdownStream; N] { let wakers = Rc::new(RefCell::new(BinaryHeap::new())); let completed = Rc::new(Cell::new(0)); diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 0414da0..4db6447 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -50,7 +50,7 @@ where pending: N, items: OutputArray::uninit(), wakers: WakerArray::new(), - state: PollArray::new(), + state: PollArray::new_pending(), futures: FutureArray::new(futures), } } @@ -138,7 +138,7 @@ where state.is_ready(), "Future should have reached a `Ready` state" ); - state.set_consumed(); + state.set_none(); } // SAFETY: we've checked with the state that all of our outputs have been @@ -202,6 +202,6 @@ mod test { let waker = Arc::new(DummyWaker()).into(); let mut cx = Context::from_waker(&waker); let _ = fut.as_mut().poll(&mut cx); - assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); + assert_eq!(format!("{:?}", fut), "[None, None]"); } } diff --git a/src/future/join/tuple.rs b/src/future/join/tuple.rs index 3727314..52e3001 100644 --- a/src/future/join/tuple.rs +++ b/src/future/join/tuple.rs @@ -60,7 +60,7 @@ macro_rules! drop_initialized_values { // SAFETY: we've just filtered down to *only* the initialized values. // We can assume they're initialized, and this is where we drop them. unsafe { $output.assume_init_drop() }; - $states[$state_idx].set_consumed(); + $states[$state_idx].set_none(); } drop_initialized_values!(@drop $($rem_outs,)* | $states, $($rem_idx,)*); }; @@ -267,7 +267,7 @@ macro_rules! impl_join_tuple { let ($($F,)+): ($($F,)+) = self; $StructName { futures: $mod_name::Futures {$($F: ManuallyDrop::new($F.into_future()),)+}, - state: PollArray::new(), + state: PollArray::new_pending(), outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+), wakers: WakerArray::new(), completed: 0, diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index a995987..e847832 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -44,7 +44,7 @@ where pending: len, items: OutputVec::uninit(len), wakers: WakerVec::new(len), - state: PollVec::new(len), + state: PollVec::new_pending(len), futures: FutureVec::new(futures), } } @@ -132,7 +132,7 @@ where state.is_ready(), "Future should have reached a `Ready` state" ); - state.set_consumed(); + state.set_none(); }); // SAFETY: we've checked with the state that all of our outputs have been @@ -196,6 +196,6 @@ mod test { let waker = Arc::new(DummyWaker()).into(); let mut cx = Context::from_waker(&waker); let _ = fut.as_mut().poll(&mut cx); - assert_eq!(format!("{:?}", fut), "[Consumed, Consumed]"); + assert_eq!(format!("{:?}", fut), "[None, None]"); } } diff --git a/src/future/race_ok/tuple/mod.rs b/src/future/race_ok/tuple/mod.rs index 18c59e3..e7f47a6 100644 --- a/src/future/race_ok/tuple/mod.rs +++ b/src/future/race_ok/tuple/mod.rs @@ -72,7 +72,7 @@ macro_rules! impl_race_ok_tuple { done: false, indexer: utils::Indexer::new($StructName), errors: array::from_fn(|_| MaybeUninit::uninit()), - errors_states: PollArray::new(), + errors_states: PollArray::new_pending(), $($F: $F.into_future()),* } } @@ -154,7 +154,7 @@ macro_rules! impl_race_ok_tuple { .for_each(|(st, err)| { // SAFETY: we've filtered down to only the `ready`/initialized data unsafe { err.assume_init_drop() }; - st.set_consumed(); + st.set_none(); }); } } diff --git a/src/future/try_join/array.rs b/src/future/try_join/array.rs index 7440b65..5406f74 100644 --- a/src/future/try_join/array.rs +++ b/src/future/try_join/array.rs @@ -50,7 +50,7 @@ where pending: N, items: OutputArray::uninit(), wakers: WakerArray::new(), - state: PollArray::new(), + state: PollArray::new_pending(), futures: FutureArray::new(futures), } } @@ -147,7 +147,7 @@ where state.is_ready(), "Future should have reached a `Ready` state" ); - state.set_consumed(); + state.set_none(); } // SAFETY: we've checked with the state that all of our outputs have been diff --git a/src/future/try_join/tuple.rs b/src/future/try_join/tuple.rs index 36f28e4..2a8e941 100644 --- a/src/future/try_join/tuple.rs +++ b/src/future/try_join/tuple.rs @@ -70,7 +70,7 @@ macro_rules! drop_initialized_values { // SAFETY: we've just filtered down to *only* the initialized values. // We can assume they're initialized, and this is where we drop them. unsafe { $output.assume_init_drop() }; - $states[$state_idx].set_consumed(); + $states[$state_idx].set_none(); } drop_initialized_values!(@drop $($rem_outs,)* | $states, $($rem_idx,)*); }; @@ -289,7 +289,7 @@ macro_rules! impl_try_join_tuple { futures: $mod_name::Futures {$( $F: ManuallyDrop::new($F.into_future()), )+}, - state: PollArray::new(), + state: PollArray::new_pending(), outputs: ($(MaybeUninit::<$T>::uninit(),)+), wakers: WakerArray::new(), completed: 0, diff --git a/src/future/try_join/vec.rs b/src/future/try_join/vec.rs index 07207d9..9c2f800 100644 --- a/src/future/try_join/vec.rs +++ b/src/future/try_join/vec.rs @@ -51,7 +51,7 @@ where pending: len, items: OutputVec::uninit(len), wakers: WakerVec::new(len), - state: PollVec::new(len), + state: PollVec::new_pending(len), futures: FutureVec::new(futures), } } @@ -148,7 +148,7 @@ where state.is_ready(), "Future should have reached a `Ready` state" ); - state.set_consumed(); + state.set_none(); } // SAFETY: we've checked with the state that all of our outputs have been diff --git a/src/stream/merge/array.rs b/src/stream/merge/array.rs index 8a5fe39..2501b91 100644 --- a/src/stream/merge/array.rs +++ b/src/stream/merge/array.rs @@ -37,7 +37,7 @@ where streams, indexer: Indexer::new(N), wakers: WakerArray::new(), - state: PollArray::new(), + state: PollArray::new_pending(), complete: 0, done: false, } @@ -72,7 +72,7 @@ where if !readiness.any_ready() { // Nothing is ready yet return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + } else if !readiness.clear_ready(index) || this.state[index].is_none() { continue; } @@ -91,7 +91,7 @@ where } Poll::Ready(None) => { *this.complete += 1; - this.state[index].set_consumed(); + this.state[index].set_none(); if *this.complete == this.streams.len() { return Poll::Ready(None); } diff --git a/src/stream/merge/tuple.rs b/src/stream/merge/tuple.rs index f9f9177..834de90 100644 --- a/src/stream/merge/tuple.rs +++ b/src/stream/merge/tuple.rs @@ -23,7 +23,7 @@ macro_rules! poll_stream { } Poll::Ready(None) => { *$this.completed += 1; - $this.state[$stream_idx].set_consumed(); + $this.state[$stream_idx].set_none(); if *$this.completed == $len_streams { return Poll::Ready(None); } @@ -132,7 +132,7 @@ macro_rules! impl_merge_tuple { if !readiness.any_ready() { // Nothing is ready yet return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + } else if !readiness.clear_ready(index) || this.state[index].is_none() { continue; } @@ -175,7 +175,7 @@ macro_rules! impl_merge_tuple { streams: $mod_name::Streams { $($F: $F.into_stream()),+ }, indexer: utils::Indexer::new(utils::tuple_len!($($F,)*)), wakers: WakerArray::new(), - state: PollArray::new(), + state: PollArray::new_pending(), completed: 0, } } diff --git a/src/stream/merge/vec.rs b/src/stream/merge/vec.rs index 80869b4..27508ee 100644 --- a/src/stream/merge/vec.rs +++ b/src/stream/merge/vec.rs @@ -36,7 +36,7 @@ where let len = streams.len(); Self { wakers: WakerVec::new(len), - state: PollVec::new(len), + state: PollVec::new_pending(len), indexer: Indexer::new(len), streams, complete: 0, @@ -73,7 +73,7 @@ where if !readiness.any_ready() { // Nothing is ready yet return Poll::Pending; - } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + } else if !readiness.clear_ready(index) || this.state[index].is_none() { continue; } @@ -92,7 +92,7 @@ where } Poll::Ready(None) => { *this.complete += 1; - this.state[index].set_consumed(); + this.state[index].set_none(); if *this.complete == this.streams.len() { return Poll::Ready(None); } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index ef6a374..9bdc336 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -51,8 +51,13 @@ pub use chain::Chain; pub use into_stream::IntoStream; pub use merge::Merge; pub use stream_ext::StreamExt; +#[doc(inline)] +pub use stream_group::StreamGroup; pub use zip::Zip; +/// A growable group of streams which act as a single unit. +pub mod stream_group; + pub(crate) mod chain; mod into_stream; pub(crate) mod merge; diff --git a/src/stream/stream_group.rs b/src/stream/stream_group.rs new file mode 100644 index 0000000..a45cdf7 --- /dev/null +++ b/src/stream/stream_group.rs @@ -0,0 +1,425 @@ +use futures_core::Stream; +use slab::Slab; +use smallvec::{smallvec, SmallVec}; +use std::collections::BTreeSet; +use std::fmt::{self, Debug}; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::utils::{PollState, PollVec, WakerVec}; + +/// A growable group of streams which act as a single unit. +/// +/// In order go mutate the group during iteration, the stream should be +/// combined with a mechanism such as +/// [`lend_mut`](https://docs.rs/async-iterator/latest/async_iterator/trait.Iterator.html#method.lend_mut). +/// This is not yet provided by the `futures-concurrency` crate. +/// +/// # Example +/// +/// ```rust +/// use futures_concurrency::stream::StreamGroup; +/// use futures_lite::{stream, StreamExt}; +/// +/// # futures_lite::future::block_on(async { +/// let mut group = StreamGroup::new(); +/// group.insert(stream::once(2)); +/// group.insert(stream::once(4)); +/// +/// let mut out = 0; +/// while let Some(num) = group.next().await { +/// out += num; +/// } +/// assert_eq!(out, 6); +/// # }); +/// ``` +#[must_use = "`StreamGroup` does nothing if not iterated over"] +#[derive(Default)] +#[pin_project::pin_project] +pub struct StreamGroup { + #[pin] + streams: Slab, + wakers: WakerVec, + states: PollVec, + keys: BTreeSet, + key_removal_queue: SmallVec<[usize; 10]>, +} + +impl Debug for StreamGroup { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("StreamGroup") + .field("slab", &"[..]") + .finish() + } +} + +impl StreamGroup { + /// Create a new instance of `StreamGroup`. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// + /// let group = StreamGroup::new(); + /// # let group: StreamGroup = group; + /// ``` + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Create a new instance of `StreamGroup` with a given capacity. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// + /// let group = StreamGroup::with_capacity(2); + /// # let group: StreamGroup = group; + /// ``` + pub fn with_capacity(capacity: usize) -> Self { + Self { + streams: Slab::with_capacity(capacity), + wakers: WakerVec::new(capacity), + states: PollVec::new(capacity), + keys: BTreeSet::new(), + key_removal_queue: smallvec![], + } + } + + /// Return the number of futures currently active in the group. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// use futures_lite::stream; + /// + /// let mut group = StreamGroup::with_capacity(2); + /// assert_eq!(group.len(), 0); + /// group.insert(stream::once(12)); + /// assert_eq!(group.len(), 1); + /// ``` + pub fn len(&self) -> usize { + self.streams.len() + } + + /// Return the capacity of the `StreamGroup`. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// use futures_lite::stream; + /// + /// let group = StreamGroup::with_capacity(2); + /// assert_eq!(group.capacity(), 2); + /// # let group: StreamGroup = group; + /// ``` + pub fn capacity(&self) -> usize { + self.streams.capacity() + } + + /// Returns true if there are no futures currently active in the group. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// use futures_lite::stream; + /// + /// let mut group = StreamGroup::with_capacity(2); + /// assert!(group.is_empty()); + /// group.insert(stream::once(12)); + /// assert!(!group.is_empty()); + /// ``` + pub fn is_empty(&self) -> bool { + self.streams.is_empty() + } + + /// Insert a new future into the group. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// use futures_lite::stream; + /// + /// let mut group = StreamGroup::with_capacity(2); + /// group.insert(stream::once(12)); + /// ``` + pub fn insert(&mut self, stream: S) -> Key + where + S: Stream, + { + let index = self.streams.insert(stream); + self.keys.insert(index); + let key = Key(index); + + // If our slab allocated more space we need to + // update our tracking structures along with it. + let max_len = self.capacity().max(index); + self.wakers.resize(max_len); + self.states.resize(max_len); + + // Set the corresponding state + self.states[index].set_pending(); + + key + } + + /// Removes a stream from the group. Returns whether the value was present in + /// the group. + /// + /// # Example + /// + /// ``` + /// use futures_lite::stream; + /// use futures_concurrency::stream::StreamGroup; + /// + /// # futures_lite::future::block_on(async { + /// let mut group = StreamGroup::new(); + /// let key = group.insert(stream::once(4)); + /// assert_eq!(group.len(), 1); + /// group.remove(key); + /// assert_eq!(group.len(), 0); + /// # }) + /// ``` + pub fn remove(&mut self, key: Key) -> bool { + let is_present = self.keys.remove(&key.0); + if is_present { + self.states[key.0].set_none(); + self.streams.remove(key.0); + } + is_present + } + + /// Returns `true` if the `StreamGroup` contains a value for the specified key. + /// + /// # Example + /// + /// ``` + /// use futures_lite::stream; + /// use futures_concurrency::stream::StreamGroup; + /// + /// # futures_lite::future::block_on(async { + /// let mut group = StreamGroup::new(); + /// let key = group.insert(stream::once(4)); + /// assert!(group.contains_key(key)); + /// group.remove(key); + /// assert!(!group.contains_key(key)); + /// # }) + /// ``` + pub fn contains_key(&mut self, key: Key) -> bool { + self.keys.contains(&key.0) + } +} + +impl StreamGroup { + /// Create a stream which also yields the key of each item. + /// + /// # Example + /// + /// ```rust + /// use futures_concurrency::stream::StreamGroup; + /// use futures_lite::{stream, StreamExt}; + /// + /// # futures_lite::future::block_on(async { + /// let mut group = StreamGroup::new(); + /// group.insert(stream::once(2)); + /// group.insert(stream::once(4)); + /// + /// let mut out = 0; + /// let mut group = group.keyed(); + /// while let Some((_key, num)) = group.next().await { + /// out += num; + /// } + /// assert_eq!(out, 6); + /// # }); + /// ``` + pub fn keyed(self) -> Keyed { + Keyed { group: self } + } +} + +impl StreamGroup { + fn poll_next_inner( + self: Pin<&mut Self>, + cx: &std::task::Context<'_>, + ) -> Poll::Item)>> { + let mut this = self.project(); + + // Short-circuit if we have no streams to iterate over + if this.streams.is_empty() { + return Poll::Ready(None); + } + + // Set the top-level waker and check readiness + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } + + // Setup our stream state + let mut ret = Poll::Pending; + let mut done_count = 0; + let stream_count = this.streams.len(); + let states = this.states; + + // SAFETY: We unpin the stream set so we can later individually access + // single streams. Either to read from them or to drop them. + let streams = unsafe { this.streams.as_mut().get_unchecked_mut() }; + + for index in this.keys.iter().cloned() { + if states[index].is_pending() && readiness.clear_ready(index) { + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + + // SAFETY: this stream here is a projection from the streams + // vec, which we're reading from. + let stream = unsafe { Pin::new_unchecked(&mut streams[index]) }; + match stream.poll_next(&mut cx) { + Poll::Ready(Some(item)) => { + // Set the return type for the function + ret = Poll::Ready(Some((Key(index), item))); + + // We just obtained an item from this index, make sure + // we check it again on a next iteration + states[index] = PollState::Pending; + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_ready(index); + + break; + } + Poll::Ready(None) => { + // A stream has ended, make note of that + done_count += 1; + + // Remove all associated data about the stream. + // The only data we can't remove directly is the key entry. + states[index] = PollState::None; + streams.remove(index); + this.key_removal_queue.push(index); + } + // Keep looping if there is nothing for us to do + Poll::Pending => {} + }; + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); + } + } + + // Now that we're no longer borrowing `this.keys` we can loop over + // which items we need to remove + if !this.key_removal_queue.is_empty() { + for key in this.key_removal_queue.iter() { + this.keys.remove(key); + } + this.key_removal_queue.clear(); + } + + // If all streams turned up with `Poll::Ready(None)` our + // stream should return that + if done_count == stream_count { + ret = Poll::Ready(None); + } + + ret + } +} + +impl Stream for StreamGroup { + type Item = ::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.poll_next_inner(cx) { + Poll::Ready(Some((_key, item))) => Poll::Ready(Some(item)), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl FromIterator for StreamGroup { + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let len = iter.size_hint().1.unwrap_or_default(); + let mut this = Self::with_capacity(len); + for stream in iter { + this.insert(stream); + } + this + } +} + +/// A key used to index into the `StreamGroup` type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Key(usize); + +/// Iterate over items in the stream group with their associated keys. +#[derive(Debug)] +#[pin_project::pin_project] +pub struct Keyed { + #[pin] + group: StreamGroup, +} + +impl Deref for Keyed { + type Target = StreamGroup; + + fn deref(&self) -> &Self::Target { + &self.group + } +} + +impl DerefMut for Keyed { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.group + } +} + +impl Stream for Keyed { + type Item = (Key, ::Item); + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut this = self.project(); + this.group.as_mut().poll_next_inner(cx) + } +} + +#[cfg(test)] +mod test { + use super::StreamGroup; + use futures_lite::{prelude::*, stream}; + + #[test] + fn smoke() { + futures_lite::future::block_on(async { + let mut group = StreamGroup::new(); + group.insert(stream::once(2)); + group.insert(stream::once(4)); + + let mut out = 0; + while let Some(num) = group.next().await { + out += num; + } + assert_eq!(out, 6); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + }); + } +} diff --git a/src/stream/zip/array.rs b/src/stream/zip/array.rs index f9d6acb..fa69b01 100644 --- a/src/stream/zip/array.rs +++ b/src/stream/zip/array.rs @@ -1,6 +1,6 @@ use super::Zip as ZipTrait; use crate::stream::IntoStream; -use crate::utils::{self, PollArray, PollState, WakerArray}; +use crate::utils::{self, PollArray, WakerArray}; use core::array; use core::fmt; @@ -40,7 +40,7 @@ where Self { streams, output: array::from_fn(|_| MaybeUninit::uninit()), - state: PollArray::new(), + state: PollArray::new_pending(), wakers: WakerArray::new(), done: false, } @@ -96,7 +96,7 @@ where // Reset the future's state. readiness = this.wakers.readiness().lock().unwrap(); readiness.set_all_ready(); - this.state.fill_with(PollState::default); + this.state.set_all_pending(); // Take the output // diff --git a/src/stream/zip/tuple.rs b/src/stream/zip/tuple.rs index b41393d..05d3f91 100644 --- a/src/stream/zip/tuple.rs +++ b/src/stream/zip/tuple.rs @@ -6,7 +6,7 @@ use core::task::{Context, Poll}; use futures_core::Stream; use super::Zip; -use crate::utils::{PollArray, PollState, WakerArray}; +use crate::utils::{PollArray, WakerArray}; macro_rules! impl_zip_for_tuple { ($mod_name: ident $StructName: ident $($F: ident)+) => { @@ -128,7 +128,7 @@ macro_rules! impl_zip_for_tuple { // Reset the future's state. readiness = this.wakers.readiness().lock().unwrap(); readiness.set_all_ready(); - this.state.fill_with(PollState::default); + this.state.set_all_pending(); // Take the output // @@ -169,7 +169,7 @@ macro_rules! impl_zip_for_tuple { Self::Stream { done: false, output: Default::default(), - state: PollArray::new(), + state: PollArray::new_pending(), wakers: WakerArray::new(), $($F,)+ } diff --git a/src/stream/zip/vec.rs b/src/stream/zip/vec.rs index 7ebf5f5..ba29084 100644 --- a/src/stream/zip/vec.rs +++ b/src/stream/zip/vec.rs @@ -1,6 +1,6 @@ use super::Zip as ZipTrait; use crate::stream::IntoStream; -use crate::utils::{self, PollState, WakerVec}; +use crate::utils::{self, PollVec, WakerVec}; use core::fmt; use core::mem::MaybeUninit; @@ -27,7 +27,7 @@ where streams: Vec, output: Vec::Item>>, wakers: WakerVec, - state: Vec, + state: PollVec, done: bool, len: usize, } @@ -43,7 +43,7 @@ where streams, wakers: WakerVec::new(len), output: (0..len).map(|_| MaybeUninit::uninit()).collect(), - state: (0..len).map(|_| PollState::default()).collect(), + state: PollVec::new_pending(len), done: false, } } @@ -98,7 +98,7 @@ where // Reset the future's state. readiness = this.wakers.readiness().lock().unwrap(); readiness.set_all_ready(); - this.state.fill_with(PollState::default); + this.state.set_all_pending(); // Take the output // diff --git a/src/utils/poll_state/array.rs b/src/utils/poll_state/array.rs index 1ca647c..eee7bca 100644 --- a/src/utils/poll_state/array.rs +++ b/src/utils/poll_state/array.rs @@ -7,9 +7,18 @@ pub(crate) struct PollArray { } impl PollArray { + /// Create a new `PollArray` with all state marked as `None` + #[allow(unused)] pub(crate) fn new() -> Self { Self { - state: [PollState::default(); N], + state: [PollState::None; N], + } + } + + /// Create a new `PollArray` with all state marked as `Pending` + pub(crate) fn new_pending() -> Self { + Self { + state: [PollState::Pending; N], } } @@ -21,10 +30,23 @@ impl PollArray { state.is_ready(), "Future should have reached a `Ready` state" ); - state.set_consumed(); + state.set_none(); }) } + /// Mark all items as "pending" + #[inline] + pub(crate) fn set_all_pending(&mut self) { + self.fill(PollState::Pending); + } + + /// Mark all items as "none" + #[inline] + #[allow(unused)] + pub(crate) fn set_all_none(&mut self) { + self.fill(PollState::None); + } + /// Get an iterator of indexes of all items which are "ready". pub(crate) fn ready_indexes(&self) -> impl Iterator + '_ { self.iter() diff --git a/src/utils/poll_state/poll_state.rs b/src/utils/poll_state/poll_state.rs index 35be7fe..f33d698 100644 --- a/src/utils/poll_state/poll_state.rs +++ b/src/utils/poll_state/poll_state.rs @@ -1,19 +1,27 @@ /// Enumerate the current poll state. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy)] #[repr(u8)] pub(crate) enum PollState { - /// Polling the underlying future or stream. - #[default] + /// There is no associated future or stream. + /// This can be because no item was placed to begin with, or because there + /// are was previously an item but there no longer is. + None, + /// Polling the associated future or stream. Pending, /// Data has been written to the output structure, and is now ready to be /// read. Ready, - /// The underlying future or stream has finished yielding data and all data - /// has been read. We can now stop reasoning about it. - Consumed, } impl PollState { + /// Returns `true` if the metadata is [`None`][Self::None]. + #[must_use] + #[inline] + #[allow(unused)] + pub(crate) fn is_none(&self) -> bool { + matches!(self, Self::None) + } + /// Returns `true` if the metadata is [`Pending`][Self::Pending]. #[must_use] #[inline] @@ -28,22 +36,22 @@ impl PollState { matches!(self, Self::Ready) } - /// Sets the poll state to [`Ready`][Self::Ready]. + /// Sets the poll state to [`None`][Self::None]. #[inline] - pub(crate) fn set_ready(&mut self) { - *self = PollState::Ready; + pub(crate) fn set_none(&mut self) { + *self = PollState::None; } - /// Returns `true` if the poll state is [`Consumed`][Self::Consumed]. - #[must_use] + /// Sets the poll state to [`Ready`][Self::Pending]. #[inline] - pub(crate) fn is_consumed(&self) -> bool { - matches!(self, Self::Consumed) + #[allow(unused)] + pub(crate) fn set_pending(&mut self) { + *self = PollState::Pending; } - /// Sets the poll state to [`Consumed`][Self::Consumed]. + /// Sets the poll state to [`Ready`][Self::Ready]. #[inline] - pub(crate) fn set_consumed(&mut self) { - *self = PollState::Consumed; + pub(crate) fn set_ready(&mut self) { + *self = PollState::Ready; } } diff --git a/src/utils/poll_state/vec.rs b/src/utils/poll_state/vec.rs index 915851f..4de8078 100644 --- a/src/utils/poll_state/vec.rs +++ b/src/utils/poll_state/vec.rs @@ -1,3 +1,4 @@ +use smallvec::{smallvec, SmallVec}; use std::ops::{Deref, DerefMut}; use super::PollState; @@ -5,7 +6,7 @@ use super::PollState; /// The maximum number of entries that `PollStates` can store without /// dynamic memory allocation. /// -/// The `Boxed` variant is the minimum size the data structure can have. +/// The heap variant is the minimum size the data structure can have. /// It consists of a boxed slice (=2 usizes) and space for the enum /// tag (another usize because of padding), so 3 usizes. /// The inline variant then consists of `3 * size_of(usize) - 2` entries. @@ -25,28 +26,16 @@ use super::PollState; /// ``` const MAX_INLINE_ENTRIES: usize = std::mem::size_of::() * 3 - 2; -pub(crate) enum PollVec { - Inline(u8, [PollState; MAX_INLINE_ENTRIES]), - Boxed(Box<[PollState]>), -} +#[derive(Default)] +pub(crate) struct PollVec(SmallVec<[PollState; MAX_INLINE_ENTRIES]>); impl PollVec { pub(crate) fn new(len: usize) -> Self { - assert!(MAX_INLINE_ENTRIES <= u8::MAX as usize); - - if len <= MAX_INLINE_ENTRIES { - Self::Inline(len as u8, Default::default()) - } else { - // Make sure that we don't reallocate the vec's memory - // during `Vec::into_boxed_slice()`. - let mut states = Vec::new(); - debug_assert_eq!(states.capacity(), 0); - states.reserve_exact(len); - debug_assert_eq!(states.capacity(), len); - states.resize(len, PollState::default()); - debug_assert_eq!(states.capacity(), len); - Self::Boxed(states.into_boxed_slice()) - } + Self(smallvec![PollState::None; len]) + } + + pub(crate) fn new_pending(len: usize) -> Self { + Self(smallvec![PollState::Pending; len]) } /// Get an iterator of indexes of all items which are "ready". @@ -67,25 +56,47 @@ impl PollVec { .filter(|(_, state)| state.is_pending()) .map(|(i, _)| i) } + + /// Get an iterator of indexes of all items which are "consumed". + #[allow(unused)] + pub(crate) fn consumed_indexes(&self) -> impl Iterator + '_ { + self.iter() + .cloned() + .enumerate() + .filter(|(_, state)| state.is_none()) + .map(|(i, _)| i) + } + + /// Mark all items as "pending" + #[inline] + pub(crate) fn set_all_pending(&mut self) { + self.0.fill(PollState::Pending); + } + + /// Mark all items as "none" + #[inline] + #[allow(unused)] + pub(crate) fn set_all_none(&mut self) { + self.0.fill(PollState::None); + } + + /// Resize the `PollVec` + pub(crate) fn resize(&mut self, len: usize) { + self.0.resize_with(len, || PollState::None) + } } impl Deref for PollVec { type Target = [PollState]; fn deref(&self) -> &Self::Target { - match self { - PollVec::Inline(len, states) => &states[..*len as usize], - Self::Boxed(states) => &states[..], - } + &self.0 } } impl DerefMut for PollVec { fn deref_mut(&mut self) -> &mut Self::Target { - match self { - PollVec::Inline(len, states) => &mut states[..*len as usize], - Self::Boxed(states) => &mut states[..], - } + &mut self.0 } } @@ -95,15 +106,16 @@ mod tests { #[test] fn type_size() { + // PollVec is three words plus two bits assert_eq!( std::mem::size_of::(), - std::mem::size_of::() * 3 + std::mem::size_of::() * 4 ); } #[test] fn boxed_does_not_allocate_twice() { // Make sure the debug_assertions in PollStates::new() don't fail. - let _ = PollVec::new(MAX_INLINE_ENTRIES + 10); + let _ = PollVec::new_pending(MAX_INLINE_ENTRIES + 10); } } diff --git a/src/utils/wakers/array/mod.rs b/src/utils/wakers/array/mod.rs index 7303d9c..fc3b38a 100644 --- a/src/utils/wakers/array/mod.rs +++ b/src/utils/wakers/array/mod.rs @@ -1,7 +1,7 @@ -mod readiness; +mod readiness_array; mod waker; mod waker_array; -pub(crate) use readiness::ReadinessArray; +pub(crate) use readiness_array::ReadinessArray; pub(crate) use waker::InlineWakerArray; pub(crate) use waker_array::WakerArray; diff --git a/src/utils/wakers/array/readiness.rs b/src/utils/wakers/array/readiness_array.rs similarity index 82% rename from src/utils/wakers/array/readiness.rs rename to src/utils/wakers/array/readiness_array.rs index 047103c..c7e2157 100644 --- a/src/utils/wakers/array/readiness.rs +++ b/src/utils/wakers/array/readiness_array.rs @@ -4,7 +4,7 @@ use std::task::Waker; #[derive(Debug)] pub(crate) struct ReadinessArray { count: usize, - ready: [bool; N], + readiness_list: [bool; N], parent_waker: Option, } @@ -13,16 +13,16 @@ impl ReadinessArray { pub(crate) fn new() -> Self { Self { count: N, - ready: [true; N], // TODO: use a bitarray instead + readiness_list: [true; N], // TODO: use a bitarray instead parent_waker: None, } } /// Returns the old ready state for this id pub(crate) fn set_ready(&mut self, id: usize) -> bool { - if !self.ready[id] { + if !self.readiness_list[id] { self.count += 1; - self.ready[id] = true; + self.readiness_list[id] = true; false } else { @@ -32,15 +32,15 @@ impl ReadinessArray { /// Set all markers to ready. pub(crate) fn set_all_ready(&mut self) { - self.ready.fill(true); + self.readiness_list.fill(true); self.count = N; } /// Returns whether the task id was previously ready pub(crate) fn clear_ready(&mut self, id: usize) -> bool { - if self.ready[id] { + if self.readiness_list[id] { self.count -= 1; - self.ready[id] = false; + self.readiness_list[id] = false; true } else { diff --git a/src/utils/wakers/vec/mod.rs b/src/utils/wakers/vec/mod.rs index e002fd3..1cbdeb7 100644 --- a/src/utils/wakers/vec/mod.rs +++ b/src/utils/wakers/vec/mod.rs @@ -1,7 +1,7 @@ -mod readiness; +mod readiness_vec; mod waker; mod waker_vec; -pub(crate) use readiness::ReadinessVec; +pub(crate) use readiness_vec::ReadinessVec; pub(crate) use waker::InlineWakerVec; pub(crate) use waker_vec::WakerVec; diff --git a/src/utils/wakers/vec/readiness.rs b/src/utils/wakers/vec/readiness.rs deleted file mode 100644 index 08b6045..0000000 --- a/src/utils/wakers/vec/readiness.rs +++ /dev/null @@ -1,70 +0,0 @@ -use bitvec::{bitvec, vec::BitVec}; -use std::task::Waker; - -/// Tracks which wakers are "ready" and should be polled. -#[derive(Debug)] -pub(crate) struct ReadinessVec { - count: usize, - max_count: usize, - ready: BitVec, - parent_waker: Option, -} - -impl ReadinessVec { - /// Create a new instance of readiness. - pub(crate) fn new(count: usize) -> Self { - Self { - count, - max_count: count, - ready: bitvec![true as usize; count], - parent_waker: None, - } - } - - /// Returns the old ready state for this id - pub(crate) fn set_ready(&mut self, id: usize) -> bool { - if !self.ready[id] { - self.count += 1; - self.ready.set(id, true); - - false - } else { - true - } - } - - /// Set all markers to ready. - pub(crate) fn set_all_ready(&mut self) { - self.ready.fill(true); - self.count = self.max_count; - } - - /// Returns whether the task id was previously ready - pub(crate) fn clear_ready(&mut self, id: usize) -> bool { - if self.ready[id] { - self.count -= 1; - self.ready.set(id, false); - - true - } else { - false - } - } - - /// Returns `true` if any of the wakers are ready. - pub(crate) fn any_ready(&self) -> bool { - self.count > 0 - } - - /// Access the parent waker. - #[inline] - pub(crate) fn parent_waker(&self) -> Option<&Waker> { - self.parent_waker.as_ref() - } - - /// Set the parent `Waker`. This needs to be called at the start of every - /// `poll` function. - pub(crate) fn set_waker(&mut self, parent_waker: &Waker) { - self.parent_waker = Some(parent_waker.clone()); - } -} diff --git a/src/utils/wakers/vec/readiness_vec.rs b/src/utils/wakers/vec/readiness_vec.rs new file mode 100644 index 0000000..76be854 --- /dev/null +++ b/src/utils/wakers/vec/readiness_vec.rs @@ -0,0 +1,107 @@ +use bitvec::{bitvec, vec::BitVec}; +use std::task::Waker; + +/// Tracks which wakers are "ready" and should be polled. +#[derive(Debug)] +pub(crate) struct ReadinessVec { + ready_count: usize, + max_count: usize, + readiness_list: BitVec, + parent_waker: Option, +} + +impl ReadinessVec { + /// Create a new instance of readiness. + pub(crate) fn new(len: usize) -> Self { + Self { + ready_count: len, + max_count: len, + readiness_list: bitvec![true as usize; len], + parent_waker: None, + } + } + + /// Set the ready state to `true` for the given index + /// + /// Returns the old ready state for this id + pub(crate) fn set_ready(&mut self, index: usize) -> bool { + if !self.readiness_list[index] { + self.ready_count += 1; + self.readiness_list.set(index, true); + false + } else { + true + } + } + + /// Set all markers to ready. + pub(crate) fn set_all_ready(&mut self) { + self.readiness_list.fill(true); + self.ready_count = self.max_count; + } + + /// Set the ready state to `false` for the given index + /// + /// Returns whether the task id was previously ready + pub(crate) fn clear_ready(&mut self, index: usize) -> bool { + if self.readiness_list[index] { + self.ready_count -= 1; + self.readiness_list.set(index, false); + true + } else { + false + } + } + + /// Returns whether the task id was previously ready + #[allow(unused)] + pub(crate) fn clear_all_ready(&mut self) { + self.readiness_list.fill(false); + self.ready_count = 0; + } + + /// Returns `true` if any of the wakers are ready. + pub(crate) fn any_ready(&self) -> bool { + self.ready_count > 0 + } + + /// Access the parent waker. + #[inline] + pub(crate) fn parent_waker(&self) -> Option<&Waker> { + self.parent_waker.as_ref() + } + + /// Set the parent `Waker`. This needs to be called at the start of every + /// `poll` function. + pub(crate) fn set_waker(&mut self, parent_waker: &Waker) { + self.parent_waker = Some(parent_waker.clone()); + } + + /// Resize `readiness` to the new length. + /// + /// If new entries are created, they will be marked as 'ready'. + pub(crate) fn resize(&mut self, len: usize) { + self.max_count = len; + self.readiness_list.resize(len, true); + self.ready_count = self.readiness_list.iter_ones().count(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn resize() { + let mut readiness = ReadinessVec::new(10); + assert!(readiness.any_ready()); + readiness.clear_all_ready(); + assert!(!readiness.any_ready()); + readiness.set_ready(9); + assert!(readiness.any_ready()); + readiness.resize(9); + assert!(!readiness.any_ready()); + readiness.resize(10); + assert!(readiness.any_ready()); + } +} diff --git a/src/utils/wakers/vec/waker_vec.rs b/src/utils/wakers/vec/waker_vec.rs index e317f0b..dcc6c41 100644 --- a/src/utils/wakers/vec/waker_vec.rs +++ b/src/utils/wakers/vec/waker_vec.rs @@ -10,6 +10,12 @@ pub(crate) struct WakerVec { readiness: Arc>, } +impl Default for WakerVec { + fn default() -> Self { + Self::new(0) + } +} + impl WakerVec { /// Create a new instance of `WakerVec`. pub(crate) fn new(len: usize) -> Self { @@ -28,4 +34,20 @@ impl WakerVec { pub(crate) fn readiness(&self) -> &Mutex { self.readiness.as_ref() } + + /// Resize the `WakerVec` to the new size. + pub(crate) fn resize(&mut self, len: usize) { + // If we grow the vec we'll need to extend beyond the current index. + // Which means the first position is the current length, and every position + // beyond that is incremented by 1. + let mut index = self.wakers.len(); + self.wakers.resize_with(len, || { + let ret = Arc::new(InlineWakerVec::new(index, self.readiness.clone())).into(); + index += 1; + ret + }); + + let mut readiness = self.readiness.lock().unwrap(); + readiness.resize(len); + } }