diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index c08c0d9c74..85ff1be4f5 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -1,40 +1,97 @@ +//! A minimalist clone of the `async-stream` crate in 100% safe code, without proc macros. +//! +//! This was created initially to get around some weird compiler errors we were getting with +//! `async-stream`, and now it'd just be more work to replace. + use std::future::Future; use std::pin::Pin; +use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::Stream; -use futures_util::{pin_mut, FutureExt, SinkExt}; +use futures_core::FusedFuture; +use futures_util::future::Fuse; +use futures_util::FutureExt; use crate::error::Error; pub struct TryAsyncStream<'a, T> { - receiver: mpsc::Receiver>, - future: BoxFuture<'a, Result<(), Error>>, + yielder: Yielder, + future: Fuse>>, } impl<'a, T> TryAsyncStream<'a, T> { pub fn new(f: F) -> Self where - F: FnOnce(mpsc::Sender>) -> Fut + Send, + F: FnOnce(Yielder) -> Fut + Send, Fut: 'a + Future> + Send, T: 'a + Send, { - let (mut sender, receiver) = mpsc::channel(0); + let yielder = Yielder::new(); - let future = f(sender.clone()); - let future = async move { - if let Err(error) = future.await { - let _ = sender.send(Err(error)).await; - } + let future = f(yielder.duplicate()).boxed().fuse(); + + Self { future, yielder } + } +} + +pub struct Yielder { + // This mutex should never have any contention in normal operation. + // We're just using it because `Rc>>` would not be `Send`. + value: Arc>>, +} + +impl Yielder { + fn new() -> Self { + Yielder { + value: Arc::new(Mutex::new(None)), + } + } - Ok(()) + // Don't want to expose a `Clone` impl + fn duplicate(&self) -> Self { + Yielder { + value: self.value.clone(), } - .fuse() - .boxed(); + } + + /// NOTE: may deadlock the task if called from outside the future passed to `TryAsyncStream`. + pub async fn r#yield(&self, val: T) { + let replaced = self + .value + .lock() + .expect("BUG: panicked while holding a lock") + .replace(val); - Self { future, receiver } + debug_assert!( + replaced.is_none(), + "BUG: previously yielded value not taken" + ); + + let mut yielded = false; + + // Allows the generating future to suspend its execution without changing the task priority, + // which would happen with `tokio::task::yield_now()`. + // + // Note that because this has no way to schedule a wakeup, this could deadlock the task + // if called in the wrong place. + futures_util::future::poll_fn(|_cx| { + if !yielded { + yielded = true; + Poll::Pending + } else { + Poll::Ready(()) + } + }) + .await + } + + fn take(&self) -> Option { + self.value + .lock() + .expect("BUG: panicked while holding a lock") + .take() } } @@ -42,29 +99,35 @@ impl<'a, T> Stream for TryAsyncStream<'a, T> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let future = &mut self.future; - pin_mut!(future); - - // the future is fused so its safe to call forever - // the future advances our "stream" - // the future should be polled in tandem with the stream receiver - let _ = future.poll(cx); - - let receiver = &mut self.receiver; - pin_mut!(receiver); + if self.future.is_terminated() { + return Poll::Ready(None); + } - // then we check to see if we have anything to return - receiver.poll_next(cx) + match self.future.poll_unpin(cx) { + Poll::Ready(Ok(())) => { + // Future returned without yielding another value. + Poll::Ready(None) + } + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => self + .yielder + .take() + .map_or(Poll::Pending, |val| Poll::Ready(Some(Ok(val)))), + } } } #[macro_export] macro_rules! try_stream { ($($block:tt)*) => { - crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move { + crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move { + // Anti-footgun: effectively pins `yielder` to this future to prevent any accidental + // move to another task, which could deadlock. + let ref yielder = yielder; + macro_rules! r#yield { ($v:expr) => {{ - let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await; + yielder.r#yield($v).await; }} }