diff --git a/src/stream/merge/vec.rs b/src/stream/merge/vec.rs index 5cc1086..e17b741 100644 --- a/src/stream/merge/vec.rs +++ b/src/stream/merge/vec.rs @@ -1,6 +1,6 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils::{self, Fuse, RandomGenerator, WakerList}; +use crate::utils::{self, PollState, PollStates, RandomGenerator, WakerList}; use core::fmt; use futures_core::Stream; @@ -20,10 +20,13 @@ where S: Stream, { #[pin] - streams: Vec>, + streams: Vec, rng: RandomGenerator, complete: usize, wakers: WakerList, + state: PollStates, + done: bool, + len: usize, } impl Merge @@ -31,11 +34,15 @@ where S: Stream, { pub(crate) fn new(streams: Vec) -> Self { + let len = streams.len(); Self { - wakers: WakerList::new(streams.len()), - streams: streams.into_iter().map(Fuse::new).collect(), + wakers: WakerList::new(len), + state: PollStates::new(len), + streams, rng: RandomGenerator::new(), complete: 0, + done: false, + len, } } } @@ -58,22 +65,19 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + // Iterate over our streams one-by-one. If a stream yields a value, // we exit early. By default we'll return `Poll::Ready(None)`, but // this changes if we encounter a `Poll::Pending`. - let mut index = this.rng.generate(this.streams.len() as u32) as usize; - - let mut readiness = this.wakers.readiness().lock().unwrap(); - readiness.set_waker(cx.waker()); - loop { + let len = *this.len; + let r = this.rng.generate(this.streams.len() as u32) as usize; + for index in (0..len).map(|n| (r + n).wrapping_rem(len)) { if !readiness.any_ready() { // Nothing is ready yet return Poll::Pending; - } - - index = (index + 1).wrapping_rem(this.streams.len()); - - if !readiness.clear_ready(index) { + } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { continue; } @@ -92,6 +96,7 @@ where } Poll::Ready(None) => { *this.complete += 1; + this.state[index] = PollState::Consumed; if *this.complete == this.streams.len() { return Poll::Ready(None); } @@ -102,6 +107,8 @@ where // Lock readiness so we can use it again readiness = this.wakers.readiness().lock().unwrap(); } + + Poll::Pending } }