Skip to content

Commit

Permalink
Merge pull request #71 from yoshuawuyts/race-use-pollstate
Browse files Browse the repository at this point in the history
make `{array,vec}::race` fair
  • Loading branch information
yoshuawuyts authored Nov 16, 2022
2 parents 00acd05 + 7d05b37 commit c13f81e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 28 deletions.
36 changes: 22 additions & 14 deletions src/future/race/array.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::utils::{self, RandomGenerator};

use super::Race as RaceTrait;

use core::fmt;
Expand All @@ -20,7 +22,9 @@ pub struct Race<Fut, const N: usize>
where
Fut: Future,
{
futs: [Fut; N],
#[pin]
futures: [Fut; N],
rng: RandomGenerator,
done: bool,
}

Expand All @@ -30,7 +34,7 @@ where
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.futs.iter()).finish()
f.debug_list().entries(self.futures.iter()).finish()
}
}

Expand All @@ -41,16 +45,19 @@ where
type Output = Fut::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
assert!(
!*this.done,
"Futures must not be polled after being completed"
);
for fut in this.futs {
let fut = unsafe { Pin::new_unchecked(fut) };
if let Poll::Ready(output) = Future::poll(fut, cx) {
*this.done = true;
return Poll::Ready(output);
let mut this = self.project();
assert!(!*this.done, "Futures must not be polled after completing");

let index = this.rng.generate(N as u32) as usize;

for index in (0..N).map(|pos| (index + pos).wrapping_rem(N)) {
let fut = utils::get_pin_mut(this.futures.as_mut(), index).unwrap();
match fut.poll(cx) {
Poll::Ready(item) => {
*this.done = true;
return Poll::Ready(item);
}
Poll::Pending => continue,
}
}
Poll::Pending
Expand All @@ -66,7 +73,8 @@ where

fn race(self) -> Self::Future {
Race {
futs: self.map(|fut| fut.into_future()),
futures: self.map(|fut| fut.into_future()),
rng: RandomGenerator::new(),
done: false,
}
}
Expand All @@ -84,7 +92,7 @@ mod test {
let res = [future::ready("hello"), future::ready("world")]
.race()
.await;
assert_eq!(res, "hello");
assert!(matches!(res, "hello" | "world"));
});
}
}
37 changes: 23 additions & 14 deletions src/future/race/vec.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::utils::{self, RandomGenerator};

use super::Race as RaceTrait;

use core::fmt;
Expand All @@ -20,7 +22,9 @@ pub struct Race<Fut>
where
Fut: Future,
{
futs: Vec<Fut>,
#[pin]
futures: Vec<Fut>,
rng: RandomGenerator,
done: bool,
}

Expand All @@ -30,7 +34,7 @@ where
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.futs.iter()).finish()
f.debug_list().entries(self.futures.iter()).finish()
}
}

Expand All @@ -41,16 +45,20 @@ where
type Output = Fut::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
assert!(
!*this.done,
"Futures must not be polled after being completed"
);
for fut in this.futs {
let fut = unsafe { Pin::new_unchecked(fut) };
if let Poll::Ready(output) = Future::poll(fut, cx) {
*this.done = true;
return Poll::Ready(output);
let mut this = self.project();
assert!(!*this.done, "Futures must not be polled after completing");

let len = this.futures.len();
let index = this.rng.generate(len as u32) as usize;

for index in (0..len).map(|pos| (index + pos).wrapping_rem(len)) {
let fut = utils::get_pin_mut_from_vec(this.futures.as_mut(), index).unwrap();
match fut.poll(cx) {
Poll::Ready(item) => {
*this.done = true;
return Poll::Ready(item);
}
Poll::Pending => continue,
}
}
Poll::Pending
Expand All @@ -66,7 +74,8 @@ where

fn race(self) -> Self::Future {
Race {
futs: self.into_iter().map(|fut| fut.into_future()).collect(),
futures: self.into_iter().map(|fut| fut.into_future()).collect(),
rng: RandomGenerator::new(),
done: false,
}
}
Expand All @@ -84,7 +93,7 @@ mod test {
let res = vec![future::ready("hello"), future::ready("world")]
.race()
.await;
assert_eq!(res, "hello");
assert!(matches!(res, "hello" | "world"));
});
}
}

0 comments on commit c13f81e

Please sign in to comment.