From abff373305572e6d706fb296959b4936867b289d Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 16 Dec 2024 13:37:37 -0800 Subject: [PATCH 01/14] Upgrade jemalloc to 0.6; add jemalloc feature jemalloc is used by default on linux. The feature enables it on non-linux platforms. --- ipa-core/Cargo.toml | 11 ++++++++--- ipa-core/benches/oneshot/ipa.rs | 8 ++------ ipa-core/build.rs | 11 +++++++++++ ipa-core/src/bin/helper.rs | 8 ++++++-- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 8bb49f189..101416e4f 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -63,7 +63,9 @@ enable-benches = ["cli", "in-memory-infra", "test-fixture", "criterion", "iai"] # of unit tests use it. Real world infra uses HTTP implementation and is suitable for integration/e2e tests in-memory-infra = [] real-world-infra = [] -dhat-heap = ["cli", "test-fixture"] +# Force use of jemalloc on non-Linux platforms. jemalloc is used by default on Linux. +jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"] +dhat-heap = ["cli", "dhat", "test-fixture"] # Enable this feature to enable our colossally weak Fp31. weak-field = [] # Enable using more than one thread for protocol execution. Most of the parallelism occurs at parallel/seq_join operations @@ -111,7 +113,7 @@ criterion = { version = "0.5.1", optional = true, default-features = false, feat curve25519-dalek = "4.1.1" dashmap = "5.4" delegate = "0.10.0" -dhat = "0.3.2" +dhat = { version = "0.3.2", optional = true } embed-doc-image = "0.1.4" futures = "0.3.28" futures-util = "0.3.28" @@ -144,6 +146,8 @@ sha2 = "0.10" shuttle-crate = { package = "shuttle", version = "0.6.1", optional = true } subtle = "2.6" thiserror = "1.0" +tikv-jemallocator = { version = "0.6", optional = true, features = ["profiling"] } +tikv-jemalloc-ctl = { version = "0.6", optional = true, features = ["stats"] } time = { version = "0.3", optional = true } tokio = { version = "1.42", features = ["fs", "rt", "rt-multi-thread", "macros"] } tokio-rustls = { version = "0.26", optional = true } @@ -158,7 +162,8 @@ typenum = { version = "1.17", features = ["i128"] } x25519-dalek = "2.0.0-rc.3" [target.'cfg(all(not(target_env = "msvc"), not(target_os = "macos")))'.dependencies] -tikv-jemallocator = "0.5.0" +tikv-jemallocator = { version = "0.6", features = ["profiling"] } +tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] } [build-dependencies] cfg_aliases = "0.1.1" diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index b24d7ea7b..1f9eec376 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -19,13 +19,9 @@ use ipa_step::StepNarrow; use rand::{random, rngs::StdRng, SeedableRng}; use tokio::runtime::Builder; -#[cfg(all( - not(target_env = "msvc"), - not(feature = "dhat-heap"), - not(target_os = "macos") -))] +#[cfg(jemalloc)] #[global_allocator] -static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[cfg(feature = "dhat-heap")] #[global_allocator] diff --git a/ipa-core/build.rs b/ipa-core/build.rs index ce1987c72..07d70ffd4 100644 --- a/ipa-core/build.rs +++ b/ipa-core/build.rs @@ -47,10 +47,21 @@ fn main() { descriptive_gate: { all(not(feature = "compact-gate"), feature = "descriptive-gate") }, unit_test: { all(not(feature = "shuttle"), feature = "in-memory-infra", descriptive_gate) }, web_test: { all(not(feature = "shuttle"), feature = "real-world-infra") }, + 
jemalloc: { all( + not(feature = "dhat-heap"), + any( + feature = "jemalloc", + all( + not(target_env = "msvc"), + not(target_os = "macos") + ) + ) + ) }, } println!("cargo::rustc-check-cfg=cfg(descriptive_gate)"); println!("cargo::rustc-check-cfg=cfg(compact_gate)"); println!("cargo::rustc-check-cfg=cfg(unit_test)"); println!("cargo::rustc-check-cfg=cfg(web_test)"); + println!("cargo::rustc-check-cfg=cfg(jemalloc)"); println!("cargo::rustc-check-cfg=cfg(coverage)"); } diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 85b0c7110..8622ea998 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -30,9 +30,13 @@ use ipa_core::{ use tokio::runtime::Runtime; use tracing::{error, info}; -#[cfg(all(not(target_env = "msvc"), not(target_os = "macos")))] +#[cfg(jemalloc)] #[global_allocator] -static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +#[cfg(feature = "dhat-heap")] +#[global_allocator] +static ALLOC: dhat::Alloc = dhat::Alloc; #[derive(Debug, Parser)] #[clap( From e15bcc363a96e681f4adf6116fbb5b39bb9f7ed5 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 17 Dec 2024 11:52:45 -0800 Subject: [PATCH 02/14] Remove unused features --- ipa-core/Cargo.toml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 101416e4f..1b273521f 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -18,7 +18,6 @@ default = [ "tracing/max_level_trace", "tracing/release_max_level_info", "stall-detection", - "ipa-prf", "descriptive-gate", ] cli = ["comfy-table", "clap", "num_cpus"] @@ -74,14 +73,6 @@ multi-threading = ["async-scoped"] # RUSTFLAGS="--cfg tokio_unstable" cargo run ... --features="tokio-console ...". # Note that if there are other flags enabled on your platform in .cargo/config.toml, you need to include them as well. tokio-console = ["console-subscriber", "tokio/tracing"] - -# If this flag is used, then the new breakdown reveal based aggregation is used -reveal-aggregation = [] -# Standalone aggregation protocol. We use IPA infra for communication -# but it has nothing to do with IPA. 
-aggregate-circuit = [] -# IPA protocol based on OPRF -ipa-prf = [] # relaxed DP, off by default relaxed-dp = [] From 62653e5a8dcc91ce6af4f507afcd5a4a85edf7c8 Mon Sep 17 00:00:00 2001 From: Shinta Liem Date: Thu, 19 Dec 2024 01:05:11 +0800 Subject: [PATCH 03/14] Metrics exporter to Prometheus with OTLP (#1438) * Add /metrics route for metric backend scraper * /metrics endpoint returns prometheus sample metrics with otel * Added metrics exporter for prometheus * Add conversion from ipa-metric counter to OTLP * Simplified test case * Move prometheus exporter to its own crate * Wiring in logging_handle to the handler * Move metrics out of query module, add test cases * cargo fmt * Add logging_handle to TestApp * imports * cargo fmt --------- Co-authored-by: Shinta Liem --- ipa-core/Cargo.toml | 2 + ipa-core/src/app.rs | 9 ++ ipa-core/src/bin/helper.rs | 2 +- ipa-core/src/cli/metric_collector.rs | 18 +++- ipa-core/src/helpers/transport/handler.rs | 6 ++ .../helpers/transport/in_memory/transport.rs | 3 +- ipa-core/src/helpers/transport/mod.rs | 20 +++++ ipa-core/src/helpers/transport/routing.rs | 1 + ipa-core/src/net/http_serde.rs | 10 +++ ipa-core/src/net/server/handlers/metrics.rs | 62 ++++++++++++++ ipa-core/src/net/server/handlers/mod.rs | 15 ++-- ipa-core/src/net/test.rs | 8 +- ipa-core/src/net/transport.rs | 5 +- ipa-core/src/test_fixture/app.rs | 12 ++- ipa-metrics-prometheus/Cargo.toml | 16 ++++ ipa-metrics-prometheus/src/exporter.rs | 83 +++++++++++++++++++ ipa-metrics-prometheus/src/lib.rs | 3 + ipa-metrics/Cargo.toml | 1 - ipa-metrics/src/partitioned.rs | 11 +++ 19 files changed, 272 insertions(+), 15 deletions(-) create mode 100644 ipa-core/src/net/server/handlers/metrics.rs create mode 100644 ipa-metrics-prometheus/Cargo.toml create mode 100644 ipa-metrics-prometheus/src/exporter.rs create mode 100644 ipa-metrics-prometheus/src/lib.rs diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 1b273521f..3268d1ddf 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -81,6 +81,7 @@ ipa-metrics = { path = "../ipa-metrics" } ipa-metrics-tracing = { optional = true, path = "../ipa-metrics-tracing" } ipa-step = { version = "*", path = "../ipa-step" } ipa-step-derive = { version = "*", path = "../ipa-step-derive" } +ipa-metrics-prometheus = { path = "../ipa-metrics-prometheus" } aes = "0.8.3" async-trait = "0.1.79" @@ -172,6 +173,7 @@ rustls = { version = "0.23" } tempfile = "3" ipa-metrics-tracing = { path = "../ipa-metrics-tracing" } ipa-metrics = { path = "../ipa-metrics", features = ["partitions"] } +ipa-metrics-prometheus = { path = "../ipa-metrics-prometheus" } [lib] path = "src/lib.rs" diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 0d61e9bb6..b20a71c65 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -3,6 +3,7 @@ use std::sync::Weak; use async_trait::async_trait; use crate::{ + cli::LoggingHandle, executor::IpaRuntime, helpers::{ query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, @@ -65,6 +66,7 @@ struct Inner { /// the flamegraph mpc_transport: MpcTransportImpl, shard_transport: ShardTransportImpl, + logging_handle: LoggingHandle, } impl Setup { @@ -96,11 +98,13 @@ impl Setup { self, mpc_transport: MpcTransportImpl, shard_transport: ShardTransportImpl, + logging_handle: LoggingHandle, ) -> HelperApp { let app = Arc::new(Inner { query_processor: self.query_processor, mpc_transport, shard_transport, + logging_handle, }); self.mpc_handler .set_handler(Arc::downgrade(&app) as Weak>); @@ -277,6 +281,11 @@ impl RequestHandler for 
Inner {
             let query_id = ext_query_id(&req)?;
             HelperResponse::from(qp.kill(query_id)?)
         }
+        RouteId::Metrics => {
+            let logging_handler = &self.logging_handle;
+            let metrics_handle = &logging_handler.metrics_handle;
+            HelperResponse::from(metrics_handle.scrape_metrics())
+        }
         })
     }
 }
diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs
index 8622ea998..3fa554a6a 100644
--- a/ipa-core/src/bin/helper.rs
+++ b/ipa-core/src/bin/helper.rs
@@ -275,7 +275,7 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B
         Some(shard_handler),
     );
 
-    let _app = setup.connect(transport.clone(), shard_transport.clone());
+    let _app = setup.connect(transport.clone(), shard_transport.clone(), logging_handle);
 
     let listener = create_listener(args.server_socket_fd)?;
     let shard_listener = create_listener(args.shard_server_socket_fd)?;
diff --git a/ipa-core/src/cli/metric_collector.rs b/ipa-core/src/cli/metric_collector.rs
index 8f9a374b4..5b8f07b27 100644
--- a/ipa-core/src/cli/metric_collector.rs
+++ b/ipa-core/src/cli/metric_collector.rs
@@ -3,13 +3,14 @@ use std::{io, thread, thread::JoinHandle};
 use ipa_metrics::{
     MetricChannelType, MetricsCollectorController, MetricsCurrentThreadContext, MetricsProducer,
 };
+use ipa_metrics_prometheus::PrometheusMetricsExporter;
 use tokio::runtime::Builder;
 
 /// Holds a reference to metrics controller and producer
 pub struct CollectorHandle {
     thread_handle: JoinHandle<()>,
     /// This will be used once we start consuming metrics
-    _controller: MetricsCollectorController,
+    controller: MetricsCollectorController,
     producer: MetricsProducer,
 }
@@ -26,7 +27,7 @@ pub fn install_collector() -> io::Result<CollectorHandle> {
 
     Ok(CollectorHandle {
         thread_handle: handle,
-        _controller: controller,
+        controller,
         producer,
     })
 }
@@ -53,4 +54,17 @@ impl CollectorHandle {
             .on_thread_stop(flush_fn)
             .on_thread_park(flush_fn)
     }
+
+    /// Export the metrics to be consumed by a metrics scraper, e.g. Prometheus
+    ///
+    /// # Panics
+    /// If metrics are not initialized
+    #[must_use]
+    pub fn scrape_metrics(&self) -> Vec<u8> {
+        let mut store = self.controller.snapshot().expect("Metrics must be set up");
+        let mut buff = Vec::new();
+        store.export(&mut buff);
+
+        buff
+    }
 }
diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs
index a87e47ab5..9a1f1b457 100644
--- a/ipa-core/src/helpers/transport/handler.rs
+++ b/ipa-core/src/helpers/transport/handler.rs
@@ -149,6 +149,12 @@ impl<S: AsRef<str>> From<S> for HelperResponse {
     }
 }
 
+impl From<Vec<u8>> for HelperResponse {
+    fn from(value: Vec<u8>) -> Self {
+        Self { body: value }
+    }
+}
+
 /// Union of error types returned by API operations.
#[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index a2a1abea6..504921eb5 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -116,7 +116,8 @@ impl InMemoryTransport { | RouteId::QueryInput | RouteId::QueryStatus | RouteId::CompleteQuery - | RouteId::KillQuery => { + | RouteId::KillQuery + | RouteId::Metrics => { handler .as_ref() .expect("Handler is set") diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 6b8341966..00021c62c 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -229,6 +229,26 @@ where fn extra(&self) -> Self::Params; } +impl RouteParams for RouteId { + type Params = &'static str; + + fn resource_identifier(&self) -> RouteId { + *self + } + + fn query_id(&self) -> NoQueryId { + NoQueryId + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + "" + } +} + impl RouteParams for (QueryId, Gate) { type Params = &'static str; diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index c935704d4..138e496f9 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -24,6 +24,7 @@ pub enum RouteId { QueryStatus, CompleteQuery, KillQuery, + Metrics, } /// The header/metadata of the incoming request. diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index c7c3eb1a4..6d99498e0 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -68,6 +68,16 @@ pub mod echo { pub const AXUM_PATH: &str = "/echo"; } +pub mod metrics { + + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] + pub struct Request {} + + pub const AXUM_PATH: &str = "/metrics"; +} + pub mod query { use std::fmt::{Display, Formatter}; diff --git a/ipa-core/src/net/server/handlers/metrics.rs b/ipa-core/src/net/server/handlers/metrics.rs new file mode 100644 index 000000000..1c7d859b6 --- /dev/null +++ b/ipa-core/src/net/server/handlers/metrics.rs @@ -0,0 +1,62 @@ +use axum::{routing::get, Extension, Router}; +use hyper::StatusCode; + +use crate::{ + helpers::{routing::RouteId, BodyStream}, + net::{ + http_serde::{self}, + Error, MpcHttpTransport, + }, +}; + +/// Takes details from the HTTP request and creates a `[TransportCommand]::CreateQuery` that is sent +/// to the [`HttpTransport`]. 
+async fn handler(transport: Extension) -> Result, Error> { + match transport + .dispatch(RouteId::Metrics, BodyStream::empty()) + .await + { + Ok(resp) => Ok(resp.into_body()), + Err(err) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, err)), + } +} + +pub fn router(transport: MpcHttpTransport) -> Router { + Router::new() + .route(http_serde::metrics::AXUM_PATH, get(handler)) + .layer(Extension(transport)) +} + +#[cfg(all(test, unit_test))] +mod tests { + use axum::{ + body::Body, + http::uri::{self, Authority, Scheme}, + }; + + use super::*; + use crate::{ + helpers::{make_owned_handler, routing::Addr, HelperIdentity, HelperResponse}, + net::server::handlers::query::test_helpers::assert_success_with, + }; + + #[tokio::test] + async fn happy_case() { + let handler = make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { + let RouteId::Metrics = addr.route else { + panic!("unexpected call"); + }; + Ok(HelperResponse::from(Vec::new())) + }, + ); + let uri = uri::Builder::new() + .scheme(Scheme::HTTP) + .authority(Authority::from_static("localhost")) + .path_and_query(String::from("/metrics")) + .build() + .unwrap(); + let req = hyper::Request::get(uri).body(Body::empty()).unwrap(); + assert_success_with(req, handler).await; + } +} diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index c8ab75875..2571a2be0 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -1,4 +1,5 @@ mod echo; +mod metrics; mod query; use axum::Router; @@ -9,12 +10,14 @@ use crate::{ }; pub fn mpc_router(transport: MpcHttpTransport) -> Router { - echo::router().nest( - http_serde::query::BASE_AXUM_PATH, - Router::new() - .merge(query::query_router(transport.clone())) - .merge(query::h2h_router(transport.inner_transport)), - ) + echo::router() + .merge(metrics::router(transport.clone())) + .nest( + http_serde::query::BASE_AXUM_PATH, + Router::new() + .merge(query::query_router(transport.clone())) + .merge(query::h2h_router(transport.inner_transport)), + ) } pub fn shard_router(transport: Arc>) -> Router { diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index 00441e6bf..12330f67a 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -23,6 +23,8 @@ use once_cell::sync::Lazy; use rustls_pki_types::CertificateDer; use super::{ConnectionFlavor, HttpTransport, Shard}; +#[cfg(all(test, web_test, descriptive_gate))] +use crate::cli::{install_collector, LoggingHandle}; use crate::{ config::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, @@ -263,7 +265,11 @@ impl TestApp { shard_server.start_on(&IpaRuntime::current(), self.shard_server.socket.take(), ()), ) .await; - setup.connect(transport, shard_transport) + + let metrics_handle = install_collector().unwrap(); + let logging_handle = LoggingHandle { metrics_handle }; + + setup.connect(transport, shard_transport, logging_handle) } } diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index d8fbfc4b5..fda51ce70 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -115,7 +115,10 @@ impl HttpTransport { let req = serde_json::from_str(route.extra().borrow())?; self.clients[client_ix].status_match(req).await } - evt @ (RouteId::QueryInput | RouteId::ReceiveQuery | RouteId::KillQuery) => { + evt @ (RouteId::QueryInput + | RouteId::ReceiveQuery + | RouteId::KillQuery + | RouteId::Metrics) => { unimplemented!( "attempting to send 
client-specific request {evt:?} to another helper"
                 )
             }
diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs
index 1866cef74..1888b83e6 100644
--- a/ipa-core/src/test_fixture/app.rs
+++ b/ipa-core/src/test_fixture/app.rs
@@ -6,6 +6,7 @@ use typenum::Unsigned;
 
 use crate::{
     app::AppConfig,
+    cli::{install_collector, LoggingHandle},
     ff::Serializable,
     helpers::{
         query::{QueryConfig, QueryInput},
@@ -68,8 +69,15 @@ impl Default for TestApp {
         let mpc_network = InMemoryMpcNetwork::new(handlers.map(Some));
         let shard_network = InMemoryShardNetwork::with_shards(1);
 
-        let drivers = zip3(mpc_network.transports().each_ref(), setup)
-            .map(|(t, s)| s.connect(Clone::clone(t), shard_network.transport(t.identity(), 0)));
+        let drivers = zip3(mpc_network.transports().each_ref(), setup).map(|(t, s)| {
+            let metrics_handle = install_collector().unwrap();
+            let logging_handle = LoggingHandle { metrics_handle };
+            s.connect(
+                Clone::clone(t),
+                shard_network.transport(t.identity(), 0),
+                logging_handle,
+            )
+        });
 
         Self {
             drivers,
diff --git a/ipa-metrics-prometheus/Cargo.toml b/ipa-metrics-prometheus/Cargo.toml
new file mode 100644
index 000000000..36ce2be11
--- /dev/null
+++ b/ipa-metrics-prometheus/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "ipa-metrics-prometheus"
+version = "0.1.0"
+edition = "2021"
+
+[features]
+default = []
+
+[dependencies]
+ipa-metrics = { path = "../ipa-metrics" }
+
+# Open telemetry crates: the opentelemetry-prometheus implementation is based on the OpenTelemetry API and SDK 0.24.
+opentelemetry = "0.24"
+opentelemetry_sdk = { version = "0.24", features = ["metrics", "rt-tokio"] }
+opentelemetry-prometheus = { version = "0.17" }
+prometheus = "0.13.3"
diff --git a/ipa-metrics-prometheus/src/exporter.rs b/ipa-metrics-prometheus/src/exporter.rs
new file mode 100644
index 000000000..8ba17f905
--- /dev/null
+++ b/ipa-metrics-prometheus/src/exporter.rs
@@ -0,0 +1,83 @@
+use std::io;
+
+use ipa_metrics::MetricsStore;
+use opentelemetry::{metrics::MeterProvider, KeyValue};
+use opentelemetry_sdk::metrics::SdkMeterProvider;
+use prometheus::{self, Encoder, TextEncoder};
+
+pub trait PrometheusMetricsExporter {
+    fn export<W: io::Write>(&mut self, w: &mut W);
+}
+
+impl PrometheusMetricsExporter for MetricsStore {
+    fn export<W: io::Write>(&mut self, w: &mut W) {
+        // Set up the Prometheus registry and the OpenTelemetry exporter
+        let registry = prometheus::Registry::new();
+
+        let exporter = opentelemetry_prometheus::exporter()
+            .with_registry(registry.clone())
+            .build()
+            .unwrap();
+
+        let meter_provider = SdkMeterProvider::builder().with_reader(exporter).build();
+
+        // Convert the snapshot to otel struct
+        // TODO : We need to define a proper scope for the metrics
+        let meter = meter_provider.meter("ipa-helper");
+
+        let counters = self.counters();
+        counters.for_each(|(counter_name, counter_value)| {
+            let otlp_counter = meter.u64_counter(counter_name.key).init();
+
+            let attributes: Vec<KeyValue> = counter_name
+                .labels()
+                .map(|l| KeyValue::new(l.name, l.val.to_string()))
+                .collect();
+
+            otlp_counter.add(counter_value, &attributes[..]);
+        });
+
+        let encoder = TextEncoder::new();
+        let metric_families = registry.gather();
+        encoder.encode(&metric_families, w).unwrap();
+    }
+}
+
+#[cfg(test)]
+mod test {
+
+    use std::thread;
+
+    use ipa_metrics::{counter, install_new_thread, MetricChannelType};
+
+    use super::PrometheusMetricsExporter;
+
+    #[test]
+    fn export_to_prometheus() {
+        let (producer, controller, _) = install_new_thread(MetricChannelType::Rendezvous).unwrap();
+
+        thread::spawn(move
|| { + producer.install(); + counter!("baz", 4); + counter!("bar", 1); + let _ = producer.drop_handle(); + }) + .join() + .unwrap(); + + let mut store = controller.snapshot().unwrap(); + + let mut buff = Vec::new(); + store.export(&mut buff); + + let expected_result = "# TYPE bar_total counter +bar_total{otel_scope_name=\"ipa-helper\"} 1 +# TYPE baz_total counter +baz_total{otel_scope_name=\"ipa-helper\"} 4 +# HELP target_info Target metadata +# TYPE target_info gauge +target_info{service_name=\"unknown_service\",telemetry_sdk_language=\"rust\",telemetry_sdk_name=\"opentelemetry\",telemetry_sdk_version=\"0.24.1\"} 1\n"; + let result = String::from_utf8(buff).unwrap(); + assert_eq!(result, expected_result); + } +} diff --git a/ipa-metrics-prometheus/src/lib.rs b/ipa-metrics-prometheus/src/lib.rs new file mode 100644 index 000000000..b4cc164fc --- /dev/null +++ b/ipa-metrics-prometheus/src/lib.rs @@ -0,0 +1,3 @@ +mod exporter; + +pub use exporter::PrometheusMetricsExporter; diff --git a/ipa-metrics/Cargo.toml b/ipa-metrics/Cargo.toml index ebaeb9473..cc9ec52f0 100644 --- a/ipa-metrics/Cargo.toml +++ b/ipa-metrics/Cargo.toml @@ -17,4 +17,3 @@ hashbrown = "0.15" rustc-hash = "2.0.0" # logging tracing = "0.1" - diff --git a/ipa-metrics/src/partitioned.rs b/ipa-metrics/src/partitioned.rs index 0f71d0e28..3b2419fae 100644 --- a/ipa-metrics/src/partitioned.rs +++ b/ipa-metrics/src/partitioned.rs @@ -20,6 +20,7 @@ use hashbrown::hash_map::Entry; use rustc_hash::FxBuildHasher; use crate::{ + key::OwnedMetricName, kind::CounterValue, store::{CounterHandle, Store}, MetricName, @@ -120,6 +121,16 @@ impl PartitionedStore { self.get_mut(CurrentThreadContext::get()).counter(key) } + pub fn counters(&self) -> impl Iterator { + if let Some(partition) = CurrentThreadContext::get() { + return match self.inner.get(&partition) { + Some(store) => store.counters(), + None => self.default_store.counters(), + }; + } + self.default_store.counters() + } + #[must_use] pub fn len(&self) -> usize { self.inner.len() + self.default_store.len() From d6a244ee4d9faa25a2e4b1a747aa7d94328af7b1 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 18 Dec 2024 11:39:29 -0800 Subject: [PATCH 04/14] Split encryption and s3 presigned urls scripts (#1507) --- scripts/presigned-s3-urls.sh | 29 +++++++++++++ scripts/split-encrypted-files.py | 71 ++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100755 scripts/presigned-s3-urls.sh create mode 100644 scripts/split-encrypted-files.py diff --git a/scripts/presigned-s3-urls.sh b/scripts/presigned-s3-urls.sh new file mode 100755 index 000000000..3e0ff3295 --- /dev/null +++ b/scripts/presigned-s3-urls.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Set the usage message +usage="Usage: $0 " + +# Example invocation +# from ipa/input_data_S02/ +# ../scripts/presigned-s3-urls.sh encryptions/1B_cat/30_shards/ s3://stg-ipa-encrypted-reports/testing-sharded-data/1B/30_shards presigned_urls_30_shards.txt + +# Check if the correct number of arguments were provided +if [ $# -ne 3 ]; then + echo "$usage" + exit 1 +fi + +# Set the directory path and S3 URI from the command-line arguments +dir_path="$1" +s3_uri="$2" +output_file="$3" + +# Iterate over the files in the directory +for file in "$dir_path"/*; do + # Get the file name without the directory path + filename=$(basename "$file") + echo "Processing: $(basename "$file")" + # Call the aws s3 presign command and append the output to the output file + # expires in 14 days (14 * 24 * 60 * 60) + aws s3 presign "$s3_uri/$filename" 
--expires-in 1209600 >> "$output_file"
+done
diff --git a/scripts/split-encrypted-files.py b/scripts/split-encrypted-files.py
new file mode 100644
index 000000000..900b5a5c4
--- /dev/null
+++ b/scripts/split-encrypted-files.py
@@ -0,0 +1,71 @@
+import argparse
+import binascii
+import os
+
+try:
+    from tqdm import tqdm
+except ImportError:
+    print("tqdm not installed. run `pip install tqdm` to see progress")
+
+    def tqdm(iterable, *args, **kwargs):
+        return iterable
+
+
+def split_hex_file(input_filename, output_stem, num_files):
+    """
+    Reads in a file of hex strings, one per line, splits it up into N files,
+    and writes out each line as length-delimited binary data.
+
+    :param input_filename: The name of the input file containing hex strings.
+    :param num_files: The number of output files to split the input into.
+    """
+    output_files = [
+        open(f"{output_stem}_shard_{i:03d}.bin", "wb") for i in range(num_files)
+    ]
+
+    input_filesize = os.path.getsize(input_filename)
+    # estimate: each line is about 250 bytes
+    approx_row_count = input_filesize / 250
+    with open(input_filename, "r") as input_file:
+        for i, line in enumerate(
+            tqdm(input_file, desc="Processing lines", total=approx_row_count)
+        ):
+            # Remove any leading or trailing whitespace from the line
+            line = line.strip()
+
+            # Convert the hex string to bytes
+            data = binascii.unhexlify(line)
+
+            # Write the length of the data as a 2-byte integer (little-endian)
+            output_files[i % num_files].write(len(data).to_bytes(2, byteorder="little"))
+
+            # Write the data itself
+            output_files[i % num_files].write(data)
+
+    for f in output_files:
+        f.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Splits a file of hex strings into N length-delimited binary files"
+    )
+    parser.add_argument(
+        "-i", "--input_file", required=True, help="Input file containing hex strings"
+    )
+    parser.add_argument(
+        "-o",
+        "--output_stem",
+        required=True,
+        help="Output file stem for generated files",
+    )
+    parser.add_argument(
+        "-n",
+        "--num-files",
+        type=int,
+        required=True,
+        help="Number of output files to split the input into",
+    )
+    args = parser.parse_args()
+
+    split_hex_file(args.input_file, args.output_stem, args.num_files)

From 1cdc2f74f9ab801d6a3ff7520cf1cd82fa3eb42c Mon Sep 17 00:00:00 2001
From: Erik Taubeneck
Date: Thu, 19 Dec 2024 11:53:37 -0800
Subject: [PATCH 05/14] s3 presigned urls hot fix (#1511)

* update presigned url script expires-in flag to 7 days - max allowed by aws
* update s3 script to use s3 ls instead of a local directory
* typo
* only get filenames from ls command
* typo
* different loop approach
* typo
* specify expires_in as hours not seconds
---
 scripts/presigned-s3-urls.sh | 44 +++++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 16 deletions(-)

diff --git a/scripts/presigned-s3-urls.sh b/scripts/presigned-s3-urls.sh
index 3e0ff3295..26515be2e 100755
--- a/scripts/presigned-s3-urls.sh
+++ b/scripts/presigned-s3-urls.sh
@@ -1,29 +1,41 @@
 #!/bin/bash
 
 # Set the usage message
-usage="Usage: $0 <dir_path> <s3_uri> <output_file>"
+usage="Usage: $0 <s3_uri> <output_file> [<expires_in_hours>]"
 
 # Example invocation
-# from ipa/input_data_S02/
-# ../scripts/presigned-s3-urls.sh encryptions/1B_cat/30_shards/ s3://stg-ipa-encrypted-reports/testing-sharded-data/1B/30_shards presigned_urls_30_shards.txt
+# ../scripts/presigned-s3-urls.sh s3://stg-ipa-encrypted-reports/testing-sharded-data/1B/30_shards presigned_urls_30_shards.txt 168
 
 # Check if the correct number of arguments were provided
-if [ $# -ne 3 ]; then
+if [ $# -lt 2 ] || [ $# -gt 3 ]; then
echo "$usage" exit 1 fi -# Set the directory path and S3 URI from the command-line arguments -dir_path="$1" -s3_uri="$2" -output_file="$3" +# Set and validate the S3 URI, output_file, and expires_in from the command-line arguments +s3_uri="$1" +output_file="$2" +# Check if the output file already exists +if [ -f "$output_file" ]; then + echo "Error: Output file '$output_file' already exists. Please remove it before running this script." + exit 1 +fi + +# default expires_in: 7 days (7 * 24). this is the max allowed +expires_in_hours="${1:-168}" +expires_in=$((expires_in_hours* 3600 - 1)) + +if [ $# -gt 604799 ]; then + echo "expires_in must be less than 168 hours" + exit 1 +fi -# Iterate over the files in the directory -for file in "$dir_path"/*; do - # Get the file name without the directory path - filename=$(basename "$file") - echo "Processing: $(basename "$file")" - # Call the aws s3 presign command and append the output to the output file - # expires in 14 days (14 * 24 * 60 * 60) - aws s3 presign "$s3_uri/$filename" --expires-in 1209600 >> "$output_file" +# Iterate over the files in the s3 bucket +aws s3 ls "$s3_uri" | awk '{print $4}' | while read -r line; do +# Skip directories (they end with a slash) + if [[ "$line" != */ ]]; then + echo "Processing: $(basename "$s3_uri""$line")" + # Call the aws s3 presign command and append the output to the output file + aws s3 presign "$s3_uri$line" --expires-in "$expires_in" >> "$output_file" + fi done From fc61e53f0e89a8f9a0b31eadd921bccd42dc1123 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 20 Dec 2024 10:58:43 -0800 Subject: [PATCH 06/14] Query input from URL (#1508) Support for helper to pull query input from a URL, rather than receiving it directly from the client in an HTTP request body. Also adds a simple HTTP server in test_mpc to serve local files, for testing purposes. 
--- ipa-core/Cargo.toml | 2 + ipa-core/src/app.rs | 24 +++-- ipa-core/src/bin/test_mpc.rs | 80 ++++++++++++++- ipa-core/src/cli/playbook/add.rs | 2 +- ipa-core/src/cli/playbook/hybrid.rs | 2 +- ipa-core/src/cli/playbook/ipa.rs | 2 +- ipa-core/src/cli/playbook/multiply.rs | 2 +- ipa-core/src/cli/playbook/sharded_shuffle.rs | 2 +- ipa-core/src/helpers/transport/query/mod.rs | 52 +++++++++- ipa-core/src/net/client/mod.rs | 2 +- ipa-core/src/net/http_serde.rs | 64 ++++++++++-- ipa-core/src/net/mod.rs | 2 + ipa-core/src/net/query_input.rs | 59 +++++++++++ .../src/net/server/handlers/query/input.rs | 99 +++++++++++++++---- ipa-core/src/net/transport.rs | 6 +- ipa-core/src/query/processor.rs | 39 +++----- ipa-core/src/query/state.rs | 10 +- ipa-core/src/test_fixture/app.rs | 2 +- 18 files changed, 373 insertions(+), 78 deletions(-) create mode 100644 ipa-core/src/net/query_input.rs diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 3268d1ddf..7d2216df6 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -41,6 +41,7 @@ web-app = [ "rustls", "rustls-pemfile", "time", + "tiny_http", "tokio-rustls", "toml", "tower", @@ -141,6 +142,7 @@ thiserror = "1.0" tikv-jemallocator = { version = "0.6", optional = true, features = ["profiling"] } tikv-jemalloc-ctl = { version = "0.6", optional = true, features = ["stats"] } time = { version = "0.3", optional = true } +tiny_http = { version = "0.12", optional = true } tokio = { version = "1.42", features = ["fs", "rt", "rt-multi-thread", "macros"] } tokio-rustls = { version = "0.26", optional = true } tokio-stream = "0.1.14" diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index b20a71c65..a87c253a0 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -140,12 +140,24 @@ impl HelperApp { /// /// ## Errors /// Propagates errors from the helper. + /// ## Panics + /// If `input` asks to obtain query input from a remote URL. pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> { let mpc_transport = self.inner.mpc_transport.clone_ref(); let shard_transport = self.inner.shard_transport.clone_ref(); - self.inner - .query_processor - .receive_inputs(mpc_transport, shard_transport, input)?; + let QueryInput::Inline { + query_id, + input_stream, + } = input + else { + panic!("this client does not support pulling query input from a URL"); + }; + self.inner.query_processor.receive_inputs( + mpc_transport, + shard_transport, + query_id, + input_stream, + )?; Ok(()) } @@ -258,10 +270,8 @@ impl RequestHandler for Inner { HelperResponse::from(qp.receive_inputs( Transport::clone_ref(&self.mpc_transport), Transport::clone_ref(&self.shard_transport), - QueryInput { - query_id, - input_stream: data, - }, + query_id, + data, )?) 
} RouteId::QueryStatus => { diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index baf99a2ca..3735c4c18 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -1,4 +1,13 @@ -use std::{error::Error, fmt::Debug, ops::Add, path::PathBuf}; +use std::{ + error::Error, + fmt::Debug, + fs::File, + io::ErrorKind, + net::TcpListener, + ops::Add, + os::fd::{FromRawFd, RawFd}, + path::PathBuf, +}; use clap::{Parser, Subcommand}; use generic_array::ArrayLength; @@ -21,6 +30,8 @@ use ipa_core::{ net::{Helper, IpaHttpClient}, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, }; +use tiny_http::{Response, ResponseBox, Server, StatusCode}; +use tracing::{error, info}; #[derive(Debug, Parser)] #[clap( @@ -95,6 +106,23 @@ enum TestAction { /// This is exactly what shuffle does and that's why it is picked /// for this purpose. ShardedShuffle, + ServeInput(ServeInputArgs), +} + +#[derive(Debug, clap::Args)] +#[clap(about = "Run a simple HTTP server to serve query input files")] +pub struct ServeInputArgs { + /// Port to listen on + #[arg(short, long)] + port: Option, + + /// Listen on the supplied prebound socket instead of binding a new socket + #[arg(long, conflicts_with = "port")] + fd: Option, + + /// Directory with input files to serve + #[arg(short, long = "dir")] + directory: PathBuf, } #[tokio::main] @@ -129,6 +157,7 @@ async fn main() -> Result<(), Box> { .await; sharded_shuffle(&args, clients).await } + TestAction::ServeInput(options) => serve_input(options), }; Ok(()) @@ -204,3 +233,52 @@ async fn sharded_shuffle(args: &Args, helper_clients: Vec<[IpaHttpClient assert_eq!(shuffled.len(), input_rows.len()); assert_ne!(shuffled, input_rows); } + +fn not_found() -> ResponseBox { + Response::from_string("not found") + .with_status_code(StatusCode(404)) + .boxed() +} + +#[tracing::instrument("serve_input", skip_all)] +fn serve_input(args: ServeInputArgs) { + let server = if let Some(port) = args.port { + Server::http(("localhost", port)).unwrap() + } else if let Some(fd) = args.fd { + Server::from_listener(unsafe { TcpListener::from_raw_fd(fd) }, None).unwrap() + } else { + Server::http("localhost:0").unwrap() + }; + + if args.port.is_none() { + info!( + "Listening on :{}", + server.server_addr().to_ip().unwrap().port() + ); + } + + loop { + let request = server.recv().unwrap(); + tracing::info!(target: "request_url", "{}", request.url()); + + let url = request.url()[1..].to_owned(); + let response = if url.contains('/') { + error!(target: "error", "Request URL contains a slash"); + not_found() + } else { + match File::open(args.directory.join(&url)) { + Ok(file) => Response::from_file(file).boxed(), + Err(err) => { + if err.kind() != ErrorKind::NotFound { + error!(target: "error", "{err}"); + } + not_found() + } + } + }; + + let _ = request.respond(response).map_err(|err| { + error!(target: "error", "{err}"); + }); + } +} diff --git a/ipa-core/src/cli/playbook/add.rs b/ipa-core/src/cli/playbook/add.rs index eafa1da8d..5c785bd1b 100644 --- a/ipa-core/src/cli/playbook/add.rs +++ b/ipa-core/src/cli/playbook/add.rs @@ -47,7 +47,7 @@ where .into_iter() .zip(clients) .map(|(input_stream, client)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs index 92d74b383..53bfc6c28 100644 --- a/ipa-core/src/cli/playbook/hybrid.rs +++ b/ipa-core/src/cli/playbook/hybrid.rs @@ -44,7 +44,7 @@ where |(shard_clients, 
shard_inputs)| { try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map( |(client, input)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream: input, }) diff --git a/ipa-core/src/cli/playbook/ipa.rs b/ipa-core/src/cli/playbook/ipa.rs index 6f3691306..b294d7695 100644 --- a/ipa-core/src/cli/playbook/ipa.rs +++ b/ipa-core/src/cli/playbook/ipa.rs @@ -115,7 +115,7 @@ where .into_iter() .zip(clients) .map(|(input_stream, client)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/cli/playbook/multiply.rs b/ipa-core/src/cli/playbook/multiply.rs index ec777005a..265659c85 100644 --- a/ipa-core/src/cli/playbook/multiply.rs +++ b/ipa-core/src/cli/playbook/multiply.rs @@ -55,7 +55,7 @@ where .into_iter() .zip(clients) .map(|(input_stream, client)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/cli/playbook/sharded_shuffle.rs b/ipa-core/src/cli/playbook/sharded_shuffle.rs index 0139a8171..fceb1f4e4 100644 --- a/ipa-core/src/cli/playbook/sharded_shuffle.rs +++ b/ipa-core/src/cli/playbook/sharded_shuffle.rs @@ -45,7 +45,7 @@ where let shared = chunk.iter().copied().share(); try_join_all(mpc_clients.each_ref().iter().zip(shared).map( |(mpc_client, input)| { - mpc_client.query_input(QueryInput { + mpc_client.query_input(QueryInput::Inline { query_id, input_stream: BodyStream::from_serializable_iter(input), }) diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index c1daa7c71..cd3e389d1 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -184,14 +184,58 @@ impl RouteParams for &PrepareQuery { } } -pub struct QueryInput { - pub query_id: QueryId, - pub input_stream: BodyStream, +pub enum QueryInput { + FromUrl { + query_id: QueryId, + url: String, + }, + Inline { + query_id: QueryId, + input_stream: BodyStream, + }, +} + +impl QueryInput { + #[must_use] + pub fn query_id(&self) -> QueryId { + match self { + Self::FromUrl { query_id, .. } | Self::Inline { query_id, .. } => *query_id, + } + } + + #[must_use] + pub fn input_stream(self) -> Option { + match self { + Self::Inline { input_stream, .. } => Some(input_stream), + Self::FromUrl { .. } => None, + } + } + + #[must_use] + pub fn url(&self) -> Option<&str> { + match self { + Self::FromUrl { url, .. } => Some(url), + Self::Inline { .. 
} => None, + } + } } impl Debug for QueryInput { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "query_inputs[{:?}]", self.query_id) + match self { + QueryInput::Inline { + query_id, + input_stream: _, + } => f + .debug_struct("QueryInput::Inline") + .field("query_id", query_id) + .finish(), + QueryInput::FromUrl { query_id, url } => f + .debug_struct("QueryInput::FromUrl") + .field("query_id", query_id) + .field("url", url) + .finish(), + } } } diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index d4789b198..42d7c1377 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -726,7 +726,7 @@ pub(crate) mod tests { }; test_query_command( |client| async move { - let data = QueryInput { + let data = QueryInput::Inline { query_id: expected_query_id, input_stream: expected_input.to_vec().into(), }; diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 6d99498e0..be2bc6d83 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -322,12 +322,23 @@ pub mod query { } pub mod input { - use axum::{body::Body, http::uri}; - use hyper::header::CONTENT_TYPE; + use axum::{ + async_trait, + body::Body, + extract::FromRequestParts, + http::{request::Parts, uri}, + }; + use hyper::{ + header::{HeaderValue, CONTENT_TYPE}, + Uri, + }; use crate::{ helpers::query::QueryInput, - net::{http_serde::query::BASE_AXUM_PATH, APPLICATION_OCTET_STREAM}, + net::{ + http_serde::query::BASE_AXUM_PATH, Error, APPLICATION_OCTET_STREAM, + HTTP_QUERY_INPUT_URL_HEADER, + }, }; #[derive(Debug)] @@ -351,17 +362,54 @@ pub mod query { .path_and_query(format!( "{}/{}/input", BASE_AXUM_PATH, - self.query_input.query_id.as_ref(), + self.query_input.query_id().as_ref(), )) .build()?; - let body = Body::from_stream(self.query_input.input_stream); - Ok(hyper::Request::post(uri) - .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) - .body(body)?) + let query_input_url = self.query_input.url().map(ToOwned::to_owned); + let body = self + .query_input + .input_stream() + .map_or_else(Body::empty, Body::from_stream); + let mut request = + hyper::Request::post(uri).header(CONTENT_TYPE, APPLICATION_OCTET_STREAM); + if let Some(url) = query_input_url { + request.headers_mut().unwrap().insert( + &HTTP_QUERY_INPUT_URL_HEADER, + HeaderValue::try_from(url).unwrap(), + ); + } + Ok(request.body(body)?) 
} } pub const AXUM_PATH: &str = "/:query_id/input"; + + pub struct QueryInputUrl(Option); + + #[async_trait] + impl FromRequestParts for QueryInputUrl { + type Rejection = Error; + + async fn from_request_parts( + req: &mut Parts, + _state: &S, + ) -> Result { + match req.headers.get(&HTTP_QUERY_INPUT_URL_HEADER) { + None => Ok(QueryInputUrl(None)), + Some(value) => { + let value_str = value.to_str()?; + let uri = value_str.parse()?; + Ok(QueryInputUrl(Some(uri))) + } + } + } + } + + impl From for Option { + fn from(value: QueryInputUrl) -> Self { + value.0 + } + } } pub mod step { diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 621365077..5bce44553 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -18,6 +18,7 @@ use crate::{ mod client; mod error; mod http_serde; +pub mod query_input; mod server; #[cfg(all(test, not(feature = "shuttle")))] pub mod test; @@ -32,6 +33,7 @@ const APPLICATION_JSON: &str = "application/json"; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity"); static HTTP_SHARD_INDEX_HEADER: HeaderName = HeaderName::from_static("x-unverified-shard-index"); +static HTTP_QUERY_INPUT_URL_HEADER: HeaderName = HeaderName::from_static("x-query-input-url"); /// This has the same meaning as const defined in h2 crate, but we don't import it directly. /// According to the [`spec`] it cannot exceed 2^31 - 1. diff --git a/ipa-core/src/net/query_input.rs b/ipa-core/src/net/query_input.rs new file mode 100644 index 000000000..f2608392f --- /dev/null +++ b/ipa-core/src/net/query_input.rs @@ -0,0 +1,59 @@ +use axum::{body::Body, BoxError}; +use http_body_util::BodyExt; +use hyper::Uri; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::{ + client::legacy::Client, + rt::{TokioExecutor, TokioTimer}, +}; + +use crate::{helpers::BodyStream, net::Error}; + +/// Connect to a remote URL to download query input. +/// +/// # Errors +/// If the connection to the remote URL fails or returns an HTTP error. +/// +/// # Panics +/// If unable to create an HTTPS client using the system truststore. 
+pub async fn stream_query_input_from_url(uri: &Uri) -> Result<BodyStream, Error> {
+    let mut builder = Client::builder(TokioExecutor::new());
+    // the following timer is necessary for http2, in particular for any timeouts
+    // and waits the clients will need to make
+    // TODO: implement IpaTimer to allow wrapping other than Tokio runtimes
+    builder.timer(TokioTimer::new());
+    let client = builder.build::<_, Body>(
+        HttpsConnectorBuilder::default()
+            .with_native_roots()
+            .expect("System truststore is required")
+            .https_only()
+            .enable_all_versions()
+            .build(),
+    );
+
+    let resp = client
+        .get(uri.clone())
+        .await
+        .map_err(|inner| Error::ConnectError {
+            dest: uri.to_string(),
+            inner,
+        })?;
+
+    if !resp.status().is_success() {
+        let status = resp.status();
+        assert!(status.is_client_error() || status.is_server_error()); // must be failure
+        return Err(
+            axum::body::to_bytes(Body::new(resp.into_body()), 36_000_000) // Roughly 36mb
+                .await
+                .map_or_else(Into::into, |reason_bytes| Error::FailedHttpRequest {
+                    dest: uri.to_string(),
+                    status,
+                    reason: String::from_utf8_lossy(&reason_bytes).to_string(),
+                }),
+        );
+    }
+
+    Ok(BodyStream::from_bytes_stream(
+        resp.into_body().map_err(BoxError::from).into_data_stream(),
+    ))
+}
diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs
index da47e9386..1de6ea570 100644
--- a/ipa-core/src/net/server/handlers/query/input.rs
+++ b/ipa-core/src/net/server/handlers/query/input.rs
@@ -2,25 +2,29 @@ use axum::{extract::Path, routing::post, Extension, Router};
 use hyper::StatusCode;
 
 use crate::{
-    helpers::{query::QueryInput, routing::RouteId, BodyStream},
-    net::{http_serde, transport::MpcHttpTransport, Error},
+    helpers::{routing::RouteId, BodyStream},
+    net::{
+        http_serde::{self, query::input::QueryInputUrl},
+        query_input::stream_query_input_from_url,
+        transport::MpcHttpTransport,
+        Error,
+    },
     protocol::QueryId,
 };
 
 async fn handler(
     transport: Extension<MpcHttpTransport>,
     Path(query_id): Path<QueryId>,
+    input_url: QueryInputUrl,
     input_stream: BodyStream,
 ) -> Result<(), Error> {
-    let query_input = QueryInput {
-        query_id,
-        input_stream,
+    let input_stream = if let Some(url) = input_url.into() {
+        stream_query_input_from_url(&url).await?
+ } else { + input_stream }; let _ = transport - .dispatch( - (RouteId::QueryInput, query_input.query_id), - query_input.input_stream, - ) + .dispatch((RouteId::QueryInput, query_id), input_stream) .await .map_err(|e| Error::application(StatusCode::INTERNAL_SERVER_ERROR, e))?; @@ -35,10 +39,15 @@ pub fn router(transport: MpcHttpTransport) -> Router { #[cfg(all(test, unit_test))] mod tests { + use std::thread; + use axum::{ body::Body, http::uri::{Authority, Scheme}, }; + use bytes::BytesMut; + use futures::TryStreamExt; + use http_body_util::BodyExt; use hyper::StatusCode; use tokio::runtime::Handle; @@ -49,24 +58,22 @@ mod tests { net::{ http_serde, server::handlers::query::test_helpers::{assert_fails_with, assert_success_with}, + test::TestServer, }, protocol::QueryId, }; #[tokio::test(flavor = "multi_thread")] - async fn input_test() { - let expected_query_id = QueryId; + async fn input_inline() { + const QUERY_ID: QueryId = QueryId; let expected_input = &[4u8; 4]; - let req = http_serde::query::input::Request::new(QueryInput { - query_id: expected_query_id, - input_stream: expected_input.to_vec().into(), - }); + let req_handler = make_owned_handler(move |addr, data| async move { let RouteId::QueryInput = addr.route else { panic!("unexpected call"); }; - assert_eq!(addr.query_id, Some(expected_query_id)); + assert_eq!(addr.query_id, Some(QUERY_ID)); assert_eq!( tokio::task::block_in_place(move || { Handle::current().block_on(async move { data.to_vec().await }) @@ -76,10 +83,66 @@ mod tests { Ok(HelperResponse::ok()) }); - let req = req + + let req = http_serde::query::input::Request::new(QueryInput::Inline { + query_id: QUERY_ID, + input_stream: expected_input.to_vec().into(), + }); + let hyper_req = req .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) .unwrap(); - assert_success_with(req, req_handler).await; + + assert_success_with(hyper_req, req_handler).await; + } + + #[tokio::test(flavor = "multi_thread")] + async fn input_from_url() { + const QUERY_ID: QueryId = QueryId; + const DATA: &str = "input records"; + + let server = tiny_http::Server::http("localhost:0").unwrap(); + let addr = server.server_addr(); + thread::spawn(move || { + let request = server.recv().unwrap(); + let response = tiny_http::Response::from_string(DATA); + request.respond(response).unwrap(); + }); + + let req_handler = make_owned_handler(move |addr, body| async move { + let RouteId::QueryInput = addr.route else { + panic!("unexpected call"); + }; + + assert_eq!(addr.query_id, Some(QUERY_ID)); + assert_eq!(body.try_collect::().await.unwrap(), DATA); + + Ok(HelperResponse::ok()) + }); + let test_server = TestServer::builder() + .with_request_handler(req_handler) + .build() + .await; + + let url = format!( + "http://localhost:{}{}/{QUERY_ID}/input", + addr.to_ip().unwrap().port(), + http_serde::query::BASE_AXUM_PATH, + ); + let req = http_serde::query::input::Request::new(QueryInput::FromUrl { + query_id: QUERY_ID, + url, + }); + let hyper_req = req + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + + let resp = test_server.server.handle_req(hyper_req).await; + if !resp.status().is_success() { + let (head, body) = resp.into_parts(); + let body_bytes = body.collect().await.unwrap().to_bytes(); + let body = String::from_utf8_lossy(&body_bytes); + panic!("{head:?}\n{body}"); + } } struct OverrideReq { diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index fda51ce70..599024271 100644 --- a/ipa-core/src/net/transport.rs +++ 
b/ipa-core/src/net/transport.rs @@ -577,7 +577,7 @@ mod tests { let mut handle_resps = Vec::with_capacity(helper_shares.len()); for (i, input_stream) in helper_shares.into_iter().enumerate() { - let data = QueryInput { + let data = QueryInput::Inline { query_id, input_stream, }; @@ -589,7 +589,7 @@ mod tests { // convention - first client is shard leader, and we submitted the inputs to it. try_join_all(clients.iter().skip(1).map(|ring| { try_join_all(ring.each_ref().map(|shard_client| { - shard_client.query_input(QueryInput { + shard_client.query_input(QueryInput::Inline { query_id, input_stream: BodyStream::empty(), }) @@ -641,7 +641,7 @@ mod tests { |(helper, shard_streams)| async move { try_join_all(shard_streams.into_iter().enumerate().map( |(shard, input_stream)| { - clients[shard][helper].query_input(QueryInput { + clients[shard][helper].query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index df392b44f..95cfa0f44 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -11,10 +11,10 @@ use crate::{ error::Error as ProtocolError, executor::IpaRuntime, helpers::{ - query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, + query::{CompareStatusRequest, PrepareQuery, QueryConfig}, routing::RouteId, - BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, - RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, + BodyStream, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, + Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, hpke::{KeyRegistry, PrivateKeyOnly}, protocol::QueryId, @@ -213,7 +213,7 @@ impl Processor { // to rollback 1,2 and 3 shard_transport.broadcast(prepare_request.clone()).await?; - handle.set_state(QueryState::AwaitingInputs(query_id, req, roles))?; + handle.set_state(QueryState::AwaitingInputs(req, roles))?; guard.restore(); Ok(prepare_request) @@ -249,11 +249,7 @@ impl Processor { // TODO: If shards 1,2 and 3 succeed but 4 fails, then we need to rollback 1,2 and 3. 
shard_transport.broadcast(req.clone()).await?; - handle.set_state(QueryState::AwaitingInputs( - req.query_id, - req.config, - req.roles, - ))?; + handle.set_state(QueryState::AwaitingInputs(req.config, req.roles))?; Ok(()) } @@ -280,11 +276,7 @@ impl Processor { return Err(PrepareQueryError::AlreadyRunning); } - handle.set_state(QueryState::AwaitingInputs( - req.query_id, - req.config, - req.roles, - ))?; + handle.set_state(QueryState::AwaitingInputs(req.config, req.roles))?; Ok(()) } @@ -300,17 +292,14 @@ impl Processor { &self, mpc_transport: MpcTransportImpl, shard_transport: ShardTransportImpl, - input: QueryInput, + query_id: QueryId, + input_stream: BodyStream, ) -> Result<(), QueryInputError> { let mut queries = self.queries.inner.lock().unwrap(); - match queries.entry(input.query_id) { + match queries.entry(query_id) { Entry::Occupied(entry) => { let state = entry.remove(); - if let QueryState::AwaitingInputs(query_id, config, role_assignment) = state { - assert_eq!( - input.query_id, query_id, - "received inputs for a different query" - ); + if let QueryState::AwaitingInputs(config, role_assignment) = state { let mut gateway_config = GatewayConfig::default(); if let Some(active_work) = self.active_work { gateway_config.active = active_work; @@ -325,13 +314,13 @@ impl Processor { shard_transport, ); queries.insert( - input.query_id, + query_id, QueryState::Running(executor::execute( &self.runtime, config, Arc::clone(&self.key_registry), gateway, - input.input_stream, + input_stream, )), ); Ok(()) @@ -340,11 +329,11 @@ impl Processor { from: QueryStatus::from(&state), to: QueryStatus::Running, }; - queries.insert(input.query_id, state); + queries.insert(query_id, state); Err(QueryInputError::StateError { source: error }) } } - Entry::Vacant(_) => Err(QueryInputError::NoSuchQuery(input.query_id)), + Entry::Vacant(_) => Err(QueryInputError::NoSuchQuery(query_id)), } } diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index bca5c7e1d..148a40565 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -46,7 +46,7 @@ impl From<&QueryState> for QueryStatus { match source { QueryState::Empty => panic!("Query cannot be in the empty state"), QueryState::Preparing(_) => QueryStatus::Preparing, - QueryState::AwaitingInputs(_, _, _) => QueryStatus::AwaitingInputs, + QueryState::AwaitingInputs(_, _) => QueryStatus::AwaitingInputs, QueryState::Running(_) => QueryStatus::Running, QueryState::AwaitingCompletion => QueryStatus::AwaitingCompletion, QueryState::Completed(_) => QueryStatus::Completed, @@ -78,7 +78,7 @@ pub fn min_status(a: QueryStatus, b: QueryStatus) -> QueryStatus { pub enum QueryState { Empty, Preparing(QueryConfig), - AwaitingInputs(QueryId, QueryConfig, RoleAssignment), + AwaitingInputs(QueryConfig, RoleAssignment), Running(RunningQuery), AwaitingCompletion, Completed(QueryResult), @@ -91,9 +91,9 @@ impl QueryState { match (cur_state, &new_state) { // If query is not running, coordinator initial state is preparing // and followers initial state is awaiting inputs - (Empty, Preparing(_) | AwaitingInputs(_, _, _)) - | (Preparing(_), AwaitingInputs(_, _, _)) - | (AwaitingInputs(_, _, _), Running(_)) => Ok(new_state), + (Empty, Preparing(_) | AwaitingInputs(_, _)) + | (Preparing(_), AwaitingInputs(_, _)) + | (AwaitingInputs(_, _), Running(_)) => Ok(new_state), (_, Preparing(_)) => Err(StateError::AlreadyRunning), (_, _) => Err(StateError::InvalidState { from: cur_state.into(), diff --git a/ipa-core/src/test_fixture/app.rs 
b/ipa-core/src/test_fixture/app.rs index 1888b83e6..7ea7957cf 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -112,7 +112,7 @@ impl TestApp { .into_iter() .enumerate() .map(|(i, input)| { - self.drivers[i].execute_query(QueryInput { + self.drivers[i].execute_query(QueryInput::Inline { query_id, input_stream: input.into(), }) From c49d3f9b83faf74a7429444892973405aaedd22c Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 20 Dec 2024 11:32:31 -0800 Subject: [PATCH 07/14] add flag to encrypt_util for length delimited binary format (#1509) --- ipa-core/src/cli/crypto/hybrid_decrypt.rs | 22 +++- ipa-core/src/cli/crypto/hybrid_encrypt.rs | 148 ++++++++++++++++++---- ipa-core/src/cli/crypto/mod.rs | 11 +- 3 files changed, 144 insertions(+), 37 deletions(-) diff --git a/ipa-core/src/cli/crypto/hybrid_decrypt.rs b/ipa-core/src/cli/crypto/hybrid_decrypt.rs index be705e062..285df9e7e 100644 --- a/ipa-core/src/cli/crypto/hybrid_decrypt.rs +++ b/ipa-core/src/cli/crypto/hybrid_decrypt.rs @@ -226,9 +226,14 @@ mod tests { let output_dir = tempdir().unwrap(); let network_file = hybrid_sample_data::test_keys().network_config(); - HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) - .encrypt() - .unwrap(); + HybridEncryptArgs::new( + input_file.path(), + output_dir.path(), + network_file.path(), + false, + ) + .encrypt() + .unwrap(); let decrypt_output = output_dir.path().join("output"); let enc1 = output_dir.path().join("DOES_NOT_EXIST.enc"); @@ -258,9 +263,14 @@ mod tests { let network_file = hybrid_sample_data::test_keys().network_config(); let output_dir = tempdir().unwrap(); - HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) - .encrypt() - .unwrap(); + HybridEncryptArgs::new( + input_file.path(), + output_dir.path(), + network_file.path(), + false, + ) + .encrypt() + .unwrap(); let decrypt_output = output_dir.path().join("output"); let enc1 = output_dir.path().join("helper1.enc"); diff --git a/ipa-core/src/cli/crypto/hybrid_encrypt.rs b/ipa-core/src/cli/crypto/hybrid_encrypt.rs index 1a60d6944..c90bc16e5 100644 --- a/ipa-core/src/cli/crypto/hybrid_encrypt.rs +++ b/ipa-core/src/cli/crypto/hybrid_encrypt.rs @@ -52,15 +52,38 @@ pub struct HybridEncryptArgs { /// Path to helper network configuration file #[arg(long)] network: PathBuf, + /// a flag to produce length delimited binary instead of newline delimited hex + #[arg(long)] + length_delimited: bool, +} + +#[derive(Copy, Clone)] +enum FileFormat { + LengthDelimitedBinary, + NewlineDelimitedHex, } impl HybridEncryptArgs { #[must_use] - pub fn new(input_file: &Path, output_dir: &Path, network: &Path) -> Self { + pub fn new( + input_file: &Path, + output_dir: &Path, + network: &Path, + length_delimited: bool, + ) -> Self { Self { input_file: input_file.to_path_buf(), output_dir: output_dir.to_path_buf(), network: network.to_path_buf(), + length_delimited, + } + } + + fn file_format(&self) -> FileFormat { + if self.length_delimited { + FileFormat::LengthDelimitedBinary + } else { + FileFormat::NewlineDelimitedHex } } @@ -89,7 +112,8 @@ impl HybridEncryptArgs { panic!("could not load network file") }; - let mut worker_pool = ReportWriter::new(key_registries, &self.output_dir); + let mut worker_pool = + ReportWriter::new(key_registries, &self.output_dir, self.file_format()); for (report_id, record) in input.iter::().enumerate() { worker_pool.submit(report_id, record.share())?; } @@ -118,6 +142,7 @@ impl EncryptorPool { thread_count: usize, 
        file_writer: [SyncSender; 3],
        key_registries: [KeyRegistry; 3],
+        file_format: FileFormat,
    ) -> Self {
        Self {
            pool: (0..thread_count)
@@ -132,11 +157,23 @@ impl EncryptorPool {
                    .spawn(move || {
                        for (i, helper_id, report) in rx {
                            let key_registry = &key_registries[helper_id];
-                            let output = report.encrypt(
-                                DEFAULT_KEY_ID,
-                                key_registry,
-                                &mut thread_rng(),
-                            )?;
+                            let mut output =
+                                Vec::with_capacity(usize::from(report.encrypted_len() + 2));
+                            match file_format {
+                                FileFormat::NewlineDelimitedHex => report.encrypt_to(
+                                    DEFAULT_KEY_ID,
+                                    key_registry,
+                                    &mut thread_rng(),
+                                    &mut output,
+                                )?,
+                                FileFormat::LengthDelimitedBinary => report
+                                    .delimited_encrypt_to(
+                                        DEFAULT_KEY_ID,
+                                        key_registry,
+                                        &mut thread_rng(),
+                                        &mut output,
+                                    )?,
+                            }
                            file_writer[helper_id].send((i, output))?;
                        }
@@ -178,7 +215,11 @@ struct ReportWriter {
 }
 
 impl ReportWriter {
-    pub fn new(key_registries: [KeyRegistry; 3], output_dir: &Path) -> Self {
+    pub fn new(
+        key_registries: [KeyRegistry; 3],
+        output_dir: &Path,
+        file_format: FileFormat,
+    ) -> Self {
        // create 3 worker threads to write data into 3 files
        let workers = array::from_fn(|i| {
            let output_filename = format!("helper{}.enc", i + 1);
@@ -188,12 +229,13 @@ impl ReportWriter {
                .open(output_dir.join(&output_filename))
                .unwrap_or_else(|e| panic!("unable write to {:?}. {}", &output_filename, e));
 
-            FileWriteWorker::new(file)
+            FileWriteWorker::new(file, file_format)
        });
        let encryptor_pool = EncryptorPool::with_worker_threads(
            num_cpus::get(),
            workers.each_ref().map(|x| x.sender.clone()),
            key_registries,
+            file_format,
        );
 
        Self {
@@ -239,17 +281,26 @@ struct FileWriteWorker {
 }
 
 impl FileWriteWorker {
-    pub fn new(file: File) -> Self {
+    pub fn new(file: File, file_format: FileFormat) -> Self {
+        fn write_report<W: Write>(
+            writer: &mut W,
+            report: &[u8],
+            file_format: FileFormat,
+        ) -> Result<(), BoxError> {
+            match file_format {
+                FileFormat::LengthDelimitedBinary => {
+                    FileWriteWorker::write_report_length_delimited_binary(writer, report)
+                }
+                FileFormat::NewlineDelimitedHex => {
+                    FileWriteWorker::write_report_newline_delimited_hex(writer, report)
+                }
+            }
+        }
+
        let (tx, rx) = std::sync::mpsc::sync_channel(65535);
        Self {
            sender: tx,
            handle: thread::spawn(move || {
-                fn write_report<W: Write>(writer: &mut W, report: &[u8]) -> Result<(), BoxError> {
-                    let hex_output = hex::encode(report);
-                    writeln!(writer, "{hex_output}")?;
-                    Ok(())
-                }
-
                // write low watermark. All reports below this line have been written
                let mut lw = 0;
                let mut pending_reports = BTreeMap::new();
@@ -271,7 +322,7 @@ impl FileWriteWorker {
                        "Internal error: received a duplicate report {report_id}"
                    );
                    while let Some(report) = pending_reports.remove(&lw) {
-                        write_report(&mut writer, &report)?;
+                        write_report(&mut writer, &report, file_format)?;
                        lw += 1;
                        if lw % 1_000_000 == 0 {
                            tracing::info!("Encrypted {}M reports", lw / 1_000_000);
@@ -282,6 +333,23 @@ impl FileWriteWorker {
            }),
        }
    }
+
+    fn write_report_newline_delimited_hex<W: Write>(
+        writer: &mut W,
+        report: &[u8],
+    ) -> Result<(), BoxError> {
+        let hex_output = hex::encode(report);
+        writeln!(writer, "{hex_output}")?;
+        Ok(())
+    }
+
+    fn write_report_length_delimited_binary<W: Write>(
+        writer: &mut W,
+        report: &[u8],
+    ) -> Result<(), BoxError> {
+        writer.write_all(report)?;
+        Ok(())
+    }
 }
 
 #[cfg(all(test, unit_test))]
 mod tests {
@@ -334,12 +402,26 @@
        }
        input_file.flush().unwrap();
 
-        let output_dir = tempdir().unwrap();
+        let output_dir_1 = tempdir().unwrap();
+        let output_dir_2 = tempdir().unwrap();
        let network_file = sample_data::test_keys().network_config();
-        HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
-            .encrypt()
-            .unwrap();
+        HybridEncryptArgs::new(
+            input_file.path(),
+            output_dir_1.path(),
+            network_file.path(),
+            false,
+        )
+        .encrypt()
+        .unwrap();
+        HybridEncryptArgs::new(
+            input_file.path(),
+            output_dir_2.path(),
+            network_file.path(),
+            true,
+        )
+        .encrypt()
+        .unwrap();
    }
 
    #[test]
@@ -350,7 +432,7 @@
        let output_dir = tempdir().unwrap();
        let network_dir = tempdir().unwrap();
        let network_file = network_dir.path().join("does_not_exist");
-        HybridEncryptArgs::new(input_file.path(), output_dir.path(), &network_file)
+        HybridEncryptArgs::new(input_file.path(), output_dir.path(), &network_file, true)
            .encrypt()
            .unwrap();
    }
@@ -368,9 +450,14 @@
 this is not toml!
let mut network_file = NamedTempFile::new().unwrap(); writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); - HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) - .encrypt() - .unwrap(); + HybridEncryptArgs::new( + input_file.path(), + output_dir.path(), + network_file.path(), + true, + ) + .encrypt() + .unwrap(); } #[test] @@ -392,8 +479,13 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" let mut network_file = NamedTempFile::new().unwrap(); writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); - HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) - .encrypt() - .unwrap(); + HybridEncryptArgs::new( + input_file.path(), + output_dir.path(), + network_file.path(), + true, + ) + .encrypt() + .unwrap(); } } diff --git a/ipa-core/src/cli/crypto/mod.rs b/ipa-core/src/cli/crypto/mod.rs index ac0fd1a06..47180efef 100644 --- a/ipa-core/src/cli/crypto/mod.rs +++ b/ipa-core/src/cli/crypto/mod.rs @@ -345,9 +345,14 @@ mod tests { let input = hybrid_sample_data::test_hybrid_data().take(10); let input_file = hybrid_sample_data::write_csv(input).unwrap(); let network_file = hybrid_sample_data::test_keys().network_config(); - HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) - .encrypt() - .unwrap(); + HybridEncryptArgs::new( + input_file.path(), + output_dir.path(), + network_file.path(), + false, + ) + .encrypt() + .unwrap(); let decrypt_output = output_dir.path().join("output"); let enc1 = output_dir.path().join("helper1.enc"); From 4d412748070e28826ae5f119ed2585eab45e6375 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 20 Dec 2024 12:08:53 -0800 Subject: [PATCH 08/14] Fix input_from_url test --- ipa-core/src/net/error.rs | 2 +- ipa-core/src/net/query_input.rs | 2 +- ipa-core/src/net/server/handlers/query/input.rs | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index 82426b402..f137c3232 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -54,7 +54,7 @@ pub enum Error { status: hyper::StatusCode, reason: String, }, - #[error("Failed to connect to {dest}: {inner}")] + #[error("Failed to connect to {dest}: {inner:?}")] ConnectError { dest: String, #[source] diff --git a/ipa-core/src/net/query_input.rs b/ipa-core/src/net/query_input.rs index f2608392f..7a71b6f77 100644 --- a/ipa-core/src/net/query_input.rs +++ b/ipa-core/src/net/query_input.rs @@ -26,7 +26,7 @@ pub async fn stream_query_input_from_url(uri: &Uri) -> Result HttpsConnectorBuilder::default() .with_native_roots() .expect("System truststore is required") - .https_only() + .https_or_http() .enable_all_versions() .build(), ); diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index 1de6ea570..fa97b124b 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -98,7 +98,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn input_from_url() { const QUERY_ID: QueryId = QueryId; - const DATA: &str = "input records"; + const DATA: &str = ""; let server = tiny_http::Server::http("localhost:0").unwrap(); let addr = server.server_addr(); @@ -124,9 +124,8 @@ mod tests { .await; let url = format!( - "http://localhost:{}{}/{QUERY_ID}/input", + "http://localhost:{}/input-data", addr.to_ip().unwrap().port(), - http_serde::query::BASE_AXUM_PATH, ); let req = 
http_serde::query::input::Request::new(QueryInput::FromUrl {
            query_id: QUERY_ID,

From e0efbf816ac3f4544795483ba3acf88a33b8fb5e Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 20 Dec 2024 15:08:06 -0700
Subject: [PATCH 09/14] Support Query input from URL on report collector side
 + integration test (#1510)

* url_file_list parameter for report collector

* e2e integration test for upload

* Clippy

* Feedback
---
 ipa-core/src/bin/report_collector.rs | 153 ++++++++++++++----
 ipa-core/src/cli/playbook/hybrid.rs  |  27 ++--
 ipa-core/tests/hybrid.rs             | 227 ++++++++++++++++++++++++++-
 3 files changed, 356 insertions(+), 51 deletions(-)

diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs
index fb41fcb9a..1a0d6755c 100644
--- a/ipa-core/src/bin/report_collector.rs
+++ b/ipa-core/src/bin/report_collector.rs
@@ -4,13 +4,14 @@ use std::{
     fmt::Debug,
     fs::{File, OpenOptions},
     io,
-    io::{stdout, BufReader, Write},
+    io::{stdout, BufRead, BufReader, Write},
+    iter::zip,
     ops::Deref,
     path::{Path, PathBuf},
 };
 
 use clap::{Parser, Subcommand};
-use hyper::http::uri::Scheme;
+use hyper::{http::uri::Scheme, Uri};
 use ipa_core::{
     cli::{
         playbook::{
@@ -24,11 +25,13 @@ use ipa_core::{
     ff::{boolean_array::BA32, FieldType},
     helpers::{
         query::{
-            DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType,
+            DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QueryInput, QuerySize,
+            QueryType,
         },
         BodyStream,
     },
     net::{Helper, IpaHttpClient},
+    protocol::QueryId,
     report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID},
     test_fixture::{
         ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord},
@@ -143,7 +146,14 @@ enum ReportCollectorCommand {
     },
     MaliciousHybrid {
         #[clap(flatten)]
-        encrypted_inputs: EncryptedInputs,
+        encrypted_inputs: Option<EncryptedInputs>,
+
+        #[arg(
+            long,
+            help = "Read the list of URLs that contain the input from the provided file",
+            conflicts_with_all = ["enc_input_file1", "enc_input_file2", "enc_input_file3"]
+        )]
+        url_file_list: Option<PathBuf>,
 
        #[clap(flatten)]
        hybrid_query_config: HybridQueryParams,
@@ -267,6 +277,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
        }
        ReportCollectorCommand::MaliciousHybrid {
            ref encrypted_inputs,
+            ref url_file_list,
            hybrid_query_config,
            count,
            set_fixed_polling_ms,
@@ -275,7 +286,19 @@ async fn main() -> Result<(), Box<dyn Error>> {
            &args,
            hybrid_query_config,
            clients,
-            encrypted_inputs,
+            |query_id| {
+                if let Some(ref url_file_list) = url_file_list {
+                    inputs_from_url_file(url_file_list, query_id, args.shard_count)
+                } else if let Some(ref encrypted_inputs) = encrypted_inputs {
+                    Ok(inputs_from_encrypted_inputs(
+                        encrypted_inputs,
+                        query_id,
+                        args.shard_count,
+                    ))
+                } else {
+                    panic!("Either --url-file-list or --enc-input-file1, --enc-input-file2, and --enc-input-file3 must be provided");
+                }
+            },
            count.try_into().expect("u32 should fit into usize"),
            set_fixed_polling_ms,
        )
@@ -286,6 +309,95 @@ async fn main() -> Result<(), Box<dyn Error>> {
     Ok(())
 }
 
+fn inputs_from_url_file(
+    url_file_path: &Path,
+    query_id: QueryId,
+    shard_count: usize,
+) -> Result<Vec<[QueryInput; 3]>, Box<dyn Error>> {
+    let mut file = BufReader::new(File::open(url_file_path)?);
+    let mut buf = String::new();
+    let mut inputs = [Vec::new(), Vec::new(), Vec::new()];
+    for helper_input in inputs.iter_mut() {
+        for _ in 0..shard_count {
+            buf.clear();
+            if file.read_line(&mut buf)? == 0 {
+                break;
+            }
+            helper_input
+                .push(Uri::try_from(buf.trim()).map_err(|e| format!("Invalid URL {buf:?}: {e}"))?);
+        }
+    }
+
+    // make sure all helpers have the expected number of inputs (one per shard)
+    let all_rows = inputs.iter().map(|v| v.len()).sum::<usize>();
+    if all_rows != 3 * shard_count {
+        return Err(format!(
+            "The number of URLs in {url_file_path:?} '{all_rows}' does not match the expected 3*{shard_count}."
+        )
+        .into());
+    }
+
+    let [h1, h2, h3] = inputs;
+    Ok(zip(zip(h1, h2), h3)
+        .map(|((h1, h2), h3)| {
+            [
+                QueryInput::FromUrl {
+                    url: h1.to_string(),
+                    query_id,
+                },
+                QueryInput::FromUrl {
+                    url: h2.to_string(),
+                    query_id,
+                },
+                QueryInput::FromUrl {
+                    url: h3.to_string(),
+                    query_id,
+                },
+            ]
+        })
+        .collect())
+}
+
+fn inputs_from_encrypted_inputs(
+    encrypted_inputs: &EncryptedInputs,
+    query_id: QueryId,
+    shard_count: usize,
+) -> Vec<[QueryInput; 3]> {
+    let [h1_streams, h2_streams, h3_streams] = [
+        &encrypted_inputs.enc_input_file1,
+        &encrypted_inputs.enc_input_file2,
+        &encrypted_inputs.enc_input_file3,
+    ]
+    .map(|path| {
+        let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
+        RoundRobinSubmission::new(BufReader::new(file))
+    })
+    .map(|s| s.into_byte_streams(shard_count));
+
+    // create byte streams for each shard
+    h1_streams
+        .into_iter()
+        .zip(h2_streams)
+        .zip(h3_streams)
+        .map(|((s1, s2), s3)| {
+            [
+                QueryInput::Inline {
+                    input_stream: BodyStream::from_bytes_stream(s1),
+                    query_id,
+                },
+                QueryInput::Inline {
+                    input_stream: BodyStream::from_bytes_stream(s2),
+                    query_id,
+                },
+                QueryInput::Inline {
+                    input_stream: BodyStream::from_bytes_stream(s3),
+                    query_id,
+                },
+            ]
+        })
+        .collect::<Vec<_>>()
+}
+
 fn gen_hybrid_inputs(
     count: u32,
     seed: Option<u64>,
@@ -422,41 +534,16 @@ fn write_hybrid_output_file(
     Ok(())
 }
 
-async fn hybrid(
+async fn hybrid<F: FnOnce(QueryId) -> Result<Vec<[QueryInput; 3]>, Box<dyn Error>>>(
     args: &Args,
     hybrid_query_config: HybridQueryParams,
     helper_clients: Vec<[IpaHttpClient<Helper>; 3]>,
-    encrypted_inputs: &EncryptedInputs,
+    make_inputs_fn: F,
     count: usize,
     set_fixed_polling_ms: Option<u64>,
 ) -> Result<(), Box<dyn Error>> {
     let query_type = QueryType::MaliciousHybrid(hybrid_query_config);
 
-    let [h1_streams, h2_streams, h3_streams] = [
-        &encrypted_inputs.enc_input_file1,
-        &encrypted_inputs.enc_input_file2,
-        &encrypted_inputs.enc_input_file3,
-    ]
-    .map(|path| {
-        let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
-        RoundRobinSubmission::new(BufReader::new(file))
-    })
-    .map(|s| s.into_byte_streams(args.shard_count));
-
-    // create byte streams for each shard
-    let submissions = h1_streams
-        .into_iter()
-        .zip(h2_streams.into_iter())
-        .zip(h3_streams.into_iter())
-        .map(|((s1, s2), s3)| {
-            [
-                BodyStream::from_bytes_stream(s1),
-                BodyStream::from_bytes_stream(s2),
-                BodyStream::from_bytes_stream(s3),
-            ]
-        })
-        .collect::<Vec<_>>();
-
     let query_config = QueryConfig {
         size: QuerySize::try_from(count).unwrap(),
         field_type: FieldType::Fp32BitPrime,
@@ -469,6 +556,7 @@ async fn hybrid(
         .expect("Unable to create query!");
 
     tracing::info!("Starting query for OPRF");
+    let submissions = make_inputs_fn(query_id)?;
 
     // the value for histogram values (BA32) must be kept in sync with the server-side
     // implementation, otherwise a runtime reconstruct error will be generated.
@@ -477,7 +565,6 @@ async fn hybrid(
         submissions,
         count,
         helper_clients,
-        query_id,
         hybrid_query_config,
         set_fixed_polling_ms,
     )
diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs
index 53bfc6c28..ff6f7d3a4 100644
--- a/ipa-core/src/cli/playbook/hybrid.rs
+++ b/ipa-core/src/cli/playbook/hybrid.rs
@@ -11,12 +11,8 @@ use tokio::time::sleep;
 
 use crate::{
     ff::{Serializable, U128Conversions},
-    helpers::{
-        query::{HybridQueryParams, QueryInput, QuerySize},
-        BodyStream,
-    },
+    helpers::query::{HybridQueryParams, QueryInput, QuerySize},
     net::{Helper, IpaHttpClient},
-    protocol::QueryId,
     query::QueryStatus,
     secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
     test_fixture::Reconstruct,
@@ -26,10 +22,9 @@ use crate::{
 /// if results are invalid
 #[allow(clippy::disallowed_methods)] // allow try_join_all
 pub async fn run_hybrid_query_and_validate<HV>(
-    inputs: Vec<[BodyStream; 3]>,
+    inputs: Vec<[QueryInput; 3]>,
     query_size: usize,
     clients: Vec<[IpaHttpClient<Helper>; 3]>,
-    query_id: QueryId,
     query_config: HybridQueryParams,
     set_fixed_polling_ms: Option<u64>,
 ) -> HybridQueryResult
 where
     HV: SharedValue + U128Conversions,
     AdditiveShare<HV>: Serializable,
 {
+    let query_id = inputs
+        .first()
+        .map(|v| v[0].query_id())
+        .expect("At least one shard must be used to run a Hybrid query");
     let mpc_time = Instant::now();
     assert_eq!(clients.len(), inputs.len());
     // submit inputs to each shard
     let _ = try_join_all(zip(clients.iter(), inputs.into_iter()).map(
         |(shard_clients, shard_inputs)| {
-            try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map(
-                |(client, input)| {
-                    client.query_input(QueryInput::Inline {
-                        query_id,
-                        input_stream: input,
-                    })
-                },
-            ))
+            try_join_all(
+                shard_clients
+                    .iter()
+                    .zip(shard_inputs.into_iter())
+                    .map(|(client, input)| client.query_input(input)),
+            )
         },
     ))
     .await
diff --git a/ipa-core/tests/hybrid.rs b/ipa-core/tests/hybrid.rs
index b40f524f0..50251f8bc 100644
--- a/ipa-core/tests/hybrid.rs
+++ b/ipa-core/tests/hybrid.rs
@@ -4,22 +4,34 @@ mod common;
 
 use std::{
     fs::File,
+    io::{BufReader, Read, Write},
+    iter::once,
+    net::TcpListener,
+    os::fd::AsRawFd,
+    path::{Path, PathBuf},
     process::{Command, Stdio},
 };
 
+use bytes::Bytes;
+use command_fds::CommandFdExt;
 use common::{
     spawn_shards, tempdir::TempDir, test_sharded_setup, CommandExt, TerminateOnDropExt,
     UnwrapStatusExt, CRYPTO_UTIL_BIN, TEST_RC_BIN,
 };
-use ipa_core::{cli::playbook::HybridQueryResult, helpers::query::HybridQueryParams};
+use futures_util::{StreamExt, TryStreamExt};
+use ipa_core::{
+    cli::playbook::HybridQueryResult,
+    error::BoxError,
+    helpers::{query::HybridQueryParams, LengthDelimitedStream},
+};
 use rand::thread_rng;
 use rand_core::RngCore;
 use serde_json::from_reader;
 
+use crate::common::TEST_MPC_BIN;
+
 pub const IN_THE_CLEAR_BIN: &str = env!("CARGO_BIN_EXE_in_the_clear");
 
-// this currently only generates data and runs in the clear
-// eventaully we'll want to add the MPC as well
 #[test]
 fn test_hybrid() {
     const INPUT_SIZE: usize = 100;
@@ -134,3 +146,212 @@ fn test_hybrid() {
         .zip(expected_result.iter())
         .all(|(a, b)| a == b));
 }
+
+#[test]
+fn test_hybrid_poll() {
+    const INPUT_SIZE: usize = 100;
+    const SHARDS: usize = 5;
+    const MAX_CONVERSION_VALUE: usize = 5;
+
+    let config = HybridQueryParams {
+        max_breakdown_key: 5,
+        with_dp: 0,
+        epsilon: 0.0,
+        // only encrypted inputs are supported
+        plaintext_match_keys: false,
+    };
+
+    let dir = TempDir::new_delete_on_drop();
+
+    // Gen inputs
+    let input_file = dir.path().join("ipa_inputs.txt");
dir.path().join("ipa_inputs.txt"); + let in_the_clear_output_file = dir.path().join("ipa_output_in_the_clear.json"); + let output_file = dir.path().join("ipa_output.json"); + + let mut command = Command::new(TEST_RC_BIN); + command + .args(["--output-file".as_ref(), input_file.as_os_str()]) + .arg("gen-hybrid-inputs") + .args(["--count", &INPUT_SIZE.to_string()]) + .args(["--max-conversion-value", &MAX_CONVERSION_VALUE.to_string()]) + .args(["--max-breakdown-key", &config.max_breakdown_key.to_string()]) + .args(["--seed", &thread_rng().next_u64().to_string()]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); + + let mut command = Command::new(IN_THE_CLEAR_BIN); + command + .args(["--input-file".as_ref(), input_file.as_os_str()]) + .args([ + "--output-file".as_ref(), + in_the_clear_output_file.as_os_str(), + ]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); + + let config_path = dir.path().join("config"); + let sockets = test_sharded_setup::(&config_path); + let _helpers = spawn_shards(&config_path, &sockets, true); + + // encrypt input + let mut command = Command::new(CRYPTO_UTIL_BIN); + command + .arg("hybrid-encrypt") + .args(["--input-file".as_ref(), input_file.as_os_str()]) + .args(["--output-dir".as_ref(), dir.path().as_os_str()]) + .args(["--length-delimited"]) + .args(["--network".into(), config_path.join("network.toml")]) + .stdin(Stdio::piped()); + command.status().unwrap_status(); + let enc1 = dir.path().join("helper1.enc"); + let enc2 = dir.path().join("helper2.enc"); + let enc3 = dir.path().join("helper3.enc"); + + let poll_port = TcpListener::bind("127.0.0.1:0").unwrap(); + + // split encryption into N shards and create a metadata file that contains + // all files + let upload_metadata = create_upload_files::( + &enc1, + &enc2, + &enc3, + poll_port.local_addr().unwrap().port(), + dir.path(), + ) + .unwrap(); + + // spawn HTTP server to serve the uploaded files + let mut command = Command::new(TEST_MPC_BIN); + command + .arg("serve-input") + .preserved_fds(vec![poll_port.as_raw_fd()]) + .args(["--fd", &poll_port.as_raw_fd().to_string()]) + .args([ + "--dir".as_ref(), + upload_metadata.parent().unwrap().as_os_str(), + ]) + .silent(); + + let _server_handle = command.spawn().unwrap().terminate_on_drop(); + + // Run Hybrid + let mut command = Command::new(TEST_RC_BIN); + command + .args(["--network".into(), config_path.join("network.toml")]) + .args(["--output-file".as_ref(), output_file.as_os_str()]) + .args(["--shard-count", SHARDS.to_string().as_str()]) + .args(["--wait", "2"]) + .arg("malicious-hybrid") + .silent() + .args(["--count", INPUT_SIZE.to_string().as_str()]) + .args(["--url-file-list".into(), upload_metadata]) + .args(["--max-breakdown-key", &config.max_breakdown_key.to_string()]); + + match config.with_dp { + 0 => { + command.args(["--with-dp", &config.with_dp.to_string()]); + } + _ => { + command + .args(["--with-dp", &config.with_dp.to_string()]) + .args(["--epsilon", &config.epsilon.to_string()]); + } + } + command.stdin(Stdio::piped()); + + let test_mpc = command.spawn().unwrap().terminate_on_drop(); + test_mpc.wait().unwrap_status(); + + // basic output checks - output should have the exact size as number of breakdowns + let output = serde_json::from_str::( + &std::fs::read_to_string(&output_file).expect("IPA results file should exist"), + ) + .expect("IPA results file is valid JSON"); + + assert_eq!( + usize::try_from(config.max_breakdown_key).unwrap(), + output.breakdowns.len(), + "Number of breakdowns does 
not match the expected", + ); + assert_eq!(INPUT_SIZE, usize::from(output.input_size)); + + let expected_result: Vec = from_reader( + File::open(in_the_clear_output_file) + .expect("file should exist as it's created above in the test"), + ) + .expect("should match hard coded format from in_the_clear"); + assert!(output + .breakdowns + .iter() + .zip(expected_result.iter()) + .all(|(a, b)| a == b)); +} + +fn create_upload_files( + enc_file1: &Path, + enc_file2: &Path, + enc_file3: &Path, + port: u16, + dest: &Path, +) -> Result { + let manifest_path = dest.join("manifest.txt"); + let mut manifest_file = File::create_new(&manifest_path)?; + create_upload_file::("h1", enc_file1, port, dest, &mut manifest_file)?; + create_upload_file::("h2", enc_file2, port, dest, &mut manifest_file)?; + create_upload_file::("h3", enc_file3, port, dest, &mut manifest_file)?; + + manifest_file.flush()?; + + Ok(manifest_path) +} + +fn create_upload_file( + prefix: &str, + enc_file: &Path, + port: u16, + dest_dir: &Path, + metadata_file: &mut File, +) -> Result<(), BoxError> { + let mut files = (0..SHARDS) + .map(|i| { + let path = dest_dir.join(format!("{prefix}_shard_{i}.enc")); + let file = File::create_new(&path)?; + Ok((path, file)) + }) + .collect::>>()?; + + // we assume files are tiny for the integration tests + let mut input = BufReader::new(File::open(enc_file)?); + let mut buf = Vec::new(); + if input.read_to_end(&mut buf)? == 0 { + panic!("{:?} file is empty", enc_file); + } + + // read length delimited data and write it to each file + let stream = + LengthDelimitedStream::::new(futures::stream::iter(once(Ok::<_, BoxError>( + buf.into(), + )))) + .map_ok(|v| futures::stream::iter(v).map(Ok::<_, BoxError>)) + .try_flatten(); + + for (i, next_bytes) in futures::executor::block_on_stream(stream).enumerate() { + let next_bytes = next_bytes?; + let file = &mut files[i % SHARDS].1; + let len = u16::try_from(next_bytes.len()) + .map_err(|_| format!("record is too too big: {} > 65535", next_bytes.len()))?; + file.write(&len.to_le_bytes())?; + file.write_all(&next_bytes)?; + } + + // update manifest file + for (path, mut file) in files { + file.flush()?; + let path = path.file_name().and_then(|p| p.to_str()).unwrap(); + writeln!(metadata_file, "http://localhost:{port}/{path}")?; + } + + Ok(()) +} From dc91dcda3d206f3e4f781b64d6c27b92ccc33275 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 20 Dec 2024 20:42:06 -0800 Subject: [PATCH 10/14] Parallelize decryption of reports (#1512) --- ipa-core/src/query/runner/hybrid.rs | 33 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index d9d2099fe..e41fbc152 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -5,7 +5,7 @@ use std::{ sync::Arc, }; -use futures::{stream::iter, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use generic_array::ArrayLength; use super::QueryResult; @@ -20,7 +20,9 @@ use crate::{ }, helpers::{ query::{DpMechanism, HybridQueryParams, QueryConfig, QuerySize}, - setup_cross_shard_prss, BodyStream, Gateway, LengthDelimitedStream, + setup_cross_shard_prss, + stream::TryFlattenItersExt, + BodyStream, Gateway, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, protocol::{ @@ -105,7 +107,7 @@ where config, key_registry, phantom_data: _, - } = self; + } = &self; tracing::info!("New hybrid query: {config:?}"); let ctx = ctx.narrow(&Hybrid); @@ -118,21 +120,18 @@ where 
         }
 
         let stream = LengthDelimitedStream::<EncryptedHybridReport<BA8, BA3>, _>::new(input_stream)
-            .map_err(Into::<Error>::into)
-            .map_ok(|enc_reports| {
-                iter(enc_reports.into_iter().map({
-                    |enc_report| {
-                        let dec_report = enc_report
-                            .decrypt(key_registry.as_ref())
-                            .map_err(Into::<Error>::into);
-                        let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
-                        dec_report.map(|dec_report1| (dec_report1, unique_tag))
-                    }
-                }))
+            .map_err(Into::into)
+            .try_flatten_iters()
+            .map(|enc_report_res| async move {
+                enc_report_res.and_then(|enc_report| {
+                    let dec_report = enc_report
+                        .decrypt(key_registry.as_ref())
+                        .map_err(Into::<Error>::into);
+                    let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
+                    dec_report.map(|dec_report1| (dec_report1, unique_tag))
+                })
+            })
-            .try_flatten()
-            .take(sz)
-            .map(|v| async move { v });
+            .take(sz);
 
         let (decrypted_reports, resharded_tags) = reshard_aad(
             ctx.narrow(&HybridStep::ReshardByTag),

From 3413c3c351e8a5d1b73dfc348af6b3ef1fa8013d Mon Sep 17 00:00:00 2001
From: Andy Leiserson
Date: Sat, 21 Dec 2024 12:33:50 -0800
Subject: [PATCH 11/14] Periodically report memory usage (#1503)

---
 ipa-core/benches/oneshot/ipa.rs       | 15 +++---
 ipa-core/src/bin/helper.rs            | 15 +++---
 ipa-core/src/lib.rs                   | 16 ++++++
 ipa-core/src/seq_join/local.rs        |  7 +++
 ipa-core/src/seq_join/multi_thread.rs | 10 +++-
 ipa-core/src/telemetry/memory.rs      | 76 +++++++++++++++++++++++++++
 ipa-core/src/telemetry/mod.rs         |  1 +
 7 files changed, 123 insertions(+), 17 deletions(-)
 create mode 100644 ipa-core/src/telemetry/memory.rs

diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs
index 1f9eec376..cb2d963d2 100644
--- a/ipa-core/benches/oneshot/ipa.rs
+++ b/ipa-core/benches/oneshot/ipa.rs
@@ -19,14 +19,6 @@ use ipa_step::StepNarrow;
 use rand::{random, rngs::StdRng, SeedableRng};
 use tokio::runtime::Builder;
 
-#[cfg(jemalloc)]
-#[global_allocator]
-static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
-
-#[cfg(feature = "dhat-heap")]
-#[global_allocator]
-static ALLOC: dhat::Alloc = dhat::Alloc;
-
 /// A benchmark for the full IPA protocol.
 #[derive(Parser)]
 #[command(about, long_about = None)]
@@ -165,6 +157,13 @@ async fn run(args: Args) -> Result<(), Error> {
 }
 
 fn main() -> Result<(), Error> {
+    #[cfg(jemalloc)]
+    ipa_core::use_jemalloc!();
+
+    #[cfg(feature = "dhat-heap")]
+    #[global_allocator]
+    static ALLOC: dhat::Alloc = dhat::Alloc;
+
     #[cfg(feature = "dhat-heap")]
     let _profiler = dhat::Profiler::new_heap();
 
diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs
index 3fa554a6a..3b62653a4 100644
--- a/ipa-core/src/bin/helper.rs
+++ b/ipa-core/src/bin/helper.rs
@@ -30,14 +30,6 @@ use ipa_core::{
 use tokio::runtime::Runtime;
 use tracing::{error, info};
 
-#[cfg(jemalloc)]
-#[global_allocator]
-static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
-
-#[cfg(feature = "dhat-heap")]
-#[global_allocator]
-static ALLOC: dhat::Alloc = dhat::Alloc;
-
 #[derive(Debug, Parser)]
 #[clap(
     name = "helper",
@@ -369,6 +361,13 @@ fn new_query_runtime(logging_handle: &LoggingHandle) -> Runtime {
 /// runtimes to use in MPC queries and HTTP.
#[tokio::main(flavor = "current_thread")] pub async fn main() { + #[cfg(jemalloc)] + ipa_core::use_jemalloc!(); + + #[cfg(feature = "dhat-heap")] + #[global_allocator] + static ALLOC: dhat::Alloc = dhat::Alloc; + let args = Args::parse(); let handle = args.logging.setup_logging(); diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index ce98182f3..188f8154b 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -32,6 +32,7 @@ mod seq_join; mod serde; pub mod sharding; pub mod utils; + pub use app::{AppConfig, HelperApp, Setup as AppSetup}; pub use utils::NonZeroU32PowerOfTwo; @@ -348,6 +349,21 @@ pub(crate) mod test_executor { pub const CRATE_NAME: &str = env!("CARGO_CRATE_NAME"); +/// This macro should be called in a binary that uses `ipa_core`, if that binary wishes +/// to use jemalloc. +/// +/// Besides declaring the `#[global_allocator]`, the macro also activates some memory +/// reporting. +#[macro_export] +macro_rules! use_jemalloc { + () => { + #[global_allocator] + static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + + $crate::telemetry::memory::jemalloc::activate(); + }; +} + #[macro_export] macro_rules! const_assert { ($x:expr $(,)?) => { diff --git a/ipa-core/src/seq_join/local.rs b/ipa-core/src/seq_join/local.rs index 89ba2ca91..951df9701 100644 --- a/ipa-core/src/seq_join/local.rs +++ b/ipa-core/src/seq_join/local.rs @@ -10,6 +10,8 @@ use std::{ use futures::{stream::Fuse, Future, Stream, StreamExt}; use pin_project::pin_project; +use crate::telemetry::memory::periodic_memory_report; + enum ActiveItem { Pending(Pin>), Resolved(F::Output), @@ -56,6 +58,7 @@ where #[pin] source: Fuse, active: VecDeque>, + spawned: usize, _marker: PhantomData &'unused ()>, } @@ -68,6 +71,7 @@ where Self { source: source.fuse(), active: VecDeque::with_capacity(active.get()), + spawned: 0, _marker: PhantomData, } } @@ -88,6 +92,8 @@ where if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { this.active .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); + periodic_memory_report(*this.spawned); + *this.spawned += 1; } else { break; } @@ -104,6 +110,7 @@ where Poll::Pending } } else if this.source.is_done() { + periodic_memory_report(*this.spawned); Poll::Ready(None) } else { Poll::Pending diff --git a/ipa-core/src/seq_join/multi_thread.rs b/ipa-core/src/seq_join/multi_thread.rs index 2ac8f458f..d68253417 100644 --- a/ipa-core/src/seq_join/multi_thread.rs +++ b/ipa-core/src/seq_join/multi_thread.rs @@ -9,6 +9,8 @@ use futures::{stream::Fuse, Stream, StreamExt}; use pin_project::pin_project; use tracing::{Instrument, Span}; +use crate::telemetry::memory::periodic_memory_report; + #[cfg(feature = "shuttle")] mod shuttle_spawner { use std::future::Future; @@ -62,6 +64,7 @@ where #[pin] source: Fuse, capacity: usize, + spawned: usize, } impl SequentialFutures<'_, S, F> @@ -75,6 +78,7 @@ where spawner: unsafe { create_spawner() }, source: source.fuse(), capacity: active.get(), + spawned: 0, } } } @@ -103,11 +107,14 @@ where // a dependency between futures, pending one will never complete. // Cancellable futures will be cancelled when spawner is dropped which is // the behavior we want. 
-                    let task_index = this.spawner.len();
+                    let task_index = *this.spawned;
                    this.spawner
                        .spawn_cancellable(f.into_future().instrument(Span::current()), move || {
                            panic!("SequentialFutures: spawned task {task_index} cancelled")
                        });
+
+                    periodic_memory_report(*this.spawned);
+                    *this.spawned += 1;
                } else {
                    break;
                }
@@ -127,6 +134,7 @@ where
                None => None,
            })
        } else if this.source.is_done() {
+            periodic_memory_report(*this.spawned);
            Poll::Ready(None)
        } else {
            Poll::Pending
diff --git a/ipa-core/src/telemetry/memory.rs b/ipa-core/src/telemetry/memory.rs
new file mode 100644
index 000000000..d43e791f1
--- /dev/null
+++ b/ipa-core/src/telemetry/memory.rs
@@ -0,0 +1,76 @@
+pub fn periodic_memory_report(count: usize) {
+    #[cfg(not(jemalloc))]
+    let _ = count;
+
+    #[cfg(jemalloc)]
+    jemalloc::periodic_memory_report(count);
+}
+
+#[cfg(jemalloc)]
+pub mod jemalloc {
+    use std::sync::RwLock;
+
+    use tikv_jemalloc_ctl::{epoch_mib, stats::allocated_mib};
+
+    const MB: usize = 1 << 20;
+
+    // In an unfortunate acronym collision, `mib` in the names of the jemalloc
+    // statistics stands for "Management Information Base", not "mebibytes".
+    // The reporting unit is bytes.
+
+    struct JemallocControls {
+        epoch: epoch_mib,
+        allocated: allocated_mib,
+    }
+
+    static CONTROLS: RwLock<Option<JemallocControls>> = RwLock::new(None);
+
+    /// Activates periodic memory usage reporting during `seq_join`.
+    ///
+    /// # Panics
+    /// If `RwLock` is poisoned.
+    pub fn activate() {
+        let mut controls = CONTROLS.write().unwrap();
+
+        let epoch = tikv_jemalloc_ctl::epoch::mib().unwrap();
+        let allocated = tikv_jemalloc_ctl::stats::allocated::mib().unwrap();
+
+        *controls = Some(JemallocControls { epoch, allocated });
+    }
+
+    fn report_memory_usage(controls: &JemallocControls, count: usize) {
+        // Some of the information jemalloc uses when reporting statistics is cached, and
+        // refreshed only upon advancing the epoch.
+        controls.epoch.advance().unwrap();
+        let allocated = controls.allocated.read().unwrap() / MB;
+        tracing::debug!("i={count}: {allocated} MiB allocated");
+    }
+
+    fn should_print_report(count: usize) -> bool {
+        if count == 0 {
+            return true;
+        }
+
+        let bits = count.ilog2();
+        let report_interval_log2 = std::cmp::max(bits.saturating_sub(1), 8);
+        let report_interval_mask = (1 << report_interval_log2) - 1;
+        (count & report_interval_mask) == 0
+    }
+
+    /// Print a memory report periodically, based on the value of `count`.
+    ///
+    /// As `count` increases, so does the report interval. This results in
+    /// a tolerable amount of log messages for loops with many iterations,
+    /// while still providing some reporting for shorter loops.
+    ///
+    /// # Panics
+    /// If `RwLock` is poisoned.
+    pub fn periodic_memory_report(count: usize) {
+        let controls_opt = CONTROLS.read().unwrap();
+        if let Some(controls) = controls_opt.as_ref() {
+            if should_print_report(count) {
+                report_memory_usage(controls, count);
+            }
+        }
+    }
+}
diff --git a/ipa-core/src/telemetry/mod.rs b/ipa-core/src/telemetry/mod.rs
index aae6ea1d4..eb31325d7 100644
--- a/ipa-core/src/telemetry/mod.rs
+++ b/ipa-core/src/telemetry/mod.rs
@@ -1,3 +1,4 @@
+pub mod memory;
 pub mod stats;
 mod step_stats;
 

From 14d8e2a140dce99c1d44404106888ff681fdbe2e Mon Sep 17 00:00:00 2001
From: Andy Leiserson
Date: Sat, 21 Dec 2024 12:57:55 -0800
Subject: [PATCH 12/14] Increase a couple test timeouts

---
 ipa-core/src/protocol/dp/mod.rs                  | 6 +++++-
 ipa-core/src/protocol/hybrid/breakdown_reveal.rs | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs
index fbd8263f4..a8b9938d7 100644
--- a/ipa-core/src/protocol/dp/mod.rs
+++ b/ipa-core/src/protocol/dp/mod.rs
@@ -976,7 +976,11 @@ mod test {
         type OutputValue = BA32;
         const NUM_BREAKDOWNS: u32 = 32;
 
-        let world = TestWorld::new_with(TestWorldConfig::default().enable_metrics());
+        let world = TestWorld::new_with(
+            TestWorldConfig::default()
+                .with_timeout_secs(30)
+                .enable_metrics(),
+        );
         let num_bernoulli: u32 = 1_000;
         let result: [Vec<Replicated<OutputValue>>; 3] = world
diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs
index d94bf001c..ab6c8cec9 100644
--- a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs
+++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs
@@ -628,7 +628,7 @@ mod proptests {
         } = input_struct;
         let config = TestWorldConfig {
             seed,
-            timeout: Some(Duration::from_secs(20)),
+            timeout: Some(Duration::from_secs(30)),
             ..Default::default()
         };
         let result = TestWorld::<WithShards<SHARDS>>::with_config(&config)

From 80743d1ee9e2ec7de63ede3c2fb458e6baf36a34 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 26 Dec 2024 12:47:01 -0500
Subject: [PATCH 13/14] Use sharded shuffle in Hybrid

It turns out we never used the sharded shuffle in Hybrid: each shard was
running a local shuffle instead. This change fixes that. It also drops
the `Shuffle` implementation for sharded contexts to prevent this from
happening again.
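[Editor's note: the sketch below is commentary on this commit, not part of
the patch. It illustrates the compile-time guarantee the message above
relies on. The trait and context names mirror the real ones in
ipa-core/src/protocol/ipa_prf/shuffle/mod.rs, but the signatures are
simplified stand-ins invented for this example.]

    /// Local shuffle: permutes only the rows held by one context.
    trait Shuffle {
        fn shuffle(self, shares: Vec<u32>) -> Vec<u32>;
    }

    /// Cross-shard shuffle: also moves rows between shards.
    trait ShardedShuffle {
        fn sharded_shuffle(self, shares: Vec<u32>) -> Vec<u32>;
    }

    struct NotShardedCtx;
    struct ShardedCtx;

    impl Shuffle for NotShardedCtx {
        fn shuffle(self, mut shares: Vec<u32>) -> Vec<u32> {
            shares.reverse(); // placeholder for a real permutation
            shares
        }
    }

    impl ShardedShuffle for ShardedCtx {
        fn sharded_shuffle(self, mut shares: Vec<u32>) -> Vec<u32> {
            shares.rotate_left(1); // placeholder for a cross-shard permutation
            shares
        }
    }

    // With `Shuffle` implemented only for non-sharded contexts, sharded
    // protocol code must name `ShardedShuffle` in its bounds; calling the
    // local `shuffle` on a sharded context no longer type-checks.
    fn hybrid_like<C: ShardedShuffle>(ctx: C, shares: Vec<u32>) -> Vec<u32> {
        ctx.sharded_shuffle(shares)
    }

    fn main() {
        assert_eq!(NotShardedCtx.shuffle(vec![1, 2]), vec![2, 1]);
        assert_eq!(hybrid_like(ShardedCtx, vec![1, 2, 3]), vec![2, 3, 1]);
    }

With the sharded impls removed, forgetting to opt into the sharded shuffle
becomes a type error rather than a silent correctness bug.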
--- ipa-core/src/protocol/hybrid/breakdown_reveal.rs | 6 +++--- ipa-core/src/protocol/hybrid/mod.rs | 6 +++--- ipa-core/src/protocol/hybrid/oprf.rs | 8 +++++--- ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 8 ++++---- ipa-core/src/query/runner/hybrid.rs | 4 ++-- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs index d94bf001c..9af2b7ede 100644 --- a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -20,7 +20,7 @@ use crate::{ AGGREGATE_DEPTH, }, oprf_padding::{apply_dp_padding, PaddingParameters}, - shuffle::Shuffle, + shuffle::ShardedShuffle, }, BooleanProtocols, RecordId, }, @@ -66,7 +66,7 @@ pub async fn breakdown_reveal_aggregation( padding_params: &PaddingParameters, ) -> Result>, Error> where - C: UpgradableContext + Shuffle + ShardedContext, + C: UpgradableContext + ShardedShuffle + ShardedContext, Boolean: FieldSimd, Replicated: BooleanProtocols, B>, BK: BooleanArray + U128Conversions, @@ -94,7 +94,7 @@ where let attributions = ctx .narrow(&Step::Shuffle) - .shuffle(attributed_values_padded) + .sharded_shuffle(attributed_values_padded) .instrument(info_span!("shuffle_attribution_outputs")) .await?; diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs index 5217365b6..b6de81fc6 100644 --- a/ipa-core/src/protocol/hybrid/mod.rs +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -33,7 +33,7 @@ use crate::{ ipa_prf::{ oprf_padding::{apply_dp_padding, PaddingParameters}, prf_eval::PrfSharing, - shuffle::Shuffle, + shuffle::ShardedShuffle, }, prss::FromPrss, BooleanProtocols, @@ -79,7 +79,7 @@ pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const SS_BITS: usize, const B: where C: UpgradableContext + 'ctx - + Shuffle + + ShardedShuffle + ShardedContext + FinalizerContext>, BK: BreakdownKey, @@ -121,7 +121,7 @@ where let shuffled_input_rows = ctx .narrow(&Step::InputShuffle) - .shuffle(padded_input_rows) + .sharded_shuffle(padded_input_rows) .instrument(info_span!("shuffle_inputs")) .await?; diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index 29d13d755..8bed507ea 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -181,9 +181,11 @@ where // reshard reports based on OPRF values. This ensures at the end of this function // reports with the same value end up on the same shard. - reshard_try_stream(ctx, report_stream, |ctx, _, report| { - report.match_key % ctx.shard_count() - }) + reshard_try_stream( + ctx.narrow(&HybridStep::ReshardByPrf), + report_stream, + |ctx, _, report| report.match_key % ctx.shard_count(), + ) .await } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 0ee65d169..189b3e75b 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -9,7 +9,7 @@ use crate::{ context::{Context, MaliciousContext, SemiHonestContext}, ipa_prf::shuffle::sharded::ShuffleContext, }, - sharding::{ShardBinding, Sharded}, + sharding::{Sharded}, }; mod base; @@ -21,6 +21,7 @@ use base::shuffle_protocol as base_shuffle; use malicious::{malicious_sharded_shuffle, malicious_shuffle}; use sharded::shuffle as sharded_shuffle; pub use sharded::{MaliciousShuffleable, Shuffleable}; +use crate::sharding::NotSharded; /// This struct stores some intermediate messages during the shuffle. 
/// In a maliciously secure shuffle, @@ -63,7 +64,7 @@ pub trait Shuffle: Context { S: MaliciousShuffleable; } -impl Shuffle for SemiHonestContext<'_, T> { +impl Shuffle for SemiHonestContext<'_, NotSharded> { fn shuffle(self, shares: Vec) -> impl Future, Error>> + Send where S: MaliciousShuffleable, @@ -73,7 +74,7 @@ impl Shuffle for SemiHonestContext<'_, T> { } } -impl Shuffle for MaliciousContext<'_, T> { +impl Shuffle for MaliciousContext<'_, NotSharded> { fn shuffle(self, shares: Vec) -> impl Future, Error>> + Send where S: MaliciousShuffleable, @@ -84,7 +85,6 @@ impl Shuffle for MaliciousContext<'_, T> { /// Trait used by protocols to invoke either semi-honest or malicious sharded shuffle, /// depending on the type of context being used. -#[allow(dead_code)] pub trait ShardedShuffle: ShuffleContext { fn sharded_shuffle( self, diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index e41fbc152..e3b4f6311 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -35,7 +35,7 @@ use crate::{ oprf::{CONV_CHUNK, PRF_CHUNK}, step::HybridStep, }, - ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle}, + ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::ShardedShuffle}, prss::{Endpoint, FromPrss}, step::ProtocolStep::Hybrid, Gate, @@ -73,7 +73,7 @@ impl Query { impl Query where C: UpgradableContext - + Shuffle + + ShardedShuffle + ShardedContext + FinalizerContext>, HV: BooleanArray + U128Conversions, From 251d1f3984f11348b8392166928df166c11529a5 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 26 Dec 2024 16:00:41 -0500 Subject: [PATCH 14/14] Fix compact gate steps for sharded shuffle --- .../src/protocol/hybrid/breakdown_reveal.rs | 6 ++---- ipa-core/src/protocol/hybrid/step.rs | 19 +++++++++++++++++-- .../src/protocol/ipa_prf/shuffle/malicious.rs | 8 ++++---- ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 3 ++- ipa-core/src/protocol/ipa_prf/shuffle/step.rs | 5 +++++ 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs index 9af2b7ede..d851a05ac 100644 --- a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -14,11 +14,9 @@ use crate::{ dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, ShardedContext, UpgradableContext, }, + hybrid::step::AggregationStep as Step, ipa_prf::{ - aggregation::{ - aggregate_values, aggregate_values_proof_chunk, step::AggregationStep as Step, - AGGREGATE_DEPTH, - }, + aggregation::{aggregate_values, aggregate_values_proof_chunk, AGGREGATE_DEPTH}, oprf_padding::{apply_dp_padding, PaddingParameters}, shuffle::ShardedShuffle, }, diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index 2a98488fc..7f4de200a 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -5,7 +5,7 @@ pub(crate) enum HybridStep { ReshardByTag, #[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="report_padding_dp")] PaddingDp, - #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] + #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)] InputShuffle, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::Fp25519ConversionStep)] ConvertFp25519, @@ -19,7 +19,7 @@ pub(crate) enum HybridStep { GroupBySum, #[step(child = 
crate::protocol::context::step::DzkpValidationProtocolStep)] GroupBySumValidate, - #[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)] + #[step(child = AggregationStep)] Aggregate, #[step(child = FinalizeSteps)] Finalize, @@ -40,3 +40,18 @@ pub(crate) enum FinalizeSteps { #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] Validate, } + +#[derive(CompactStep)] +pub(crate) enum AggregationStep { + #[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="padding_dp")] + PaddingDp, + #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)] + Shuffle, + Reveal, + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + RevealValidate, // only partly used -- see code + #[step(count = 4, child = crate::protocol::ipa_prf::aggregation::step::AggregateChunkStep, name = "chunks")] + Aggregate(usize), + #[step(count = 4, child = crate::protocol::context::step::DzkpValidationProtocolStep)] + AggregateValidate(usize), +} diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index b0745f568..f73e1e973 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -30,7 +30,7 @@ use crate::{ h1_shuffle_for_shard, h2_shuffle_for_shard, h3_shuffle_for_shard, MaliciousShuffleable, ShuffleShare, Shuffleable, }, - step::{OPRFShuffleStep, VerifyShuffleStep}, + step::{OPRFShuffleStep, ShardedShuffleStep, VerifyShuffleStep}, IntermediateShuffleMessages, }, prss::SharedRandomness, @@ -179,11 +179,11 @@ where // prepare keys let amount_of_keys: usize = (usize::try_from(S::Share::BITS).unwrap() + 31) / 32; - let keys = setup_keys(ctx.narrow(&OPRFShuffleStep::SetupKeys), amount_of_keys).await?; + let keys = setup_keys(ctx.narrow(&ShardedShuffleStep::SetupKeys), amount_of_keys).await?; // compute and append tags to rows let shares_and_tags: Vec> = - compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; + compute_and_add_tags(ctx.narrow(&ShardedShuffleStep::GenerateTags), &keys, shares).await?; let (shuffled_shares, messages) = match ctx.role() { Role::H1 => h1_shuffle_for_shard(ctx.clone(), shares_and_tags).await, @@ -193,7 +193,7 @@ where // verify the shuffle verify_shuffle::<_, S>( - ctx.narrow(&OPRFShuffleStep::VerifyShuffle), + ctx.narrow(&ShardedShuffleStep::VerifyShuffle), &keys, &shuffled_shares, messages, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 189b3e75b..b9d7c40b3 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -9,7 +9,7 @@ use crate::{ context::{Context, MaliciousContext, SemiHonestContext}, ipa_prf::shuffle::sharded::ShuffleContext, }, - sharding::{Sharded}, + sharding::Sharded, }; mod base; @@ -21,6 +21,7 @@ use base::shuffle_protocol as base_shuffle; use malicious::{malicious_sharded_shuffle, malicious_shuffle}; use sharded::shuffle as sharded_shuffle; pub use sharded::{MaliciousShuffleable, Shuffleable}; + use crate::sharding::NotSharded; /// This struct stores some intermediate messages during the shuffle. 
diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index 6a4ff2050..6ca43b5ad 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -29,6 +29,9 @@ pub(crate) enum VerifyShuffleStep { #[derive(CompactStep)] pub(crate) enum ShardedShuffleStep { + SetupKeys, + #[step(child = crate::protocol::boolean::step::EightBitStep)] + GenerateTags, /// Depending on the helper position inside the MPC ring, generate Ã, B̃ or both. PseudoRandomTable, /// Permute the input according to the PRSS shared between H1 and H2. @@ -46,6 +49,8 @@ pub(crate) enum ShardedShuffleStep { TransferXY, /// H2 and H3 interaction - Exchange `C_1` and `C_2`. TransferC, + #[step(child = crate::protocol::ipa_prf::shuffle::step::VerifyShuffleStep)] + VerifyShuffle, } #[derive(CompactStep)]
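[Editor's note: end of series. The sketch below is an invented illustration
of the "compact gate" bookkeeping that PATCH 14 keeps consistent; it is not
the real ipa-step API. The idea: every step enum maps into one flat,
contiguous index range, with each #[step(child = ...)] attribute reserving
the child enum's sub-range, which is why the hybrid steps must reference the
sharded-shuffle step tree (ShardedShuffleStep) rather than the OPRF one
(OPRFShuffleStep) once the sharded shuffle gains its own variants.]

    // Invented stand-ins: the variant names and indexing scheme are for
    // illustration only.
    #[derive(Clone, Copy)]
    enum VerifyStep {
        RevealKeys,
        CheckHashes,
    }

    #[derive(Clone, Copy)]
    enum ShuffleStep {
        SetupKeys,
        GenerateTags,
        Verify(VerifyStep), // child steps occupy a contiguous sub-range
    }

    /// Maps each reachable step to a unique index in one flat range, which
    /// is what lets compact-gate builds replace string gates with integers.
    /// Adding a variant (like `SetupKeys` above) shifts every later index,
    /// so parent and child step trees must agree at compile time.
    fn gate_index(step: ShuffleStep) -> usize {
        match step {
            ShuffleStep::SetupKeys => 0,
            ShuffleStep::GenerateTags => 1,
            ShuffleStep::Verify(v) => 2 + v as usize,
        }
    }

    fn main() {
        assert_eq!(gate_index(ShuffleStep::GenerateTags), 1);
        assert_eq!(gate_index(ShuffleStep::Verify(VerifyStep::RevealKeys)), 2);
        assert_eq!(gate_index(ShuffleStep::Verify(VerifyStep::CheckHashes)), 3);
    }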