diff --git a/src/lib.rs b/src/lib.rs index 11f5546..dfe65ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,8 +57,10 @@ pub mod prelude { pub use super::future::Race as _; pub use super::future::RaceOk as _; pub use super::future::TryJoin as _; + pub use super::stream::Chain as _; pub use super::stream::IntoStream as _; pub use super::stream::Merge as _; + pub use super::stream::Zip as _; } pub mod future; @@ -70,7 +72,9 @@ pub mod array { pub use crate::future::race::array::Race; pub use crate::future::race_ok::array::{AggregateError, RaceOk}; pub use crate::future::try_join::array::TryJoin; + pub use crate::stream::chain::array::Chain; pub use crate::stream::merge::array::Merge; + pub use crate::stream::zip::array::Zip; } /// A contiguous growable array type with heap-allocated contents, written `Vec`. pub mod vec { @@ -78,5 +82,7 @@ pub mod vec { pub use crate::future::race::vec::Race; pub use crate::future::race_ok::vec::{AggregateError, RaceOk}; pub use crate::future::try_join::vec::TryJoin; + pub use crate::stream::chain::vec::Chain; pub use crate::stream::merge::vec::Merge; + pub use crate::stream::zip::vec::Zip; } diff --git a/src/stream/chain/array.rs b/src/stream/chain/array.rs new file mode 100644 index 0000000..39f92d2 --- /dev/null +++ b/src/stream/chain/array.rs @@ -0,0 +1,101 @@ +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures_core::Stream; +use pin_project::pin_project; + +use crate::utils; + +use super::Chain as ChainTrait; + +/// A stream that chains multiple streams one after another. +/// +/// This `struct` is created by the [`chain`] method on the [`Chain`] trait. See its +/// documentation for more. +/// +/// [`chain`]: trait.Chain.html#method.merge +/// [`Chain`]: trait.Chain.html +#[pin_project] +pub struct Chain { + #[pin] + streams: [S; N], + index: usize, + len: usize, + done: bool, +} + +impl Stream for Chain { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + assert!(!*this.done, "Stream should not be polled after completion"); + + loop { + if this.index == this.len { + *this.done = true; + return Poll::Ready(None); + } + let stream = utils::iter_pin_mut(this.streams.as_mut()) + .nth(*this.index) + .unwrap(); + match stream.poll_next(cx) { + Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + Poll::Ready(None) => { + *this.index += 1; + continue; + } + Poll::Pending => return Poll::Pending, + } + } + } +} + +impl fmt::Debug for Chain +where + S: Stream + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.streams.iter()).finish() + } +} + +impl ChainTrait for [S; N] { + type Item = S::Item; + + type Stream = Chain; + + fn chain(self) -> Self::Stream { + Chain { + len: self.len(), + streams: self, + index: 0, + done: false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_lite::future::block_on; + use futures_lite::prelude::*; + use futures_lite::stream; + + #[test] + fn chain_3() { + block_on(async { + let a = stream::once(1); + let b = stream::once(2); + let c = stream::once(3); + let mut s = [a, b, c].chain(); + + assert_eq!(s.next().await, Some(1)); + assert_eq!(s.next().await, Some(2)); + assert_eq!(s.next().await, Some(3)); + assert_eq!(s.next().await, None); + }) + } +} diff --git a/src/stream/chain/mod.rs b/src/stream/chain/mod.rs new file mode 100644 index 0000000..3dadf3a --- /dev/null +++ b/src/stream/chain/mod.rs @@ -0,0 +1,17 @@ +use futures_core::Stream; + +pub(crate) mod array; +pub(crate) mod tuple; +pub(crate) mod vec; + +/// Takes multiple streams and creates a new stream over all in sequence. +pub trait Chain { + /// What's the return type of our stream? + type Item; + + /// What stream do we return? + type Stream: Stream; + + /// Combine multiple streams into a single stream. + fn chain(self) -> Self::Stream; +} diff --git a/src/stream/chain/tuple.rs b/src/stream/chain/tuple.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/stream/chain/tuple.rs @@ -0,0 +1 @@ + diff --git a/src/stream/chain/vec.rs b/src/stream/chain/vec.rs new file mode 100644 index 0000000..76ce0f6 --- /dev/null +++ b/src/stream/chain/vec.rs @@ -0,0 +1,101 @@ +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures_core::Stream; +use pin_project::pin_project; + +use crate::utils; + +use super::Chain as ChainTrait; + +/// A stream that chains multiple streams one after another. +/// +/// This `struct` is created by the [`chain`] method on the [`Chain`] trait. See its +/// documentation for more. +/// +/// [`chain`]: trait.Chain.html#method.merge +/// [`Chain`]: trait.Chain.html +#[pin_project] +pub struct Chain { + #[pin] + streams: Vec, + index: usize, + len: usize, + done: bool, +} + +impl Stream for Chain { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + assert!(!*this.done, "Stream should not be polled after completion"); + + loop { + if this.index == this.len { + *this.done = true; + return Poll::Ready(None); + } + let stream = utils::iter_pin_mut_vec(this.streams.as_mut()) + .nth(*this.index) + .unwrap(); + match stream.poll_next(cx) { + Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + Poll::Ready(None) => { + *this.index += 1; + continue; + } + Poll::Pending => return Poll::Pending, + } + } + } +} + +impl fmt::Debug for Chain +where + S: Stream + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.streams.iter()).finish() + } +} + +impl ChainTrait for Vec { + type Item = S::Item; + + type Stream = Chain; + + fn chain(self) -> Self::Stream { + Chain { + len: self.len(), + streams: self, + index: 0, + done: false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_lite::future::block_on; + use futures_lite::prelude::*; + use futures_lite::stream; + + #[test] + fn chain_3() { + block_on(async { + let a = stream::once(1); + let b = stream::once(2); + let c = stream::once(3); + let mut s = vec![a, b, c].chain(); + + assert_eq!(s.next().await, Some(1)); + assert_eq!(s.next().await, Some(2)); + assert_eq!(s.next().await, Some(3)); + assert_eq!(s.next().await, None); + }) + } +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 92cecfe..e235412 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -47,8 +47,12 @@ //! //! See the [future concurrency][crate::future#concurrency] documentation for //! more on futures concurrency. +pub use chain::Chain; pub use into_stream::IntoStream; pub use merge::Merge; +pub use zip::Zip; +pub(crate) mod chain; mod into_stream; pub(crate) mod merge; +pub(crate) mod zip; diff --git a/src/stream/zip/array.rs b/src/stream/zip/array.rs new file mode 100644 index 0000000..339dd0d --- /dev/null +++ b/src/stream/zip/array.rs @@ -0,0 +1,191 @@ +use super::Zip as ZipTrait; +use crate::stream::IntoStream; +use crate::utils::{self, PollState, WakerList}; + +use core::array; +use core::fmt; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; +use std::mem; + +use futures_core::Stream; +use pin_project::{pin_project, pinned_drop}; + +/// A stream that ‘zips up’ multiple streams into a single stream of pairs. +/// +/// This `struct` is created by the [`zip`] method on the [`Zip`] trait. See its +/// documentation for more. +/// +/// [`zip`]: trait.Zip.html#method.zip +/// [`Zip`]: trait.Zip.html +#[pin_project(PinnedDrop)] +pub struct Zip +where + S: Stream, +{ + #[pin] + streams: [S; N], + output: [MaybeUninit<::Item>; N], + wakers: WakerList, + state: [PollState; N], + done: bool, +} + +impl Zip +where + S: Stream, +{ + pub(crate) fn new(streams: [S; N]) -> Self { + Self { + streams, + output: array::from_fn(|_| MaybeUninit::uninit()), + state: array::from_fn(|_| PollState::default()), + wakers: WakerList::new(N), + done: false, + } + } +} + +impl fmt::Debug for Zip +where + S: Stream + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.streams.iter()).finish() + } +} + +impl Stream for Zip +where + S: Stream, +{ + type Item = [S::Item; N]; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + assert!(!*this.done, "Stream should not be polled after completion"); + + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + for index in 0..N { + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } else if this.state[index].is_done() || !readiness.clear_ready(index) { + // We already have data stored for this stream, + // Or this waker isn't ready yet + continue; + } + + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + + let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap(); + match stream.poll_next(&mut cx) { + Poll::Ready(Some(item)) => { + this.output[index] = MaybeUninit::new(item); + this.state[index] = PollState::Done; + + let all_ready = this.state.iter().all(|state| state.is_done()); + if all_ready { + // Reset the future's state. + readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_all_ready(); + this.state.fill(PollState::Pending); + + // Take the output + // + // SAFETY: we just validated all our data is populated, meaning + // we can assume this is initialized. + let mut output = array::from_fn(|_| MaybeUninit::uninit()); + mem::swap(this.output, &mut output); + let output = unsafe { array_assume_init(output) }; + return Poll::Ready(Some(output)); + } + } + Poll::Ready(None) => { + // If one stream returns `None`, we can no longer return + // pairs - meaning the stream is over. + *this.done = true; + return Poll::Ready(None); + } + Poll::Pending => {} + } + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); + } + Poll::Pending + } +} + +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for Zip +where + S: Stream, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + + for (state, output) in this.state.iter_mut().zip(this.output.iter_mut()) { + if state.is_done() { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { output.assume_init_drop() }; + } + } + } +} + +impl ZipTrait for [S; N] +where + S: IntoStream, +{ + type Item = as Stream>::Item; + type Stream = Zip; + + fn zip(self) -> Self::Stream { + Zip::new(self.map(|i| i.into_stream())) + } +} + +#[cfg(test)] +mod tests { + use crate::stream::Zip; + use futures_lite::future::block_on; + use futures_lite::prelude::*; + use futures_lite::stream; + + #[test] + fn zip_array_3() { + block_on(async { + let a = stream::repeat(1).take(2); + let b = stream::repeat(2).take(2); + let c = stream::repeat(3).take(2); + let mut s = Zip::zip([a, b, c]); + + assert_eq!(s.next().await, Some([1, 2, 3])); + assert_eq!(s.next().await, Some([1, 2, 3])); + assert_eq!(s.next().await, None); + }) + } +} + +// Inlined version of the unstable `MaybeUninit::array_assume_init` feature. +// FIXME: replace with `utils::array_assume_init` +unsafe fn array_assume_init(array: [MaybeUninit; N]) -> [T; N] { + // SAFETY: + // * The caller guarantees that all elements of the array are initialized + // * `MaybeUninit` and T are guaranteed to have the same layout + // * `MaybeUninit` does not drop, so there are no double-frees + // And thus the conversion is safe + let ret = unsafe { (&array as *const _ as *const [T; N]).read() }; + mem::forget(array); + ret +} diff --git a/src/stream/zip/mod.rs b/src/stream/zip/mod.rs new file mode 100644 index 0000000..14b33b0 --- /dev/null +++ b/src/stream/zip/mod.rs @@ -0,0 +1,17 @@ +use futures_core::Stream; + +pub(crate) mod array; +pub(crate) mod tuple; +pub(crate) mod vec; + +/// ‘Zips up’ multiple streams into a single stream of pairs. +pub trait Zip { + /// What's the return type of our stream? + type Item; + + /// What stream do we return? + type Stream: Stream; + + /// Combine multiple streams into a single stream. + fn zip(self) -> Self::Stream; +} diff --git a/src/stream/zip/tuple.rs b/src/stream/zip/tuple.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/stream/zip/tuple.rs @@ -0,0 +1 @@ + diff --git a/src/stream/zip/vec.rs b/src/stream/zip/vec.rs new file mode 100644 index 0000000..7c870fe --- /dev/null +++ b/src/stream/zip/vec.rs @@ -0,0 +1,193 @@ +use super::Zip as ZipTrait; +use crate::stream::IntoStream; +use crate::utils::{self, PollState, WakerList}; + +use core::fmt; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; +use std::mem; + +use futures_core::Stream; +use pin_project::{pin_project, pinned_drop}; + +/// A stream that ‘zips up’ multiple streams into a single stream of pairs. +/// +/// This `struct` is created by the [`zip`] method on the [`Zip`] trait. See its +/// documentation for more. +/// +/// [`zip`]: trait.Zip.html#method.zip +/// [`Zip`]: trait.Zip.html +#[pin_project(PinnedDrop)] +pub struct Zip +where + S: Stream, +{ + #[pin] + streams: Vec, + output: Vec::Item>>, + wakers: WakerList, + state: Vec, + done: bool, + len: usize, +} + +impl Zip +where + S: Stream, +{ + pub(crate) fn new(streams: Vec) -> Self { + let len = streams.len(); + Self { + len, + streams, + wakers: WakerList::new(len), + output: (0..len).map(|_| MaybeUninit::uninit()).collect(), + state: (0..len).map(|_| PollState::default()).collect(), + done: false, + } + } +} + +impl fmt::Debug for Zip +where + S: Stream + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.streams.iter()).finish() + } +} + +impl Stream for Zip +where + S: Stream, +{ + type Item = Vec; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + assert!(!*this.done, "Stream should not be polled after completion"); + + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + for index in 0..*this.len { + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } else if this.state[index].is_done() || !readiness.clear_ready(index) { + // We already have data stored for this stream, + // Or this waker isn't ready yet + continue; + } + + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + + let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap(); + match stream.poll_next(&mut cx) { + Poll::Ready(Some(item)) => { + this.output[index] = MaybeUninit::new(item); + this.state[index] = PollState::Done; + + let all_ready = this.state.iter().all(|state| state.is_done()); + if all_ready { + // Reset the future's state. + readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_all_ready(); + this.state.fill(PollState::Pending); + + // Take the output + // + // SAFETY: we just validated all our data is populated, meaning + // we can assume this is initialized. + let mut output = (0..*this.len).map(|_| MaybeUninit::uninit()).collect(); + mem::swap(this.output, &mut output); + let output = unsafe { vec_assume_init(output) }; + return Poll::Ready(Some(output)); + } + } + Poll::Ready(None) => { + // If one stream returns `None`, we can no longer return + // pairs - meaning the stream is over. + *this.done = true; + return Poll::Ready(None); + } + Poll::Pending => {} + } + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); + } + Poll::Pending + } +} + +/// Drop the already initialized values on cancellation. +#[pinned_drop] +impl PinnedDrop for Zip +where + S: Stream, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + + for (state, output) in this.state.iter_mut().zip(this.output.iter_mut()) { + if state.is_done() { + // SAFETY: we've just filtered down to *only* the initialized values. + // We can assume they're initialized, and this is where we drop them. + unsafe { output.assume_init_drop() }; + } + } + } +} + +impl ZipTrait for Vec +where + S: IntoStream, +{ + type Item = as Stream>::Item; + type Stream = Zip; + + fn zip(self) -> Self::Stream { + Zip::new(self.into_iter().map(|i| i.into_stream()).collect()) + } +} + +#[cfg(test)] +mod tests { + use crate::stream::Zip; + use futures_lite::future::block_on; + use futures_lite::prelude::*; + use futures_lite::stream; + + #[test] + fn zip_array_3() { + block_on(async { + let a = stream::repeat(1).take(2); + let b = stream::repeat(2).take(2); + let c = stream::repeat(3).take(2); + let mut s = vec![a, b, c].zip(); + + assert_eq!(s.next().await, Some(vec![1, 2, 3])); + assert_eq!(s.next().await, Some(vec![1, 2, 3])); + assert_eq!(s.next().await, None); + }) + } +} + +// Inlined version of the unstable `MaybeUninit::array_assume_init` feature. +// FIXME: replace with `utils::array_assume_init` +unsafe fn vec_assume_init(vec: Vec>) -> Vec { + // SAFETY: + // * The caller guarantees that all elements of the vec are initialized + // * `MaybeUninit` and T are guaranteed to have the same layout + // * `MaybeUninit` does not drop, so there are no double-frees + // And thus the conversion is safe + let ret = unsafe { (&vec as *const _ as *const Vec).read() }; + mem::forget(vec); + ret +} diff --git a/src/utils/wakers/readiness.rs b/src/utils/wakers/readiness.rs index 1cd9a93..b6618c8 100644 --- a/src/utils/wakers/readiness.rs +++ b/src/utils/wakers/readiness.rs @@ -7,6 +7,7 @@ use crate::utils; #[derive(Debug)] pub(crate) struct Readiness { count: usize, + max_count: usize, ready: BitVec, parent_waker: Option, } @@ -16,6 +17,7 @@ impl Readiness { pub(crate) fn new(count: usize) -> Self { Self { count, + max_count: count, ready: bitvec![true as usize; count], parent_waker: None, } @@ -33,6 +35,12 @@ impl Readiness { } } + /// Set all markers to ready. + pub(crate) fn set_all_ready(&mut self) { + self.ready.fill(true); + self.count = self.max_count; + } + /// Returns whether the task id was previously ready pub(crate) fn clear_ready(&mut self, id: usize) -> bool { if self.ready[id] { diff --git a/src/utils/wakers/waker_list.rs b/src/utils/wakers/waker_list.rs index dc305e5..2895e8d 100644 --- a/src/utils/wakers/waker_list.rs +++ b/src/utils/wakers/waker_list.rs @@ -15,12 +15,10 @@ impl WakerList { /// Create a new instance of `WakerList`. pub(crate) fn new(len: usize) -> Self { let readiness = Arc::new(Mutex::new(Readiness::new(len))); - Self { - wakers: (0..len) - .map(|i| Arc::new(InlineWaker::new(i, readiness.clone())).into()) - .collect(), - readiness, - } + let wakers = (0..len) + .map(|i| Arc::new(InlineWaker::new(i, readiness.clone())).into()) + .collect(); + Self { wakers, readiness } } pub(crate) fn get(&self, index: usize) -> Option<&Waker> {