Skip to content

Commit

Permalink
Merge pull request private-attribution#1336 from akoshelev/executor-r…
Browse files Browse the repository at this point in the history
…untime

Introduce IpaRuntime and plumb it all the way down to executor
  • Loading branch information
akoshelev authored Oct 9, 2024
2 parents 58e3a25 + 67e3828 commit 7e1c180
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 32 deletions.
10 changes: 9 additions & 1 deletion ipa-core/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Weak;
use async_trait::async_trait;

use crate::{
executor::IpaRuntime,
helpers::{
query::{PrepareQuery, QueryConfig, QueryInput},
routing::{Addr, RouteId},
Expand All @@ -20,6 +21,7 @@ use crate::{
pub struct AppConfig {
active_work: Option<NonZeroU32PowerOfTwo>,
key_registry: Option<KeyRegistry<PrivateKeyOnly>>,
runtime: IpaRuntime,
}

impl AppConfig {
Expand All @@ -34,6 +36,12 @@ impl AppConfig {
self.key_registry = Some(key_registry);
self
}

#[must_use]
pub fn with_runtime(mut self, runtime: IpaRuntime) -> Self {
self.runtime = runtime;
self
}
}

pub struct Setup {
Expand Down Expand Up @@ -61,7 +69,7 @@ impl Setup {
#[must_use]
pub fn new(config: AppConfig) -> (Self, HandlerRef) {
let key_registry = config.key_registry.unwrap_or_else(KeyRegistry::empty);
let query_processor = QueryProcessor::new(key_registry, config.active_work);
let query_processor = QueryProcessor::new(key_registry, config.active_work, config.runtime);
let handler = HandlerBox::empty();
let this = Self {
query_processor,
Expand Down
82 changes: 82 additions & 0 deletions ipa-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,91 @@ pub(crate) mod shim {

#[cfg(not(all(feature = "shuttle", test)))]
pub(crate) mod task {
#[allow(unused_imports)]
pub use tokio::task::{JoinError, JoinHandle};
}

#[cfg(not(feature = "shuttle"))]
pub mod executor {
use std::future::Future;

use tokio::{runtime::Handle, task::JoinHandle};

/// In prod we use Tokio scheduler, so this struct just wraps
/// its runtime handle and mimics the standard executor API.
/// The name was chosen to avoid clashes with tokio runtime
/// when importing it
#[derive(Clone)]
pub struct IpaRuntime(Handle);

/// Wrapper around Tokio's [`JoinHandle`]
pub struct IpaJoinHandle<T>(JoinHandle<T>);

impl Default for IpaRuntime {
fn default() -> Self {
Self::current()
}
}

impl IpaRuntime {
#[must_use]
pub fn current() -> Self {
Self(Handle::current())
}

#[must_use]
pub fn spawn<F>(&self, future: F) -> IpaJoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
IpaJoinHandle(self.0.spawn(future))
}
}

impl<T> IpaJoinHandle<T> {
pub fn abort(self) {
self.0.abort();
}
}
}

#[cfg(feature = "shuttle")]
pub(crate) mod executor {
use std::future::Future;

use shuttle_crate::future::{spawn, JoinHandle};

/// Shuttle does not support more than one runtime
/// so we always use its default
#[derive(Clone, Default)]
pub struct IpaRuntime;
pub struct IpaJoinHandle<T>(JoinHandle<T>);

impl IpaRuntime {
#[must_use]
pub fn current() -> Self {
Self
}

#[must_use]
#[allow(clippy::unused_self)] // to conform with runtime API
pub fn spawn<F>(&self, future: F) -> IpaJoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
IpaJoinHandle(spawn(future))
}
}

impl<T> IpaJoinHandle<T> {
pub fn abort(self) {
self.0.abort();
}
}
}

#[cfg(all(feature = "shuttle", test))]
pub(crate) mod test_executor {
use std::future::Future;
Expand Down
69 changes: 46 additions & 23 deletions ipa-core/src/query/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ use generic_array::GenericArray;
use ipa_step::StepNarrow;
use rand::rngs::StdRng;
use rand_core::SeedableRng;
#[cfg(all(feature = "shuttle", test))]
use shuttle::future as tokio;
use typenum::Unsigned;

#[cfg(any(
Expand All @@ -26,11 +24,8 @@ use typenum::Unsigned;
feature = "weak-field"
))]
use crate::ff::FieldType;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
use crate::{
ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field,
};
use crate::{
executor::IpaRuntime,
ff::{boolean_array::BA32, Serializable},
helpers::{
negotiate_prss,
Expand All @@ -49,6 +44,10 @@ use crate::{
},
sync::Arc,
};
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
use crate::{
ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field,
};

pub trait Result: Send + Debug {
fn to_bytes(&self) -> Vec<u8>;
Expand All @@ -74,52 +73,71 @@ where
/// Needless pass by value because IPA v3 does not make use of key registry yet.
#[allow(clippy::too_many_lines, clippy::needless_pass_by_value)]
pub fn execute<R: PrivateKeyRegistry>(
runtime: &IpaRuntime,
config: QueryConfig,
key_registry: Arc<R>,
gateway: Gateway,
input: BodyStream,
) -> RunningQuery {
match (config.query_type, config.field_type) {
#[cfg(any(test, feature = "weak-field"))]
(QueryType::TestMultiply, FieldType::Fp31) => {
do_query(config, gateway, input, |prss, gateway, _config, input| {
(QueryType::TestMultiply, FieldType::Fp31) => do_query(
runtime,
config,
gateway,
input,
|prss, gateway, _config, input| {
Box::pin(execute_test_multiply::<crate::ff::Fp31>(
prss, gateway, input,
))
})
}
},
),
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
(QueryType::TestMultiply, FieldType::Fp32BitPrime) => {
do_query(config, gateway, input, |prss, gateway, _config, input| {
(QueryType::TestMultiply, FieldType::Fp32BitPrime) => do_query(
runtime,
config,
gateway,
input,
|prss, gateway, _config, input| {
Box::pin(execute_test_multiply::<Fp32BitPrime>(prss, gateway, input))
})
}
},
),
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
(QueryType::TestShardedShuffle, _) => do_query(
runtime,
config,
gateway,
input,
|_prss, _gateway, _config, _input| unimplemented!(),
),
#[cfg(any(test, feature = "weak-field"))]
(QueryType::TestAddInPrimeField, FieldType::Fp31) => {
do_query(config, gateway, input, |prss, gateway, _config, input| {
(QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query(
runtime,
config,
gateway,
input,
|prss, gateway, _config, input| {
Box::pin(test_add_in_prime_field::<crate::ff::Fp31>(
prss, gateway, input,
))
})
}
},
),
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
(QueryType::TestAddInPrimeField, FieldType::Fp32BitPrime) => {
do_query(config, gateway, input, |prss, gateway, _config, input| {
(QueryType::TestAddInPrimeField, FieldType::Fp32BitPrime) => do_query(
runtime,
config,
gateway,
input,
|prss, gateway, _config, input| {
Box::pin(test_add_in_prime_field::<Fp32BitPrime>(
prss, gateway, input,
))
})
}
},
),
// TODO(953): This is really using BA32, not Fp32bitPrime. The `FieldType` mechanism needs
// to be reworked.
(QueryType::SemiHonestOprfIpa(ipa_config), _) => do_query(
runtime,
config,
gateway,
input,
Expand All @@ -133,6 +151,7 @@ pub fn execute<R: PrivateKeyRegistry>(
},
),
(QueryType::MaliciousOprfIpa(ipa_config), _) => do_query(
runtime,
config,
gateway,
input,
Expand All @@ -146,6 +165,7 @@ pub fn execute<R: PrivateKeyRegistry>(
},
),
(QueryType::SemiHonestHybrid(query_params), _) => do_query(
runtime,
config,
gateway,
input,
Expand All @@ -162,6 +182,7 @@ pub fn execute<R: PrivateKeyRegistry>(
}

pub fn do_query<B, F>(
executor_handle: &IpaRuntime,
config: QueryConfig,
gateway: B,
input_stream: BodyStream,
Expand All @@ -180,7 +201,7 @@ where
{
let (tx, rx) = oneshot::channel();

let join_handle = tokio::spawn(async move {
let join_handle = executor_handle.spawn(async move {
let gateway = gateway.borrow();
// TODO: make it a generic argument for this function
let mut rng = StdRng::from_entropy();
Expand Down Expand Up @@ -232,6 +253,7 @@ mod tests {
use tokio::sync::Barrier;

use crate::{
executor::IpaRuntime,
ff::{FieldType, Fp31, U128Conversions},
helpers::{
query::{QueryConfig, QueryType},
Expand Down Expand Up @@ -352,6 +374,7 @@ mod tests {
Fut: Future<Output = ()> + Send,
{
do_query(
&IpaRuntime::current(),
QueryConfig {
size: 1.try_into().unwrap(),
field_type: FieldType::Fp31,
Expand Down
21 changes: 15 additions & 6 deletions ipa-core/src/query/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::Serialize;

use crate::{
error::Error as ProtocolError,
executor::IpaRuntime,
helpers::{
query::{PrepareQuery, QueryConfig, QueryInput},
Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment,
Expand Down Expand Up @@ -45,6 +46,7 @@ pub struct Processor {
queries: RunningQueries,
key_registry: Arc<KeyRegistry<PrivateKeyOnly>>,
active_work: Option<NonZeroU32PowerOfTwo>,
runtime: IpaRuntime,
}

impl Default for Processor {
Expand All @@ -53,6 +55,7 @@ impl Default for Processor {
queries: RunningQueries::default(),
key_registry: Arc::new(KeyRegistry::<PrivateKeyOnly>::empty()),
active_work: None,
runtime: IpaRuntime::current(),
}
}
}
Expand Down Expand Up @@ -119,11 +122,13 @@ impl Processor {
pub fn new(
key_registry: KeyRegistry<PrivateKeyOnly>,
active_work: Option<NonZeroU32PowerOfTwo>,
runtime: IpaRuntime,
) -> Self {
Self {
queries: RunningQueries::default(),
key_registry: Arc::new(key_registry),
active_work,
runtime,
}
}

Expand Down Expand Up @@ -249,6 +254,7 @@ impl Processor {
queries.insert(
input.query_id,
QueryState::Running(executor::execute(
&self.runtime,
config,
Arc::clone(&self.key_registry),
gateway,
Expand Down Expand Up @@ -584,6 +590,7 @@ mod tests {
use std::sync::Arc;

use crate::{
executor::IpaRuntime,
ff::FieldType,
helpers::{
query::{
Expand All @@ -603,11 +610,13 @@ mod tests {

#[test]
fn non_existent_query() {
let processor = Processor::default();
assert!(matches!(
processor.kill(QueryId),
Err(QueryKillStatus::NoSuchQuery(QueryId))
));
run(|| async {
let processor = Processor::default();
assert!(matches!(
processor.kill(QueryId),
Err(QueryKillStatus::NoSuchQuery(QueryId))
));
});
}

#[test]
Expand Down Expand Up @@ -650,7 +659,7 @@ mod tests {
let processor = Processor::default();
let (_tx, rx) = tokio::sync::oneshot::channel();
let counter = Arc::new(1);
let task = tokio::spawn({
let task = IpaRuntime::current().spawn({
let counter = Arc::clone(&counter);
async move {
loop {
Expand Down
Loading

0 comments on commit 7e1c180

Please sign in to comment.