diff --git a/src/concurrent_stream/enumerate.rs b/src/concurrent_stream/enumerate.rs index 07c1a09..796d686 100644 --- a/src/concurrent_stream/enumerate.rs +++ b/src/concurrent_stream/enumerate.rs @@ -1,3 +1,5 @@ +use pin_project::pin_project; + use super::{ConcurrentStream, Consumer}; use core::future::Future; use core::num::NonZeroUsize; @@ -47,7 +49,9 @@ impl ConcurrentStream for Enumerate { } } +#[pin_project] struct EnumerateConsumer { + #[pin] inner: C, count: usize, } @@ -58,18 +62,21 @@ where { type Output = C::Output; - async fn send(&mut self, future: Fut) -> super::ConsumerState { - let count = self.count; - self.count += 1; - self.inner.send(EnumerateFuture::new(future, count)).await + async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState { + let this = self.project(); + let count = *this.count; + *this.count += 1; + this.inner.send(EnumerateFuture::new(future, count)).await } - async fn progress(&mut self) -> super::ConsumerState { - self.inner.progress().await + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let this = self.project(); + this.inner.progress().await } - async fn flush(&mut self) -> Self::Output { - self.inner.flush().await + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let this = self.project(); + this.inner.flush().await } } diff --git a/src/concurrent_stream/for_each.rs b/src/concurrent_stream/for_each.rs index dffc84f..03c8711 100644 --- a/src/concurrent_stream/for_each.rs +++ b/src/concurrent_stream/for_each.rs @@ -23,7 +23,7 @@ where // NOTE: we can remove the `Arc` here if we're willing to make this struct self-referential count: Arc, #[pin] - group: Pin>>>, + group: FutureGroup>, limit: usize, f: F, _phantom: PhantomData<(T, FutB)>, @@ -45,7 +45,7 @@ where f, _phantom: PhantomData, count: Arc::new(AtomicUsize::new(0)), - group: Box::pin(FutureGroup::new()), + group: FutureGroup::new(), } } } @@ -60,30 +60,33 @@ where { type Output = (); - async fn send(&mut self, future: FutT) -> super::ConsumerState { + async fn send(mut self: Pin<&mut Self>, future: FutT) -> super::ConsumerState { + let mut this = self.project(); // If we have no space, we're going to provide backpressure until we have space - while self.count.load(Ordering::Relaxed) >= self.limit { - self.group.next().await; + while this.count.load(Ordering::Relaxed) >= *this.limit { + this.group.next().await; } // Space was available! - insert the item for posterity - self.count.fetch_add(1, Ordering::Relaxed); - let fut = ForEachFut::new(self.f.clone(), future, self.count.clone()); - self.group.as_mut().insert_pinned(fut); + this.count.fetch_add(1, Ordering::Relaxed); + let fut = ForEachFut::new(this.f.clone(), future, this.count.clone()); + this.group.as_mut().insert_pinned(fut); ConsumerState::Continue } - async fn progress(&mut self) -> super::ConsumerState { - while let Some(_) = self.group.next().await {} + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let mut this = self.project(); + while let Some(_) = this.group.next().await {} ConsumerState::Empty } - async fn flush(&mut self) -> Self::Output { + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let mut this = self.project(); // 4. We will no longer receive any additional futures from the // underlying stream; wait until all the futures in the group have // resolved. - while let Some(_) = self.group.next().await {} + while let Some(_) = this.group.next().await {} } } diff --git a/src/concurrent_stream/from_concurrent_stream.rs b/src/concurrent_stream/from_concurrent_stream.rs index 6db7372..a3597fd 100644 --- a/src/concurrent_stream/from_concurrent_stream.rs +++ b/src/concurrent_stream/from_concurrent_stream.rs @@ -5,6 +5,7 @@ use alloc::vec::Vec; use core::future::Future; use core::pin::Pin; use futures_lite::StreamExt; +use pin_project::pin_project; /// Conversion from a [`ConcurrentStream`] #[allow(async_fn_in_trait)] @@ -28,15 +29,17 @@ impl FromConcurrentStream for Vec { } // TODO: replace this with a generalized `fold` operation +#[pin_project] pub(crate) struct VecConsumer<'a, Fut: Future> { - group: Pin>>, + #[pin] + group: FutureGroup, output: &'a mut Vec, } impl<'a, Fut: Future> VecConsumer<'a, Fut> { pub(crate) fn new(output: &'a mut Vec) -> Self { Self { - group: Box::pin(FutureGroup::new()), + group: FutureGroup::new(), output, } } @@ -48,21 +51,24 @@ where { type Output = (); - async fn send(&mut self, future: Fut) -> super::ConsumerState { + async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState { + let mut this = self.project(); // unbounded concurrency, so we just goooo - self.group.as_mut().insert_pinned(future); + this.group.as_mut().insert_pinned(future); ConsumerState::Continue } - async fn progress(&mut self) -> super::ConsumerState { - while let Some(item) = self.group.next().await { - self.output.push(item); + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let mut this = self.project(); + while let Some(item) = this.group.next().await { + this.output.push(item); } ConsumerState::Empty } - async fn flush(&mut self) -> Self::Output { - while let Some(item) = self.group.next().await { - self.output.push(item); + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let mut this = self.project(); + while let Some(item) = this.group.next().await { + this.output.push(item); } } } diff --git a/src/concurrent_stream/limit.rs b/src/concurrent_stream/limit.rs index 346ff65..9ea1363 100644 --- a/src/concurrent_stream/limit.rs +++ b/src/concurrent_stream/limit.rs @@ -1,6 +1,9 @@ +use pin_project::pin_project; + use super::{ConcurrentStream, Consumer}; use core::future::Future; use core::num::NonZeroUsize; +use core::pin::Pin; /// A concurrent iterator that limits the amount of concurrency applied. /// @@ -43,7 +46,9 @@ impl ConcurrentStream for Limit { } } +#[pin_project] struct LimitConsumer { + #[pin] inner: C, } impl Consumer for LimitConsumer @@ -53,15 +58,18 @@ where { type Output = C::Output; - async fn send(&mut self, future: Fut) -> super::ConsumerState { - self.inner.send(future).await + async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState { + let this = self.project(); + this.inner.send(future).await } - async fn progress(&mut self) -> super::ConsumerState { - self.inner.progress().await + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let this = self.project(); + this.inner.progress().await } - async fn flush(&mut self) -> Self::Output { - self.inner.flush().await + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let this = self.project(); + this.inner.flush().await } } diff --git a/src/concurrent_stream/map.rs b/src/concurrent_stream/map.rs index 4a76935..a492f45 100644 --- a/src/concurrent_stream/map.rs +++ b/src/concurrent_stream/map.rs @@ -1,3 +1,5 @@ +use pin_project::pin_project; + use super::{ConcurrentStream, Consumer}; use core::num::NonZeroUsize; use core::{ @@ -71,7 +73,7 @@ where } } -// OK: validated! - all bounds should check out +#[pin_project] pub struct MapConsumer where FutT: Future, @@ -80,6 +82,7 @@ where F: Clone, FutB: Future, { + #[pin] inner: C, f: F, _phantom: PhantomData<(FutT, T, FutB, B)>, @@ -95,17 +98,20 @@ where { type Output = C::Output; - async fn progress(&mut self) -> super::ConsumerState { - self.inner.progress().await + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let this = self.project(); + this.inner.progress().await } - async fn send(&mut self, future: FutT) -> super::ConsumerState { - let fut = MapFuture::new(self.f.clone(), future); - self.inner.send(fut).await + async fn send(self: Pin<&mut Self>, future: FutT) -> super::ConsumerState { + let this = self.project(); + let fut = MapFuture::new(this.f.clone(), future); + this.inner.send(fut).await } - async fn flush(&mut self) -> Self::Output { - self.inner.flush().await + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let this = self.project(); + this.inner.flush().await } } diff --git a/src/concurrent_stream/mod.rs b/src/concurrent_stream/mod.rs index 01d2efd..1ac985f 100644 --- a/src/concurrent_stream/mod.rs +++ b/src/concurrent_stream/mod.rs @@ -12,6 +12,7 @@ mod try_for_each; use core::future::Future; use core::num::NonZeroUsize; +use core::pin::Pin; use for_each::ForEachConsumer; use try_for_each::TryForEachConsumer; @@ -37,18 +38,18 @@ where type Output; /// Send an item down to the next step in the processing queue. - async fn send(&mut self, fut: Fut) -> ConsumerState; + async fn send(self: Pin<&mut Self>, fut: Fut) -> ConsumerState; /// Make progress on the consumer while doing something else. /// /// It should always be possible to drop the future returned by this /// function. This is solely intended to keep work going on the `Consumer` /// while doing e.g. waiting for new futures from a stream. - async fn progress(&mut self) -> ConsumerState; + async fn progress(self: Pin<&mut Self>) -> ConsumerState; /// We have no more data left to send to the `Consumer`; wait for its /// output. - async fn flush(&mut self) -> Self::Output; + async fn flush(self: Pin<&mut Self>) -> Self::Output; } /// Concurrently operate over items in a stream diff --git a/src/concurrent_stream/take.rs b/src/concurrent_stream/take.rs index 8de5cb5..e951147 100644 --- a/src/concurrent_stream/take.rs +++ b/src/concurrent_stream/take.rs @@ -1,6 +1,9 @@ +use pin_project::pin_project; + use super::{ConcurrentStream, Consumer, ConsumerState}; use core::future::Future; use core::num::NonZeroUsize; +use core::pin::Pin; /// A concurrent iterator that only iterates over the first `n` iterations of `iter`. /// @@ -49,7 +52,9 @@ impl ConcurrentStream for Take { } } +#[pin_project] struct TakeConsumer { + #[pin] inner: C, count: usize, limit: usize, @@ -61,22 +66,25 @@ where { type Output = C::Output; - async fn send(&mut self, future: Fut) -> ConsumerState { - self.count += 1; - let state = self.inner.send(future).await; - if self.count >= self.limit { + async fn send(self: Pin<&mut Self>, future: Fut) -> ConsumerState { + let this = self.project(); + *this.count += 1; + let state = this.inner.send(future).await; + if this.count >= this.limit { ConsumerState::Break } else { state } } - async fn progress(&mut self) -> ConsumerState { - self.inner.progress().await + async fn progress(self: Pin<&mut Self>) -> ConsumerState { + let this = self.project(); + this.inner.progress().await } - async fn flush(&mut self) -> Self::Output { - self.inner.flush().await + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let this = self.project(); + this.inner.flush().await } } diff --git a/src/concurrent_stream/try_for_each.rs b/src/concurrent_stream/try_for_each.rs index 55cef21..a8dadf2 100644 --- a/src/concurrent_stream/try_for_each.rs +++ b/src/concurrent_stream/try_for_each.rs @@ -2,9 +2,9 @@ use crate::concurrent_stream::ConsumerState; use crate::future::FutureGroup; use crate::private::Try; use futures_lite::StreamExt; +use pin_project::pin_project; use super::Consumer; -use alloc::boxed::Box; use alloc::sync::Arc; use core::future::Future; use core::marker::PhantomData; @@ -14,7 +14,7 @@ use core::pin::Pin; use core::sync::atomic::{AtomicUsize, Ordering}; use core::task::{ready, Context, Poll}; -// OK: validated! - all bounds should check out +#[pin_project] pub(crate) struct TryForEachConsumer where FutT: Future, @@ -25,7 +25,8 @@ where // NOTE: we can remove the `Arc` here if we're willing to make this struct self-referential count: Arc, // TODO: remove the `Pin` from this signature by requiring this struct is pinned - group: Pin>>>, + #[pin] + group: FutureGroup>, limit: usize, residual: Option, f: F, @@ -49,7 +50,7 @@ where f, residual: None, count: Arc::new(AtomicUsize::new(0)), - group: Box::pin(FutureGroup::new()), + group: FutureGroup::new(), _phantom: PhantomData, } } @@ -65,15 +66,16 @@ where { type Output = B; - async fn send(&mut self, future: FutT) -> super::ConsumerState { + async fn send(self: Pin<&mut Self>, future: FutT) -> super::ConsumerState { + let mut this = self.project(); // If we have no space, we're going to provide backpressure until we have space - while self.count.load(Ordering::Relaxed) >= self.limit { - match self.group.next().await { + while this.count.load(Ordering::Relaxed) >= *this.limit { + match this.group.next().await { None => break, Some(res) => match res.branch() { ControlFlow::Continue(_) => todo!(), ControlFlow::Break(residual) => { - self.residual = Some(residual); + *this.residual = Some(residual); return ConsumerState::Break; } }, @@ -81,32 +83,34 @@ where } // Space was available! - insert the item for posterity - self.count.fetch_add(1, Ordering::Relaxed); - let fut = TryForEachFut::new(self.f.clone(), future, self.count.clone()); - self.group.as_mut().insert_pinned(fut); + this.count.fetch_add(1, Ordering::Relaxed); + let fut = TryForEachFut::new(this.f.clone(), future, this.count.clone()); + this.group.as_mut().insert_pinned(fut); ConsumerState::Continue } - async fn progress(&mut self) -> super::ConsumerState { - while let Some(res) = self.group.next().await { + async fn progress(self: Pin<&mut Self>) -> super::ConsumerState { + let mut this = self.project(); + while let Some(res) = this.group.next().await { if let ControlFlow::Break(residual) = res.branch() { - self.residual = Some(residual); + *this.residual = Some(residual); return ConsumerState::Break; } } ConsumerState::Empty } - async fn flush(&mut self) -> Self::Output { + async fn flush(self: Pin<&mut Self>) -> Self::Output { + let mut this = self.project(); // Return the error if we stopped iteration because of a previous error. - if self.residual.is_some() { - return B::from_residual(self.residual.take().unwrap()); + if this.residual.is_some() { + return B::from_residual(this.residual.take().unwrap()); } // We will no longer receive any additional futures from the // underlying stream; wait until all the futures in the group have // resolved. - while let Some(res) = self.group.next().await { + while let Some(res) = this.group.next().await { if let ControlFlow::Break(residual) = res.branch() { return B::from_residual(residual); }