From b6231e918d765d8d9618a895446c9ce033e9573c Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 8 Nov 2023 14:58:09 -0800 Subject: [PATCH] Make `parallel_join` spawn tasks --- Cargo.toml | 1 + src/protocol/basics/reshare.rs | 2 +- src/protocol/boolean/generate_random_bits.rs | 4 +- .../modulus_conversion/convert_shares.rs | 32 ++++++++-------- src/protocol/sort/generate_permutation_opt.rs | 4 +- src/secret_sharing/scheme.rs | 3 +- src/seq_join.rs | 37 ++++++++++++++++--- 7 files changed, 57 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9962d04ec..3eb6d5cf7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ ipa-prf = [] [dependencies] aes = "0.8.3" +async-scoped = { version = "0.7.1", features = ["use-tokio"], path = "../rmstuff/async-scoped" } async-trait = "0.1.68" axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = ["rustls", "rustls-pemfile", "tls-rustls"] } diff --git a/src/protocol/basics/reshare.rs b/src/protocol/basics/reshare.rs index 2e9a868e6..70b65c1c3 100644 --- a/src/protocol/basics/reshare.rs +++ b/src/protocol/basics/reshare.rs @@ -42,7 +42,7 @@ use crate::{ /// `to_helper` = (`rand_left`, `rand_right`) = (r0, r1) /// `to_helper.right` = (`rand_right`, part1 + part2) = (r0, part1 + part2) #[async_trait] -pub trait Reshare: Sized { +pub trait Reshare: Sized + 'static { async fn reshare<'fut>( &self, ctx: C, diff --git a/src/protocol/boolean/generate_random_bits.rs b/src/protocol/boolean/generate_random_bits.rs index 184bf7d80..67b3a209f 100644 --- a/src/protocol/boolean/generate_random_bits.rs +++ b/src/protocol/boolean/generate_random_bits.rs @@ -101,10 +101,10 @@ impl Iterator for RawRandomBitIter { /// # Panics /// If the provided context has an unspecified total record count. /// An indeterminate limit works, but setting a fixed value greatly helps performance. -pub fn random_bits(ctx: C) -> impl Stream, Error>> +pub fn random_bits<'ctx, F, C>(ctx: C) -> impl Stream, Error>> + 'ctx where F: PrimeField, - C: UpgradedContext, + C: UpgradedContext + 'ctx, C::Share: LinearSecretSharing + SecureMul, { debug_assert!(ctx.total_records().is_specified()); diff --git a/src/protocol/modulus_conversion/convert_shares.rs b/src/protocol/modulus_conversion/convert_shares.rs index 87df09abf..dae9ae8c1 100644 --- a/src/protocol/modulus_conversion/convert_shares.rs +++ b/src/protocol/modulus_conversion/convert_shares.rs @@ -292,17 +292,17 @@ where /// # Panics /// If the total record count on the context is unspecified. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_bits( +pub fn convert_bits<'a, F, V, C, S, VS>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, Error>> +) -> impl Stream, Error>> + 'a where F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { @@ -313,35 +313,37 @@ where /// Note that unconverted fields are not upgraded, so they might need to be upgraded either before or /// after invoking this function. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_selected_bits( +pub fn convert_selected_bits<'inp, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'inp where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'inp, + C: UpgradedContext + 'inp, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'inp, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { convert_some_bits(ctx, binary_shares, RecordId::FIRST, bit_range) } -pub(crate) fn convert_some_bits( +pub(crate) fn convert_some_bits<'a, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, first_record: RecordId, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'a where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { diff --git a/src/protocol/sort/generate_permutation_opt.rs b/src/protocol/sort/generate_permutation_opt.rs index 3d1db3dff..0e8790be2 100644 --- a/src/protocol/sort/generate_permutation_opt.rs +++ b/src/protocol/sort/generate_permutation_opt.rs @@ -298,7 +298,9 @@ mod tests { } /// Passing 32 records for Fp31 doesn't work. - #[tokio::test] + /// + /// Requires one extra thread to cancel futures running in parallel with the one that panics. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[should_panic = "prime field ipa::ff::prime_field::fp31::Fp31 is too small to sort 32 records"] async fn fp31_overflow() { const COUNT: usize = 32; diff --git a/src/secret_sharing/scheme.rs b/src/secret_sharing/scheme.rs index 0d2131eeb..b20348539 100644 --- a/src/secret_sharing/scheme.rs +++ b/src/secret_sharing/scheme.rs @@ -7,7 +7,7 @@ use super::{SharedValue, WeakSharedValue}; use crate::ff::{AddSub, AddSubAssign, GaloisField}; /// Secret sharing scheme i.e. Replicated secret sharing -pub trait SecretSharing: Clone + Debug + Sized + Send + Sync { +pub trait SecretSharing: Clone + Debug + Sized + Send + Sync + 'static { const ZERO: Self; } @@ -21,6 +21,7 @@ pub trait Linear: + Mul + for<'r> Mul<&'r V, Output = Self> + Neg + + 'static { } diff --git a/src/seq_join.rs b/src/seq_join.rs index e3cb8c5b3..34832b436 100644 --- a/src/seq_join.rs +++ b/src/seq_join.rs @@ -5,6 +5,9 @@ use std::{ pin::Pin, task::{Context, Poll}, }; +use async_scoped::spawner::use_tokio::Tokio; +use async_scoped::TokioScope; +use async_trait::async_trait; use futures::{ stream::{iter, Iter as StreamIter, TryCollect}, @@ -82,13 +85,35 @@ pub trait SeqJoin { } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. - fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll - where - I: IntoIterator, - I::Item: futures::future::TryFuture, + fn parallel_join<'a, I, F, O, E>(&self, iterable: I) -> Pin, E>> + Send + 'a>> + where + I: IntoIterator + Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static { - #[allow(clippy::disallowed_methods)] // Just in this one place. - futures::future::try_join_all(iterable) + // TODO: implement spawner for shuttle + let mut scope = { + let iter = iterable.into_iter(); + let mut scope = unsafe { TokioScope::create(Tokio) }; + for element in iter { + // it is important to make those cancellable. + // TODO: elaborate why + scope.spawn_cancellable(element, || panic!("Future is cancelled.")); + } + scope + }; + + Box::pin(async move { + let mut result = Vec::with_capacity(scope.len()); + while let Some(item) = scope.next().await { // join error is nothing we can do about + result.push(item.unwrap()?) + } + Ok(result) + }) + + // #[allow(clippy::disallowed_methods)] // Just in this one place. + // futures::future::try_join_all(iterable) } /// The amount of active work that is concurrently permitted.