Skip to content

Commit

Permalink
use the private Try trait for try_for_each
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshuawuyts committed Mar 17, 2024
1 parent a1695a6 commit 46ab3f4
Showing 1 changed file with 47 additions and 36 deletions.
83 changes: 47 additions & 36 deletions src/concurrent_stream/try_for_each.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::concurrent_stream::ConsumerState;
use crate::future::FutureGroup;
use crate::private::Try;
use futures_lite::StreamExt;

use super::Consumer;
Expand All @@ -8,32 +9,35 @@ use alloc::sync::Arc;
use core::future::Future;
use core::marker::PhantomData;
use core::num::NonZeroUsize;
use core::ops::ControlFlow;
use core::pin::Pin;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::{ready, Context, Poll};

// OK: validated! - all bounds should check out
pub(crate) struct TryForEachConsumer<FutT, T, F, FutB, E>
pub(crate) struct TryForEachConsumer<FutT, T, F, FutB, B>
where
FutT: Future<Output = T>,
F: Fn(T) -> FutB,
FutB: Future<Output = Result<(), E>>,
F: Clone + Fn(T) -> FutB,
FutB: Future<Output = B>,
B: Try<Output = ()>,
{
// NOTE: we can remove the `Arc` here if we're willing to make this struct self-referential
count: Arc<AtomicUsize>,
// TODO: remove the `Pin<Box>` from this signature by requiring this struct is pinned
group: Pin<Box<FutureGroup<TryForEachFut<F, FutT, T, FutB, E>>>>,
group: Pin<Box<FutureGroup<TryForEachFut<F, FutT, T, FutB, B>>>>,
limit: usize,
err: Option<E>,
err: Option<B::Residual>,
f: F,
_phantom: PhantomData<(T, FutB)>,
}

impl<A, T, F, B, E> TryForEachConsumer<A, T, F, B, E>
impl<FutT, T, F, FutB, B> TryForEachConsumer<FutT, T, F, FutB, B>
where
A: Future<Output = T>,
F: Fn(T) -> B,
B: Future<Output = Result<(), E>>,
FutT: Future<Output = T>,
F: Clone + Fn(T) -> FutB,
FutB: Future<Output = B>,
B: Try<Output = ()>,
{
pub(crate) fn new(limit: Option<NonZeroUsize>, f: F) -> Self {
let limit = match limit {
Expand All @@ -52,26 +56,28 @@ where
}

// OK: validated! - we push types `B` into the next consumer
impl<FutT, T, F, B, E> Consumer<T, FutT> for TryForEachConsumer<FutT, T, F, B, E>
impl<FutT, T, F, FutB, B> Consumer<T, FutT> for TryForEachConsumer<FutT, T, F, FutB, B>
where
FutT: Future<Output = T>,
F: Fn(T) -> B,
F: Clone,
B: Future<Output = Result<(), E>>,
F: Clone + Fn(T) -> FutB,
FutB: Future<Output = B>,
B: Try<Output = ()>,
{
type Output = Result<(), E>;
type Output = B;

async fn send(&mut self, future: FutT) -> super::ConsumerState {
// If we have no space, we're going to provide backpressure until we have space
while self.count.load(Ordering::Relaxed) >= self.limit {
match self.group.next().await {
Some(Ok(_)) => continue,
Some(Err(err)) => {
self.err = Some(err);
return ConsumerState::Break;
}
None => break,
};
Some(res) => match res.branch() {
ControlFlow::Continue(_) => todo!(),
ControlFlow::Break(residual) => {
self.err = Some(residual);
return ConsumerState::Break;
}
},
}
}

// Space was available! - insert the item for posterity
Expand All @@ -83,8 +89,8 @@ where

async fn progress(&mut self) -> super::ConsumerState {
while let Some(res) = self.group.next().await {
if let Err(err) = res {
self.err = Some(err);
if let ControlFlow::Break(residual) = res.branch() {
self.err = Some(residual);
return ConsumerState::Break;
}
}
Expand All @@ -93,27 +99,30 @@ where

async fn finish(mut self) -> Self::Output {
// Return the error if we stopped iteration because of a previous error.
if let Some(err) = self.err {
return Err(err);
if let Some(residual) = self.err {
return B::from_residual(residual);
}

// We will no longer receive any additional futures from the
// underlying stream; wait until all the futures in the group have
// resolved.
while let Some(res) = self.group.next().await {
res?;
if let ControlFlow::Break(residual) = res.branch() {
return B::from_residual(residual);
}
}
Ok(())
B::from_output(())
}
}

/// Takes a future and maps it to another future via a closure
#[derive(Debug)]
pub struct TryForEachFut<F, FutT, T, FutB, E>
pub struct TryForEachFut<F, FutT, T, FutB, B>
where
FutT: Future<Output = T>,
F: Fn(T) -> FutB,
FutB: Future<Output = Result<(), E>>,
F: Clone + Fn(T) -> FutB,
FutB: Future<Output = B>,
B: Try<Output = ()>,
{
done: bool,
count: Arc<AtomicUsize>,
Expand All @@ -122,11 +131,12 @@ where
fut_b: Option<FutB>,
}

impl<F, FutT, T, FutB, E> TryForEachFut<F, FutT, T, FutB, E>
impl<F, FutT, T, FutB, B> TryForEachFut<F, FutT, T, FutB, B>
where
FutT: Future<Output = T>,
F: Fn(T) -> FutB,
FutB: Future<Output = Result<(), E>>,
F: Clone + Fn(T) -> FutB,
FutB: Future<Output = B>,
B: Try<Output = ()>,
{
fn new(f: F, fut_t: FutT, count: Arc<AtomicUsize>) -> Self {
Self {
Expand All @@ -139,13 +149,14 @@ where
}
}

impl<F, FutT, T, FutB, E> Future for TryForEachFut<F, FutT, T, FutB, E>
impl<F, FutT, T, FutB, B> Future for TryForEachFut<F, FutT, T, FutB, B>
where
FutT: Future<Output = T>,
F: Fn(T) -> FutB,
FutB: Future<Output = Result<(), E>>,
F: Clone + Fn(T) -> FutB,
FutB: Future<Output = B>,
B: Try<Output = ()>,
{
type Output = Result<(), E>;
type Output = B;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: we need to access the inner future's fields to project them
Expand Down

0 comments on commit 46ab3f4

Please sign in to comment.