From 5f3e801e41fba888d432a63077b6e16e85c0bebd Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 8 Nov 2023 15:46:04 -0800 Subject: [PATCH] Remove async scope --- Cargo.toml | 1 - .../replicated/malicious/additive_share.rs | 2 +- src/seq_join.rs | 178 ++++++++++++------ 3 files changed, 118 insertions(+), 63 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3eb6d5cf7..9962d04ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,6 @@ ipa-prf = [] [dependencies] aes = "0.8.3" -async-scoped = { version = "0.7.1", features = ["use-tokio"], path = "../rmstuff/async-scoped" } async-trait = "0.1.68" axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = ["rustls", "rustls-pemfile", "tls-rustls"] } diff --git a/src/secret_sharing/replicated/malicious/additive_share.rs b/src/secret_sharing/replicated/malicious/additive_share.rs index c2fe440c4..d92a0d000 100644 --- a/src/secret_sharing/replicated/malicious/additive_share.rs +++ b/src/secret_sharing/replicated/malicious/additive_share.rs @@ -77,7 +77,7 @@ impl LinearSecretSharing for AdditiveShare< /// when the protocol is done. This should not be used directly. #[async_trait] pub trait Downgrade: Send { - type Target: Send; + type Target: Send + 'static; async fn downgrade(self) -> UnauthorizedDowngradeWrapper; } diff --git a/src/seq_join.rs b/src/seq_join.rs index 34832b436..6269bb4fb 100644 --- a/src/seq_join.rs +++ b/src/seq_join.rs @@ -5,18 +5,65 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use async_scoped::spawner::use_tokio::Tokio; -use async_scoped::TokioScope; +use std::marker::PhantomData; use async_trait::async_trait; +use clap::builder::TypedValueParser; -use futures::{ - stream::{iter, Iter as StreamIter, TryCollect}, - Future, Stream, StreamExt, TryStreamExt, -}; +use futures::{stream::{iter, Iter as StreamIter, TryCollect}, Future, Stream, StreamExt, TryStreamExt, TryFuture}; +use futures_util::future::TryJoinAll; +use futures_util::stream::FuturesOrdered; use pin_project::pin_project; use crate::exact::ExactSizeStream; + + + +struct UnsafeSpawner<'a, T> { + _t_marker: PhantomData, + // Future proof against variance changes + _marker: PhantomData &'a ()>, +} + +impl <'a, T> Default for UnsafeSpawner<'a, T> { + fn default() -> Self { + Self { + _t_marker: PhantomData, + _marker: PhantomData, + } + } +} + +#[pin_project] +struct UnsafeSpawnerHandle { + #[pin] + inner: tokio::task::JoinHandle +} + +impl Future for UnsafeSpawnerHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().inner.poll(cx) { + Poll::Ready(Ok(t)) => Poll::Ready(t), + Poll::Ready(Err(e)) => panic!("cancelled: {e}"), + Poll::Pending => Poll::Pending, + } + } +} + +impl <'a, T: Send + 'static> UnsafeSpawner<'a, T> { + fn spawn + Send + 'a>(&self, f: F) -> UnsafeSpawnerHandle { + let handle = tokio::spawn(unsafe { + std::mem::transmute::<_, Pin + Send>>>( + Box::pin(f) as Pin>> + ) + }); + + UnsafeSpawnerHandle { inner: handle } + } +} + /// This helper function might be necessary to convince the compiler that /// the return value from [`seq_try_join_all`] implements `Send`. /// Use this if you get higher-ranked lifetime errors that mention `std::marker::Send`. @@ -42,12 +89,14 @@ pub fn assert_send<'a, O>( /// [`try_join_all`]: futures::future::try_join_all /// [`Stream`]: futures::stream::Stream /// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered -pub fn seq_join(active: NonZeroUsize, source: S) -> SequentialFutures -where - S: Stream + Send, - F: Future, +pub fn seq_join<'a, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'a, S, F> + where + S: Stream + Send, + F: Future + Send + 'a, + O: Send + 'static { SequentialFutures { + spawner: UnsafeSpawner::default(), source: source.fuse(), active: VecDeque::with_capacity(active.get()), } @@ -56,6 +105,7 @@ where /// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter /// from the provided context so that the value can be made consistent. pub trait SeqJoin { + /// Perform a sequential join of the futures from the provided iterable. /// This uses [`seq_join`], with the current state of the associated object /// being used to determine the number of active items to track (see [`active_work`]). @@ -75,74 +125,74 @@ pub trait SeqJoin { /// [`active_work`]: Self::active_work /// [`parallel_join`]: Self::parallel_join /// [`join3`]: futures::future::join3 - fn try_join(&self, iterable: I) -> TryCollect, Vec> - where - I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, + fn try_join<'a, I, F, O, E>(&self, iterable: I) -> TryCollect, Vec> + where + I: IntoIterator + Send, + I::IntoIter: Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static { seq_try_join_all(self.active_work(), iterable) } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. - fn parallel_join<'a, I, F, O, E>(&self, iterable: I) -> Pin, E>> + Send + 'a>> + fn parallel_join(&self, iterable: I) -> Pin, E>> + Send>> where I: IntoIterator + Send, - F: Future> + Send + 'a, + F: Future> + Send, O: Send + 'static, E: Send + 'static { - // TODO: implement spawner for shuttle - let mut scope = { - let iter = iterable.into_iter(); - let mut scope = unsafe { TokioScope::create(Tokio) }; - for element in iter { - // it is important to make those cancellable. - // TODO: elaborate why - scope.spawn_cancellable(element, || panic!("Future is cancelled.")); - } - scope - }; - - Box::pin(async move { - let mut result = Vec::with_capacity(scope.len()); - while let Some(item) = scope.next().await { // join error is nothing we can do about - result.push(item.unwrap()?) - } - Ok(result) - }) + // let iterable = iterable.into_iter().map(|f| { + // spawner.spawn(f) + // }); + // let spawner = UnsafeSpawner::default(); + let mut futures = FuturesOrdered::default(); + let spawner = UnsafeSpawner::default(); + for f in iterable.into_iter() { + futures.push_back(spawner.spawn(f.into_future())); + } + Box::pin(async move { futures.try_collect().await }) + // ParallelFutures2 { + // spawner, + // inner: futures::future::try_join_all(iterable.into_iter().map(|f| spawner.spawn(f))), + // } // #[allow(clippy::disallowed_methods)] // Just in this one place. - // futures::future::try_join_all(iterable) + // futures::future::try_join_all(iterable.into_iter() + // .map(|f| tokio::spawn())) } /// The amount of active work that is concurrently permitted. fn active_work(&self) -> NonZeroUsize; } -type SeqTryJoinAll = SequentialFutures::IntoIter>, F>; +type SeqTryJoinAll<'a, I, F> = SequentialFutures<'a, StreamIter<::IntoIter>, F>; /// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. /// This awaits all the provided futures in order, /// aborting early if any future returns `Result::Err`. -pub fn seq_try_join_all( +pub fn seq_try_join_all<'a, I, F, O, E>( active: NonZeroUsize, source: I, -) -> TryCollect, Vec> -where - I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, +) -> TryCollect, Vec> + where + I: IntoIterator + Send, + I::IntoIter: Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static { seq_join(active, iter(source)).try_collect() } enum ActiveItem { - Pending(Pin>), + Pending(Pin>>), Resolved(F::Output), } -impl ActiveItem { +impl, T: Send + 'static> ActiveItem { /// Drives this item to resolved state when value is ready to be taken out. Has no effect /// if the value is ready. /// @@ -175,20 +225,23 @@ impl ActiveItem { } #[pin_project] -pub struct SequentialFutures -where - S: Stream + Send, - F: IntoFuture, +pub struct SequentialFutures<'a, S, F> + where + S: Stream + Send, + F: IntoFuture, { + spawner: UnsafeSpawner<'a, F::Output>, #[pin] source: futures::stream::Fuse, active: VecDeque>, } -impl Stream for SequentialFutures -where - S: Stream + Send, - F: IntoFuture, +impl <'a, S, F, T> Stream for SequentialFutures<'a, S, F> + where + S: Stream + Send, + F: IntoFuture, + ::IntoFuture: Send + 'a, + T: Send + 'static { type Item = F::Output; @@ -198,8 +251,9 @@ where // Draw more values from the input, up to the capacity. while this.active.len() < this.active.capacity() { if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { - this.active - .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); + this.active.push_back(ActiveItem::Pending(Box::pin(this.spawner.spawn(f.into_future())))); + // this.active + // .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); } else { break; } @@ -232,10 +286,12 @@ where } } -impl ExactSizeStream for SequentialFutures -where - S: Stream + Send + ExactSizeStream, - F: IntoFuture, +impl<'a, S, F, T> ExactSizeStream for SequentialFutures<'a, S, F> + where + S: Stream + Send + ExactSizeStream, + F: IntoFuture, + ::IntoFuture: Send + 'a, + T: Send + 'static { } @@ -403,7 +459,7 @@ mod test { *produced_w.lock().unwrap() += 1; lazy(|_| VALUE) }) - .take(COUNT); + .take(COUNT); let mut joined = seq_join(capacity, stream); let waker = fake_waker(); let mut cx = Context::from_waker(&waker);