From 6907e20e58918cc5535d5796e3ce42ccbd31a2d9 Mon Sep 17 00:00:00 2001 From: Yoshua Wuyts <2467194+yoshuawuyts@users.noreply.github.com> Date: Tue, 15 Nov 2022 00:37:43 +0100 Subject: [PATCH] perfect waker for `array::Merge` --- src/stream/merge/array.rs | 212 +++++++++++++++++++++++++++++++++----- 1 file changed, 185 insertions(+), 27 deletions(-) diff --git a/src/stream/merge/array.rs b/src/stream/merge/array.rs index 1c8de5b..14f53f9 100644 --- a/src/stream/merge/array.rs +++ b/src/stream/merge/array.rs @@ -1,6 +1,6 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils::{self, Fuse}; +use crate::utils::{self, Fuse, RandomGenerator, WakerList}; use core::fmt; use futures_core::Stream; @@ -21,7 +21,9 @@ where { #[pin] streams: [Fuse; N], - rng: utils::RandomGenerator, + rng: RandomGenerator, + complete: usize, + wakers: WakerList, } impl Merge @@ -30,8 +32,10 @@ where { pub(crate) fn new(streams: [S; N]) -> Self { Self { + wakers: WakerList::new(streams.len()), streams: streams.map(Fuse::new), - rng: utils::RandomGenerator::new(), + rng: RandomGenerator::new(), + complete: 0, } } } @@ -54,35 +58,50 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - // Randomize the indexes into our streams array. This ensures that when - // multiple streams are ready at the same time, we don't accidentally - // exhaust one stream before another. - let mut arr: [usize; N] = { - // this is an inlined version of `core::array::from_fn` - // TODO: replace this with `core::array::from_fn` when it becomes stable - let cb = |n| n; - let mut idx = 0; - [(); N].map(|_| { - let res = cb(idx); - idx += 1; - res - }) - }; - arr.sort_by_cached_key(|_| this.rng.generate(1000)); - // Iterate over our streams one-by-one. If a stream yields a value, // we exit early. By default we'll return `Poll::Ready(None)`, but // this changes if we encounter a `Poll::Pending`. - let mut res = Poll::Ready(None); - for index in arr { + let mut index = this.rng.generate(this.streams.len() as u32) as usize; + + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + loop { + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } + + index = (index + 1).wrapping_rem(this.streams.len()); + + if !readiness.clear_ready(index) { + continue; + } + + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap(); - match stream.poll_next(cx) { - Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), - Poll::Ready(None) => continue, - Poll::Pending => res = Poll::Pending, + match stream.poll_next(&mut cx) { + Poll::Ready(Some(item)) => { + // Mark ourselves as ready again because we need to poll for the next item. + this.wakers.readiness().lock().unwrap().set_ready(index); + return Poll::Ready(Some(item)); + } + Poll::Ready(None) => { + *this.complete += 1; + if *this.complete == this.streams.len() { + return Poll::Ready(None); + } + } + Poll::Pending => {} } + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); } - res } } @@ -100,13 +119,22 @@ where #[cfg(test)] mod tests { + use std::cell::RefCell; + use std::collections::VecDeque; + use std::rc::Rc; + use std::task::Waker; + use super::*; + use futures::executor::LocalPool; + use futures::task::LocalSpawnExt; use futures_lite::future::block_on; use futures_lite::prelude::*; use futures_lite::stream; + use crate::future::join::Join; + #[test] - fn merge_tuple_4() { + fn merge_vec_4() { block_on(async { let a = stream::once(1); let b = stream::once(2); @@ -121,4 +149,134 @@ mod tests { assert_eq!(counter, 10); }) } + + #[test] + fn merge_vec_2x2() { + block_on(async { + let a = stream::repeat(1).take(2); + let b = stream::repeat(2).take(2); + let mut s = [a, b].merge(); + + let mut counter = 0; + while let Some(n) = s.next().await { + counter += n; + } + assert_eq!(counter, 6); + }) + } + + /// This test case uses channels so we'll have streams that return Pending from time to time. + /// + /// The purpose of this test is to make sure we have the waking logic working. + #[test] + fn merge_channels() { + struct LocalChannel { + queue: VecDeque, + waker: Option, + closed: bool, + } + + struct LocalReceiver { + channel: Rc>>, + } + + impl Stream for LocalReceiver { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut channel = self.channel.borrow_mut(); + + match channel.queue.pop_front() { + Some(item) => Poll::Ready(Some(item)), + None => { + if channel.closed { + Poll::Ready(None) + } else { + channel.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } + } + } + + struct LocalSender { + channel: Rc>>, + } + + impl LocalSender { + fn send(&self, item: T) { + let mut channel = self.channel.borrow_mut(); + + channel.queue.push_back(item); + + let _ = channel.waker.take().map(Waker::wake); + } + } + + impl Drop for LocalSender { + fn drop(&mut self) { + let mut channel = self.channel.borrow_mut(); + channel.closed = true; + let _ = channel.waker.take().map(Waker::wake); + } + } + + fn local_channel() -> (LocalSender, LocalReceiver) { + let channel = Rc::new(RefCell::new(LocalChannel { + queue: VecDeque::new(), + waker: None, + closed: false, + })); + + ( + LocalSender { + channel: channel.clone(), + }, + LocalReceiver { channel }, + ) + } + + let mut pool = LocalPool::new(); + + let done = Rc::new(RefCell::new(false)); + let done2 = done.clone(); + + pool.spawner() + .spawn_local(async move { + let (send1, receive1) = local_channel(); + let (send2, receive2) = local_channel(); + let (send3, receive3) = local_channel(); + + let (count, ()) = ( + async { + [receive1, receive2, receive3] + .merge() + .fold(0, |a, b| a + b) + .await + }, + async { + for i in 1..=4 { + send1.send(i); + send2.send(i); + send3.send(i); + } + drop(send1); + drop(send2); + drop(send3); + }, + ) + .join() + .await; + + assert_eq!(count, 30); + + *done2.borrow_mut() = true; + }) + .unwrap(); + + while !*done.borrow() { + pool.run_until_stalled() + } + } }