Skip to content

Commit

Permalink
Remove async scope
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Nov 8, 2023
1 parent b6231e9 commit 5f3e801
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 63 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ ipa-prf = []

[dependencies]
aes = "0.8.3"
async-scoped = { version = "0.7.1", features = ["use-tokio"], path = "../rmstuff/async-scoped" }
async-trait = "0.1.68"
axum = { version = "0.5.17", optional = true, features = ["http2"] }
axum-server = { version = "0.5.1", optional = true, features = ["rustls", "rustls-pemfile", "tls-rustls"] }
Expand Down
2 changes: 1 addition & 1 deletion src/secret_sharing/replicated/malicious/additive_share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl<V: SharedValue + ExtendableField> LinearSecretSharing<V> for AdditiveShare<
/// when the protocol is done. This should not be used directly.
#[async_trait]
pub trait Downgrade: Send {
type Target: Send;
type Target: Send + 'static;
async fn downgrade(self) -> UnauthorizedDowngradeWrapper<Self::Target>;
}

Expand Down
178 changes: 117 additions & 61 deletions src/seq_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,65 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use async_scoped::spawner::use_tokio::Tokio;
use async_scoped::TokioScope;
use std::marker::PhantomData;
use async_trait::async_trait;
use clap::builder::TypedValueParser;

use futures::{
stream::{iter, Iter as StreamIter, TryCollect},
Future, Stream, StreamExt, TryStreamExt,
};
use futures::{stream::{iter, Iter as StreamIter, TryCollect}, Future, Stream, StreamExt, TryStreamExt, TryFuture};
use futures_util::future::TryJoinAll;
use futures_util::stream::FuturesOrdered;
use pin_project::pin_project;

use crate::exact::ExactSizeStream;




struct UnsafeSpawner<'a, T> {
_t_marker: PhantomData<T>,
// Future proof against variance changes
_marker: PhantomData<fn(&'a ()) -> &'a ()>,
}

impl <'a, T> Default for UnsafeSpawner<'a, T> {
fn default() -> Self {
Self {
_t_marker: PhantomData,
_marker: PhantomData,
}
}
}

#[pin_project]
struct UnsafeSpawnerHandle<T> {
#[pin]
inner: tokio::task::JoinHandle<T>
}

impl <T: Send + 'static> Future for UnsafeSpawnerHandle<T> {
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.poll(cx) {
Poll::Ready(Ok(t)) => Poll::Ready(t),
Poll::Ready(Err(e)) => panic!("cancelled: {e}"),
Poll::Pending => Poll::Pending,
}
}
}

impl <'a, T: Send + 'static> UnsafeSpawner<'a, T> {
fn spawn<F: Future<Output = T> + Send + 'a>(&self, f: F) -> UnsafeSpawnerHandle<T> {
let handle = tokio::spawn(unsafe {
std::mem::transmute::<_, Pin<Box<dyn Future<Output = T> + Send>>>(
Box::pin(f) as Pin<Box<dyn Future<Output = T>>>
)
});

UnsafeSpawnerHandle { inner: handle }
}
}

