diff --git a/src/concurrent_stream/convert.rs b/src/concurrent_stream/convert.rs new file mode 100644 index 0000000..c06d7a6 --- /dev/null +++ b/src/concurrent_stream/convert.rs @@ -0,0 +1,100 @@ +use super::{ConcurrentStream, Consumer, ConsumerState}; +use crate::future::FutureGroup; +use core::future::Future; +use core::pin::Pin; +use futures_lite::StreamExt; + +/// Conversion into a [`ConcurrentStream`] +pub trait IntoConcurrentStream { + /// The type of the elements being iterated over. + type Item; + /// Which kind of iterator are we turning this into? + type ConcurrentStream: ConcurrentStream; + + /// Convert `self` into a concurrent iterator. + fn into_concurrent_stream(self) -> Self::ConcurrentStream; +} + +impl IntoConcurrentStream for S { + type Item = S::Item; + type ConcurrentStream = S; + + fn into_concurrent_stream(self) -> Self::ConcurrentStream { + self + } +} + +/// Conversion from a [`ConcurrentStream`] +#[allow(async_fn_in_trait)] +pub trait FromConcurrentStream: Sized { + /// Creates a value from a concurrent iterator. + async fn from_concurrent_stream(iter: T) -> Self + where + T: IntoConcurrentStream; +} + +impl FromConcurrentStream for Vec { + async fn from_concurrent_stream(iter: S) -> Self + where + S: IntoConcurrentStream, + { + let stream = iter.into_concurrent_stream(); + let mut output = Vec::with_capacity(stream.size_hint().1.unwrap_or_default()); + stream.drive(VecConsumer::new(&mut output)).await; + output + } +} + +// TODO: replace this with a generalized `fold` operation +pub(crate) struct VecConsumer<'a, Fut: Future> { + group: Pin>>, + output: &'a mut Vec, +} + +impl<'a, Fut: Future> VecConsumer<'a, Fut> { + pub(crate) fn new(output: &'a mut Vec) -> Self { + Self { + group: Box::pin(FutureGroup::new()), + output, + } + } +} + +impl<'a, Item, Fut> Consumer for VecConsumer<'a, Fut> +where + Fut: Future, +{ + type Output = (); + + async fn send(&mut self, future: Fut) -> super::ConsumerState { + // unbounded concurrency, so we just goooo + self.group.as_mut().insert_pinned(future); + ConsumerState::Continue + } + + async fn progress(&mut self) -> super::ConsumerState { + while let Some(item) = self.group.next().await { + self.output.push(item); + } + ConsumerState::Empty + } + async fn finish(mut self) -> Self::Output { + while let Some(item) = self.group.next().await { + self.output.push(item); + } + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + use futures_lite::stream; + + #[test] + fn collect() { + futures_lite::future::block_on(async { + let v: Vec<_> = stream::repeat(1).co().take(5).collect().await; + assert_eq!(v, &[1, 1, 1, 1, 1]); + }); + } +} diff --git a/src/concurrent_stream/mod.rs b/src/concurrent_stream/mod.rs index 21ae9d9..6e49790 100644 --- a/src/concurrent_stream/mod.rs +++ b/src/concurrent_stream/mod.rs @@ -1,5 +1,6 @@ //! Concurrent execution of streams +mod convert; mod drain; mod enumerate; mod for_each; @@ -14,6 +15,7 @@ use core::num::NonZeroUsize; use for_each::ForEachConsumer; use try_for_each::TryForEachConsumer; +pub use convert::{FromConcurrentStream, IntoConcurrentStream}; pub use enumerate::Enumerate; pub use from_stream::FromStream; pub use limit::Limit; @@ -148,6 +150,15 @@ pub trait ConcurrentStream { let limit = self.concurrency_limit(); self.drive(TryForEachConsumer::new(limit, f)).await } + + /// Transforms an iterator into a collection. + async fn collect(self) -> B + where + B: FromConcurrentStream, + Self: Sized, + { + B::from_concurrent_stream(self).await + } } /// The state of the consumer, used to communicate back to the source. @@ -162,23 +173,6 @@ pub enum ConsumerState { Empty, } -/// Convert into a concurrent stream -pub trait IntoConcurrentStream { - /// The type of concurrent stream we're returning. - type ConcurrentStream: ConcurrentStream; - - /// Convert `self` into a concurrent stream. - fn into_concurrent_stream(self) -> Self::ConcurrentStream; -} - -impl IntoConcurrentStream for S { - type ConcurrentStream = S; - - fn into_concurrent_stream(self) -> Self::ConcurrentStream { - self - } -} - #[cfg(test)] mod test { use super::*;