diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 50529f5..a673603 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -1,5 +1,5 @@ use super::Join as JoinTrait; -use crate::utils::{self, PollArray}; +use crate::utils::{self, PollArray, WakerArray}; use core::array; use core::fmt; @@ -26,6 +26,7 @@ where consumed: bool, pending: usize, items: [MaybeUninit<::Output>; N], + wakers: WakerArray, state: PollArray, #[pin] futures: [Fut; N], @@ -41,6 +42,7 @@ where consumed: false, pending: N, items: array::from_fn(|_| MaybeUninit::uninit()), + wakers: WakerArray::new(), state: PollArray::new(), futures, } @@ -85,12 +87,32 @@ where "Futures must not be polled after completing" ); - // Poll all futures + // Mark futures as ready according to the wakers + { + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } + + for (i, state) in this.state.iter_mut().enumerate() { + if !state.is_consumed() && readiness.clear_ready(i) { + state.set_ready(); + } + } + } + + // Poll all ready futures for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() { - if this.state[i].is_pending() { - if let Poll::Ready(value) = fut.poll(cx) { + if this.state[i].is_ready() { + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); + + if let Poll::Ready(value) = fut.poll(&mut cx) { this.items[i] = MaybeUninit::new(value); - this.state[i].set_ready(); + this.state[i].set_consumed(); *this.pending -= 1; } } @@ -102,10 +124,9 @@ where *this.consumed = true; for state in this.state.iter_mut() { debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" + state.is_consumed(), + "Future should have reached a `Consumed` state" ); - state.set_consumed(); } let mut items = array::from_fn(|_| MaybeUninit::uninit()); @@ -135,7 +156,7 @@ where .state .iter_mut() .enumerate() - .filter(|(_, state)| state.is_ready()) + .filter(|(_, state)| state.is_consumed()) .map(|(i, _)| i); // Drop each value at the index. diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index 4efcf70..1073ff0 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -87,28 +87,34 @@ where "Futures must not be polled after completing" ); - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); + // Mark futures as ready according to the wakers + { + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } - if !readiness.any_ready() { - // Nothing is ready yet - return Poll::Pending; + for (i, state) in this.state.iter_mut().enumerate() { + if !state.is_consumed() && readiness.clear_ready(i) { + state.set_ready(); + } + } } - // unlock readiness so we don't deadlock when polling - drop(readiness); - - // Poll all futures + // Poll all ready futures let futures = this.futures.as_mut(); let states = &mut this.state[..]; for (i, fut) in iter_pin_mut_vec(futures).enumerate() { - if states[i].is_pending() && this.wakers.readiness().lock().unwrap().clear_ready(i) { + if states[i].is_ready() { // Obtain the intermediate waker. let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); if let Poll::Ready(value) = fut.poll(&mut cx) { this.items[i] = MaybeUninit::new(value); - states[i].set_ready(); + states[i].set_consumed(); *this.pending -= 1; } } @@ -120,10 +126,9 @@ where *this.consumed = true; this.state.iter_mut().for_each(|state| { debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" + state.is_consumed(), + "Future should have reached a `Consumed` state" ); - state.set_consumed(); }); // SAFETY: we've checked with the state that all of our outputs have been @@ -153,7 +158,7 @@ where .state .iter_mut() .enumerate() - .filter(|(_, state)| state.is_ready()) + .filter(|(_, state)| state.is_consumed()) .map(|(i, _)| i); // Drop each value at the index. diff --git a/src/utils/poll_state/poll_state.rs b/src/utils/poll_state/poll_state.rs index 35be7fe..1843ef4 100644 --- a/src/utils/poll_state/poll_state.rs +++ b/src/utils/poll_state/poll_state.rs @@ -17,6 +17,7 @@ impl PollState { /// Returns `true` if the metadata is [`Pending`][Self::Pending]. #[must_use] #[inline] + #[allow(unused)] pub(crate) fn is_pending(&self) -> bool { matches!(self, Self::Pending) }