diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index fb6f9fdb7..792215b1e 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -161,7 +161,7 @@ impl HelperApp { Ok(self .inner .query_processor - .complete(query_id) + .complete(query_id, self.inner.shard_transport.clone_ref()) .await? .to_bytes()) } @@ -251,7 +251,10 @@ impl RequestHandler for Inner { } RouteId::CompleteQuery => { let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.complete(query_id).await?) + HelperResponse::from( + qp.complete(query_id, self.shard_transport.clone_ref()) + .await?, + ) } RouteId::KillQuery => { let query_id = ext_query_id(&req)?; diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index 3d9c2bb5f..6cb1006df 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -8,7 +8,7 @@ use crate::{ }; // The type of request made to an MPC helper. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum RouteId { Records, ReceiveQuery, diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 120e1c5ca..d53277e09 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -11,6 +11,7 @@ use crate::{ executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, + routing::RouteId, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, @@ -123,6 +124,8 @@ pub enum QueryCompletionError { }, #[error("query execution failed: {0}")] ExecutionError(#[from] ProtocolError), + #[error("one or more shards rejected this request: {0}")] + ShardError(#[from] BroadcastError), } impl Debug for Processor { @@ -373,6 +376,7 @@ impl Processor { pub async fn complete( &self, query_id: QueryId, + shard_transport: ShardTransportImpl, ) -> Result, QueryCompletionError> { let handle = { let mut 
queries = self.queries.inner.lock().unwrap(); @@ -397,6 +401,18 @@ impl Processor { } }; // release mutex before await + // Inform other shards about our intent to complete the query. + // If any of them rejects it, report the error back. We expect all shards + // to be in the same state. In the normal cycle, this API is called only after + // query status reports completion. + if shard_transport.identity() == ShardIndex::FIRST { + // See the shard finalizer protocol for how shards merge their results together. + // At the end, only the leader holds the value + shard_transport + .broadcast((RouteId::CompleteQuery, query_id)) + .await?; + } + Ok(handle.await?) } @@ -440,7 +456,8 @@ mod tests { use tokio::sync::Barrier; use crate::{ - ff::FieldType, + executor::IpaRuntime, + ff::{boolean_array::BA64, FieldType}, helpers::{ make_owned_handler, query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, @@ -450,8 +467,9 @@ }, protocol::QueryId, query::{ - processor::Processor, state::StateError, NewQueryError, PrepareQueryError, QueryStatus, - QueryStatusError, + processor::Processor, + state::{QueryState, RunningQuery, StateError}, + NewQueryError, PrepareQueryError, QueryStatus, QueryStatusError, }, sharding::ShardIndex, }; @@ -472,7 +490,7 @@ } fn shard_respond_ok(_si: ShardIndex) -> Arc> { - prepare_query_handler(|_| async { Ok(HelperResponse::ok()) }) + make_owned_handler(move |_req, _| futures::future::ok(HelperResponse::ok())) } fn test_multiply_config() -> QueryConfig { @@ -559,6 +578,12 @@ shard_transport: InMemoryTransport, } + impl Default for TestComponents { + fn default() -> Self { + Self::new(TestComponentsArgs::default()) + } + } + impl TestComponents { fn new(mut args: TestComponentsArgs) -> Self { let mpc_network = InMemoryMpcNetwork::new( @@ -584,6 +609,32 @@ shard_transport, } } + + /// This initiates a new query on all shards and puts them all 
on running state. + /// It also makes up a fake query result + async fn new_running_query(&self) -> QueryId { + self.processor + .new_query( + self.first_transport.clone_ref(), + self.shard_transport.clone_ref(), + self.query_config, + ) + .await + .unwrap(); + // don't care about the result here + let (tx, rx) = tokio::sync::oneshot::channel(); + self.processor + .queries + .handle(QueryId) + .set_state(QueryState::Running(RunningQuery { + result: rx, + join_handle: IpaRuntime::current().spawn(async {}), + })) + .unwrap(); + tx.send(Ok(Box::new(Vec::::new()))).unwrap(); + + QueryId + } } #[tokio::test] @@ -755,6 +806,81 @@ mod tests { assert!(t.processor.get_status(QueryId).is_none()); } + mod complete { + + use crate::{ + helpers::{make_owned_handler, routing::RouteId, ApiError, Transport}, + query::{ + processor::{ + tests::{HelperResponse, TestComponents, TestComponentsArgs}, + QueryId, + }, + QueryCompletionError, + }, + sharding::ShardIndex, + }; + + #[tokio::test] + async fn complete_basic() { + let t = TestComponents::default(); + let query_id = t.new_running_query().await; + + t.processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap(); + } + + #[tokio::test] + async fn complete_one_shard_fails() { + let mut args = TestComponentsArgs::default(); + + args.set_shard_handler(|shard_id| { + make_owned_handler(move |req, _| { + if shard_id != ShardIndex::from(1) || req.route != RouteId::CompleteQuery { + futures::future::ok(HelperResponse::ok()) + } else { + futures::future::err(QueryCompletionError::NoSuchQuery(QueryId).into()) + } + }) + }); + + let t = TestComponents::new(args); + let query_id = t.new_running_query().await; + + let _ = t + .processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap_err(); + } + + #[tokio::test] + async fn only_leader_broadcasts() { + let mut args = TestComponentsArgs::default(); + + args.set_shard_handler(|shard_id| { + make_owned_handler(move |_req, _| { + if shard_id == 
ShardIndex::FIRST { + futures::future::err(ApiError::BadRequest( + "Leader shard must not receive requests through shard channels".into(), + )) + } else { + futures::future::ok(HelperResponse::ok()) + } + }) + }); + + let t = TestComponents::new(args); + let query_id = t.new_running_query().await; + + t.processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap(); + } + } + mod prepare { use super::*; use crate::query::QueryStatusError; diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 460296022..f745ada31 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -60,13 +60,14 @@ pub enum QueryState { impl QueryState { pub fn transition(cur_state: &Self, new_state: Self) -> Result { - use QueryState::{AwaitingInputs, Empty, Preparing}; + use QueryState::{AwaitingInputs, Empty, Preparing, Running}; match (cur_state, &new_state) { // If query is not running, coordinator initial state is preparing // and followers initial state is awaiting inputs (Empty, Preparing(_) | AwaitingInputs(_, _, _)) - | (Preparing(_), AwaitingInputs(_, _, _)) => Ok(new_state), + | (Preparing(_), AwaitingInputs(_, _, _)) + | (AwaitingInputs(_, _, _), Running(_)) => Ok(new_state), (_, Preparing(_)) => Err(StateError::AlreadyRunning), (_, _) => Err(StateError::InvalidState { from: cur_state.into(),