Skip to content

Commit

Permalink
Merge pull request #109 from matheus-consoli/impl-race-ok-for-tuple
Browse files Browse the repository at this point in the history
Initial implementation of `RaceOk` for tuples
  • Loading branch information
eholk authored Dec 1, 2022
2 parents 058da68 + 983cbc5 commit 38fbdb7
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/future/race_ok/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use core::future::Future;

pub(crate) mod array;
pub(crate) mod tuple;
pub(crate) mod vec;

/// Wait for the first successful future to complete.
Expand Down
54 changes: 54 additions & 0 deletions src/future/race_ok/tuple/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use core::fmt;
use core::ops::{Deref, DerefMut};
use std::error::Error;

/// A collection of errors.
#[repr(transparent)]
pub struct AggregateError<E, const N: usize> {
inner: [E; N],
}

impl<E, const N: usize> AggregateError<E, N> {
pub(super) fn new(inner: [E; N]) -> Self {
Self { inner }
}
}

impl<E: Error, const N: usize> fmt::Debug for AggregateError<E, N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "{self}:")?;

for (i, err) in self.inner.iter().enumerate() {
writeln!(f, "- Error {}: {err}", i + 1)?;
let mut source = err.source();
while let Some(err) = source {
writeln!(f, " ↳ Caused by: {err}")?;
source = err.source();
}
}

Ok(())
}
}

impl<E: Error, const N: usize> fmt::Display for AggregateError<E, N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} errors occured", self.inner.len())
}
}

impl<E, const N: usize> Deref for AggregateError<E, N> {
type Target = [E; N];

fn deref(&self) -> &Self::Target {
&self.inner
}
}

impl<E, const N: usize> DerefMut for AggregateError<E, N> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

impl<E: Error, const N: usize> std::error::Error for AggregateError<E, N> {}
192 changes: 192 additions & 0 deletions src/future/race_ok/tuple/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use super::RaceOk;
use crate::utils;

use core::array;
use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem::{self, MaybeUninit};
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::pin_project;

mod error;
pub(crate) use error::AggregateError;

macro_rules! impl_race_ok_tuple {
($StructName:ident $($F:ident)+) => {
/// Wait for the first successful future to complete.
///
/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See
/// its documentation for more.
///
/// [`race_ok`]: crate::future::RaceOk::race_ok
/// [`RaceOk`]: crate::future::RaceOk
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(non_snake_case)]
#[pin_project]
pub struct $StructName<T, ERR, $($F),*>
where
$( $F: Future<Output = Result<T, ERR>>, )*
ERR: fmt::Debug,
{
completed: usize,
done: bool,
indexer: utils::Indexer,
errors: [MaybeUninit<ERR>; {utils::tuple_len!($($F,)*)}],
$(#[pin] $F: $F,)*
}

impl<T, ERR, $($F),*> fmt::Debug for $StructName<T, ERR, $($F),*>
where
$( $F: Future<Output = Result<T, ERR>> + fmt::Debug, )*
T: fmt::Debug,
ERR: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Race")
$(.field(&self.$F))*
.finish()
}
}

impl<T, ERR, $($F),*> RaceOk for ($($F,)*)
where
$( $F: IntoFuture<Output = Result<T, ERR>>, )*
ERR: fmt::Debug,
{
type Output = T;
type Error = AggregateError<ERR, {utils::tuple_len!($($F,)*)}>;
type Future = $StructName<T, ERR, $($F::IntoFuture),*>;

fn race_ok(self) -> Self::Future {
let ($($F,)*): ($($F,)*) = self;
$StructName {
completed: 0,
done: false,
indexer: utils::Indexer::new(utils::tuple_len!($($F,)*)),
errors: array::from_fn(|_| MaybeUninit::uninit()),
$($F: $F.into_future()),*
}
}
}

impl<T, ERR, $($F: Future),*> Future for $StructName<T, ERR, $($F),*>
where
$( $F: Future<Output = Result<T, ERR>>, )*
ERR: fmt::Debug,
{
type Output = Result<T, AggregateError<ERR, {utils::tuple_len!($($F,)*)}>>;

fn poll(
self: Pin<&mut Self>, cx: &mut Context<'_>
) -> Poll<Self::Output> {
const LEN: usize = utils::tuple_len!($($F,)*);

let mut this = self.project();

let can_poll = !*this.done;
assert!(can_poll, "Futures must not be polled after completing");

#[repr(usize)]
enum Indexes {
$($F),*
}

for i in this.indexer.iter() {
utils::gen_conditions!(i, this, cx, poll, $((Indexes::$F as usize; $F, {
Poll::Ready(output) => match output {
Ok(output) => {
*this.done = true;
*this.completed += 1;
return Poll::Ready(Ok(output));
},
Err(err) => {
this.errors[i] = MaybeUninit::new(err);
*this.completed += 1;
continue;
},
},
_ => continue,
}))*);
}

let all_completed = *this.completed == LEN;
if all_completed {
let mut errors = array::from_fn(|_| MaybeUninit::uninit());
mem::swap(&mut errors, this.errors);

let result = unsafe { utils::array_assume_init(errors) };

*this.done = true;
return Poll::Ready(Err(AggregateError::new(result)));
}

Poll::Pending
}
}
};
}

impl_race_ok_tuple! { RaceOk1 A }
impl_race_ok_tuple! { RaceOk2 A B }
impl_race_ok_tuple! { RaceOk3 A B C }
impl_race_ok_tuple! { RaceOk4 A B C D }
impl_race_ok_tuple! { RaceOk5 A B C D E }
impl_race_ok_tuple! { RaceOk6 A B C D E F }
impl_race_ok_tuple! { RaceOk7 A B C D E F G }
impl_race_ok_tuple! { RaceOk8 A B C D E F G H }
impl_race_ok_tuple! { RaceOk9 A B C D E F G H I }
impl_race_ok_tuple! { RaceOk10 A B C D E F G H I J }
impl_race_ok_tuple! { RaceOk11 A B C D E F G H I J K }
impl_race_ok_tuple! { RaceOk12 A B C D E F G H I J K L }

#[cfg(test)]
mod test {
use super::*;
use core::future;
use std::error::Error;

type DynError = Box<dyn Error>;

#[test]
fn race_ok_1() {
futures_lite::future::block_on(async {
let a = async { Ok::<_, DynError>("world") };
let res = (a,).race_ok().await;
assert!(matches!(res, Ok("world")));
});
}

#[test]
fn race_ok_2() {
futures_lite::future::block_on(async {
let a = future::pending();
let b = async { Ok::<_, DynError>("world") };
let res = (a, b).race_ok().await;
assert!(matches!(res, Ok("world")));
});
}

#[test]
fn race_ok_3() {
futures_lite::future::block_on(async {
let a = future::pending();
let b = async { Ok::<_, DynError>("hello") };
let c = async { Ok::<_, DynError>("world") };
let result = (a, b, c).race_ok().await;
assert!(matches!(result, Ok("hello") | Ok("world")));
});
}

#[test]
fn race_ok_err() {
futures_lite::future::block_on(async {
let a = async { Err::<(), _>("hello") };
let b = async { Err::<(), _>("world") };
let errors = (a, b).race_ok().await.unwrap_err();
assert_eq!(errors[0], "hello");
assert_eq!(errors[1], "world");
});
}
}

0 comments on commit 38fbdb7

Please sign in to comment.