diff --git a/src/concurrent_stream/from_concurrent_stream.rs b/src/concurrent_stream/from_concurrent_stream.rs index ab704e9..b0f874d 100644 --- a/src/concurrent_stream/from_concurrent_stream.rs +++ b/src/concurrent_stream/from_concurrent_stream.rs @@ -28,6 +28,18 @@ impl FromConcurrentStream for Vec { } } +impl FromConcurrentStream> for Result, E> { + async fn from_concurrent_stream(iter: S) -> Self + where + S: IntoConcurrentStream>, + { + let stream = iter.into_co_stream(); + let mut output = Ok(Vec::with_capacity(stream.size_hint().1.unwrap_or_default())); + stream.drive(ResultVecConsumer::new(&mut output)).await; + output + } +} + // TODO: replace this with a generalized `fold` operation #[pin_project] pub(crate) struct VecConsumer<'a, Fut: Future> { @@ -73,6 +85,60 @@ where } } +#[pin_project] +pub(crate) struct ResultVecConsumer<'a, Fut: Future, T, E> { + #[pin] + group: FuturesUnordered, + output: &'a mut Result, E>, +} + +impl<'a, Fut: Future, T, E> ResultVecConsumer<'a, Fut, T, E> { + pub(crate) fn new(output: &'a mut Result, E>) -> Self { + Self { + group: FuturesUnordered::new(), + output, + } + } +} + +impl<'a, Fut, T, E> Consumer, Fut> for ResultVecConsumer<'a, Fut, T, E> +where + Fut: Future>, +{ + type Output = (); + + async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState { + let mut this = self.project(); + // unbounded concurrency, so we just goooo + this.group.as_mut().push(future); + ConsumerState::Continue + } + + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let mut this = self.project(); + + while let Some(item) = this.group.next().await { + match item { + Ok(item) => { + let Ok(items) = this.output else { + panic!("progress called after returning ConsumerState::Break"); + }; + items.push(item); + } + Err(e) => { + **this.output = Err(e); + return ConsumerState::Break; + } + } + } + ConsumerState::Empty + } + + async fn flush(self: Pin<&mut Self>) -> Self::Output { + self.progress().await; + } +} + #[cfg(test)] mod test { use crate::prelude::*; @@ -85,4 +151,24 @@ mod test { assert_eq!(v, &[1, 1, 1, 1, 1]); }); } + + #[test] + fn collect_to_result_ok() { + futures_lite::future::block_on(async { + let v: Result, ()> = stream::repeat(Ok(1)).co().take(5).collect().await; + assert_eq!(v, Ok(vec![1, 1, 1, 1, 1])); + }); + } + + #[test] + fn collect_to_result_err() { + futures_lite::future::block_on(async { + let v: Result, _> = stream::repeat(Err::(())) + .co() + .take(5) + .collect() + .await; + assert_eq!(v, Err(())); + }); + } }