Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perfect waker for array::Merge #75

Merged
merged 1 commit into from
Nov 16, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 183 additions & 18 deletions src/stream/merge/array.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::Merge as MergeTrait;
use crate::stream::IntoStream;
use crate::utils::{self, PollState};
use crate::utils::{self, PollState, PollStates, RandomGenerator, WakerList};

use core::array;
use core::fmt;
use futures_core::Stream;
use std::pin::Pin;
Expand All @@ -22,8 +21,10 @@ where
{
#[pin]
streams: [S; N],
rng: utils::RandomGenerator,
poll_state: [PollState; N],
rng: RandomGenerator,
complete: usize,
wakers: WakerList,
state: PollStates,
done: bool,
}

Expand All @@ -32,10 +33,13 @@ where
S: Stream,
{
pub(crate) fn new(streams: [S; N]) -> Self {
let len = streams.len();
Self {
wakers: WakerList::new(len),
state: PollStates::new(len),
streams,
rng: utils::RandomGenerator::new(),
poll_state: array::from_fn(|_| PollState::default()),
rng: RandomGenerator::new(),
complete: 0,
done: false,
}
}
Expand All @@ -59,28 +63,50 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

// Iterate over our streams one-by-one. If a stream yields a value,
// we exit early. By default we'll return `Poll::Ready(None)`, but
// this changes if we encounter a `Poll::Pending`.
let mut res = Poll::Ready(None);
let index = this.rng.generate(N as u32) as usize;

for index in (0..N).map(|pos| (index + pos).wrapping_rem(N)) {
if this.poll_state[index].is_consumed() {
let len = this.streams.len();
let r = this.rng.generate(len as u32) as usize;
for index in (0..len).map(|n| (r + n).wrapping_rem(len)) {
if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
} else if !readiness.clear_ready(index) || this.state[index].is_consumed() {
continue;
}

// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap();
match stream.poll_next(cx) {
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
match stream.poll_next(&mut cx) {
Poll::Ready(Some(item)) => {
// Mark ourselves as ready again because we need to poll for the next item.
this.wakers.readiness().lock().unwrap().set_ready(index);
return Poll::Ready(Some(item));
}
Poll::Ready(None) => {
this.poll_state[index] = PollState::Consumed;
continue;
*this.complete += 1;
this.state[index] = PollState::Consumed;
if *this.complete == this.streams.len() {
return Poll::Ready(None);
}
}
Poll::Pending => res = Poll::Pending,
Poll::Pending => {}
}

// Lock readiness so we can use it again
readiness = this.wakers.readiness().lock().unwrap();
}
res

Poll::Pending
}
}

Expand All @@ -98,13 +124,22 @@ where

#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::task::Waker;

use super::*;
use futures::executor::LocalPool;
use futures::task::LocalSpawnExt;
use futures_lite::future::block_on;
use futures_lite::prelude::*;
use futures_lite::stream;

use crate::future::join::Join;

#[test]
fn merge_tuple_4() {
fn merge_array_4() {
block_on(async {
let a = stream::once(1);
let b = stream::once(2);
Expand All @@ -119,4 +154,134 @@ mod tests {
assert_eq!(counter, 10);
})
}

#[test]
fn merge_array_2x2() {
block_on(async {
let a = stream::repeat(1).take(2);
let b = stream::repeat(2).take(2);
let mut s = [a, b].merge();

let mut counter = 0;
while let Some(n) = s.next().await {
counter += n;
}
assert_eq!(counter, 6);
})
}

/// This test case uses channels so we'll have streams that return Pending from time to time.
///
/// The purpose of this test is to make sure we have the waking logic working.
#[test]
fn merge_channels() {
struct LocalChannel<T> {
yoshuawuyts marked this conversation as resolved.
Show resolved Hide resolved
queue: VecDeque<T>,
waker: Option<Waker>,
closed: bool,
}

struct LocalReceiver<T> {
channel: Rc<RefCell<LocalChannel<T>>>,
}

impl<T> Stream for LocalReceiver<T> {
type Item = T;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut channel = self.channel.borrow_mut();

match channel.queue.pop_front() {
Some(item) => Poll::Ready(Some(item)),
None => {
if channel.closed {
Poll::Ready(None)
} else {
channel.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
}

struct LocalSender<T> {
channel: Rc<RefCell<LocalChannel<T>>>,
}

impl<T> LocalSender<T> {
fn send(&self, item: T) {
let mut channel = self.channel.borrow_mut();

channel.queue.push_back(item);

let _ = channel.waker.take().map(Waker::wake);
}
}

impl<T> Drop for LocalSender<T> {
fn drop(&mut self) {
let mut channel = self.channel.borrow_mut();
channel.closed = true;
let _ = channel.waker.take().map(Waker::wake);
}
}

fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
let channel = Rc::new(RefCell::new(LocalChannel {
queue: VecDeque::new(),
waker: None,
closed: false,
}));

(
LocalSender {
channel: channel.clone(),
},
LocalReceiver { channel },
)
}

let mut pool = LocalPool::new();

let done = Rc::new(RefCell::new(false));
let done2 = done.clone();

pool.spawner()
.spawn_local(async move {
let (send1, receive1) = local_channel();
let (send2, receive2) = local_channel();
let (send3, receive3) = local_channel();

let (count, ()) = (
async {
[receive1, receive2, receive3]
.merge()
.fold(0, |a, b| a + b)
.await
},
async {
for i in 1..=4 {
send1.send(i);
send2.send(i);
send3.send(i);
}
drop(send1);
drop(send2);
drop(send3);
},
)
.join()
.await;

assert_eq!(count, 30);

*done2.borrow_mut() = true;
})
.unwrap();

while !*done.borrow() {
pool.run_until_stalled()
}
}
}