diff --git a/src/stream/merge/array.rs b/src/stream/merge/array.rs index 9020040..1413dea 100644 --- a/src/stream/merge/array.rs +++ b/src/stream/merge/array.rs @@ -1,8 +1,7 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils::{self, PollState}; +use crate::utils::{self, PollState, PollStates, RandomGenerator, WakerList}; -use core::array; use core::fmt; use futures_core::Stream; use std::pin::Pin; @@ -22,8 +21,10 @@ where { #[pin] streams: [S; N], - rng: utils::RandomGenerator, - poll_state: [PollState; N], + rng: RandomGenerator, + complete: usize, + wakers: WakerList, + state: PollStates, done: bool, } @@ -32,10 +33,13 @@ where S: Stream, { pub(crate) fn new(streams: [S; N]) -> Self { + let len = streams.len(); Self { + wakers: WakerList::new(len), + state: PollStates::new(len), streams, - rng: utils::RandomGenerator::new(), - poll_state: array::from_fn(|_| PollState::default()), + rng: RandomGenerator::new(), + complete: 0, done: false, } } @@ -59,28 +63,50 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + // Iterate over our streams one-by-one. If a stream yields a value, // we exit early. By default we'll return `Poll::Ready(None)`, but // this changes if we encounter a `Poll::Pending`. - let mut res = Poll::Ready(None); - let index = this.rng.generate(N as u32) as usize; - - for index in (0..N).map(|pos| (index + pos).wrapping_rem(N)) { - if this.poll_state[index].is_consumed() { + let len = this.streams.len(); + let r = this.rng.generate(len as u32) as usize; + for index in (0..len).map(|n| (r + n).wrapping_rem(len)) { + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { continue; } + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap(); - match stream.poll_next(cx) { - Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + match stream.poll_next(&mut cx) { + Poll::Ready(Some(item)) => { + // Mark ourselves as ready again because we need to poll for the next item. + this.wakers.readiness().lock().unwrap().set_ready(index); + return Poll::Ready(Some(item)); + } Poll::Ready(None) => { - this.poll_state[index] = PollState::Consumed; - continue; + *this.complete += 1; + this.state[index] = PollState::Consumed; + if *this.complete == this.streams.len() { + return Poll::Ready(None); + } } - Poll::Pending => res = Poll::Pending, + Poll::Pending => {} } + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); } - res + + Poll::Pending } } @@ -98,13 +124,22 @@ where #[cfg(test)] mod tests { + use std::cell::RefCell; + use std::collections::VecDeque; + use std::rc::Rc; + use std::task::Waker; + use super::*; + use futures::executor::LocalPool; + use futures::task::LocalSpawnExt; use futures_lite::future::block_on; use futures_lite::prelude::*; use futures_lite::stream; + use crate::future::join::Join; + #[test] - fn merge_tuple_4() { + fn merge_array_4() { block_on(async { let a = stream::once(1); let b = stream::once(2); @@ -119,4 +154,134 @@ mod tests { assert_eq!(counter, 10); }) } + + #[test] + fn merge_array_2x2() { + block_on(async { + let a = stream::repeat(1).take(2); + let b = stream::repeat(2).take(2); + let mut s = [a, b].merge(); + + let mut counter = 0; + while let Some(n) = s.next().await { + counter += n; + } + assert_eq!(counter, 6); + }) + } + + /// This test case uses channels so we'll have streams that return Pending from time to time. + /// + /// The purpose of this test is to make sure we have the waking logic working. + #[test] + fn merge_channels() { + struct LocalChannel { + queue: VecDeque, + waker: Option, + closed: bool, + } + + struct LocalReceiver { + channel: Rc>>, + } + + impl Stream for LocalReceiver { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut channel = self.channel.borrow_mut(); + + match channel.queue.pop_front() { + Some(item) => Poll::Ready(Some(item)), + None => { + if channel.closed { + Poll::Ready(None) + } else { + channel.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } + } + } + + struct LocalSender { + channel: Rc>>, + } + + impl LocalSender { + fn send(&self, item: T) { + let mut channel = self.channel.borrow_mut(); + + channel.queue.push_back(item); + + let _ = channel.waker.take().map(Waker::wake); + } + } + + impl Drop for LocalSender { + fn drop(&mut self) { + let mut channel = self.channel.borrow_mut(); + channel.closed = true; + let _ = channel.waker.take().map(Waker::wake); + } + } + + fn local_channel() -> (LocalSender, LocalReceiver) { + let channel = Rc::new(RefCell::new(LocalChannel { + queue: VecDeque::new(), + waker: None, + closed: false, + })); + + ( + LocalSender { + channel: channel.clone(), + }, + LocalReceiver { channel }, + ) + } + + let mut pool = LocalPool::new(); + + let done = Rc::new(RefCell::new(false)); + let done2 = done.clone(); + + pool.spawner() + .spawn_local(async move { + let (send1, receive1) = local_channel(); + let (send2, receive2) = local_channel(); + let (send3, receive3) = local_channel(); + + let (count, ()) = ( + async { + [receive1, receive2, receive3] + .merge() + .fold(0, |a, b| a + b) + .await + }, + async { + for i in 1..=4 { + send1.send(i); + send2.send(i); + send3.send(i); + } + drop(send1); + drop(send2); + drop(send3); + }, + ) + .join() + .await; + + assert_eq!(count, 30); + + *done2.borrow_mut() = true; + }) + .unwrap(); + + while !*done.borrow() { + pool.run_until_stalled() + } + } }