From 74d58892db1360b5aa94bd259034b0a279c9d1b2 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Mon, 4 Mar 2024 11:55:30 +0000 Subject: [PATCH] Add poll_recv method to Receiver (#56) This PR Adds a `poll_recv()` method to the `Receiver` type. It returns the same `Result` type that the `receiver.recv()` future returns (hence the name). This method can be used when defining custom streams that internally make use of `async_broadcast` and want to know about whether the `async_broadcast` stream has overflowed or not. --- src/lib.rs | 126 +++++++++++++++++++++++++++++++++++--------------- tests/test.rs | 35 ++++++++++++++ 2 files changed, 124 insertions(+), 37 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d6c9d2f..90334b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1298,6 +1298,89 @@ impl Receiver { listener: None, } } + + /// A low level poll method that is similar to [`Receiver::recv()`] or + /// [`Receiver::recv_direct()`], and can be useful for building stream implementations which + /// use a [`Receiver`] under the hood and want to know if the stream has overflowed. + /// + /// Prefer to use [`Receiver::recv()`] or [`Receiver::recv_direct()`] when otherwise possible. + /// + /// # Errors + /// + /// If the number of messages that have been sent has overflowed the channel capacity, a + /// [`RecvError::Overflowed`] variant is returned containing the number of items that + /// overflowed and were lost. + /// + /// # Examples + /// + /// This example shows how the [`Receiver::poll_recv`] method can be used to allow a custom + /// stream implementation to internally make use of a [`Receiver`]. This example implementation + /// differs from the stream implementation of [`Receiver`] because it returns an error if + /// the channel capacity overflows, which the built in [`Receiver`] stream doesn't do. + /// + /// ``` + /// use futures_core::Stream; + /// use async_broadcast::{Receiver, RecvError}; + /// use std::{pin::Pin, task::{Poll, Context}}; + /// + /// struct MyStream(Receiver); + /// + /// impl futures_core::Stream for MyStream { + /// type Item = Result; + /// fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + /// Pin::new(&mut self.0).poll_recv(cx) + /// } + /// } + /// ``` + pub fn poll_recv( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + // If this stream is listening for events, first wait for a notification. + if let Some(listener) = self.listener.as_mut() { + ready!(Pin::new(listener).poll(cx)); + self.listener = None; + } + + loop { + // Attempt to receive a message. + match self.try_recv() { + Ok(msg) => { + // The stream is not blocked on an event - drop the listener. + self.listener = None; + return Poll::Ready(Some(Ok(msg))); + } + Err(TryRecvError::Closed) => { + // The stream is not blocked on an event - drop the listener. + self.listener = None; + return Poll::Ready(None); + } + Err(TryRecvError::Overflowed(n)) => { + // The stream is not blocked on an event - drop the listener. + self.listener = None; + return Poll::Ready(Some(Err(RecvError::Overflowed(n)))); + } + Err(TryRecvError::Empty) => {} + } + + // Receiving failed - now start listening for notifications or wait for one. + match self.listener.as_mut() { + None => { + // Start listening and then try receiving again. + self.listener = { + let inner = self.inner.write().unwrap(); + Some(inner.recv_ops.listen()) + }; + } + Some(_) => { + // Go back to the outer loop to poll the listener. + break; + } + } + } + } + } } impl Drop for Receiver { @@ -1363,43 +1446,12 @@ impl Stream for Receiver { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - // If this stream is listening for events, first wait for a notification. - if let Some(listener) = self.listener.as_mut() { - ready!(Pin::new(listener).poll(cx)); - self.listener = None; - } - - loop { - // Attempt to receive a message. - match self.try_recv() { - Ok(msg) => { - // The stream is not blocked on an event - drop the listener. - self.listener = None; - return Poll::Ready(Some(msg)); - } - Err(TryRecvError::Closed) => { - // The stream is not blocked on an event - drop the listener. - self.listener = None; - return Poll::Ready(None); - } - Err(TryRecvError::Overflowed(_)) => continue, - Err(TryRecvError::Empty) => {} - } - - // Receiving failed - now start listening for notifications or wait for one. - match self.listener.as_mut() { - None => { - // Start listening and then try receiving again. - self.listener = { - let inner = self.inner.write().unwrap(); - Some(inner.recv_ops.listen()) - }; - } - Some(_) => { - // Go back to the outer loop to poll the listener. - break; - } - } + match ready!(self.as_mut().poll_recv(cx)) { + Some(Ok(val)) => return Poll::Ready(Some(val)), + // If overflowed, we expect future operations to succeed so try again. + Some(Err(RecvError::Overflowed(_))) => continue, + // RecvError::Closed should never appear here, but handle it anyway. + None | Some(Err(RecvError::Closed)) => return Poll::Ready(None), } } } diff --git a/tests/test.rs b/tests/test.rs index 4416909..79718cb 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -288,3 +288,38 @@ fn inactive_drop() { assert!(s.is_closed()) } + +#[test] +fn poll_recv() { + let (s, mut r) = broadcast::(2); + r.set_overflow(true); + + // A quick custom stream impl to demonstrate/test `poll_recv`. + struct MyStream(Receiver); + impl futures_core::Stream for MyStream { + type Item = Result; + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_recv(cx) + } + } + + block_on(async move { + let mut stream = MyStream(r); + + s.broadcast(1).await.unwrap(); + s.broadcast(2).await.unwrap(); + s.broadcast(3).await.unwrap(); + s.broadcast(4).await.unwrap(); + + assert_eq!(stream.next().await.unwrap(), Err(RecvError::Overflowed(2))); + assert_eq!(stream.next().await.unwrap(), Ok(3)); + assert_eq!(stream.next().await.unwrap(), Ok(4)); + + drop(s); + + assert_eq!(stream.next().await, None); + }) +}