Skip to content

Commit

Permalink
Merge pull request #75 from yoshuawuyts/fast-merge-array
Browse files Browse the repository at this point in the history
perfect waker for `array::Merge`
  • Loading branch information
yoshuawuyts authored Nov 16, 2022
2 parents faba0b7 + f550650 commit eb31f45
Showing 1 changed file with 183 additions and 18 deletions.
201 changes: 183 additions & 18 deletions src/stream/merge/array.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::Merge as MergeTrait;
use crate::stream::IntoStream;
use crate::utils::{self, PollState};
use crate::utils::{self, PollState, PollStates, RandomGenerator, WakerList};

use core::array;
use core::fmt;
use futures_core::Stream;
use std::pin::Pin;
Expand All @@ -22,8 +21,10 @@ where
{
#[pin]
streams: [S; N],
rng: utils::RandomGenerator,
poll_state: [PollState; N],
rng: RandomGenerator,
complete: usize,
wakers: WakerList,
state: PollStates,
done: bool,
}

Expand All @@ -32,10 +33,13 @@ where
S: Stream,
{
pub(crate) fn new(streams: [S; N]) -> Self {
let len = streams.len();
Self {
wakers: WakerList::new(len),
state: PollStates::new(len),
streams,
rng: utils::RandomGenerator::new(),
poll_state: array::from_fn(|_| PollState::default()),
rng: RandomGenerator::new(),
complete: 0,
done: false,
}
}
Expand All @@ -59,28 +63,50 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

// Iterate over our streams one-by-one. If a stream yields a value,
// we exit early. By default we'll return `Poll::Ready(None)`, but
// this changes if we encounter a `Poll::Pending`.
let mut res = Poll::Ready(None);
let index = this.rng.generate(N as u32) as usize;

for index in (0..N).map(|pos| (index + pos).wrapping_rem(N)) {
if this.poll_state[index].is_consumed() {
let len = this.streams.len();
let r = this.rng.generate(len as u32) as usize;
for index in (0..len).map(|n| (r + n).wrapping_rem(len)) {
if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
} else if !readiness.clear_ready(index) || this.state[index].is_consumed() {
continue;
}

// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap();
match stream.poll_next(cx) {
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
match stream.poll_next(&mut cx) {
Poll::Ready(Some(item)) => {
// Mark ourselves as ready again because we need to poll for the next item.
this.wakers.readiness().lock().unwrap().set_ready(index);
return Poll::Ready(Some(item));
}
Poll::Ready(None) => {
this.poll_state[index] = PollState::Consumed;
continue;
*this.complete += 1;
this.state[index] = PollState::Consumed;
if *this.complete == this.streams.len() {
return Poll::Ready(None);
}
}
Poll::Pending => res = Poll::Pending,
Poll::Pending => {}
}

// Lock readiness so we can use it again
readiness = this.wakers.readiness().lock().unwrap();
}
res

Poll::Pending
}
}

Expand All @@ -98,13 +124,22 @@ where

#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::task::Waker;

use super::*;
use futures::executor::LocalPool;
use futures::task::LocalSpawnExt;
use futures_lite::future::block_on;
use futures_lite::prelude::*;
use futures_lite::stream;

use crate::future::join::Join;

#[test]
fn merge_tuple_4() {
fn merge_array_4() {
block_on(async {
let a = stream::once(1);
let b = stream::once(2);
Expand All @@ -119,4 +154,134 @@ mod tests {
assert_eq!(counter, 10);
})
}

#[test]
fn merge_array_2x2() {
block_on(async {
let a = stream::repeat(1).take(2);
let b = stream::repeat(2).take(2);
let mut s = [a, b].merge();

let mut counter = 0;
while let Some(n) = s.next().await {
counter += n;
}
assert_eq!(counter, 6);
})
}

/// This test case uses channels so we'll have streams that return Pending from time to time.
///
/// The purpose of this test is to make sure we have the waking logic working.
#[test]
fn merge_channels() {
struct LocalChannel<T> {
queue: VecDeque<T>,
waker: Option<Waker>,
closed: bool,
}

struct LocalReceiver<T> {
channel: Rc<RefCell<LocalChannel<T>>>,
}

impl<T> Stream for LocalReceiver<T> {
type Item = T;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut channel = self.channel.borrow_mut();

match channel.queue.pop_front() {
Some(item) => Poll::Ready(Some(item)),
None => {
if channel.closed {
Poll::Ready(None)
} else {
channel.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
}

struct LocalSender<T> {
channel: Rc<RefCell<LocalChannel<T>>>,
}

impl<T> LocalSender<T> {
fn send(&self, item: T) {
let mut channel = self.channel.borrow_mut();

channel.queue.push_back(item);

let _ = channel.waker.take().map(Waker::wake);
}
}

impl<T> Drop for LocalSender<T> {
fn drop(&mut self) {
let mut channel = self.channel.borrow_mut();
channel.closed = true;
let _ = channel.waker.take().map(Waker::wake);
}
}

fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
let channel = Rc::new(RefCell::new(LocalChannel {
queue: VecDeque::new(),
waker: None,
closed: false,
}));

(
LocalSender {
channel: channel.clone(),
},
LocalReceiver { channel },
)
}

let mut pool = LocalPool::new();

let done = Rc::new(RefCell::new(false));
let done2 = done.clone();

pool.spawner()
.spawn_local(async move {
let (send1, receive1) = local_channel();
let (send2, receive2) = local_channel();
let (send3, receive3) = local_channel();

let (count, ()) = (
async {
[receive1, receive2, receive3]
.merge()
.fold(0, |a, b| a + b)
.await
},
async {
for i in 1..=4 {
send1.send(i);
send2.send(i);
send3.send(i);
}
drop(send1);
drop(send2);
drop(send3);
},
)
.join()
.await;

assert_eq!(count, 30);

*done2.borrow_mut() = true;
})
.unwrap();

while !*done.borrow() {
pool.run_until_stalled()
}
}
}

0 comments on commit eb31f45

Please sign in to comment.