Skip to content

Commit

Permalink
Complete query API for sharded environments
Browse files Browse the repository at this point in the history
As we discussed, we want to make the complete query API simple to implement. That means we assume that all shards communicate their results to the leader. That leaves the complete API to just clean up the state on each shard and return the response from the leader. This PR does exactly that.
  • Loading branch information
akoshelev committed Nov 19, 2024
1 parent e831a21 commit c5e3924
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 9 deletions.
7 changes: 5 additions & 2 deletions ipa-core/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl HelperApp {
Ok(self
.inner
.query_processor
.complete(query_id)
.complete(query_id, self.inner.shard_transport.clone_ref())
.await?
.to_bytes())
}
Expand Down Expand Up @@ -251,7 +251,10 @@ impl RequestHandler<HelperIdentity> for Inner {
}
RouteId::CompleteQuery => {
let query_id = ext_query_id(&req)?;
HelperResponse::from(qp.complete(query_id).await?)
HelperResponse::from(
qp.complete(query_id, self.shard_transport.clone_ref())
.await?,
)
}
RouteId::KillQuery => {
let query_id = ext_query_id(&req)?;
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/transport/routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
};

// The type of request made to an MPC helper.
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum RouteId {
Records,
ReceiveQuery,
Expand Down
134 changes: 130 additions & 4 deletions ipa-core/src/query/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{
executor::IpaRuntime,
helpers::{
query::{PrepareQuery, QueryConfig, QueryInput},
routing::RouteId,
BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role,
RoleAssignment, ShardTransportError, ShardTransportImpl, Transport,
},
Expand Down Expand Up @@ -123,6 +124,8 @@ pub enum QueryCompletionError {
},
#[error("query execution failed: {0}")]
ExecutionError(#[from] ProtocolError),
#[error("one or more shards rejected this request: {0}")]
ShardError(#[from] BroadcastError<ShardIndex, ShardTransportError>),
}

impl Debug for Processor {
Expand Down Expand Up @@ -373,6 +376,7 @@ impl Processor {
pub async fn complete(
&self,
query_id: QueryId,
shard_transport: ShardTransportImpl,
) -> Result<Box<dyn ProtocolResult>, QueryCompletionError> {
let handle = {
let mut queries = self.queries.inner.lock().unwrap();
Expand All @@ -397,6 +401,18 @@ impl Processor {
}
}; // release mutex before await

// Inform other shards about our intent to complete the query.
// If any of them rejects it, report the error back. We expect all shards
// to be in the same state. In a normal cycle, this API is called only after
// query status reports completion.
if shard_transport.identity() == ShardIndex::FIRST {
// See shard finalizer protocol to see how shards merge their results together.
// At the end, only leader holds the value
shard_transport
.broadcast((RouteId::CompleteQuery, query_id))
.await?;
}

Ok(handle.await?)
}

Expand Down Expand Up @@ -440,7 +456,8 @@ mod tests {
use tokio::sync::Barrier;

use crate::{
ff::FieldType,
executor::IpaRuntime,
ff::{boolean_array::BA64, FieldType},
helpers::{
make_owned_handler,
query::{PrepareQuery, QueryConfig, QueryType::TestMultiply},
Expand All @@ -450,8 +467,9 @@ mod tests {
},
protocol::QueryId,
query::{
processor::Processor, state::StateError, NewQueryError, PrepareQueryError, QueryStatus,
QueryStatusError,
processor::Processor,
state::{QueryState, RunningQuery, StateError},
NewQueryError, PrepareQueryError, QueryStatus, QueryStatusError,
},
sharding::ShardIndex,
};
Expand All @@ -472,7 +490,8 @@ mod tests {
}

/// Returns a shard request handler that unconditionally acknowledges every
/// incoming request with an "ok" response, regardless of the shard index.
/// Useful as the default handler in tests that don't care about shard traffic.
fn shard_respond_ok(_si: ShardIndex) -> Arc<dyn RequestHandler<ShardIndex>> {
    make_owned_handler(move |_req, _| futures::future::ok(HelperResponse::ok()))
}

fn test_multiply_config() -> QueryConfig {
Expand Down Expand Up @@ -559,6 +578,12 @@ mod tests {
shard_transport: InMemoryTransport<ShardIndex>,
}

impl Default for TestComponents {
fn default() -> Self {
Self::new(TestComponentsArgs::default())
}
}

impl TestComponents {
fn new(mut args: TestComponentsArgs) -> Self {
let mpc_network = InMemoryMpcNetwork::new(
Expand All @@ -584,6 +609,32 @@ mod tests {
shard_transport,
}
}

/// Initiates a new query on all shards and forcibly transitions it into the
/// `Running` state, wiring up a fake (empty) query result so that `complete`
/// can be exercised without actually running a protocol.
///
/// Returns the id of the newly created query.
async fn new_running_query(&self) -> QueryId {
// Create the query through the normal path; this sets up state on the
// leader and broadcasts preparation to the other shards.
self.processor
.new_query(
self.first_transport.clone_ref(),
self.shard_transport.clone_ref(),
self.query_config,
)
.await
.unwrap();
// Don't care about the actual protocol result here — fabricate one.
// The oneshot channel stands in for the channel a real execution
// would use to deliver its output.
let (tx, rx) = tokio::sync::oneshot::channel();
// Force the state machine into `Running` with a no-op join handle so
// that `complete` sees a query it is allowed to finish.
self.processor
.queries
.handle(QueryId)
.set_state(QueryState::Running(RunningQuery {
result: rx,
join_handle: IpaRuntime::current().spawn(async {}),
}))
.unwrap();
// Deliver the fake (empty) result so awaiting the handle resolves
// immediately instead of hanging.
tx.send(Ok(Box::new(Vec::<BA64>::new()))).unwrap();

QueryId
}
}

#[tokio::test]
Expand Down Expand Up @@ -755,6 +806,81 @@ mod tests {
assert!(t.processor.get_status(QueryId).is_none());
}

mod complete {

use crate::{
helpers::{make_owned_handler, routing::RouteId, ApiError, Transport},
query::{
processor::{
tests::{HelperResponse, TestComponents, TestComponentsArgs},
QueryId,
},
QueryCompletionError,
},
sharding::ShardIndex,
};

#[tokio::test]
async fn complete_basic() {
    // Happy path: a query that is already in the running state with a
    // ready result must complete successfully on the leader.
    let components = TestComponents::default();
    let query_id = components.new_running_query().await;

    components
        .processor
        .complete(query_id, components.shard_transport.clone_ref())
        .await
        .unwrap();
}

#[tokio::test]
async fn complete_one_shard_fails() {
    let mut args = TestComponentsArgs::default();

    // Every shard acknowledges the broadcast except shard 1, which
    // rejects the `CompleteQuery` route specifically.
    args.set_shard_handler(|shard_id| {
        make_owned_handler(move |req, _| {
            if shard_id == ShardIndex::from(1) && req.route == RouteId::CompleteQuery {
                futures::future::err(QueryCompletionError::NoSuchQuery(QueryId).into())
            } else {
                futures::future::ok(HelperResponse::ok())
            }
        })
    });

    let t = TestComponents::new(args);
    let query_id = t.new_running_query().await;

    let err = t
        .processor
        .complete(query_id, t.shard_transport.clone_ref())
        .await
        .unwrap_err();
    // The rejection must surface as a shard broadcast error — asserting
    // the variant ensures the test can't pass on an unrelated failure
    // (e.g. a wrong-state `NoSuchQuery` from the leader itself).
    assert!(matches!(err, QueryCompletionError::ShardError(_)));
}

#[tokio::test]
async fn only_leader_broadcasts() {
    let mut args = TestComponentsArgs::default();

    // Follower shards simply acknowledge; any shard-channel request
    // that reaches the leader itself is treated as a protocol violation.
    args.set_shard_handler(|shard_id| {
        make_owned_handler(move |_req, _| {
            if shard_id != ShardIndex::FIRST {
                return futures::future::ok(HelperResponse::ok());
            }
            futures::future::err(ApiError::BadRequest(
                "Leader shard must not receive requests through shard channels".into(),
            ))
        })
    });

    let t = TestComponents::new(args);
    let query_id = t.new_running_query().await;

    t.processor
        .complete(query_id, t.shard_transport.clone_ref())
        .await
        .unwrap();
}
}

mod prepare {
use super::*;
use crate::query::QueryStatusError;
Expand Down
5 changes: 3 additions & 2 deletions ipa-core/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ pub enum QueryState {

impl QueryState {
pub fn transition(cur_state: &Self, new_state: Self) -> Result<Self, StateError> {
use QueryState::{AwaitingInputs, Empty, Preparing};
use QueryState::{AwaitingInputs, Empty, Preparing, Running};

match (cur_state, &new_state) {
// If query is not running, coordinator initial state is preparing
// and followers initial state is awaiting inputs
(Empty, Preparing(_) | AwaitingInputs(_, _, _))
| (Preparing(_), AwaitingInputs(_, _, _)) => Ok(new_state),
| (Preparing(_), AwaitingInputs(_, _, _))
| (AwaitingInputs(_, _, _), Running(_)) => Ok(new_state),
(_, Preparing(_)) => Err(StateError::AlreadyRunning),
(_, _) => Err(StateError::InvalidState {
from: cur_state.into(),
Expand Down

0 comments on commit c5e3924

Please sign in to comment.