/// This helper function might be necessary to convince the compiler that
/// the return value from [`seq_try_join_all`] implements `Send`.
/// Use this if you get higher-ranked lifetime errors that mention `std::marker::Send`.
Expand All @@ -42,12 +89,14 @@ pub fn assert_send<'a, O>(
/// [`try_join_all`]: futures::future::try_join_all
/// [`Stream`]: futures::stream::Stream
/// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered
pub fn seq_join<S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<S, F>
where
S: Stream<Item = F> + Send,
F: Future<Output = O>,
pub fn seq_join<'a, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'a, S, F>
where
S: Stream<Item = F> + Send,
F: Future<Output = O> + Send + 'a,
O: Send + 'static
{
SequentialFutures {
spawner: UnsafeSpawner::default(),
source: source.fuse(),
active: VecDeque::with_capacity(active.get()),
}
Expand All @@ -56,6 +105,7 @@ where
/// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter
/// from the provided context so that the value can be made consistent.
pub trait SeqJoin {

/// Perform a sequential join of the futures from the provided iterable.
/// This uses [`seq_join`], with the current state of the associated object
/// being used to determine the number of active items to track (see [`active_work`]).
Expand All @@ -75,74 +125,74 @@ pub trait SeqJoin {
/// [`active_work`]: Self::active_work
/// [`parallel_join`]: Self::parallel_join
/// [`join3`]: futures::future::join3
fn try_join<I, F, O, E>(&self, iterable: I) -> TryCollect<SeqTryJoinAll<I, F>, Vec<O>>
where
I: IntoIterator<Item = F> + Send,
I::IntoIter: Send,
F: Future<Output = Result<O, E>>,
fn try_join<'a, I, F, O, E>(&self, iterable: I) -> TryCollect<SeqTryJoinAll<'a, I, F>, Vec<O>>
where
I: IntoIterator<Item = F> + Send,
I::IntoIter: Send,
F: Future<Output = Result<O, E>> + Send + 'a,
O: Send + 'static,
E: Send + 'static
{
seq_try_join_all(self.active_work(), iterable)
}

/// Join multiple tasks in parallel. Only do this if you can't use a sequential join.
fn parallel_join<'a, I, F, O, E>(&self, iterable: I) -> Pin<Box<dyn Future<Output = Result<Vec<O>, E>> + Send + 'a>>
fn parallel_join<I, F, O, E>(&self, iterable: I) -> Pin<Box<dyn Future<Output = Result<Vec<O>, E>> + Send>>
where
I: IntoIterator<Item = F> + Send,
F: Future<Output = Result<O, E>> + Send + 'a,
F: Future<Output = Result<O, E>> + Send,
O: Send + 'static,
E: Send + 'static
{
// TODO: implement spawner for shuttle
let mut scope = {
let iter = iterable.into_iter();
let mut scope = unsafe { TokioScope::create(Tokio) };
for element in iter {
// it is important to make those cancellable.
// TODO: elaborate why
scope.spawn_cancellable(element, || panic!("Future is cancelled."));
}
scope
};

Box::pin(async move {
let mut result = Vec::with_capacity(scope.len());
while let Some(item) = scope.next().await { // join error is nothing we can do about
result.push(item.unwrap()?)
}
Ok(result)
})
// let iterable = iterable.into_iter().map(|f| {
// spawner.spawn(f)
// });
// let spawner = UnsafeSpawner::default();
let mut futures = FuturesOrdered::default();
let spawner = UnsafeSpawner::default();
for f in iterable.into_iter() {
futures.push_back(spawner.spawn(f.into_future()));
}

Box::pin(async move { futures.try_collect().await })
// ParallelFutures2 {
// spawner,
// inner: futures::future::try_join_all(iterable.into_iter().map(|f| spawner.spawn(f))),
// }
// #[allow(clippy::disallowed_methods)] // Just in this one place.
// futures::future::try_join_all(iterable)
// futures::future::try_join_all(iterable.into_iter()
// .map(|f| tokio::spawn()))
}

/// The amount of active work that is concurrently permitted.
fn active_work(&self) -> NonZeroUsize;
}

type SeqTryJoinAll<I, F> = SequentialFutures<StreamIter<<I as IntoIterator>::IntoIter>, F>;
type SeqTryJoinAll<'a, I, F> = SequentialFutures<'a, StreamIter<<I as IntoIterator>::IntoIter>, F>;

/// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`].
/// This awaits all the provided futures in order,
/// aborting early if any future returns `Result::Err`.
pub fn seq_try_join_all<I, F, O, E>(
pub fn seq_try_join_all<'a, I, F, O, E>(
active: NonZeroUsize,
source: I,
) -> TryCollect<SeqTryJoinAll<I, F>, Vec<O>>
where
I: IntoIterator<Item = F> + Send,
I::IntoIter: Send,
F: Future<Output = Result<O, E>>,
) -> TryCollect<SeqTryJoinAll<'a, I, F>, Vec<O>>
where
I: IntoIterator<Item = F> + Send,
I::IntoIter: Send,
F: Future<Output = Result<O, E>> + Send + 'a,
O: Send + 'static,
E: Send + 'static
{
seq_join(active, iter(source)).try_collect()
}

enum ActiveItem<F: IntoFuture> {
Pending(Pin<Box<F::IntoFuture>>),
Pending(Pin<Box<UnsafeSpawnerHandle<F::Output>>>),
Resolved(F::Output),
}

impl<F: IntoFuture> ActiveItem<F> {
impl<F: IntoFuture<Output = T>, T: Send + 'static> ActiveItem<F> {
/// Drives this item to resolved state when value is ready to be taken out. Has no effect
/// if the value is ready.
///
Expand Down Expand Up @@ -175,20 +225,23 @@ impl<F: IntoFuture> ActiveItem<F> {
}

#[pin_project]
pub struct SequentialFutures<S, F>
where
S: Stream<Item = F> + Send,
F: IntoFuture,
pub struct SequentialFutures<'a, S, F>
where
S: Stream<Item = F> + Send,
F: IntoFuture,
{
spawner: UnsafeSpawner<'a, F::Output>,
#[pin]
source: futures::stream::Fuse<S>,
active: VecDeque<ActiveItem<F>>,
}

impl<S, F> Stream for SequentialFutures<S, F>
where
S: Stream<Item = F> + Send,
F: IntoFuture,
impl <'a, S, F, T> Stream for SequentialFutures<'a, S, F>
where
S: Stream<Item = F> + Send,
F: IntoFuture<Output = T>,
<F as IntoFuture>::IntoFuture: Send + 'a,
T: Send + 'static
{
type Item = F::Output;

Expand All @@ -198,8 +251,9 @@ where
// Draw more values from the input, up to the capacity.
while this.active.len() < this.active.capacity() {
if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) {
this.active
.push_back(ActiveItem::Pending(Box::pin(f.into_future())));
this.active.push_back(ActiveItem::Pending(Box::pin(this.spawner.spawn(f.into_future()))));
// this.active
// .push_back(ActiveItem::Pending(Box::pin(f.into_future())));
} else {
break;
}
Expand Down Expand Up @@ -232,10 +286,12 @@ where
}
}

impl<S, F> ExactSizeStream for SequentialFutures<S, F>
where
S: Stream<Item = F> + Send + ExactSizeStream,
F: IntoFuture,
impl<'a, S, F, T> ExactSizeStream for SequentialFutures<'a, S, F>
where
S: Stream<Item = F> + Send + ExactSizeStream,
F: IntoFuture<Output = T>,
<F as IntoFuture>::IntoFuture: Send + 'a,
T: Send + 'static
{
}

Expand Down Expand Up @@ -403,7 +459,7 @@ mod test {
*produced_w.lock().unwrap() += 1;
lazy(|_| VALUE)
})
.take(COUNT);
.take(COUNT);
let mut joined = seq_join(capacity, stream);
let waker = fake_waker();
let mut cx = Context::from_waker(&waker);
Expand Down

0 comments on commit 5f3e801

Please sign in to comment.