From 20340e724d7de50cbfdcbf900c9523451a6d69ce Mon Sep 17 00:00:00 2001 From: Collin Styles Date: Sat, 21 Oct 2023 13:56:34 -0700 Subject: [PATCH 1/3] Panic if `All` or `Any` are polled after completing due to a short-circuit These futures should panic if they are polled after completing. Currently they do so but only if they complete due to exhausting the `Stream` that they pull data from. If they complete due to short-circuiting, they are left in a state where `fut` and `accum` are still `Some`. This means that if they are polled again, they end up polling the inner `fut` again. That usually causes a panic but the error message will likely reference the internal `Future`, not `All` / `Any`. With this commit, `All` and `Any`'s internal state will be set such that if they are polled again after completing, they will panic without polling `fut`. --- futures-util/src/stream/stream/all.rs | 3 ++- futures-util/src/stream/stream/any.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/futures-util/src/stream/stream/all.rs b/futures-util/src/stream/stream/all.rs index ba2baa5cf1..b2aaa6b78a 100644 --- a/futures-util/src/stream/stream/all.rs +++ b/futures-util/src/stream/stream/all.rs @@ -69,11 +69,12 @@ where if let Some(fut) = this.future.as_mut().as_pin_mut() { // we're currently processing a future to produce a new accum value let acc = this.accum.unwrap() && ready!(fut.poll(cx)); + this.future.set(None); if !acc { + this.accum.take().unwrap(); break false; } // early exit *this.accum = Some(acc); - this.future.set(None); } else if this.accum.is_some() { // we're waiting on a new item from the stream match ready!(this.stream.as_mut().poll_next(cx)) { diff --git a/futures-util/src/stream/stream/any.rs b/futures-util/src/stream/stream/any.rs index f023125c70..f8b2a5829a 100644 --- a/futures-util/src/stream/stream/any.rs +++ b/futures-util/src/stream/stream/any.rs @@ -69,11 +69,12 @@ where if let Some(fut) = this.future.as_mut().as_pin_mut() { // we're currently processing a future to produce a new accum value let acc = this.accum.unwrap() || ready!(fut.poll(cx)); + this.future.set(None); if acc { + this.accum.take().unwrap(); break true; } // early exit *this.accum = Some(acc); - this.future.set(None); } else if this.accum.is_some() { // we're waiting on a new item from the stream match ready!(this.stream.as_mut().poll_next(cx)) { From 32a87173808e6caab11dc4ba28f178c72e85231d Mon Sep 17 00:00:00 2001 From: Collin Styles Date: Sat, 21 Oct 2023 14:10:31 -0700 Subject: [PATCH 2/3] Replace `All` and `Any`'s `accum` field with `done` It looks like `All` was originally implemented by copying from `TryFold` from which it inherited its `accum` field. However, `accum` can only ever be one of two values: `None` (if `All` has already completed) or `Some(true)` (if it's still processing values from the inner `Stream`). It doesn't need to keep track of an accumulator because the very fact that it hasn't short-circuited yet means that the accumulated value can't be `Some(false)`. Therefore, we only need two values here and we can represent them with a `bool` indicating whether or not `All` has already completed. The same principle applies for `Any` but substituting `Some(false)` for `Some(true)`. --- futures-util/src/stream/stream/all.rs | 22 +++++++++++----------- futures-util/src/stream/stream/any.rs | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/futures-util/src/stream/stream/all.rs b/futures-util/src/stream/stream/all.rs index b2aaa6b78a..1435c798f2 100644 --- a/futures-util/src/stream/stream/all.rs +++ b/futures-util/src/stream/stream/all.rs @@ -13,7 +13,7 @@ pin_project! { #[pin] stream: St, f: F, - accum: Option, + done: bool, #[pin] future: Option, } @@ -27,7 +27,7 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("All") .field("stream", &self.stream) - .field("accum", &self.accum) + .field("done", &self.done) .field("future", &self.future) .finish() } @@ -40,7 +40,7 @@ where Fut: Future, { pub(super) fn new(stream: St, f: F) -> Self { - Self { stream, f, accum: Some(true), future: None } + Self { stream, f, done: false, future: None } } } @@ -51,7 +51,7 @@ where Fut: Future, { fn is_terminated(&self) -> bool { - self.accum.is_none() && self.future.is_none() + self.done && self.future.is_none() } } @@ -67,22 +67,22 @@ where let mut this = self.project(); Poll::Ready(loop { if let Some(fut) = this.future.as_mut().as_pin_mut() { - // we're currently processing a future to produce a new accum value - let acc = this.accum.unwrap() && ready!(fut.poll(cx)); + // we're currently processing a future to produce a new value + let res = ready!(fut.poll(cx)); this.future.set(None); - if !acc { - this.accum.take().unwrap(); + if !res { + *this.done = true; break false; } // early exit - *this.accum = Some(acc); - } else if this.accum.is_some() { + } else if !*this.done { // we're waiting on a new item from the stream match ready!(this.stream.as_mut().poll_next(cx)) { Some(item) => { this.future.set(Some((this.f)(item))); } None => { - break this.accum.take().unwrap(); + *this.done = true; + break true; } } } else { diff --git a/futures-util/src/stream/stream/any.rs b/futures-util/src/stream/stream/any.rs index f8b2a5829a..cc3d695b9d 100644 --- a/futures-util/src/stream/stream/any.rs +++ b/futures-util/src/stream/stream/any.rs @@ -13,7 +13,7 @@ pin_project! { #[pin] stream: St, f: F, - accum: Option, + done: bool, #[pin] future: Option, } @@ -27,7 +27,7 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Any") .field("stream", &self.stream) - .field("accum", &self.accum) + .field("done", &self.done) .field("future", &self.future) .finish() } @@ -40,7 +40,7 @@ where Fut: Future, { pub(super) fn new(stream: St, f: F) -> Self { - Self { stream, f, accum: Some(false), future: None } + Self { stream, f, done: false, future: None } } } @@ -51,7 +51,7 @@ where Fut: Future, { fn is_terminated(&self) -> bool { - self.accum.is_none() && self.future.is_none() + self.done && self.future.is_none() } } @@ -67,22 +67,22 @@ where let mut this = self.project(); Poll::Ready(loop { if let Some(fut) = this.future.as_mut().as_pin_mut() { - // we're currently processing a future to produce a new accum value - let acc = this.accum.unwrap() || ready!(fut.poll(cx)); + // we're currently processing a future to produce a new value + let res = ready!(fut.poll(cx)); this.future.set(None); - if acc { - this.accum.take().unwrap(); + if res { + *this.done = true; break true; } // early exit - *this.accum = Some(acc); - } else if this.accum.is_some() { + } else if !*this.done { // we're waiting on a new item from the stream match ready!(this.stream.as_mut().poll_next(cx)) { Some(item) => { this.future.set(Some((this.f)(item))); } None => { - break this.accum.take().unwrap(); + *this.done = true; + break false; } } } else { From 830333ce13f1c35d082f47035a2c27c76ce8185a Mon Sep 17 00:00:00 2001 From: Collin Styles Date: Sat, 21 Oct 2023 18:56:11 -0700 Subject: [PATCH 3/3] Add tests for `StreamExt::all` and `StreamExt::any` --- futures/tests/stream.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 41721f15cf..4f042832f2 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -552,3 +552,43 @@ fn select_with_strategy_doesnt_terminate_early() { assert_eq!(count.get(), times_should_poll + 1); } } + +async fn is_even(number: u8) -> bool { + number % 2 == 0 +} + +#[test] +fn all() { + block_on(async { + let empty: [u8; 0] = []; + let st = stream::iter(empty); + let all = st.all(is_even).await; + assert!(all); + + let st = stream::iter([2, 4, 6, 8]); + let all = st.all(is_even).await; + assert!(all); + + let st = stream::iter([2, 3, 4]); + let all = st.all(is_even).await; + assert!(!all); + }); +} + +#[test] +fn any() { + block_on(async { + let empty: [u8; 0] = []; + let st = stream::iter(empty); + let any = st.any(is_even).await; + assert!(!any); + + let st = stream::iter([1, 2, 3]); + let any = st.any(is_even).await; + assert!(any); + + let st = stream::iter([1, 3, 5]); + let any = st.any(is_even).await; + assert!(!any); + }); +}