Skip to content

Commit

Permalink
Merge pull request #204 from jstarks/race_ok
Browse files Browse the repository at this point in the history
Fix RaceOk tuple and array impls
  • Loading branch information
yoshuawuyts authored Jan 29, 2025
2 parents abc7a90 + 0af8217 commit 9dabeb6
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
94 changes: 91 additions & 3 deletions src/future/race_ok/array/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::RaceOk as RaceOkTrait;
use crate::utils::array_assume_init;
use crate::utils::iter_pin_mut;
use crate::utils::PollArray;

use core::array;
use core::fmt;
Expand All @@ -9,7 +10,7 @@ use core::mem::{self, MaybeUninit};
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::pin_project;
use pin_project::{pin_project, pinned_drop};

mod error;

Expand All @@ -23,17 +24,38 @@ pub use error::AggregateError;
/// [`race_ok`]: crate::future::RaceOk::race_ok
/// [`RaceOk`]: crate::future::RaceOk
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project]
#[pin_project(PinnedDrop)]
pub struct RaceOk<Fut, T, E, const N: usize>
where
Fut: Future<Output = Result<T, E>>,
{
#[pin]
futures: [Fut; N],
errors: [MaybeUninit<E>; N],
error_states: PollArray<N>,
completed: usize,
}

#[pinned_drop]
impl<Fut, T, E, const N: usize> PinnedDrop for RaceOk<Fut, T, E, N>
where
Fut: Future<Output = Result<T, E>>,
{
fn drop(self: Pin<&mut Self>) {
let this = self.project();
for (st, err) in this
.error_states
.iter_mut()
.zip(this.errors.iter_mut())
.filter(|(st, _err)| st.is_ready())
{
// SAFETY: we've filtered down to only the `ready`/initialized data
unsafe { err.assume_init_drop() };
st.set_none();
}
}
}

impl<Fut, T, E, const N: usize> fmt::Debug for RaceOk<Fut, T, E, N>
where
Fut: Future<Output = Result<T, E>> + fmt::Debug,
Expand All @@ -55,13 +77,20 @@ where

let futures = iter_pin_mut(this.futures);

for (fut, out) in futures.zip(this.errors.iter_mut()) {
for ((fut, out), st) in futures
.zip(this.errors.iter_mut())
.zip(this.error_states.iter_mut())
{
if st.is_ready() {
continue;
}
if let Poll::Ready(output) = fut.poll(cx) {
match output {
Ok(ok) => return Poll::Ready(Ok(ok)),
Err(err) => {
*out = MaybeUninit::new(err);
*this.completed += 1;
st.set_ready();
}
}
}
Expand All @@ -71,6 +100,7 @@ where
if all_completed {
let mut errors = array::from_fn(|_| MaybeUninit::uninit());
mem::swap(&mut errors, this.errors);
this.error_states.set_all_none();

// SAFETY: we know that all futures are properly initialized because they're all completed
let result = unsafe { array_assume_init(errors) };
Expand All @@ -94,6 +124,7 @@ where
RaceOk {
futures: self.map(|fut| fut.into_future()),
errors: array::from_fn(|_| MaybeUninit::uninit()),
error_states: PollArray::new_pending(),
completed: 0,
}
}
Expand Down Expand Up @@ -138,4 +169,61 @@ mod test {
assert_eq!(errs[1], "oh no");
});
}

#[test]
fn resume_after_completion() {
use futures_lite::future::yield_now;
futures_lite::future::block_on(async {
let fut = |ok| async move {
if ok {
yield_now().await;
yield_now().await;
Ok(())
} else {
Err(())
}
};

let res = [fut(true), fut(false)].race_ok().await;
assert_eq!(res.ok().unwrap(), ());
});
}

#[test]
fn drop_errors() {
use futures_lite::future::yield_now;

struct Droper<'a>(&'a core::cell::Cell<usize>);
impl Drop for Droper<'_> {
fn drop(&mut self) {
self.0.set(self.0.get() + 1);
}
}

futures_lite::future::block_on(async {
let drop_count = Default::default();
let fut = |ok| {
let drop_count = &drop_count;
async move {
if ok {
yield_now().await;
yield_now().await;
Ok(())
} else {
Err(Droper(drop_count))
}
}
};
let res = [fut(true), fut(false)].race_ok().await;
assert_eq!(drop_count.get(), 1);
assert_eq!(res.ok().unwrap(), ());

drop_count.set(0);
let res = [fut(false), fut(false)].race_ok().await;
assert!(res.is_err());
assert_eq!(drop_count.get(), 0);
drop(res);
assert_eq!(drop_count.get(), 2);
})
}
}
20 changes: 20 additions & 0 deletions src/future/race_ok/tuple/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ macro_rules! impl_race_ok_tuple {
}

for i in this.indexer.iter() {
if this.errors_states[i].is_ready() {
continue;
}
utils::gen_conditions!(i, this, cx, poll, $((Indexes::$F as usize; $F, {
Poll::Ready(output) => match output {
Ok(output) => {
Expand Down Expand Up @@ -219,4 +222,21 @@ mod test {
assert_eq!(errors[1], "world");
});
}

#[test]
fn race_ok_resume_after_completion() {
use futures_lite::future::yield_now;
futures_lite::future::block_on(async {
let ok = async {
yield_now().await;
yield_now().await;
Ok::<_, ()>(())
};
let err = async { Err::<(), _>(()) };

let res = (ok, err).race_ok().await;

assert_eq!(res.ok().unwrap(), ());
});
}
}

0 comments on commit 9dabeb6

Please sign in to comment.