From 0af8217512f8ceb4bbfff15913ae2438ee1165e0 Mon Sep 17 00:00:00 2001 From: John Starks Date: Mon, 27 Jan 2025 18:35:25 +0000 Subject: [PATCH] Fix RaceOk tuple and array impls Both the tuple and array implementations of `RaceOk` would poll futures that had already been polled to completion. Fix these by tracking per-future state. Additionally, the array implementation failed to drop any accumulated errors in the success case. Add a drop implementation to fix this. --- src/future/race_ok/array/mod.rs | 94 +++++++++++++++++++++++++++++++-- src/future/race_ok/tuple/mod.rs | 20 +++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/src/future/race_ok/array/mod.rs b/src/future/race_ok/array/mod.rs index 9772035..d3d1ca8 100644 --- a/src/future/race_ok/array/mod.rs +++ b/src/future/race_ok/array/mod.rs @@ -1,6 +1,7 @@ use super::RaceOk as RaceOkTrait; use crate::utils::array_assume_init; use crate::utils::iter_pin_mut; +use crate::utils::PollArray; use core::array; use core::fmt; @@ -9,7 +10,7 @@ use core::mem::{self, MaybeUninit}; use core::pin::Pin; use core::task::{Context, Poll}; -use pin_project::pin_project; +use pin_project::{pin_project, pinned_drop}; mod error; @@ -23,7 +24,7 @@ pub use error::AggregateError; /// [`race_ok`]: crate::future::RaceOk::race_ok /// [`RaceOk`]: crate::future::RaceOk #[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] +#[pin_project(PinnedDrop)] pub struct RaceOk where Fut: Future>, @@ -31,9 +32,30 @@ where #[pin] futures: [Fut; N], errors: [MaybeUninit; N], + error_states: PollArray, completed: usize, } +#[pinned_drop] +impl PinnedDrop for RaceOk +where + Fut: Future>, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + for (st, err) in this + .error_states + .iter_mut() + .zip(this.errors.iter_mut()) + .filter(|(st, _err)| st.is_ready()) + { + // SAFETY: we've filtered down to only the `ready`/initialized data + unsafe { err.assume_init_drop() }; + st.set_none(); + } + } +} + impl fmt::Debug for RaceOk where Fut: Future> + fmt::Debug, @@ -55,13 +77,20 @@ where let futures = iter_pin_mut(this.futures); - for (fut, out) in futures.zip(this.errors.iter_mut()) { + for ((fut, out), st) in futures + .zip(this.errors.iter_mut()) + .zip(this.error_states.iter_mut()) + { + if st.is_ready() { + continue; + } if let Poll::Ready(output) = fut.poll(cx) { match output { Ok(ok) => return Poll::Ready(Ok(ok)), Err(err) => { *out = MaybeUninit::new(err); *this.completed += 1; + st.set_ready(); } } } @@ -71,6 +100,7 @@ where if all_completed { let mut errors = array::from_fn(|_| MaybeUninit::uninit()); mem::swap(&mut errors, this.errors); + this.error_states.set_all_none(); // SAFETY: we know that all futures are properly initialized because they're all completed let result = unsafe { array_assume_init(errors) }; @@ -94,6 +124,7 @@ where RaceOk { futures: self.map(|fut| fut.into_future()), errors: array::from_fn(|_| MaybeUninit::uninit()), + error_states: PollArray::new_pending(), completed: 0, } } @@ -138,4 +169,61 @@ mod test { assert_eq!(errs[1], "oh no"); }); } + + #[test] + fn resume_after_completion() { + use futures_lite::future::yield_now; + futures_lite::future::block_on(async { + let fut = |ok| async move { + if ok { + yield_now().await; + yield_now().await; + Ok(()) + } else { + Err(()) + } + }; + + let res = [fut(true), fut(false)].race_ok().await; + assert_eq!(res.ok().unwrap(), ()); + }); + } + + #[test] + fn drop_errors() { + use futures_lite::future::yield_now; + + struct Droper<'a>(&'a core::cell::Cell); + impl Drop for Droper<'_> { + fn drop(&mut self) { + self.0.set(self.0.get() + 1); + } + } + + futures_lite::future::block_on(async { + let drop_count = Default::default(); + let fut = |ok| { + let drop_count = &drop_count; + async move { + if ok { + yield_now().await; + yield_now().await; + Ok(()) + } else { + Err(Droper(drop_count)) + } + } + }; + let res = [fut(true), fut(false)].race_ok().await; + assert_eq!(drop_count.get(), 1); + assert_eq!(res.ok().unwrap(), ()); + + drop_count.set(0); + let res = [fut(false), fut(false)].race_ok().await; + assert!(res.is_err()); + assert_eq!(drop_count.get(), 0); + drop(res); + assert_eq!(drop_count.get(), 2); + }) + } } diff --git a/src/future/race_ok/tuple/mod.rs b/src/future/race_ok/tuple/mod.rs index 43dd986..2556435 100644 --- a/src/future/race_ok/tuple/mod.rs +++ b/src/future/race_ok/tuple/mod.rs @@ -101,6 +101,9 @@ macro_rules! impl_race_ok_tuple { } for i in this.indexer.iter() { + if this.errors_states[i].is_ready() { + continue; + } utils::gen_conditions!(i, this, cx, poll, $((Indexes::$F as usize; $F, { Poll::Ready(output) => match output { Ok(output) => { @@ -219,4 +222,21 @@ mod test { assert_eq!(errors[1], "world"); }); } + + #[test] + fn race_ok_resume_after_completion() { + use futures_lite::future::yield_now; + futures_lite::future::block_on(async { + let ok = async { + yield_now().await; + yield_now().await; + Ok::<_, ()>(()) + }; + let err = async { Err::<(), _>(()) }; + + let res = (ok, err).race_ok().await; + + assert_eq!(res.ok().unwrap(), ()); + }); + } }