Skip to content

Commit

Permalink
Make parallel_join spawn tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Nov 8, 2023
1 parent cb0fe9e commit b6231e9
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 26 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ipa-prf = []

[dependencies]
aes = "0.8.3"
async-scoped = { version = "0.7.1", features = ["use-tokio"], path = "../rmstuff/async-scoped" }
async-trait = "0.1.68"
axum = { version = "0.5.17", optional = true, features = ["http2"] }
axum-server = { version = "0.5.1", optional = true, features = ["rustls", "rustls-pemfile", "tls-rustls"] }
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/basics/reshare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use crate::{
/// `to_helper` = (`rand_left`, `rand_right`) = (r0, r1)
/// `to_helper.right` = (`rand_right`, part1 + part2) = (r0, part1 + part2)
#[async_trait]
pub trait Reshare<C: Context, B: RecordBinding>: Sized {
pub trait Reshare<C: Context, B: RecordBinding>: Sized + 'static {
async fn reshare<'fut>(
&self,
ctx: C,
Expand Down
4 changes: 2 additions & 2 deletions src/protocol/boolean/generate_random_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ impl<F: PrimeField, C: Context> Iterator for RawRandomBitIter<F, C> {
/// # Panics
/// If the provided context has an unspecified total record count.
/// An indeterminate limit works, but setting a fixed value greatly helps performance.
pub fn random_bits<F, C>(ctx: C) -> impl Stream<Item = Result<BitDecomposed<C::Share>, Error>>
pub fn random_bits<'ctx, F, C>(ctx: C) -> impl Stream<Item = Result<BitDecomposed<C::Share>, Error>> + 'ctx
where
F: PrimeField,
C: UpgradedContext<F>,
C: UpgradedContext<F> + 'ctx,
C::Share: LinearSecretSharing<F> + SecureMul<C>,
{
debug_assert!(ctx.total_records().is_specified());
Expand Down
32 changes: 17 additions & 15 deletions src/protocol/modulus_conversion/convert_shares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,17 @@ where
/// # Panics
/// If the total record count on the context is unspecified.
#[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))]
pub fn convert_bits<F, V, C, S, VS>(
pub fn convert_bits<'a, F, V, C, S, VS>(
ctx: C,
binary_shares: VS,
bit_range: Range<u32>,
) -> impl Stream<Item = Result<BitDecomposed<S>, Error>>
) -> impl Stream<Item = Result<BitDecomposed<S>, Error>> + 'a
where
F: PrimeField,
V: ToBitConversionTriples<Residual = ()>,
C: UpgradedContext<F, Share = S>,
V: ToBitConversionTriples<Residual = ()> + 'a,
C: UpgradedContext<F, Share = S> + 'a,
S: LinearSecretSharing<F> + SecureMul<C>,
VS: Stream<Item = V> + Unpin + Send,
VS: Stream<Item = V> + Unpin + Send + 'a,
for<'u> UpgradeContext<'u, C, F, RecordId>:
UpgradeToMalicious<'u, BitConversionTriple<Replicated<F>>, BitConversionTriple<C::Share>>,
{
Expand All @@ -313,35 +313,37 @@ where
/// Note that unconverted fields are not upgraded, so they might need to be upgraded either before or
/// after invoking this function.
#[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))]
pub fn convert_selected_bits<F, V, C, S, VS, R>(
pub fn convert_selected_bits<'inp, F, V, C, S, VS, R>(
ctx: C,
binary_shares: VS,
bit_range: Range<u32>,
) -> impl Stream<Item = Result<(BitDecomposed<S>, R), Error>>
) -> impl Stream<Item = Result<(BitDecomposed<S>, R), Error>> + 'inp
where
R: Send + 'static,
F: PrimeField,
V: ToBitConversionTriples<Residual = R>,
C: UpgradedContext<F, Share = S>,
V: ToBitConversionTriples<Residual = R> + 'inp,
C: UpgradedContext<F, Share = S> + 'inp,
S: LinearSecretSharing<F> + SecureMul<C>,
VS: Stream<Item = V> + Unpin + Send,
VS: Stream<Item = V> + Unpin + Send + 'inp,
for<'u> UpgradeContext<'u, C, F, RecordId>:
UpgradeToMalicious<'u, BitConversionTriple<Replicated<F>>, BitConversionTriple<C::Share>>,
{
convert_some_bits(ctx, binary_shares, RecordId::FIRST, bit_range)
}

pub(crate) fn convert_some_bits<F, V, C, S, VS, R>(
pub(crate) fn convert_some_bits<'a, F, V, C, S, VS, R>(
ctx: C,
binary_shares: VS,
first_record: RecordId,
bit_range: Range<u32>,
) -> impl Stream<Item = Result<(BitDecomposed<S>, R), Error>>
) -> impl Stream<Item = Result<(BitDecomposed<S>, R), Error>> + 'a
where
R: Send + 'static,
F: PrimeField,
V: ToBitConversionTriples<Residual = R>,
C: UpgradedContext<F, Share = S>,
V: ToBitConversionTriples<Residual = R> + 'a,
C: UpgradedContext<F, Share = S> + 'a,
S: LinearSecretSharing<F> + SecureMul<C>,
VS: Stream<Item = V> + Unpin + Send,
VS: Stream<Item = V> + Unpin + Send + 'a,
for<'u> UpgradeContext<'u, C, F, RecordId>:
UpgradeToMalicious<'u, BitConversionTriple<Replicated<F>>, BitConversionTriple<C::Share>>,
{
Expand Down
4 changes: 3 additions & 1 deletion src/protocol/sort/generate_permutation_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,9 @@ mod tests {
}

/// Passing 32 records for Fp31 doesn't work.
#[tokio::test]
///
/// Requires one extra thread to cancel futures running in parallel with the one that panics.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[should_panic = "prime field ipa::ff::prime_field::fp31::Fp31 is too small to sort 32 records"]
async fn fp31_overflow() {
const COUNT: usize = 32;
Expand Down
3 changes: 2 additions & 1 deletion src/secret_sharing/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{SharedValue, WeakSharedValue};
use crate::ff::{AddSub, AddSubAssign, GaloisField};

/// Secret sharing scheme i.e. Replicated secret sharing
pub trait SecretSharing<V: WeakSharedValue>: Clone + Debug + Sized + Send + Sync {
pub trait SecretSharing<V: WeakSharedValue>: Clone + Debug + Sized + Send + Sync + 'static {
const ZERO: Self;
}

Expand All @@ -21,6 +21,7 @@ pub trait Linear<V: SharedValue>:
+ Mul<V, Output = Self>
+ for<'r> Mul<&'r V, Output = Self>
+ Neg<Output = Self>
+ 'static
{
}

Expand Down
37 changes: 31 additions & 6 deletions src/seq_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use async_scoped::spawner::use_tokio::Tokio;
use async_scoped::TokioScope;
use async_trait::async_trait;

use futures::{
stream::{iter, Iter as StreamIter, TryCollect},
Expand Down Expand Up @@ -82,13 +85,35 @@ pub trait SeqJoin {
}

/// Join multiple tasks in parallel. Only do this if you can't use a sequential join.
fn parallel_join<I>(&self, iterable: I) -> futures::future::TryJoinAll<I::Item>
where
I: IntoIterator,
I::Item: futures::future::TryFuture,
fn parallel_join<'a, I, F, O, E>(&self, iterable: I) -> Pin<Box<dyn Future<Output = Result<Vec<O>, E>> + Send + 'a>>
where
I: IntoIterator<Item = F> + Send,
F: Future<Output = Result<O, E>> + Send + 'a,
O: Send + 'static,
E: Send + 'static
{
#[allow(clippy::disallowed_methods)] // Just in this one place.
futures::future::try_join_all(iterable)
// TODO: implement spawner for shuttle
let mut scope = {
let iter = iterable.into_iter();
let mut scope = unsafe { TokioScope::create(Tokio) };
for element in iter {
// it is important to make those cancellable.
// TODO: elaborate why
scope.spawn_cancellable(element, || panic!("Future is cancelled."));
}
scope
};

Box::pin(async move {
let mut result = Vec::with_capacity(scope.len());
while let Some(item) = scope.next().await { // join error is nothing we can do about
result.push(item.unwrap()?)
}
Ok(result)
})

// #[allow(clippy::disallowed_methods)] // Just in this one place.
// futures::future::try_join_all(iterable)
}

/// The amount of active work that is concurrently permitted.
Expand Down

0 comments on commit b6231e9

Please sign in to comment.