Skip to content

Commit

Permalink
impl Zip for Array
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshuawuyts committed Nov 15, 2022
1 parent b6da9e2 commit 93a552e
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 0 deletions.
192 changes: 192 additions & 0 deletions src/stream/zip/array.rs
Original file line number Diff line number Diff line change
@@ -1 +1,193 @@
use super::Zip as ZipTrait;
use crate::stream::IntoStream;
use crate::utils::{self, PollState, WakerList};

use core::array;
use core::fmt;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::mem;

use futures_core::Stream;
use pin_project::{pin_project, pinned_drop};

/// ‘Zips up’ two streams into a single stream of pairs.
///
/// This `struct` is created by the [`merge`] method on the [`Zip`] trait. See its
/// documentation for more.
///
/// [`zip`]: trait.Zip.html#method.zip
/// [`Zip`]: trait.Zip.html
#[pin_project(PinnedDrop)]
pub struct Zip<S, const N: usize>
where
S: Stream,
{
#[pin]
streams: [S; N],
output: [MaybeUninit<<S as Stream>::Item>; N],
wakers: WakerList,
state: [PollState; N],
done: bool,
}

impl<S, const N: usize> Zip<S, N>
where
S: Stream,
{
pub(crate) fn new(streams: [S; N]) -> Self {
Self {
streams,
output: array::from_fn(|_| MaybeUninit::uninit()),
state: array::from_fn(|_| PollState::default()),
wakers: WakerList::new(N),
done: false,
}
}
}

impl<S, const N: usize> fmt::Debug for Zip<S, N>
where
S: Stream + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.streams.iter()).finish()
}
}

impl<S, const N: usize> Stream for Zip<S, N>
where
S: Stream,
{
type Item = [S::Item; N];

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

assert!(!*this.done, "Stream should not be polled after completion");

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());
for index in 0..N {
if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
} else if this.state[index].is_done() {
// We already have data stored for this stream
continue;
} else if !readiness.clear_ready(index) {
// This waker isn't ready yet
continue;
}

// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap();
match stream.poll_next(&mut cx) {
Poll::Ready(Some(item)) => {
this.output[index] = MaybeUninit::new(item);
this.state[index] = PollState::Done;

let all_ready = this.state.iter().all(|state| state.is_done());
if all_ready {
// Reset the future's state.
readiness = this.wakers.readiness().lock().unwrap();
readiness.set_all_ready();
this.state.fill(PollState::Pending);

// Take the output
//
// SAFETY: we just validated all our data is populated, meaning
// we can assume this is initialized.
let mut output = array::from_fn(|_| MaybeUninit::uninit());
mem::swap(this.output, &mut output);
let output = unsafe { array_assume_init(output) };
return Poll::Ready(Some(output));
}
}
Poll::Ready(None) => {
// If one stream returns `None`, we can no longer return
// pairs - meaning the stream is over.
*this.done = true;
return Poll::Ready(None);
}
Poll::Pending => {}
}

// Lock readiness so we can use it again
readiness = this.wakers.readiness().lock().unwrap();
}
Poll::Pending
}
}

/// Drop the already initialized values on cancellation.
#[pinned_drop]
impl<S, const N: usize> PinnedDrop for Zip<S, N>
where
S: Stream,
{
fn drop(self: Pin<&mut Self>) {
let this = self.project();

for (state, output) in this.state.iter_mut().zip(this.output.iter_mut()) {
if state.is_done() {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { output.assume_init_drop() };
}
}
}
}

impl<S, const N: usize> ZipTrait for [S; N]
where
S: IntoStream,
{
type Item = <Zip<S::IntoStream, N> as Stream>::Item;
type Stream = Zip<S::IntoStream, N>;

fn zip(self) -> Self::Stream {
Zip::new(self.map(|i| i.into_stream()))
}
}

#[cfg(test)]
mod tests {
use crate::stream::Zip;
use futures_lite::future::block_on;
use futures_lite::prelude::*;
use futures_lite::stream;

#[test]
fn zip_array_3() {
block_on(async {
let a = stream::repeat(1).take(2);
let b = stream::repeat(2).take(2);
let c = stream::repeat(3).take(2);
let mut s = Zip::zip([a, b, c]);

assert_eq!(s.next().await, Some([1, 2, 3]));
assert_eq!(s.next().await, Some([1, 2, 3]));
assert_eq!(s.next().await, None);
})
}
}

// Inlined version of the unstable `MaybeUninit::array_assume_init` feature.
// FIXME: replace with `utils::array_assume_init`
unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] {
// SAFETY:
// * The caller guarantees that all elements of the array are initialized
// * `MaybeUninit<T>` and T are guaranteed to have the same layout
// * `MaybeUninit` does not drop, so there are no double-frees
// And thus the conversion is safe
let ret = unsafe { (&array as *const _ as *const [T; N]).read() };
mem::forget(array);
ret
}
8 changes: 8 additions & 0 deletions src/utils/wakers/readiness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::utils;
#[derive(Debug)]
pub(crate) struct Readiness {
count: usize,
max_count: usize,
ready: BitVec,
parent_waker: Option<Waker>,
}
Expand All @@ -16,6 +17,7 @@ impl Readiness {
pub(crate) fn new(count: usize) -> Self {
Self {
count,
max_count: count,
ready: bitvec![true as usize; count],
parent_waker: None,
}
Expand All @@ -33,6 +35,12 @@ impl Readiness {
}
}

/// Set all markers to ready.
pub(crate) fn set_all_ready(&mut self) {
self.ready.fill(true);
self.count = self.max_count;
}

/// Returns whether the task id was previously ready
pub(crate) fn clear_ready(&mut self, id: usize) -> bool {
if self.ready[id] {
Expand Down

0 comments on commit 93a552e

Please sign in to comment.