diff --git a/src/helpers/buffers/unordered_receiver.rs b/src/helpers/buffers/unordered_receiver.rs index 84d529bd3..9a6f752db 100644 --- a/src/helpers/buffers/unordered_receiver.rs +++ b/src/helpers/buffers/unordered_receiver.rs @@ -110,6 +110,7 @@ impl Spare { }; Some(m) } + /// Returns `true` if there are no bytes currently awaiting a read. fn is_empty(&self) -> bool { self.offset == self.buf.len() @@ -153,6 +154,8 @@ where /// that easing load on this mechanism. There might also need to be some /// end-to-end back pressure for tasks that do not involve sending at all. overflow_wakers: Vec, + /// If this receiver is closed and no longer capable of receiving data. + closed: bool, _marker: PhantomData, } @@ -212,54 +215,79 @@ where /// Poll for the next record. This should only be invoked when /// the future for the next message is polled. + /// + /// ## Errors + /// If buffer capacity is not enough to read `M` and the underlying stream does not have + /// more data. This may lead to different behavior depending on the order of issued reads. See + /// [`read_order`] test for an example. + /// + /// [`read_order`]: test::read_order fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { // If spare has enough data for us, poll it first // otherwise, poll the underlying stream until it returns pending or it provides enough // data to return a value. - let message = self.spare.read(); - let next = match message { - Some(m) => { - // this check exists to make sure the inner stream is eventually moved to - // the closed state. We don't want to poll it too often, but we also need to know - // when it is done and `UnorderedReceiver` can be dropped. - if self.spare.is_empty() { - // we don't want to be woken up here, control loop is driven by the client. - // They decide when they want the next message and must issue a `poll` for it. - - // TODO: https://github.com/rust-lang/rust/issues/98286 - let mut cx = Context::from_waker(noop_waker_ref()); - match self.stream.as_mut().poll_next(&mut cx) { - Poll::Ready(Some(bytes)) => { - // Spare is empty because of the check above. - self.spare.replace(bytes.as_ref()); - } - Poll::Ready(None) | Poll::Pending => {} + let next = if let Some(m) = message { + // this check exists to make sure the inner stream is eventually moved to + // the closed state. We don't want to poll it too often, but we also need to know + // when it is done and `UnorderedReceiver` can be dropped. + if self.spare.is_empty() { + // we don't want to be woken up here, control loop is driven by the client. + // They decide when they want the next message and must issue a `poll` for it. + + // TODO: https://github.com/rust-lang/rust/issues/98286 + let mut cx = Context::from_waker(noop_waker_ref()); + match self.stream.as_mut().poll_next(&mut cx) { + Poll::Ready(Some(bytes)) => { + // Spare is empty because of the check above. + self.spare.replace(bytes.as_ref()); } + Poll::Ready(None) => { + self.close(); + } + Poll::Pending => {} } - - self.wake_next(); - Poll::Ready(Ok(m)) } - None => loop { - match ready!(self.stream.as_mut().poll_next(cx)) { - Some(bytes) => { - if let Some(m) = self.spare.extend(bytes.as_ref()) { - self.wake_next(); - break Poll::Ready(Ok(m)); - } - } - None => { - break Poll::Ready(Err(Error::EndOfStream { - record_id: RecordId::from(self.next), - })); + + self.wake_next(); + Poll::Ready(Ok(m)) + } else { + loop { + if let Some(bytes) = ready!(self.stream.as_mut().poll_next(cx)) { + if let Some(m) = self.spare.extend(bytes.as_ref()) { + self.wake_next(); + break Poll::Ready(Ok(m)); } + } else { + self.spare.replace(&[]); + self.close(); + break Poll::Ready(Err(Error::EndOfStream { + record_id: RecordId::from(self.next), + })); } - }, + } }; next } + + /// Returns `true` if this receiver is closed. + fn is_closed(&self) -> bool { + self.closed + } + + /// Close this receiver, so it can no longer be used to poll messages. + /// ## Errors + /// If the underlying stream is not closed or if there are unread bytes inside the buffer. + fn close(&mut self) { + if !self.closed { + assert!(self.stream.is_done()); + assert!(self.spare.is_empty()); + + self.closed = true; + self.spare = Spare::default(); + } + } } /// Take an ordered stream of bytes and make messages from that stream @@ -297,6 +325,7 @@ where spare: Spare::default(), wakers, overflow_wakers: Vec::new(), + closed: false, _marker: PhantomData, })), } @@ -326,8 +355,7 @@ where pub fn is_closed(&self) -> bool { // If this function is ever called on the hot path, consider caching closed status. // Closed streams cannot move back to open. - let inner = self.inner.lock().unwrap(); - inner.stream.is_done() && inner.spare.is_empty() + self.inner.lock().unwrap().is_closed() } } @@ -567,18 +595,53 @@ mod test { } assert!(recv.is_closed()); + assert_eq!(0, recv.inner.lock().unwrap().spare.buf.capacity()); }); } #[test] fn end_of_stream() { - const DATA: &[u8] = &[1u8, 2, 3, 4, 5]; + const DATA: &[u8] = &[1_u8, 2, 3, 4, 5]; run(|| async move { let recv = receiver([DATA]); - let _: Fp32BitPrime = recv.recv(0u8).await.unwrap(); + let _: Fp32BitPrime = recv.recv(0_u8).await.unwrap(); assert!(matches!( - recv.recv::(1u8).await, + recv.recv::(1_u8).await, + Err(EndOfStream { .. }) + )); + assert!(recv.is_closed()); + }); + } + + #[test] + fn read_order() { + const DATA: &[u8] = &[0_u8, 1, 2, 3, 4]; + // reading 3 u8 and then u32 - can read 3 items from the receiver + run(|| async move { + let recv = receiver([DATA]); + for i in 0_u8..3 { + assert_eq!( + Fp31::truncate_from(i), + recv.recv::(i).await.unwrap() + ); + } + assert!(matches!( + recv.recv::(3_u8).await, + Err(EndOfStream { .. }) + )); + }); + // reading 2 u8 and then u32 - can read 2 items from the receiver + run(|| async move { + let recv = receiver([DATA]); + for i in 0_u8..2 { + assert_eq!( + Fp31::truncate_from(i), + recv.recv::(i).await.unwrap() + ); + } + assert!(matches!( + recv.recv::(2_u8).await, Err(EndOfStream { .. }) )); });