diff --git a/src/par_stream/mod.rs b/src/par_stream/mod.rs index 500befe..69dc2a7 100644 --- a/src/par_stream/mod.rs +++ b/src/par_stream/mod.rs @@ -9,11 +9,13 @@ pub use for_each::ForEach; pub use map::Map; pub use next::NextFuture; pub use take::Take; +pub use skip::Skip; mod for_each; mod map; mod next; mod take; +mod skip; /// Parallel version of the standard `Stream` trait. pub trait ParallelStream: Sized + Send + Sync + Unpin + 'static { @@ -54,6 +56,14 @@ pub trait ParallelStream: Sized + Send + Sync + Unpin + 'static { Take::new(self, n) } + /// Creates a stream that skips the first `n` elements. + fn skip(self, n: usize) -> Skip + where + Self: Sized + { + Skip::new(self, n) + } + /// Applies `f` to each item of this stream in parallel. fn for_each(self, f: F) -> ForEach where diff --git a/src/par_stream/skip.rs b/src/par_stream/skip.rs new file mode 100644 index 0000000..278e5df --- /dev/null +++ b/src/par_stream/skip.rs @@ -0,0 +1,73 @@ +use async_std::sync::{self, Receiver}; +use async_std::task; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +use crate::ParallelStream; + +pin_project! { + /// A stream that skips the first `n` items of another stream. + /// + /// This `struct` is created by the [`skip`] method on [`ParallelStream`]. See its + /// documentation for more. + /// + /// [`skip`]: trait.ParallelStream.html#method.skip + /// [`ParallelStream`]: trait.ParallelStream.html + #[derive(Clone, Debug)] + pub struct Skip { + #[pin] + receiver: Receiver, + limit: Option, + } +} + +impl Skip { + pub(super) fn new(mut stream: S, mut skipped: usize) -> Self + where + S: ParallelStream + { + let limit = stream.get_limit(); + let (sender, receiver) = sync::channel(1); + task::spawn(async move { + while let Some(val) = stream.next().await { + if skipped == 0 { + sender.send(val).await + } else { + skipped -= 1; + } + } + }); + + Skip { limit, receiver } + } +} + +impl ParallelStream for Skip { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.receiver.poll_next(cx) + } + + fn limit(mut self, limit: impl Into>) -> Self { + self.limit = limit.into(); + self + } + + fn get_limit(&self) -> Option { + self.limit + } +} + +#[async_std::test] +async fn smoke() { + let s = async_std::stream::from_iter(vec![1, 2, 3, 4, 5, 6]); + let mut output = vec![]; + let mut stream = crate::from_stream(s).skip(3); + while let Some(n) = stream.next().await { + output.push(n); + } + assert_eq!(output, vec![4, 5, 6]); +}