diff --git a/.clippy.toml b/.clippy.toml index 5c572f532..800113fde 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -7,3 +7,5 @@ disallowed-methods = [ { path = "std::mem::ManuallyDrop::new", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, { path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, ] + +future-size-threshold = 10240 diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 8d117fa32..9a7735e9e 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -53,7 +53,7 @@ jobs: - name: Clippy if: ${{ success() || failure() }} - run: cargo clippy --tests + run: cargo clippy --features "cli test-fixture" --tests - name: Clippy concurrency tests if: ${{ success() || failure() }} @@ -68,7 +68,7 @@ jobs: run: cargo build --tests - name: Run tests - run: cargo test + run: cargo test --features "cli test-fixture relaxed-dp" - name: Run tests with multithreading feature enabled run: cargo test --features "multi-threading" @@ -76,9 +76,6 @@ jobs: - name: Run Web Tests run: cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" - - name: Run Integration Tests - run: cargo test --test encrypted_input --features "cli test-fixture web-app in-memory-infra" - release: name: Release builds and tests runs-on: ubuntu-latest @@ -148,8 +145,11 @@ jobs: - name: Run arithmetic bench run: cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" - - name: Run compact gate tests + - name: Run compact gate tests for HTTP stack run: cargo test --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + - name: Run in-memory compact gate tests + run: cargo test --features "compact-gate" slow: name: Slow tests env: @@ -172,8 +172,17 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} - - name: End-to-end tests - run: cargo test --release --test "*" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + - name: Integration Tests - Compact Gate + run: cargo test --release --test "compact_gate" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + - name: Integration Tests - Helper Networks + run: cargo test --release --test "helper_networks" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + - name: Integration Tests - Hybrid + run: cargo test --release --test "hybrid" --features "cli test-fixture" + + - name: Integration Tests - IPA with Relaxed DP + run: cargo test --release --test "ipa_with_relaxed_dp" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate relaxed-dp" # sanitizers currently require nightly https://github.com/rust-lang/rust/issues/39699 sanitize: @@ -192,7 +201,7 @@ jobs: - name: Add Rust sources run: rustup component add rust-src - name: Run tests with sanitizer - run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" + run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std -p ipa-core --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" miri: runs-on: ubuntu-latest @@ -236,4 +245,3 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} file: ipa.cov fail_ci_if_error: false - diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 12c6d40bb..8b47bf627 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -34,7 +34,9 @@ jobs: type=sha - name: "Setup Docker Buildx" - uses: docker/setup-buildx-action@v2 + uses: docker/setup-buildx-action@v3 + with: + platforms: linux/amd64 - name: "Login to GitHub Container Registry" uses: docker/login-action@v2 @@ -44,10 +46,10 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: "Build and Publish Helper Image" - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v6 with: context: . - file: ./docker/ci/helper.Dockerfile + file: ./docker/helper.Dockerfile push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} diff --git a/.gitignore b/.gitignore index 8aa41d239..674cf39f2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ /in-market-test/hpke/bin /in-market-test/hpke/lib /in-market-test/hpke/pyvenv.cfg +input-data-*.txt \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 377020368..deb437919 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["ipa-core", "ipa-step", "ipa-step-derive", "ipa-step-test"] +members = ["ipa-core", "ipa-step", "ipa-step-derive", "ipa-step-test", "ipa-metrics", "ipa-metrics-tracing"] [profile.release] incremental = true diff --git a/docker/ci/helper.Dockerfile b/docker/ci/helper.Dockerfile deleted file mode 100644 index 7f7b4d376..000000000 --- a/docker/ci/helper.Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -# syntax=docker/dockerfile:1 -FROM rust:latest as builder - -COPY . /ipa/ -RUN cd /ipa && \ - cargo build --bin helper --release --no-default-features \ - --features "web-app real-world-infra compact-gate" - -# Copy them to the final image -FROM debian:bullseye-slim - -COPY --from=builder /ipa/target/release/helper /bin/ipa-helper -ENTRYPOINT ["/bin/ipa-helper"] diff --git a/docker/helper.Dockerfile b/docker/helper.Dockerfile index 52c0806ab..fc113222d 100644 --- a/docker/helper.Dockerfile +++ b/docker/helper.Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1 ARG SOURCES_DIR=/usr/src/ipa -FROM rust:bullseye AS builder +FROM rust:bookworm AS builder ARG SOURCES_DIR # Prepare helper binaries @@ -10,7 +10,7 @@ RUN set -eux; \ cargo build --bin helper --release --no-default-features --features "web-app real-world-infra compact-gate" # Copy them to the final image -FROM debian:bullseye-slim +FROM rust:slim-bookworm ENV HELPER_BIN_PATH=/usr/local/bin/ipa-helper ENV CONF_DIR=/etc/ipa ARG SOURCES_DIR diff --git a/input-data-100.txt b/input-data-100.txt deleted file mode 100644 index 4ea0ee7c4..000000000 --- a/input-data-100.txt +++ /dev/null @@ -1,100 +0,0 @@ -600339,534942975307,0,5,0 -96422,191017627906,0,3,0 -507032,117803731851,0,10,0 -17448,304519167044,1,0,4 -224051,12251697120,0,17,0 -572331,534942975307,1,0,1 -204850,534942975307,0,12,0 -572399,865368699047,0,2,0 -595278,865368699047,1,0,4 -457115,191017627906,1,0,4 -279628,534942975307,0,7,0 -100525,925363717604,1,0,5 -565595,925363717604,0,11,0 -567404,865368699047,0,3,0 -140412,304519167044,1,0,5 -329551,925363717604,1,0,1 -524654,314908499604,0,8,0 -240982,850807271120,1,0,5 -603020,117803731851,0,1,0 -272156,865368699047,0,17,0 -227353,12251697120,0,5,0 -265919,925363717604,1,0,1 -547,12251697120,0,2,0 -342491,925363717604,1,0,1 -250600,304519167044,0,6,0 -252290,117803731851,0,18,0 -141260,850807271120,0,6,0 -248451,304519167044,0,16,0 -515699,191017627906,1,0,4 -312537,12251697120,1,0,2 -492188,283283408809,0,13,0 -451766,917537570026,0,7,0 -287218,822386586545,0,11,0 -67235,925363717604,1,0,5 -603886,917537570026,1,0,3 -213895,117803731851,0,11,0 -418303,534942975307,0,10,0 -210243,822386586545,0,9,0 -211179,117803731851,1,0,5 -568874,925363717604,0,0,0 -373535,925363717604,1,0,3 -232675,534942975307,1,0,5 -92636,191017627906,1,0,1 -398372,917537570026,0,6,0 -401827,534942975307,1,0,2 -155515,65168429090,1,0,1 -33026,304519167044,0,17,0 -493183,179797603392,1,0,1 -167758,179797603392,1,0,4 -522471,191017627906,0,11,0 -313610,925363717604,1,0,1 -176225,12251697120,0,16,0 -588107,925363717604,0,13,0 -280600,393203478859,0,10,0 -491601,179797603392,0,4,0 -445133,773905428637,1,0,3 -301999,12251697120,1,0,5 -65750,526858192111,0,19,0 -350976,12251697120,0,9,0 -67867,773905428637,1,0,2 -594037,191017627906,0,11,0 -261995,534942975307,1,0,3 -133066,288854012131,1,0,4 -40015,179797603392,1,0,5 -571126,288854012131,0,10,0 -514451,773905428637,0,8,0 -201640,288854012131,1,0,4 -71935,526858192111,1,0,2 -316596,773905428637,0,6,0 -246923,12251697120,1,0,3 -79789,773905428637,1,0,4 -47468,917537570026,0,17,0 -161925,773905428637,0,9,0 -225460,393203478859,1,0,4 -530756,640580450837,0,4,0 -94219,338037795442,1,0,4 -136211,179797603392,0,0,0 -559897,191017627906,1,0,1 -332026,179797603392,1,0,1 -35911,917537570026,1,0,5 -329450,191017627906,0,4,0 -102812,393203478859,0,11,0 -578374,917537570026,0,15,0 -156477,881719336823,0,0,0 -277455,179797603392,0,7,0 -186143,881719336823,1,0,3 -228562,393203478859,1,0,3 -346392,822386586545,1,0,3 -102532,881719336823,0,1,0 -589048,822386586545,1,0,1 -430856,288854012131,1,0,5 -408260,881719336823,0,16,0 -180588,477090731329,0,16,0 -502918,288854012131,0,7,0 -392616,393203478859,1,0,1 -463878,22654468721,1,0,1 -85787,393203478859,1,0,5 -238574,288854012131,0,4,0 -22862,822386586545,0,19,0 -481629,288854012131,0,3,0 diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 835ebc28c..6ddfc2009 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -1,7 +1,12 @@ [package] name = "ipa-core" version = "0.1.0" -rust-version = "1.80.0" +# When updating the rust version: +# 1. Check at https://hub.docker.com/_/rust that the relevant version of the +# rust:slim-bullseye docker image is available. +# 2. Update the rust version used for draft in +# https://github.com/private-attribution/draft/blob/main/sidecar/ansible/provision.yaml. +rust-version = "1.82.0" edition = "2021" build = "build.rs" @@ -16,11 +21,14 @@ default = [ "stall-detection", "aggregate-circuit", "ipa-prf", - "ipa-step/string-step", + "descriptive-gate", ] cli = ["comfy-table", "clap"] -# Enabling compact gates disables any tests that rely on descriptive gates. -compact-gate = ["ipa-step/string-step"] +# Enable compact gate optimization +compact-gate = [] +# mutually exclusive with compact-gate and disables compact gate optimization. +# It is enabled by default +descriptive-gate = ["ipa-step/string-step"] disable-metrics = [] # TODO move web-app to a separate crate. It adds a lot of build time to people who mostly write protocols # TODO Consider moving out benches as well @@ -74,6 +82,8 @@ reveal-aggregation = [] aggregate-circuit = [] # IPA protocol based on OPRF ipa-prf = [] +# relaxed DP, off by default +relaxed-dp = [] [dependencies] ipa-step = { version = "*", path = "../ipa-step" } @@ -85,16 +95,14 @@ async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.7.5", optional = true, features = ["http2", "macros"] } # The following is a temporary version until we can stabilize the build on a higher version # of axum, rustls and the http stack. -axum-server = { git = "https://github.com/cberkhoff/axum-server/", branch = "0.6.1", version = "0.6.1", optional = true, features = [ - "tls-rustls", -] } +axum-server = { version = "0.7.1", optional = true, features = ["tls-rustls"] } base64 = { version = "0.21.2", optional = true } bitvec = "1.0" bytes = "1.4" clap = { version = "4.3.2", optional = true, features = ["derive"] } comfy-table = { version = "7.0", optional = true } config = "0.14" -console-subscriber = { version = "0.2", optional = true } +console-subscriber = { version = "0.4", optional = true } criterion = { version = "0.5.1", optional = true, default-features = false, features = [ "async_tokio", "plotters", @@ -150,7 +158,7 @@ typenum = { version = "1.17", features = ["i128"] } # hpke is pinned to it x25519-dalek = "2.0.0-rc.3" -[target.'cfg(not(target_env = "msvc"))'.dependencies] +[target.'cfg(all(not(target_env = "msvc"), not(target_os = "macos")))'.dependencies] tikv-jemallocator = "0.5.0" [build-dependencies] @@ -196,7 +204,12 @@ bench = false [[bin]] name = "crypto_util" -required-features = ["cli", "test-fixture", "web-app", "in-memory-infra"] +required-features = ["cli", "test-fixture", "web-app"] +bench = false + +[[bin]] +name = "in_the_clear" +required-features = ["cli", "test-fixture", "web-app"] bench = false [[bench]] @@ -252,3 +265,21 @@ required-features = [ "real-world-infra", "test-fixture", ] + +[[test]] +name = "ipa_with_relaxed_dp" +required-features = [ + "cli", + "compact-gate", + "web-app", + "real-world-infra", + "test-fixture", + "relaxed-dp", +] + +[[test]] +name = "hybrid" +required-features = [ + "test-fixture", + "cli", +] diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index 02f6e0304..b880c28d6 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -19,7 +19,11 @@ use ipa_step::StepNarrow; use rand::{random, rngs::StdRng, SeedableRng}; use tokio::runtime::Builder; -#[cfg(all(not(target_env = "msvc"), not(feature = "dhat-heap")))] +#[cfg(all( + not(target_env = "msvc"), + not(feature = "dhat-heap"), + not(target_os = "macos") +))] #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; @@ -82,6 +86,7 @@ impl Args { self.active_work .map(NonZeroUsize::get) .unwrap_or_else(|| self.query_size.clamp(16, 1024)) + .next_power_of_two() } fn attribution_window(&self) -> Option { diff --git a/ipa-core/build.rs b/ipa-core/build.rs index ed45e74f2..ce1987c72 100644 --- a/ipa-core/build.rs +++ b/ipa-core/build.rs @@ -16,6 +16,7 @@ track_steps!( step, }, context::step, + hybrid::step, ipa_prf::{ boolean_ops::step, prf_sharding::step, @@ -27,7 +28,6 @@ track_steps!( dp::step, step, }, - test_fixture::step ); fn main() { @@ -44,7 +44,7 @@ fn main() { // https://docs.rs/tectonic_cfg_support/latest/tectonic_cfg_support/struct.TargetConfiguration.html cfg_aliases! { compact_gate: { feature = "compact-gate" }, - descriptive_gate: { not(compact_gate) }, + descriptive_gate: { all(not(feature = "compact-gate"), feature = "descriptive-gate") }, unit_test: { all(not(feature = "shuttle"), feature = "in-memory-infra", descriptive_gate) }, web_test: { all(not(feature = "shuttle"), feature = "real-world-infra") }, } diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index db603501c..4d91b92ef 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -1,8 +1,9 @@ -use std::{num::NonZeroUsize, sync::Weak}; +use std::sync::Weak; use async_trait::async_trait; use crate::{ + executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, routing::{Addr, RouteId}, @@ -13,17 +14,19 @@ use crate::{ protocol::QueryId, query::{NewQueryError, QueryProcessor, QueryStatus}, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; #[derive(Default)] pub struct AppConfig { - active_work: Option, + active_work: Option, key_registry: Option>, + runtime: IpaRuntime, } impl AppConfig { #[must_use] - pub fn with_active_work(mut self, active_work: Option) -> Self { + pub fn with_active_work(mut self, active_work: Option) -> Self { self.active_work = active_work; self } @@ -33,6 +36,12 @@ impl AppConfig { self.key_registry = Some(key_registry); self } + + #[must_use] + pub fn with_runtime(mut self, runtime: IpaRuntime) -> Self { + self.runtime = runtime; + self + } } pub struct Setup { @@ -60,7 +69,7 @@ impl Setup { #[must_use] pub fn new(config: AppConfig) -> (Self, HandlerRef) { let key_registry = config.key_registry.unwrap_or_else(KeyRegistry::empty); - let query_processor = QueryProcessor::new(key_registry, config.active_work); + let query_processor = QueryProcessor::new(key_registry, config.active_work, config.runtime); let handler = HandlerBox::empty(); let this = Self { query_processor, @@ -203,6 +212,10 @@ impl RequestHandler for Inner { let query_id = ext_query_id(&req)?; HelperResponse::from(qp.complete(query_id).await?) } + RouteId::KillQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.kill(query_id)?) + } }) } } diff --git a/ipa-core/src/bin/crypto_util.rs b/ipa-core/src/bin/crypto_util.rs index 4ded7026c..99556089f 100644 --- a/ipa-core/src/bin/crypto_util.rs +++ b/ipa-core/src/bin/crypto_util.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use clap::{Parser, Subcommand}; use ipa_core::{ - cli::crypto::{decrypt_and_reconstruct, encrypt, DecryptArgs, EncryptArgs}, + cli::crypto::{DecryptArgs, EncryptArgs}, error::BoxError, }; @@ -24,8 +24,8 @@ enum CryptoUtilCommand { async fn main() -> Result<(), BoxError> { let args = Args::parse(); match args.action { - CryptoUtilCommand::Encrypt(encrypt_args) => encrypt(&encrypt_args)?, - CryptoUtilCommand::Decrypt(decrypt_args) => decrypt_and_reconstruct(decrypt_args).await?, + CryptoUtilCommand::Encrypt(encrypt_args) => encrypt_args.encrypt()?, + CryptoUtilCommand::Decrypt(decrypt_args) => decrypt_args.decrypt_and_reconstruct().await?, } Ok(()) } diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 02f1b2101..7c8190c20 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -2,7 +2,6 @@ use std::{ fs, io::BufReader, net::TcpListener, - num::NonZeroUsize, os::fd::{FromRawFd, RawFd}, path::{Path, PathBuf}, process, @@ -16,13 +15,16 @@ use ipa_core::{ }, config::{hpke_registry, HpkeServerConfig, NetworkConfig, ServerConfig, TlsConfig}, error::BoxError, + executor::IpaRuntime, helpers::HelperIdentity, - net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient}, - AppConfig, AppSetup, + net::{ClientIdentity, MpcHelperClient, MpcHttpTransport, ShardHttpTransport}, + sharding::ShardIndex, + AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; +use tokio::runtime::Runtime; use tracing::{error, info}; -#[cfg(not(target_env = "msvc"))] +#[cfg(all(not(target_env = "msvc"), not(target_os = "macos")))] #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; @@ -93,7 +95,7 @@ struct ServerArgs { /// Override the amount of active work processed in parallel #[arg(long)] - active_work: Option, + active_work: Option, } #[derive(Debug, Subcommand)] @@ -126,7 +128,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { }), ) } - (None, None) => (ClientIdentity::Helper(my_identity), None), + (None, None) => (ClientIdentity::Header(my_identity), None), _ => panic!("should have been rejected by clap"), }; @@ -134,9 +136,12 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { private_key_file: sk_path, }); + let query_runtime = new_query_runtime(); let app_config = AppConfig::default() .with_key_registry(hpke_registry(mk_encryption.as_ref()).await?) - .with_active_work(args.active_work); + .with_active_work(args.active_work) + .with_runtime(IpaRuntime::from_tokio_runtime(&query_runtime)); + let (setup, handler) = AppSetup::new(app_config); let server_config = ServerConfig { @@ -154,17 +159,40 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { let network_config_path = args.network.as_deref().unwrap(); let network_config = NetworkConfig::from_toml_str(&fs::read_to_string(network_config_path)?)? .override_scheme(&scheme); - let clients = MpcHelperClient::from_conf(&network_config, &identity); - let (transport, server) = HttpTransport::new( + // TODO: Following is just temporary until Shard Transport is actually used. + let shard_clients_config = network_config.client.clone(); + let shard_server_config = server_config.clone(); + // --- + + let http_runtime = new_http_runtime(); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::from_tokio_runtime(&http_runtime), + &network_config, + &identity, + ); + let (transport, server) = MpcHttpTransport::new( + IpaRuntime::from_tokio_runtime(&http_runtime), my_identity, server_config, network_config, - clients, + &clients, Some(handler), ); - let _app = setup.connect(transport.clone(), HttpShardTransport); + // TODO: Following is just temporary until Shard Transport is actually used. + let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config); + let (shard_transport, _shard_server) = ShardHttpTransport::new( + IpaRuntime::from_tokio_runtime(&http_runtime), + ShardIndex::FIRST, + shard_server_config, + shard_network_config, + vec![], + None, + ); + // --- + + let _app = setup.connect(transport.clone(), shard_transport.clone()); let listener = args.server_socket_fd .map(|fd| { @@ -184,18 +212,67 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { let (_addr, server_handle) = server .start_on( + &IpaRuntime::from_tokio_runtime(&http_runtime), listener, // TODO, trace based on the content of the query. None as Option<()>, ) .await; - server_handle.await?; + server_handle.await; + [query_runtime, http_runtime].map(Runtime::shutdown_background); Ok(()) } -#[tokio::main] +/// Creates a new runtime for HTTP stack. It is useful to provide a dedicated +/// scheduler to HTTP tasks, to make sure IPA server can respond to requests, +/// if for some reason query runtime becomes overloaded. +/// When multi-threading feature is enabled it creates a runtime with thread-per-core, +/// otherwise a single-threaded runtime is created. +fn new_http_runtime() -> Runtime { + if cfg!(feature = "multi-threading") { + tokio::runtime::Builder::new_multi_thread() + .thread_name("http-worker") + .enable_all() + .build() + .unwrap() + } else { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .thread_name("http-worker") + .enable_all() + .build() + .unwrap() + } +} + +/// This function creates a runtime suitable for executing MPC queries. +/// When multi-threading feature is enabled it creates a runtime with thread-per-core, +/// otherwise a single-threaded runtime is created. +fn new_query_runtime() -> Runtime { + // it is intentional that IO driver is not enabled here (enable_time() call only). + // query runtime is supposed to use CPU/memory only, no writes to disk and all + // network communication is handled by HTTP runtime. + if cfg!(feature = "multi-threading") { + tokio::runtime::Builder::new_multi_thread() + .thread_name("query-executor") + .enable_time() + .build() + .unwrap() + } else { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .thread_name("query-executor") + .enable_time() + .build() + .unwrap() + } +} + +/// A single thread is enough here, because server spawns additional +/// runtimes to use in MPC queries and HTTP. +#[tokio::main(flavor = "current_thread")] pub async fn main() { let args = Args::parse(); let _handle = args.logging.setup_logging(); diff --git a/ipa-core/src/bin/in_the_clear.rs b/ipa-core/src/bin/in_the_clear.rs new file mode 100644 index 000000000..16b2235df --- /dev/null +++ b/ipa-core/src/bin/in_the_clear.rs @@ -0,0 +1,72 @@ +use std::{error::Error, fs::File, io::Write, num::NonZeroU32, path::PathBuf}; + +use clap::Parser; +use ipa_core::{ + cli::{playbook::InputSource, Verbosity}, + test_fixture::hybrid::{hybrid_in_the_clear, TestHybridRecord}, +}; + +#[derive(Debug, Parser)] +pub struct CommandInput { + #[arg( + long, + help = "Read the input from the provided file, instead of standard input" + )] + input_file: Option, +} + +impl From<&CommandInput> for InputSource { + fn from(source: &CommandInput) -> Self { + if let Some(ref file_name) = source.input_file { + InputSource::from_file(file_name) + } else { + InputSource::from_stdin() + } + } +} + +#[derive(Debug, Parser)] +#[clap(name = "in_the_clear", about = "In the Clear CLI")] +#[command(about)] +struct Args { + #[clap(flatten)] + logging: Verbosity, + + #[clap(flatten)] + input: CommandInput, + + /// The destination file for output. + #[arg(long, value_name = "OUTPUT_FILE")] + output_file: PathBuf, + + #[arg(long, default_value = "20")] + max_breakdown_key: NonZeroU32, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + let _handle = args.logging.setup_logging(); + + let input = InputSource::from(&args.input); + + let input_rows = input.iter::().collect::>(); + let expected = hybrid_in_the_clear( + &input_rows, + usize::try_from(args.max_breakdown_key.get()).unwrap(), + ); + + let mut file = File::options() + .write(true) + .create_new(true) + .open(&args.output_file) + .map_err(|e| { + format!( + "Failed to create output file {}: {e}", + &args.output_file.display() + ) + })?; + + write!(file, "{}", serde_json::to_string_pretty(&expected)?)?; + + Ok(()) +} diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 6debb6ead..38750e578 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -10,22 +10,23 @@ use std::{ }; use clap::{Parser, Subcommand}; -use comfy_table::{Cell, Table}; use hyper::http::uri::Scheme; use ipa_core::{ cli::{ - noise::{apply, ApplyDpArgs}, - playbook::{make_clients, playbook_oprf_ipa, validate, validate_dp, InputSource}, + playbook::{ + make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp, + InputSource, + }, CsvSerializer, IpaQueryResult, Verbosity, }, config::{KeyRegistries, NetworkConfig}, ff::{boolean_array::BA32, FieldType}, helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType}, - net::MpcHelperClient, - report::DEFAULT_KEY_ID, + net::{Helper, MpcHelperClient}, + report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ - ipa::{ipa_in_the_clear, CappingOrder, IpaQueryStyle, IpaSecurityModel, TestRawDataRecord}, - EventGenerator, EventGeneratorConfig, + ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord}, + EventGenerator, EventGeneratorConfig, HybridEventGenerator, HybridGeneratorConfig, }, }; use rand::{distributions::Alphanumeric, rngs::StdRng, thread_rng, Rng}; @@ -54,7 +55,7 @@ struct Args { input: CommandInput, /// The destination file for output. - #[arg(long, value_name = "FILE")] + #[arg(long, value_name = "OUTPUT_FILE")] output_file: Option, #[command(subcommand)] @@ -95,10 +96,42 @@ enum ReportCollectorCommand { #[clap(flatten)] gen_args: EventGeneratorConfig, }, - /// Apply differential privacy noise to IPA inputs - ApplyDpNoise(ApplyDpArgs), - /// Execute OPRF IPA in a semi-honest majority setting - OprfIpa(IpaQueryConfig), + GenHybridInputs { + /// Number of records to generate + #[clap(long, short = 'n')] + count: u32, + + /// The seed for random generator. + #[clap(long, short = 's')] + seed: Option, + + #[clap(flatten)] + gen_args: HybridGeneratorConfig, + }, + /// Execute OPRF IPA in a semi-honest majority setting with known test data + /// and compare results against expectation + SemiHonestOprfIpaTest(IpaQueryConfig), + /// Execute OPRF IPA in an honest majority (one malicious helper) setting + /// with known test data and compare results against expectation + MaliciousOprfIpaTest(IpaQueryConfig), + /// Execute OPRF IPA in a semi-honest majority setting with unknown encrypted data + #[command(visible_alias = "oprf-ipa")] + SemiHonestOprfIpa { + #[clap(flatten)] + encrypted_inputs: EncryptedInputs, + + #[clap(flatten)] + ipa_query_config: IpaQueryConfig, + }, + /// Execute OPRF IPA in an honest majority (one malicious helper) setting + /// with unknown encrypted data + MaliciousOprfIpa { + #[clap(flatten)] + encrypted_inputs: EncryptedInputs, + + #[clap(flatten)] + ipa_query_config: IpaQueryConfig, + }, } #[derive(Debug, clap::Args)] @@ -110,6 +143,21 @@ struct GenInputArgs { breakdowns: u32, } +#[derive(Debug, Parser)] +struct EncryptedInputs { + /// The encrypted input for H1 + #[arg(long, value_name = "H1_ENCRYPTED_INPUT_FILE")] + enc_input_file1: PathBuf, + + /// The encrypted input for H2 + #[arg(long, value_name = "H2_ENCRYPTED_INPUT_FILE")] + enc_input_file2: PathBuf, + + /// The encrypted input for H3 + #[arg(long, value_name = "H3_ENCRYPTED_INPUT_FILE")] + enc_input_file3: PathBuf, +} + #[tokio::main] async fn main() -> Result<(), Box> { let args = Args::parse(); @@ -128,15 +176,54 @@ async fn main() -> Result<(), Box> { seed, gen_args, } => gen_inputs(count, seed, args.output_file, gen_args)?, - ReportCollectorCommand::ApplyDpNoise(ref dp_args) => apply_dp_noise(&args, dp_args)?, - ReportCollectorCommand::OprfIpa(config) => { - ipa( + ReportCollectorCommand::GenHybridInputs { + count, + seed, + gen_args, + } => gen_hybrid_inputs(count, seed, args.output_file, gen_args)?, + ReportCollectorCommand::SemiHonestOprfIpaTest(config) => { + ipa_test( &args, &network, IpaSecurityModel::SemiHonest, config, &clients, - IpaQueryStyle::Oprf, + ) + .await? + } + ReportCollectorCommand::MaliciousOprfIpaTest(config) => { + ipa_test( + &args, + &network, + IpaSecurityModel::Malicious, + config, + &clients, + ) + .await? + } + ReportCollectorCommand::MaliciousOprfIpa { + ref encrypted_inputs, + ipa_query_config, + } => { + ipa( + &args, + IpaSecurityModel::Malicious, + ipa_query_config, + &clients, + encrypted_inputs, + ) + .await? + } + ReportCollectorCommand::SemiHonestOprfIpa { + ref encrypted_inputs, + ipa_query_config, + } => { + ipa( + &args, + IpaSecurityModel::SemiHonest, + ipa_query_config, + &clients, + encrypted_inputs, ) .await? } @@ -145,6 +232,31 @@ async fn main() -> Result<(), Box> { Ok(()) } +fn gen_hybrid_inputs( + count: u32, + seed: Option, + output_file: Option, + args: HybridGeneratorConfig, +) -> io::Result<()> { + let rng = seed + .map(StdRng::seed_from_u64) + .unwrap_or_else(StdRng::from_entropy); + let event_gen = HybridEventGenerator::with_config(rng, args).take(count as usize); + + let mut writer: Box = if let Some(path) = output_file { + Box::new(OpenOptions::new().write(true).create_new(true).open(path)?) + } else { + Box::new(stdout().lock()) + }; + + for event in event_gen { + event.to_csv(&mut writer)?; + writer.write_all(b"\n")?; + } + + Ok(()) +} + fn gen_inputs( count: u32, seed: Option, @@ -171,25 +283,111 @@ fn gen_inputs( Ok(()) } +fn get_query_type(security_model: IpaSecurityModel, ipa_query_config: IpaQueryConfig) -> QueryType { + match security_model { + IpaSecurityModel::SemiHonest => QueryType::SemiHonestOprfIpa(ipa_query_config), + IpaSecurityModel::Malicious => QueryType::MaliciousOprfIpa(ipa_query_config), + } +} + +fn write_ipa_output_file( + path: &PathBuf, + query_result: &IpaQueryResult, +) -> Result<(), Box> { + // it will be sad to lose the results if file already exists. + let path = if Path::is_file(path) { + let mut new_file_name = thread_rng() + .sample_iter(&Alphanumeric) + .take(5) + .map(char::from) + .collect::(); + let file_name = path.file_stem().ok_or("not a file")?; + + new_file_name.insert(0, '-'); + new_file_name.insert_str(0, &file_name.to_string_lossy()); + tracing::warn!( + "{} file exists, renaming to {:?}", + path.display(), + new_file_name + ); + + // it will not be 100% accurate until file_prefix API is stabilized + Cow::Owned( + path.with_file_name(&new_file_name) + .with_extension(path.extension().unwrap_or("".as_ref())), + ) + } else { + Cow::Borrowed(path) + }; + let mut file = File::options() + .write(true) + .create_new(true) + .open(path.deref()) + .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?; + + write!(file, "{}", serde_json::to_string_pretty(query_result)?)?; + Ok(()) +} + async fn ipa( args: &Args, - network: &NetworkConfig, security_model: IpaSecurityModel, ipa_query_config: IpaQueryConfig, helper_clients: &[MpcHelperClient; 3], - query_style: IpaQueryStyle, + encrypted_inputs: &EncryptedInputs, ) -> Result<(), Box> { - let input = InputSource::from(&args.input); - let query_type: QueryType; - match (security_model, &query_style) { - (IpaSecurityModel::SemiHonest, IpaQueryStyle::Oprf) => { - query_type = QueryType::OprfIpa(ipa_query_config); - } - (IpaSecurityModel::Malicious, IpaQueryStyle::Oprf) => { - panic!("OPRF for malicious is not implemented as yet") - } + let query_type = get_query_type(security_model, ipa_query_config); + + let files = [ + &encrypted_inputs.enc_input_file1, + &encrypted_inputs.enc_input_file2, + &encrypted_inputs.enc_input_file3, + ]; + + let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); + + let query_config = QueryConfig { + size: QuerySize::try_from(encrypted_oprf_report_streams.query_size).unwrap(), + field_type: FieldType::Fp32BitPrime, + query_type, }; + let query_id = helper_clients[0] + .create_query(query_config) + .await + .expect("Unable to create query!"); + + tracing::info!("Starting query for OPRF"); + // the value for histogram values (BA32) must be kept in sync with the server-side + // implementation, otherwise a runtime reconstruct error will be generated. + // see ipa-core/src/query/executor.rs + let actual = run_query_and_validate::( + encrypted_oprf_report_streams.streams, + encrypted_oprf_report_streams.query_size, + helper_clients, + query_id, + ipa_query_config, + ) + .await; + + if let Some(ref path) = args.output_file { + write_ipa_output_file(path, &actual)?; + } else { + println!("{}", serde_json::to_string_pretty(&actual)?); + } + Ok(()) +} + +async fn ipa_test( + args: &Args, + network: &NetworkConfig, + security_model: IpaSecurityModel, + ipa_query_config: IpaQueryConfig, + helper_clients: &[MpcHelperClient; 3], +) -> Result<(), Box> { + let input = InputSource::from(&args.input); + let query_type = get_query_type(security_model, ipa_query_config); + let input_rows = input.iter::().collect::>(); let query_config = QueryConfig { size: QuerySize::try_from(input_rows.len()).unwrap(), @@ -207,9 +405,7 @@ async fn ipa( ipa_query_config.per_user_credit_cap, ipa_query_config.attribution_window_seconds, ipa_query_config.max_breakdown_key, - &(match query_style { - IpaQueryStyle::Oprf => CappingOrder::CapMostRecentFirst, - }), + &CappingOrder::CapMostRecentFirst, ); // pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results @@ -225,55 +421,20 @@ async fn ipa( let Some(key_registries) = key_registries.init_from(network) else { panic!("could not load network file") }; - let actual = match query_style { - IpaQueryStyle::Oprf => { - // the value for histogram values (BA32) must be kept in sync with the server-side - // implementation, otherwise a runtime reconstruct error will be generated. - // see ipa-core/src/query/executor.rs - playbook_oprf_ipa::( - input_rows, - helper_clients, - query_id, - ipa_query_config, - Some((DEFAULT_KEY_ID, key_registries)), - ) - .await - } - }; + // the value for histogram values (BA32) must be kept in sync with the server-side + // implementation, otherwise a runtime reconstruct error will be generated. + // see ipa-core/src/query/executor.rs + let actual = playbook_oprf_ipa::( + input_rows, + helper_clients, + query_id, + ipa_query_config, + Some((DEFAULT_KEY_ID, key_registries)), + ) + .await; if let Some(ref path) = args.output_file { - // it will be sad to lose the results if file already exists. - let path = if Path::is_file(path) { - let mut new_file_name = thread_rng() - .sample_iter(&Alphanumeric) - .take(5) - .map(char::from) - .collect::(); - let file_name = path.file_stem().ok_or("not a file")?; - - new_file_name.insert(0, '-'); - new_file_name.insert_str(0, &file_name.to_string_lossy()); - tracing::warn!( - "{} file exists, renaming to {:?}", - path.display(), - new_file_name - ); - - // it will not be 100% accurate until file_prefix API is stabilized - Cow::Owned( - path.with_file_name(&new_file_name) - .with_extension(path.extension().unwrap_or("".as_ref())), - ) - } else { - Cow::Borrowed(path) - }; - let mut file = File::options() - .write(true) - .create_new(true) - .open(path.deref()) - .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?; - - write!(file, "{}", serde_json::to_string_pretty(&actual)?)?; + write_ipa_output_file(path, &actual)?; } tracing::info!("{m:?}", m = ipa_query_config); @@ -297,50 +458,3 @@ async fn ipa( Ok(()) } - -fn apply_dp_noise(args: &Args, dp_args: &ApplyDpArgs) -> Result<(), Box> { - let IpaQueryResult { breakdowns, .. } = - serde_json::from_slice(&InputSource::from(&args.input).to_vec()?)?; - - let output = apply(&breakdowns, dp_args); - let mut table = Table::new(); - let header = std::iter::once("Epsilon".to_string()) - .chain(std::iter::once("Variance".to_string())) - .chain(std::iter::once("Mean".to_string())) - .chain((0..breakdowns.len()).map(|i| format!("{}", i + 1))) - .collect::>(); - table.set_header(header); - - // original values - table.add_row( - std::iter::repeat("-".to_string()) - .take(3) - .chain(breakdowns.iter().map(ToString::to_string)), - ); - - // reverse because smaller epsilon means more noise and I print the original values - // in the first row. - for epsilon in output.keys().rev() { - let noised_values = output.get(epsilon).unwrap(); - let mut row = vec![ - Cell::new(format!("{:.3}", epsilon)), - Cell::new(format!("{:.3}", noised_values.std)), - Cell::new(format!("{:.3}", noised_values.mean)), - ]; - - for agg in noised_values.breakdowns.iter() { - row.push(Cell::new(format!("{}", agg))); - } - - table.add_row(row); - } - - println!("{}", table); - - if let Some(file) = &args.output_file { - let mut file = File::create(file)?; - serde_json::to_writer_pretty(&mut file, &output)?; - } - - Ok(()) -} diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index 74e2e7284..9da4afbb2 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -85,6 +85,11 @@ enum TestAction { /// All helpers add their shares locally and set the resulting share to be the /// sum. No communication is required to run the circuit. AddInPrimeField, + /// A test protocol for sharded MPCs. The goal here is to use + /// both shard-to-shard and helper-to-helper communication channels. + /// This is exactly what shuffle does and that's why it is picked + /// for this purpose. + ShardedShuffle, } #[tokio::main] @@ -102,6 +107,7 @@ async fn main() -> Result<(), Box> { match args.action { TestAction::Multiply => multiply(&args, &clients).await, TestAction::AddInPrimeField => add(&args, &clients).await, + TestAction::ShardedShuffle => sharded_shuffle(&args, &clients).await, }; Ok(()) @@ -159,3 +165,7 @@ async fn add(args: &Args, helper_clients: &[MpcHelperClient; 3]) { FieldType::Fp32BitPrime => add_in_field::(args, helper_clients).await, }; } + +async fn sharded_shuffle(_args: &Args, _helper_clients: &[MpcHelperClient; 3]) { + unimplemented!() +} diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs index 42835bdd0..341a4253a 100644 --- a/ipa-core/src/cli/clientconf.rs +++ b/ipa-core/src/cli/clientconf.rs @@ -186,7 +186,7 @@ fn assert_network_config(config_toml: &Map, config_str: &str) { else { panic!("peers section in toml config is not a table"); }; - for (i, peer_config_actual) in nw_config.peers.iter().enumerate() { + for (i, peer_config_actual) in nw_config.peers().iter().enumerate() { assert_peer_config(&peer_config_expected[i], peer_config_actual); } } diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs deleted file mode 100644 index c7c16ab40..000000000 --- a/ipa-core/src/cli/crypto.rs +++ /dev/null @@ -1,685 +0,0 @@ -use std::{ - fmt::Debug, - fs::{read_to_string, File, OpenOptions}, - io::{BufRead, BufReader, Write}, - iter::zip, - path::PathBuf, -}; - -use clap::Parser; -use rand::thread_rng; - -use crate::{ - cli::playbook::{BreakdownKey, InputSource, Timestamp, TriggerValue}, - config::{hpke_registry, HpkeServerConfig, KeyRegistries, NetworkConfig}, - error::BoxError, - ff::{ - boolean_array::{BA20, BA3, BA8}, - U128Conversions, - }, - hpke::{KeyRegistry, PrivateKeyOnly}, - report::{EncryptedOprfReport, EventType, InvalidReportError, OprfReport, DEFAULT_KEY_ID}, - secret_sharing::IntoShares, - test_fixture::{ipa::TestRawDataRecord, Reconstruct}, -}; - -#[derive(Debug, Parser)] -#[clap(name = "test_encrypt", about = "Test Encrypt")] -#[command(about)] -pub struct EncryptArgs { - /// Path to file to secret share and encrypt - #[arg(long)] - input_file: PathBuf, - // /// The destination dir for encrypted output. - // /// In that dir, it will create helper1.enc, - // /// helper2.enc, and helper3.enc - #[arg(long, value_name = "FILE")] - output_dir: PathBuf, - /// Path to helper network configuration file - #[arg(long)] - network: PathBuf, -} - -#[derive(Debug, Parser)] -#[clap(name = "test_decrypt", about = "Test Decrypt")] -#[command(about)] -pub struct DecryptArgs { - /// Path to helper1 file to decrypt - #[arg(long)] - input_file1: PathBuf, - - /// Helper1 Private key for decrypting match keys - #[arg(long)] - mk_private_key1: PathBuf, - - /// Path to helper2 file to decrypt - #[arg(long)] - input_file2: PathBuf, - - /// Helper2 Private key for decrypting match keys - #[arg(long)] - mk_private_key2: PathBuf, - - /// Path to helper3 file to decrypt - #[arg(long)] - input_file3: PathBuf, - - /// Helper3 Private key for decrypting match keys - #[arg(long)] - mk_private_key3: PathBuf, - - /// The destination file for decrypted output. - #[arg(long, value_name = "FILE")] - output_file: PathBuf, -} - -/// # Panics -/// if input file or network file are not correctly formatted -/// # Errors -/// if it cannot open the files -pub fn encrypt(args: &EncryptArgs) -> Result<(), BoxError> { - let input = InputSource::from_file(&args.input_file); - - let mut rng = thread_rng(); - let mut key_registries = KeyRegistries::default(); - - let network = NetworkConfig::from_toml_str( - &read_to_string(&args.network) - .unwrap_or_else(|e| panic!("Failed to open network file: {:?}. {}", &args.network, e)), - ) - .unwrap_or_else(|e| { - panic!( - "Failed to parse network file into toml: {:?}. {}", - &args.network, e - ) - }); - let Some(key_registries) = key_registries.init_from(&network) else { - panic!("could not load network file") - }; - - let shares: [Vec>; 3] = - input.iter::().share(); - - for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() { - let output_filename = format!("helper{}.enc", index + 1); - let mut writer = OpenOptions::new() - .write(true) - .create_new(true) - .open(args.output_dir.join(&output_filename)) - .unwrap_or_else(|e| panic!("unable write to {}. {}", &output_filename, e)); - - for share in shares { - let output = share - .encrypt(DEFAULT_KEY_ID, key_registry, &mut rng) - .unwrap(); - let hex_output = hex::encode(&output); - writeln!(writer, "{hex_output}")?; - } - } - - Ok(()) -} - -async fn build_hpke_registry( - private_key_file: PathBuf, -) -> Result, BoxError> { - let mk_encryption = Some(HpkeServerConfig::File { private_key_file }); - let key_registry = hpke_registry(mk_encryption.as_ref()).await?; - Ok(key_registry) -} - -struct DecryptedReports { - filename: PathBuf, - reader: BufReader, - key_registry: KeyRegistry, - iter_index: usize, -} - -impl DecryptedReports { - fn new(filename: &PathBuf, key_registry: KeyRegistry) -> Self { - let file = File::open(filename) - .unwrap_or_else(|e| panic!("unable to open file {filename:?}. {e}")); - let reader = BufReader::new(file); - Self { - filename: filename.clone(), - reader, - key_registry, - iter_index: 0, - } - } -} - -impl Iterator for DecryptedReports { - type Item = Result, InvalidReportError>; - - fn next(&mut self) -> Option { - let mut line = String::new(); - if self.reader.read_line(&mut line).unwrap() > 0 { - self.iter_index += 1; - let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); - let enc_report = - EncryptedOprfReport::from_bytes(encrypted_report_bytes.as_slice()).unwrap(); - let dec_report = enc_report.decrypt(&self.key_registry); - match dec_report { - Ok(dec_report) => Some(Ok(dec_report)), - Err(e) => { - eprintln!( - "Decryption failed: File: {0}. Record: {1}. Error: {e}.", - self.filename.display(), - self.iter_index - ); - Some(Err(e)) - } - } - } else { - None - } - } -} - -/// # Panics -// if input files or private_keys are not correctly formatted -/// # Errors -/// if it cannot open the files -pub async fn decrypt_and_reconstruct(args: DecryptArgs) -> Result<(), BoxError> { - let key_registry1 = build_hpke_registry(args.mk_private_key1).await?; - let key_registry2 = build_hpke_registry(args.mk_private_key2).await?; - let key_registry3 = build_hpke_registry(args.mk_private_key3).await?; - let decrypted_reports1 = DecryptedReports::new(&args.input_file1, key_registry1); - let decrypted_reports2 = DecryptedReports::new(&args.input_file2, key_registry2); - let decrypted_reports3 = DecryptedReports::new(&args.input_file3, key_registry3); - - let mut writer = Box::new( - OpenOptions::new() - .write(true) - .create_new(true) - .open(args.output_file)?, - ); - for (dec_report1, (dec_report2, dec_report3)) in - decrypted_reports1.zip(decrypted_reports2.zip(decrypted_reports3)) - { - if let (Ok(dec_report1), Ok(dec_report2), Ok(dec_report3)) = - (dec_report1, dec_report2, dec_report3) - { - let timestamp = [ - dec_report1.timestamp, - dec_report2.timestamp, - dec_report3.timestamp, - ] - .reconstruct() - .as_u128(); - - let match_key = [ - dec_report1.match_key, - dec_report2.match_key, - dec_report3.match_key, - ] - .reconstruct() - .as_u128(); - - // these aren't reconstucted, so we explictly make sure - // they are consistent across all three files, then set - // it to the first one (without loss of generality) - assert_eq!(dec_report1.event_type, dec_report2.event_type); - assert_eq!(dec_report2.event_type, dec_report3.event_type); - let is_trigger_report = dec_report1.event_type == EventType::Trigger; - - let breakdown_key = [ - dec_report1.breakdown_key, - dec_report2.breakdown_key, - dec_report3.breakdown_key, - ] - .reconstruct() - .as_u128(); - - let trigger_value = [ - dec_report1.trigger_value, - dec_report2.trigger_value, - dec_report3.trigger_value, - ] - .reconstruct() - .as_u128(); - - writeln!( - writer, - "{},{},{},{},{}", - timestamp, - match_key, - u8::from(is_trigger_report), - breakdown_key, - trigger_value, - )?; - } - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::{ - fs::File, - io::{BufRead, BufReader, Write}, - path::Path, - sync::Arc, - }; - - use bytes::BufMut; - use clap::Parser; - use hpke::Deserializable; - use rand::thread_rng; - use tempfile::{tempdir, NamedTempFile}; - - use crate::{ - cli::{ - crypto::{decrypt_and_reconstruct, encrypt, DecryptArgs, EncryptArgs}, - CsvSerializer, - }, - ff::{boolean_array::BA16, U128Conversions}, - helpers::{ - query::{IpaQueryConfig, QuerySize}, - BodyStream, - }, - hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, - query::OprfIpaQuery, - test_fixture::{ - ipa::TestRawDataRecord, join3v, EventGenerator, EventGeneratorConfig, Reconstruct, - TestWorld, - }, - }; - - fn are_files_equal(file1: &Path, file2: &Path) { - let file1 = - File::open(file1).unwrap_or_else(|e| panic!("unable to open {}: {e}", file1.display())); - let file2 = - File::open(file2).unwrap_or_else(|e| panic!("unable to open {}: {e}", file2.display())); - let reader1 = BufReader::new(file1).lines(); - let mut reader2 = BufReader::new(file2).lines(); - for line1 in reader1 { - let line2 = reader2.next().expect("Files have different lengths"); - assert_eq!(line1.unwrap(), line2.unwrap()); - } - assert!(reader2.next().is_none(), "Files have different lengths"); - } - - fn write_input_file() -> NamedTempFile { - let count = 10; - let rng = thread_rng(); - let event_gen_args = EventGeneratorConfig::new(10, 5, 20, 1, 10, 604_800); - - let event_gen = EventGenerator::with_config(rng, event_gen_args) - .take(count) - .collect::>(); - let mut input = NamedTempFile::new().unwrap(); - - for event in event_gen { - let _ = event.to_csv(input.as_file_mut()); - writeln!(input.as_file()).unwrap(); - } - input.as_file_mut().flush().unwrap(); - input - } - - fn write_network_file() -> NamedTempFile { - let network_data = r#" -[[peers]] -url = "helper1.test" -[peers.hpke] -public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" -[[peers]] -url = "helper2.test" -[peers.hpke] -public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" -[[peers]] -url = "helper3.test" -[peers.hpke] -public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e" -"#; - let mut network = NamedTempFile::new().unwrap(); - writeln!(network.as_file_mut(), "{network_data}").unwrap(); - network - } - - fn write_mk_private_key(mk_private_key_data: &str) -> NamedTempFile { - let mut mk_private_key = NamedTempFile::new().unwrap(); - writeln!(mk_private_key.as_file_mut(), "{mk_private_key_data}").unwrap(); - mk_private_key - } - - fn build_encrypt_args( - input_file: &Path, - output_dir: &Path, - network_file: &Path, - ) -> EncryptArgs { - EncryptArgs::try_parse_from([ - "test_encrypt", - "--input-file", - input_file.to_str().unwrap(), - "--output-dir", - output_dir.to_str().unwrap(), - "--network", - network_file.to_str().unwrap(), - ]) - .unwrap() - } - - fn build_decrypt_args( - enc1: &Path, - enc2: &Path, - enc3: &Path, - mk_private_key1: &Path, - mk_private_key2: &Path, - mk_private_key3: &Path, - decrypt_output: &Path, - ) -> DecryptArgs { - DecryptArgs::try_parse_from([ - "test_decrypt", - "--input-file1", - enc1.to_str().unwrap(), - "--input-file2", - enc2.to_str().unwrap(), - "--input-file3", - enc3.to_str().unwrap(), - "--mk-private-key1", - mk_private_key1.to_str().unwrap(), - "--mk-private-key2", - mk_private_key2.to_str().unwrap(), - "--mk-private-key3", - mk_private_key3.to_str().unwrap(), - "--output-file", - decrypt_output.to_str().unwrap(), - ]) - .unwrap() - } - - #[test] - #[should_panic = "Failed to open network file:"] - fn encrypt_no_network_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_dir = tempdir().unwrap(); - let network_file = network_dir.path().join("does_not_exist"); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.as_path()); - let _ = encrypt(&encrypt_args); - } - - #[test] - #[should_panic = "TOML parse error at"] - fn encrypt_bad_network_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_data = r" -this is not toml! -%^& weird characters -(\deadbeef>? -"; - let mut network_file = NamedTempFile::new().unwrap(); - writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); - - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - } - - #[test] - #[should_panic = "invalid length 2, expected an array of length 3"] - fn encrypt_incomplete_network_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_data = r#" -[[peers]] -url = "helper1.test" -[peers.hpke] -public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" -[[peers]] -url = "helper2.test" -[peers.hpke] -public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" -"#; - let mut network_file = NamedTempFile::new().unwrap(); - writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); - - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - } - - #[tokio::test] - #[should_panic = "No such file or directory (os error 2)"] - async fn decrypt_no_enc_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let decrypt_output = output_dir.path().join("output"); - let enc1 = output_dir.path().join("DOES_NOT_EXIST.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - - let mk_private_key1 = write_mk_private_key( - "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", - ); - let mk_private_key2 = write_mk_private_key( - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - ); - let mk_private_key3 = write_mk_private_key( - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ); - - let decrypt_args = build_decrypt_args( - enc1.as_path(), - enc2.as_path(), - enc3.as_path(), - mk_private_key1.path(), - mk_private_key2.path(), - mk_private_key3.path(), - &decrypt_output, - ); - let _ = decrypt_and_reconstruct(decrypt_args).await; - } - - #[tokio::test] - #[should_panic = "called `Result::unwrap()` on an `Err` value: Crypt(Other)"] - async fn decrypt_bad_private_key() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let decrypt_output = output_dir.path().join("output"); - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - let mk_private_key1 = write_mk_private_key( - "bad9fdc79d98471cedd07ee6743d3bb43aabbddabc49cd9fae1d5daef3f2b3ba", - ); - let mk_private_key2 = write_mk_private_key( - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - ); - let mk_private_key3 = write_mk_private_key( - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ); - - let decrypt_args = build_decrypt_args( - enc1.as_path(), - enc2.as_path(), - enc3.as_path(), - mk_private_key1.path(), - mk_private_key2.path(), - mk_private_key3.path(), - &decrypt_output, - ); - let _ = decrypt_and_reconstruct(decrypt_args).await; - } - - #[tokio::test] - async fn encrypt_and_decrypt() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let decrypt_output = output_dir.path().join("output"); - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - let mk_private_key1 = write_mk_private_key( - "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", - ); - let mk_private_key2 = write_mk_private_key( - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - ); - let mk_private_key3 = write_mk_private_key( - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ); - - let decrypt_args = build_decrypt_args( - enc1.as_path(), - enc2.as_path(), - enc3.as_path(), - mk_private_key1.path(), - mk_private_key2.path(), - mk_private_key3.path(), - &decrypt_output, - ); - let _ = decrypt_and_reconstruct(decrypt_args).await; - - are_files_equal(input_file.path(), &decrypt_output); - } - - #[tokio::test] - async fn encrypt_and_execute_query() { - const EXPECTED: &[u128] = &[0, 8, 5]; - - let records: Vec = vec![ - TestRawDataRecord { - timestamp: 0, - user_id: 12345, - is_trigger_report: false, - breakdown_key: 2, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 4, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 10, - user_id: 12345, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 5, - }, - TestRawDataRecord { - timestamp: 12, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 2, - }, - TestRawDataRecord { - timestamp: 20, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 30, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 1, - trigger_value: 7, - }, - ]; - let query_size = QuerySize::try_from(records.len()).unwrap(); - let mut input_file = NamedTempFile::new().unwrap(); - - for event in records { - let _ = event.to_csv(input_file.as_file_mut()); - writeln!(input_file.as_file()).unwrap(); - } - input_file.as_file_mut().flush().unwrap(); - - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - - let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); - for (i, path) in [enc1, enc2, enc3].iter().enumerate() { - let file = File::open(path).unwrap(); - let reader = BufReader::new(file); - for line in reader.lines() { - let line = line.unwrap(); - let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); - println!("{}", encrypted_report_bytes.len()); - buffers[i].put_u16_le(encrypted_report_bytes.len().try_into().unwrap()); - buffers[i].put_slice(encrypted_report_bytes.as_slice()); - } - } - - let world = TestWorld::default(); - let contexts = world.contexts(); - - let mk_private_keys = vec![ - hex::decode("53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff") - .expect("manually provided for test"), - hex::decode("3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569") - .expect("manually provided for test"), - hex::decode("1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7") - .expect("manually provided for test"), - ]; - - #[allow(clippy::large_futures)] - let results = join3v(buffers.into_iter().zip(contexts).zip(mk_private_keys).map( - |((buffer, ctx), mk_private_key)| { - let query_config = IpaQueryConfig { - per_user_credit_cap: 8, - attribution_window_seconds: None, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 1.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - let private_registry = - Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( - IpaPrivateKey::from_bytes(&mk_private_key) - .expect("manually constructed for test"), - )])); - - OprfIpaQuery::>::new( - query_config, - private_registry, - ) - .execute(ctx, query_size, input) - }, - )) - .await; - - assert_eq!( - results.reconstruct()[0..3] - .iter() - .map(U128Conversions::as_u128) - .collect::>(), - EXPECTED - ); - } -} diff --git a/ipa-core/src/cli/crypto/decrypt.rs b/ipa-core/src/cli/crypto/decrypt.rs new file mode 100644 index 000000000..63ec8e3f9 --- /dev/null +++ b/ipa-core/src/cli/crypto/decrypt.rs @@ -0,0 +1,278 @@ +use std::{ + fs::{File, OpenOptions}, + io::{BufRead, BufReader, Write}, + path::{Path, PathBuf}, +}; + +use clap::Parser; + +use crate::{ + config::{hpke_registry, HpkeServerConfig}, + error::BoxError, + ff::{ + boolean_array::{BA20, BA3, BA8}, + U128Conversions, + }, + hpke::{KeyRegistry, PrivateKeyOnly}, + report::{EncryptedOprfReport, EventType, OprfReport}, + test_fixture::Reconstruct, +}; + +#[derive(Debug, Parser)] +#[clap(name = "test_decrypt", about = "Test Decrypt")] +#[command(about)] +pub struct DecryptArgs { + /// Path to helper1 file to decrypt + #[arg(long)] + input_file1: PathBuf, + + /// Helper1 Private key for decrypting match keys + #[arg(long)] + mk_private_key1: PathBuf, + + /// Path to helper2 file to decrypt + #[arg(long)] + input_file2: PathBuf, + + /// Helper2 Private key for decrypting match keys + #[arg(long)] + mk_private_key2: PathBuf, + + /// Path to helper3 file to decrypt + #[arg(long)] + input_file3: PathBuf, + + /// Helper3 Private key for decrypting match keys + #[arg(long)] + mk_private_key3: PathBuf, + + /// The destination file for decrypted output. + #[arg(long, value_name = "FILE")] + output_file: PathBuf, +} + +impl DecryptArgs { + #[must_use] + pub fn new( + input_file1: &Path, + input_file2: &Path, + input_file3: &Path, + mk_private_key1: &Path, + mk_private_key2: &Path, + mk_private_key3: &Path, + output_file: &Path, + ) -> Self { + Self { + input_file1: input_file1.to_path_buf(), + mk_private_key1: mk_private_key1.to_path_buf(), + input_file2: input_file2.to_path_buf(), + mk_private_key2: mk_private_key2.to_path_buf(), + input_file3: input_file3.to_path_buf(), + mk_private_key3: mk_private_key3.to_path_buf(), + output_file: output_file.to_path_buf(), + } + } + + /// # Panics + // if input files or private_keys are not correctly formatted + /// # Errors + /// if it cannot open the files + pub async fn decrypt_and_reconstruct(self) -> Result<(), BoxError> { + let Self { + input_file1, + mk_private_key1, + input_file2, + mk_private_key2, + input_file3, + mk_private_key3, + output_file, + } = self; + let key_registry1 = build_hpke_registry(mk_private_key1).await?; + let key_registry2 = build_hpke_registry(mk_private_key2).await?; + let key_registry3 = build_hpke_registry(mk_private_key3).await?; + let decrypted_reports1 = DecryptedReports::new(&input_file1, key_registry1); + let decrypted_reports2 = DecryptedReports::new(&input_file2, key_registry2); + let decrypted_reports3 = DecryptedReports::new(&input_file3, key_registry3); + + let mut writer = Box::new( + OpenOptions::new() + .write(true) + .create_new(true) + .open(output_file)?, + ); + + for (dec_report1, (dec_report2, dec_report3)) in + decrypted_reports1.zip(decrypted_reports2.zip(decrypted_reports3)) + { + let timestamp = [ + dec_report1.timestamp, + dec_report2.timestamp, + dec_report3.timestamp, + ] + .reconstruct() + .as_u128(); + + let match_key = [ + dec_report1.match_key, + dec_report2.match_key, + dec_report3.match_key, + ] + .reconstruct() + .as_u128(); + + // these aren't reconstucted, so we explictly make sure + // they are consistent across all three files, then set + // it to the first one (without loss of generality) + assert_eq!(dec_report1.event_type, dec_report2.event_type); + assert_eq!(dec_report2.event_type, dec_report3.event_type); + let is_trigger_report = dec_report1.event_type == EventType::Trigger; + + let breakdown_key = [ + dec_report1.breakdown_key, + dec_report2.breakdown_key, + dec_report3.breakdown_key, + ] + .reconstruct() + .as_u128(); + + let trigger_value = [ + dec_report1.trigger_value, + dec_report2.trigger_value, + dec_report3.trigger_value, + ] + .reconstruct() + .as_u128(); + + writeln!( + writer, + "{},{},{},{},{}", + timestamp, + match_key, + u8::from(is_trigger_report), + breakdown_key, + trigger_value, + )?; + } + + Ok(()) + } +} + +struct DecryptedReports { + reader: BufReader, + key_registry: KeyRegistry, +} + +impl Iterator for DecryptedReports { + type Item = OprfReport; + + fn next(&mut self) -> Option { + let mut line = String::new(); + if self.reader.read_line(&mut line).unwrap() > 0 { + let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); + let enc_report = + EncryptedOprfReport::from_bytes(encrypted_report_bytes.as_slice()).unwrap(); + let dec_report: OprfReport = + enc_report.decrypt(&self.key_registry).unwrap(); + Some(dec_report) + } else { + None + } + } +} + +impl DecryptedReports { + fn new(filename: &PathBuf, key_registry: KeyRegistry) -> Self { + let file = File::open(filename) + .unwrap_or_else(|e| panic!("unable to open file {filename:?}. {e}")); + let reader = BufReader::new(file); + Self { + reader, + key_registry, + } + } +} + +async fn build_hpke_registry( + private_key_file: PathBuf, +) -> Result, BoxError> { + let mk_encryption = Some(HpkeServerConfig::File { private_key_file }); + let key_registry = hpke_registry(mk_encryption.as_ref()).await?; + Ok(key_registry) +} + +#[cfg(test)] +mod tests { + + use tempfile::tempdir; + + use crate::cli::crypto::{decrypt::DecryptArgs, encrypt::EncryptArgs, sample_data}; + + #[tokio::test] + #[should_panic = "No such file or directory (os error 2)"] + async fn decrypt_no_enc_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let output_dir = tempdir().unwrap(); + let network_file = sample_data::test_keys().network_config(); + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let decrypt_output = output_dir.path().join("output"); + let enc1 = output_dir.path().join("DOES_NOT_EXIST.enc"); + let enc2 = output_dir.path().join("helper2.enc"); + let enc3 = output_dir.path().join("helper3.enc"); + + let [mk_private_key1, mk_private_key2, mk_private_key3] = + sample_data::test_keys().sk_files(); + + let decrypt_args = DecryptArgs::new( + enc1.as_path(), + enc2.as_path(), + enc3.as_path(), + mk_private_key1.path(), + mk_private_key2.path(), + mk_private_key3.path(), + &decrypt_output, + ); + decrypt_args.decrypt_and_reconstruct().await.unwrap(); + } + + #[tokio::test] + #[should_panic = "called `Result::unwrap()` on an `Err` value: Crypt(Other)"] + async fn decrypt_bad_private_key() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let network_file = sample_data::test_keys().network_config(); + let output_dir = tempdir().unwrap(); + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let decrypt_output = output_dir.path().join("output"); + let enc1 = output_dir.path().join("helper1.enc"); + let enc2 = output_dir.path().join("helper2.enc"); + let enc3 = output_dir.path().join("helper3.enc"); + + // corrupt the secret key for H1 + let mut keys = sample_data::test_keys().clone(); + let mut sk = keys.get_sk(0); + sk[0] = !sk[0]; + keys.set_sk(0, sk); + let [mk_private_key1, mk_private_key2, mk_private_key3] = keys.sk_files(); + + DecryptArgs::new( + enc1.as_path(), + enc2.as_path(), + enc3.as_path(), + mk_private_key1.path(), + mk_private_key2.path(), + mk_private_key3.path(), + &decrypt_output, + ) + .decrypt_and_reconstruct() + .await + .unwrap(); + } +} diff --git a/ipa-core/src/cli/crypto/encrypt.rs b/ipa-core/src/cli/crypto/encrypt.rs new file mode 100644 index 000000000..c2ee6ea84 --- /dev/null +++ b/ipa-core/src/cli/crypto/encrypt.rs @@ -0,0 +1,269 @@ +use std::{ + fs::{read_to_string, OpenOptions}, + io::Write, + iter::zip, + path::{Path, PathBuf}, +}; + +use clap::Parser; +use rand::thread_rng; + +use crate::{ + cli::playbook::{BreakdownKey, InputSource, Timestamp, TriggerValue}, + config::{KeyRegistries, NetworkConfig}, + error::BoxError, + report::{OprfReport, DEFAULT_KEY_ID}, + secret_sharing::IntoShares, + test_fixture::ipa::TestRawDataRecord, +}; + +#[derive(Debug, Parser)] +#[clap(name = "test_encrypt", about = "Test Encrypt")] +#[command(about)] +pub struct EncryptArgs { + /// Path to file to secret share and encrypt + #[arg(long)] + input_file: PathBuf, + /// The destination dir for encrypted output. + /// In that dir, it will create helper1.enc, + /// helper2.enc, and helper3.enc + #[arg(long, value_name = "FILE")] + output_dir: PathBuf, + /// Path to helper network configuration file + #[arg(long)] + network: PathBuf, +} + +impl EncryptArgs { + #[must_use] + pub fn new(input_file: &Path, output_dir: &Path, network: &Path) -> Self { + Self { + input_file: input_file.to_path_buf(), + output_dir: output_dir.to_path_buf(), + network: network.to_path_buf(), + } + } + + /// # Panics + /// if input file or network file are not correctly formatted + /// # Errors + /// if it cannot open the files + pub fn encrypt(&self) -> Result<(), BoxError> { + let input = InputSource::from_file(&self.input_file); + + let mut rng = thread_rng(); + let mut key_registries = KeyRegistries::default(); + + let network = + NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| { + panic!("Failed to open network file: {:?}. {}", &self.network, e) + })) + .unwrap_or_else(|e| { + panic!( + "Failed to parse network file into toml: {:?}. {}", + &self.network, e + ) + }); + let Some(key_registries) = key_registries.init_from(&network) else { + panic!("could not load network file") + }; + + let shares: [Vec>; 3] = + input.iter::().share(); + + for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() { + let output_filename = format!("helper{}.enc", index + 1); + let mut writer = OpenOptions::new() + .write(true) + .create_new(true) + .open(self.output_dir.join(&output_filename)) + .unwrap_or_else(|e| panic!("unable write to {}. {}", &output_filename, e)); + + for share in shares { + let output = share + .encrypt(DEFAULT_KEY_ID, key_registry, &mut rng) + .unwrap(); + let hex_output = hex::encode(&output); + writeln!(writer, "{hex_output}")?; + } + } + + Ok(()) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{io::Write, sync::Arc}; + + use hpke::Deserializable; + use tempfile::{tempdir, NamedTempFile}; + + use crate::{ + cli::{ + crypto::{encrypt::EncryptArgs, sample_data}, + CsvSerializer, + }, + ff::{boolean_array::BA16, U128Conversions}, + helpers::query::{IpaQueryConfig, QuerySize}, + hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, + query::OprfIpaQuery, + report::EncryptedOprfReportStreams, + test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, + }; + + #[tokio::test] + async fn encrypt_and_execute_query() { + const EXPECTED: &[u128] = &[0, 2, 5]; + + let records = vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 4, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 5, + }, + TestRawDataRecord { + timestamp: 12, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 2, + }, + ]; + let query_size = QuerySize::try_from(records.len()).unwrap(); + let mut input_file = NamedTempFile::new().unwrap(); + + for event in records { + event.to_csv(input_file.as_file_mut()).unwrap(); + writeln!(input_file.as_file()).unwrap(); + } + input_file.flush().unwrap(); + + let output_dir = tempdir().unwrap(); + let network_file = sample_data::test_keys().network_config(); + + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let files = [ + &output_dir.path().join("helper1.enc"), + &output_dir.path().join("helper2.enc"), + &output_dir.path().join("helper3.enc"), + ]; + + let world = TestWorld::default(); + + let mk_private_keys = [ + "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", + "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", + "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", + ]; + + #[allow(clippy::large_futures)] + let results = join3v( + EncryptedOprfReportStreams::from(files) + .streams + .into_iter() + .zip(world.contexts()) + .zip(mk_private_keys.into_iter()) + .map(|((input, ctx), mk_private_key)| { + let mk_private_key = hex::decode(mk_private_key) + .map(|bytes| IpaPrivateKey::from_bytes(&bytes).unwrap()) + .unwrap(); + let query_config = IpaQueryConfig { + max_breakdown_key: 3, + with_dp: 0, + epsilon: 1.0, + ..Default::default() + }; + + OprfIpaQuery::<_, BA16, _>::new( + query_config, + Arc::new(KeyRegistry::from_keys([PrivateKeyOnly(mk_private_key)])), + ) + .execute(ctx, query_size, input) + }), + ) + .await; + + assert_eq!( + results.reconstruct()[0..3] + .iter() + .map(U128Conversions::as_u128) + .collect::>(), + EXPECTED + ); + } + + #[test] + #[should_panic = "Failed to open network file:"] + fn encrypt_no_network_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let output_dir = tempdir().unwrap(); + let network_dir = tempdir().unwrap(); + let network_file = network_dir.path().join("does_not_exist"); + EncryptArgs::new(input_file.path(), output_dir.path(), &network_file) + .encrypt() + .unwrap(); + } + + #[test] + #[should_panic = "TOML parse error at"] + fn encrypt_bad_network_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + let output_dir = tempdir().unwrap(); + let network_data = r" +this is not toml! +%^& weird characters +(\deadbeef>? +"; + let mut network_file = NamedTempFile::new().unwrap(); + writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); + + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + } + + #[test] + #[should_panic = "Expected a Vec of length 3 but it was 2"] + fn encrypt_incomplete_network_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let output_dir = tempdir().unwrap(); + let network_data = r#" +[[peers]] +url = "helper1.test" +[peers.hpke] +public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" +[[peers]] +url = "helper2.test" +[peers.hpke] +public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" +"#; + let mut network_file = NamedTempFile::new().unwrap(); + writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); + + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + } +} diff --git a/ipa-core/src/cli/crypto/mod.rs b/ipa-core/src/cli/crypto/mod.rs new file mode 100644 index 000000000..0bcb1f629 --- /dev/null +++ b/ipa-core/src/cli/crypto/mod.rs @@ -0,0 +1,198 @@ +mod decrypt; +mod encrypt; + +pub use decrypt::DecryptArgs; +pub use encrypt::EncryptArgs; + +#[cfg(test)] +mod sample_data { + use std::{io, io::Write, sync::OnceLock}; + + use hpke::{Deserializable, Serializable}; + use rand::thread_rng; + use tempfile::NamedTempFile; + + use crate::{ + cli::CsvSerializer, + hpke::{IpaPrivateKey, IpaPublicKey}, + test_fixture::{ipa::TestRawDataRecord, EventGenerator, EventGeneratorConfig}, + }; + + /// Keys that are used in crypto tests + #[derive(Clone)] + pub(super) struct TestKeys { + key_pairs: [(IpaPublicKey, IpaPrivateKey); 3], + } + + static TEST_KEYS: OnceLock = OnceLock::new(); + pub fn test_keys() -> &'static TestKeys { + TEST_KEYS.get_or_init(TestKeys::new) + } + + impl TestKeys { + pub fn new() -> Self { + Self { + key_pairs: [ + ( + decode_key::<_, IpaPublicKey>( + "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a", + ), + decode_key::<_, IpaPrivateKey>( + "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", + ), + ), + ( + decode_key::<_, IpaPublicKey>( + "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e", + ), + decode_key::<_, IpaPrivateKey>( + "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", + ), + ), + ( + decode_key::<_, IpaPublicKey>( + "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e", + ), + decode_key::<_, IpaPrivateKey>( + "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", + ), + ), + ], + } + } + + pub fn network_config(&self) -> NamedTempFile { + let mut file = NamedTempFile::new().unwrap(); + let [pk1, pk2, pk3] = self.key_pairs.each_ref().map(|(pk, _)| pk); + let [pk1, pk2, pk3] = [ + hex::encode(pk1.to_bytes()), + hex::encode(pk2.to_bytes()), + hex::encode(pk3.to_bytes()), + ]; + let network_data = format!( + r#" + [[peers]] + url = "helper1.test" + [peers.hpke] + public_key = "{pk1}" + [[peers]] + url = "helper2.test" + [peers.hpke] + public_key = "{pk2}" + [[peers]] + url = "helper3.test" + [peers.hpke] + public_key = "{pk3}" + "# + ); + file.write_all(network_data.as_bytes()).unwrap(); + + file + } + + pub fn set_sk>(&mut self, idx: usize, data: I) { + self.key_pairs[idx].1 = IpaPrivateKey::from_bytes(data.as_ref()).unwrap(); + } + + pub fn get_sk(&self, idx: usize) -> Vec { + self.key_pairs[idx].1.to_bytes().to_vec() + } + + pub fn sk_files(&self) -> [NamedTempFile; 3] { + self.key_pairs.each_ref().map(|(_, sk)| sk).map(|sk| { + let mut file = NamedTempFile::new().unwrap(); + file.write_all(hex::encode(sk.to_bytes()).as_bytes()) + .unwrap(); + file.flush().unwrap(); + + file + }) + } + } + + fn decode_key, T: Deserializable>(input: I) -> T { + let bytes = hex::decode(input).unwrap(); + T::from_bytes(&bytes).unwrap() + } + + pub fn test_ipa_data() -> impl Iterator { + let rng = thread_rng(); + let event_gen_args = EventGeneratorConfig::new(10, 5, 20, 1, 10, 604_800); + + EventGenerator::with_config(rng, event_gen_args) + } + + pub fn write_csv( + data: impl Iterator, + ) -> Result { + let mut file = NamedTempFile::new()?; + for event in data { + let () = event.to_csv(&mut file)?; + writeln!(file)?; + } + + file.flush()?; + + Ok(file) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{ + fs::File, + io::{BufRead, BufReader}, + path::Path, + }; + + use tempfile::tempdir; + + use crate::cli::crypto::{decrypt::DecryptArgs, encrypt::EncryptArgs, sample_data}; + + fn are_files_equal(file1: &Path, file2: &Path) { + let file1 = + File::open(file1).unwrap_or_else(|e| panic!("unable to open {}: {e}", file1.display())); + let file2 = + File::open(file2).unwrap_or_else(|e| panic!("unable to open {}: {e}", file2.display())); + let reader1 = BufReader::new(file1).lines(); + let mut reader2 = BufReader::new(file2).lines(); + for line1 in reader1 { + let line2 = reader2.next().expect("Files have different lengths"); + assert_eq!(line1.unwrap(), line2.unwrap()); + } + assert!(reader2.next().is_none(), "Files have different lengths"); + } + + #[tokio::test] + async fn encrypt_and_decrypt() { + let output_dir = tempdir().unwrap(); + let input = sample_data::test_ipa_data().take(10); + let input_file = sample_data::write_csv(input).unwrap(); + let network_file = sample_data::test_keys().network_config(); + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let decrypt_output = output_dir.path().join("output"); + let enc1 = output_dir.path().join("helper1.enc"); + let enc2 = output_dir.path().join("helper2.enc"); + let enc3 = output_dir.path().join("helper3.enc"); + let [mk_private_key1, mk_private_key2, mk_private_key3] = + sample_data::test_keys().sk_files(); + + DecryptArgs::new( + enc1.as_path(), + enc2.as_path(), + enc3.as_path(), + mk_private_key1.path(), + mk_private_key2.path(), + mk_private_key3.path(), + &decrypt_output, + ) + .decrypt_and_reconstruct() + .await + .unwrap(); + + are_files_equal(input_file.path(), &decrypt_output); + } +} diff --git a/ipa-core/src/cli/csv.rs b/ipa-core/src/cli/csv.rs index 37772ac55..621c9b352 100644 --- a/ipa-core/src/cli/csv.rs +++ b/ipa-core/src/cli/csv.rs @@ -20,3 +20,22 @@ impl Serializer for crate::test_fixture::ipa::TestRawDataRecord { Ok(()) } } + +#[cfg(any(test, feature = "test-fixture"))] +impl Serializer for crate::test_fixture::hybrid::TestHybridRecord { + fn to_csv(&self, buf: &mut W) -> io::Result<()> { + match self { + crate::test_fixture::hybrid::TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + write!(buf, "i,{match_key},{breakdown_key}")?; + } + crate::test_fixture::hybrid::TestHybridRecord::TestConversion { match_key, value } => { + write!(buf, "c,{match_key},{value}")?; + } + } + + Ok(()) + } +} diff --git a/ipa-core/src/cli/metric_collector.rs b/ipa-core/src/cli/metric_collector.rs index 17fd72705..881e775f5 100644 --- a/ipa-core/src/cli/metric_collector.rs +++ b/ipa-core/src/cli/metric_collector.rs @@ -31,6 +31,7 @@ pub fn install_collector() -> CollectorHandle { // register metrics crate::telemetry::metrics::register(); + tracing::info!("Metrics enabled"); CollectorHandle { snapshotter } } diff --git a/ipa-core/src/cli/mod.rs b/ipa-core/src/cli/mod.rs index 28847f416..467425785 100644 --- a/ipa-core/src/cli/mod.rs +++ b/ipa-core/src/cli/mod.rs @@ -1,19 +1,12 @@ #[cfg(feature = "web-app")] mod clientconf; -#[cfg(all( - feature = "test-fixture", - feature = "web-app", - feature = "cli", - feature = "in-memory-infra" -))] +#[cfg(all(feature = "test-fixture", feature = "web-app", feature = "cli",))] pub mod crypto; mod csv; mod ipa_output; #[cfg(feature = "web-app")] mod keygen; mod metric_collector; -#[cfg(feature = "cli")] -pub mod noise; mod paths; #[cfg(all(feature = "test-fixture", feature = "web-app", feature = "cli"))] pub mod playbook; diff --git a/ipa-core/src/cli/noise.rs b/ipa-core/src/cli/noise.rs deleted file mode 100644 index f83f93174..000000000 --- a/ipa-core/src/cli/noise.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::{ - cmp::Ordering, - collections::BTreeMap, - fmt::{Debug, Display, Formatter}, -}; - -use clap::Args; -use rand::rngs::StdRng; -use rand_core::SeedableRng; -use serde::{Deserialize, Serialize, Serializer}; - -use crate::protocol::ipa_prf::oprf_padding::InsecureDiscreteDp; - -#[derive(Debug, Args)] -#[clap(about = "Apply differential privacy noise to the given input")] -pub struct ApplyDpArgs { - /// Various epsilon values to use inside the DP. - #[arg(long, short = 'e')] - epsilon: Vec, - - /// Delta parameter for (\epsilon, \delta) DP. - #[arg(long, short = 'd', default_value = "1e-7")] - delta: f64, - - /// Seed for the random number generator. - #[arg(long, short = 's')] - seed: Option, - - /// The sensitivity of the input or maximum contribution allowed per user to preserve privacy. - #[arg(long, short = 'c')] - cap: u32, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct NoisyOutput { - /// Aggregated breakdowns with noise applied. It is important to use unsigned values here - /// to avoid bias/mean skew - pub breakdowns: Box<[i64]>, - pub mean: f64, - pub std: f64, -} - -/// This exists to be able to use f64 as key inside a map. We don't have to deal with infinities or -/// NaN values for epsilons, so we can treat them as raw bytes for this purpose. -#[derive(Debug, Copy, Clone)] -pub struct EpsilonBits(f64); - -impl Serialize for EpsilonBits { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.0.to_string()) - } -} - -impl From for EpsilonBits { - fn from(value: f64) -> Self { - assert!(value.is_finite()); - Self(value) - } -} - -// the following implementations are fine because NaN values are rejected from inside `From` - -impl PartialEq for EpsilonBits { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits().eq(&other.0.to_bits()) - } -} - -impl Eq for EpsilonBits {} - -impl PartialOrd for EpsilonBits { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for EpsilonBits { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.partial_cmp(&other.0).unwrap() - } -} - -impl Display for EpsilonBits { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -/// Apply DP noise to the given input. -/// -/// ## Panics -/// If DP parameters are not valid. -pub fn apply>(input: I, args: &ApplyDpArgs) -> BTreeMap { - let mut rng = args - .seed - .map_or_else(StdRng::from_entropy, StdRng::seed_from_u64); - let mut result = BTreeMap::new(); - for &epsilon in &args.epsilon { - let discrete_dp = - InsecureDiscreteDp::new(epsilon, args.delta, f64::from(args.cap)).unwrap(); - let mut v = input - .as_ref() - .iter() - .copied() - .map(i64::from) - .collect::>(); - discrete_dp.apply(v.as_mut_slice(), &mut rng); - - result.insert( - epsilon.into(), - NoisyOutput { - breakdowns: v.into_boxed_slice(), - mean: discrete_dp.mean(), - std: discrete_dp.std(), - }, - ); - } - - result -} diff --git a/ipa-core/src/cli/playbook/input.rs b/ipa-core/src/cli/playbook/input.rs index 70efa3bcb..5015d5587 100644 --- a/ipa-core/src/cli/playbook/input.rs +++ b/ipa-core/src/cli/playbook/input.rs @@ -7,8 +7,9 @@ use std::{ }; use crate::{ - cli::playbook::generator::U128Generator, ff::U128Conversions, - test_fixture::ipa::TestRawDataRecord, + cli::playbook::generator::U128Generator, + ff::U128Conversions, + test_fixture::{hybrid::TestHybridRecord, ipa::TestRawDataRecord}, }; pub trait InputItem { @@ -56,6 +57,40 @@ impl InputItem for TestRawDataRecord { } } +impl InputItem for TestHybridRecord { + fn from_str(s: &str) -> Self { + if let [event_type, match_key, number] = s.splitn(3, ',').collect::>()[..] { + let match_key: u64 = match_key + .parse() + .unwrap_or_else(|e| panic!("Expected an u64, got {match_key}: {e}")); + + let number: u32 = number + .parse() + .unwrap_or_else(|e| panic!("Expected an u32, got {number}: {e}")); + + match event_type { + "i" => TestHybridRecord::TestImpression { + match_key, + breakdown_key: number, + }, + + "c" => TestHybridRecord::TestConversion { + match_key, + value: number, + }, + _ => panic!( + "{}", + format!( + "Invalid input. Rows should start with 'i' or 'c'. Did not expect {event_type}" + ) + ), + } + } else { + panic!("{s} is not a valid {}", type_name::()) + } + } +} + pub struct InputSource { inner: Box, sz: Option, @@ -203,13 +238,13 @@ mod tests { } #[test] - #[should_panic] + #[should_panic(expected = "ParseIntError")] fn parse_negative() { Fp31::from_str("-1"); } #[test] - #[should_panic] + #[should_panic(expected = "ParseIntError")] fn parse_empty() { Fp31::from_str(""); } @@ -229,7 +264,7 @@ mod tests { } #[test] - #[should_panic] + #[should_panic(expected = "ParseIntError")] fn tuple_parse_error() { <(Fp31, Fp31)>::from_str("20,"); } diff --git a/ipa-core/src/cli/playbook/ipa.rs b/ipa-core/src/cli/playbook/ipa.rs index b2c8050c6..5f56911f0 100644 --- a/ipa-core/src/cli/playbook/ipa.rs +++ b/ipa-core/src/cli/playbook/ipa.rs @@ -95,6 +95,8 @@ where run_query_and_validate::(inputs, query_size, clients, query_id, query_config).await } +/// # Panics +/// if results are invalid #[allow(clippy::disallowed_methods)] // allow try_join_all pub async fn run_query_and_validate( inputs: [BodyStream; 3], diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index df43911d8..4b5164f56 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -14,12 +14,13 @@ pub use input::InputSource; pub use multiply::secure_mul; use tokio::time::sleep; -pub use self::ipa::playbook_oprf_ipa; +pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate}; use crate::{ config::{ClientConfig, NetworkConfig, PeerConfig}, + executor::IpaRuntime, ff::boolean_array::{BA20, BA3, BA8}, helpers::query::DpMechanism, - net::{ClientIdentity, MpcHelperClient}, + net::{ClientIdentity, Helper, MpcHelperClient}, protocol::{dp::NoiseParams, ipa_prf::oprf_padding::insecure::OPRFPaddingDp}, }; @@ -146,7 +147,6 @@ pub fn validate_dp( } else { next_actual_f64 }; - println!("next_actual_f64 = {next_actual_f64}, next_actual_f64_shifted = {next_actual_f64_shifted}"); let (_, std) = truncated_discrete_laplace.mean_and_std(); let tolerance_factor = 20.0; // set so this fails randomly with small probability @@ -194,25 +194,26 @@ pub async fn make_clients( network_path: Option<&Path>, scheme: Scheme, wait: usize, -) -> ([MpcHelperClient; 3], NetworkConfig) { +) -> ([MpcHelperClient; 3], NetworkConfig) { let mut wait = wait; let network = if let Some(path) = network_path { NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap() } else { - NetworkConfig { - peers: [ + NetworkConfig::::new_mpc( + vec![ PeerConfig::new("localhost:3000".parse().unwrap(), None), PeerConfig::new("localhost:3001".parse().unwrap(), None), PeerConfig::new("localhost:3002".parse().unwrap(), None), ], - client: ClientConfig::default(), - } + ClientConfig::default(), + ) }; let network = network.override_scheme(&scheme); // Note: This closure is only called when the selected action uses clients. - let clients = MpcHelperClient::from_conf(&network, &ClientIdentity::None); + let clients = + MpcHelperClient::from_conf(&IpaRuntime::current(), &network, &ClientIdentity::None); while wait > 0 && !clients_ready(&clients).await { tracing::debug!("waiting for servers to come up"); sleep(Duration::from_secs(1)).await; diff --git a/ipa-core/src/cli/verbosity.rs b/ipa-core/src/cli/verbosity.rs index 53a2bee39..068af04f5 100644 --- a/ipa-core/src/cli/verbosity.rs +++ b/ipa-core/src/cli/verbosity.rs @@ -32,24 +32,29 @@ impl Verbosity { #[must_use] pub fn setup_logging(&self) -> LoggingHandle { let filter_layer = self.log_filter(); + info!("Logging setup at level {}", filter_layer); + let fmt_layer = fmt::layer() .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .with_ansi(std::io::stderr().is_terminal()) .with_writer(stderr); - tracing_subscriber::registry() - .with(self.log_filter()) - .with(fmt_layer) - .with(MetricsLayer::new()) - .init(); + let registry = tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer); + + if cfg!(feature = "disable-metrics") { + registry.init(); + } else { + registry.with(MetricsLayer::new()).init(); + } let handle = LoggingHandle { - metrics_handle: (!self.quiet).then(install_collector), + metrics_handle: (!self.quiet && !cfg!(feature = "disable-metrics")) + .then(install_collector), }; set_global_panic_hook(); - info!("Logging setup at level {}", filter_layer); - handle } diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 0486ad490..50bd90f4b 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -1,13 +1,12 @@ use std::{ - array, borrow::{Borrow, Cow}, fmt::{Debug, Formatter}, - iter::Zip, + iter::zip, path::PathBuf, - slice, time::Duration, }; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use hyper::{http::uri::Scheme, Uri}; use hyper_util::client::legacy::Builder; use rustls_pemfile::Item; @@ -22,6 +21,8 @@ use crate::{ Deserializable as _, IpaPrivateKey, IpaPublicKey, KeyRegistry, PrivateKeyOnly, PublicKeyOnly, Serializable as _, }, + net::{ConnectionFlavor, Helper, Shard}, + sharding::ShardIndex, }; pub type OwnedCertificate = CertificateDer<'static>; @@ -37,23 +38,115 @@ pub enum Error { IOError(#[from] std::io::Error), } -/// Configuration information describing a helper network. +/// Configuration describing either 3 peers in a Ring or N shard peers. In a non-sharded case a +/// single [`NetworkConfig`] represents the entire network. In a sharded case, each host should +/// have one Ring and one Sharded configuration to know how to reach its peers. /// /// The most important thing this contains is discovery information for each of the participating -/// helpers. +/// peers. #[derive(Clone, Debug, Deserialize)] -pub struct NetworkConfig { - /// Information about each helper participating in the network. The order that helpers are - /// listed here determines their assigned helper identities in the network. Note that while the - /// helper identities are stable, roles are assigned per query. - pub peers: [PeerConfig; 3], +pub struct NetworkConfig { + peers: Vec, /// HTTP client configuration. #[serde(default)] pub client: ClientConfig, + + /// The identities of the index-matching peers. Separating this from [`Self::peers`](field) so + /// that parsing is easy to implement. + #[serde(skip)] + identities: Vec, +} + +impl NetworkConfig { + /// # Panics + /// If `PathAndQuery::from_str("")` fails + #[must_use] + pub fn override_scheme(self, scheme: &Scheme) -> Self { + Self { + peers: self + .peers + .into_iter() + .map(|mut peer| { + let mut parts = peer.url.into_parts(); + parts.scheme = Some(scheme.clone()); + // `http::uri::Uri::from_parts()` requires that a URI have a path if it has a + // scheme. If the URI does not have a scheme, it is not required to have a path. + if parts.path_and_query.is_none() { + parts.path_and_query = Some("".parse().unwrap()); + } + peer.url = Uri::try_from(parts).unwrap(); + peer + }) + .collect(), + ..self + } + } + + #[must_use] + pub fn vec_peers(&self) -> Vec { + self.peers.clone() + } + + #[must_use] + pub fn get_peer(&self, i: usize) -> Option<&PeerConfig> { + self.peers.get(i) + } + + pub fn peers_iter(&self) -> std::slice::Iter<'_, PeerConfig> { + self.peers.iter() + } + + /// We currently require an exact match with the peer cert (i.e. we don't support verifying + /// the certificate against a truststore and identifying the peer by the certificate + /// subject). This could be changed if the need arises. + #[must_use] + pub fn identify_cert(&self, cert: Option<&CertificateDer>) -> Option { + let cert = cert?; + for (id, p) in zip(self.identities.iter(), self.peers.iter()) { + if p.certificate.as_ref() == Some(cert) { + return Some(*id); + } + } + // It might be nice to log something here. We could log the certificate base64? + tracing::error!( + "A client certificate was presented that does not match a known helper. Certificate: {}", + BASE64.encode(cert), + ); + None + } +} + +impl NetworkConfig { + /// # Panics + /// In the unlikely event a usize cannot be turned into a u32 + #[must_use] + pub fn new_shards(peers: Vec, client: ClientConfig) -> Self { + let identities = (0u32..peers.len().try_into().unwrap()) + .map(ShardIndex::from) + .collect(); + Self { + peers, + client, + identities, + } + } } -impl NetworkConfig { +impl NetworkConfig { + /// Creates a new configuration for 3 MPC clients (ring) configuration. + /// # Panics + /// If the vector doesn't contain exactly 3 items. + #[must_use] + pub fn new_mpc(ring: Vec, client: ClientConfig) -> Self { + assert_eq!(3, ring.len()); + Self { + peers: ring, + client, + identities: HelperIdentity::make_three().to_vec(), + } + } + /// Reads config from string. Expects config to be toml format. /// To read file, use `fs::read_to_string` /// @@ -62,49 +155,25 @@ impl NetworkConfig { pub fn from_toml_str(input: &str) -> Result { use config::{Config, File, FileFormat}; - let conf: Self = Config::builder() + let mut conf: Self = Config::builder() .add_source(File::from_str(input, FileFormat::Toml)) .build()? .try_deserialize()?; - Ok(conf) - } - - pub fn new(peers: [PeerConfig; 3], client: ClientConfig) -> Self { - Self { peers, client } - } + conf.identities = HelperIdentity::make_three().to_vec(); - pub fn peers(&self) -> &[PeerConfig; 3] { - &self.peers - } - - // Can maybe be replaced with array::zip when stable? - pub fn enumerate_peers( - &self, - ) -> Zip, slice::Iter> { - HelperIdentity::make_three() - .into_iter() - .zip(self.peers.iter()) + Ok(conf) } + /// Clones the internal configs and returns them as an array. /// # Panics - /// If `PathAndQuery::from_str("")` fails + /// If the internal vector isn't of size 3. #[must_use] - pub fn override_scheme(self, scheme: &Scheme) -> NetworkConfig { - NetworkConfig { - peers: self.peers.map(|mut peer| { - let mut parts = peer.url.into_parts(); - parts.scheme = Some(scheme.clone()); - // `http::uri::Uri::from_parts()` requires that a URI have a path if it has a - // scheme. If the URI does not have a scheme, it is not required to have a path. - if parts.path_and_query.is_none() { - parts.path_and_query = Some("".parse().unwrap()); - } - peer.url = Uri::try_from(parts).unwrap(); - peer - }), - ..self - } + pub fn peers(&self) -> [PeerConfig; 3] { + self.peers + .clone() + .try_into() + .unwrap_or_else(|v: Vec<_>| panic!("Expected a Vec of length 3 but it was {}", v.len())) } } @@ -422,10 +491,11 @@ impl KeyRegistries { /// If network file is improperly formatted pub fn init_from( &mut self, - network: &NetworkConfig, + network: &NetworkConfig, ) -> Option<[&KeyRegistry; 3]> { // Get the configs, if all three peers have one - let configs = network.peers().iter().try_fold(Vec::new(), |acc, peer| { + let peers = network.peers(); + let configs = peers.iter().try_fold(Vec::new(), |acc, peer| { if let (mut vec, Some(hpke_config)) = (acc, peer.hpke_config.as_ref()) { vec.push(hpke_config); Some(vec) @@ -453,10 +523,12 @@ mod tests { use rand::rngs::StdRng; use rand_core::SeedableRng; + use super::{NetworkConfig, PeerConfig}; use crate::{ config::{ClientConfig, HpkeClientConfig, Http2Configurator, HttpClientConfigurator}, helpers::HelperIdentity, net::test::TestConfigBuilder, + sharding::ShardIndex, }; const URI_1: &str = "http://localhost:3000"; @@ -531,4 +603,13 @@ mod tests { }), ); } + + #[test] + fn indexing_peer_happy_case() { + let uri1 = URI_1.parse::().unwrap(); + let pc1 = PeerConfig::new(uri1, None); + let client = ClientConfig::default(); + let conf = NetworkConfig::new_shards(vec![pc1.clone()], client); + assert_eq!(conf.peers[ShardIndex(0)].url, pc1.url); + } } diff --git a/ipa-core/src/error.rs b/ipa-core/src/error.rs index 771cb1c0a..168827c8e 100644 --- a/ipa-core/src/error.rs +++ b/ipa-core/src/error.rs @@ -104,6 +104,8 @@ pub enum Error { }, #[error("The verification of the shuffle failed: {0}")] ShuffleValidationFailed(String), + #[error("Duplicate bytes found after {0} checks")] + DuplicateBytes(usize), } impl Default for Error { diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index 3df41f269..43bfee4a2 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -5,7 +5,7 @@ use bitvec::{ slice::Iter, }; use generic_array::GenericArray; -use typenum::{U14, U18, U2, U32, U8}; +use typenum::{U12, U14, U18, U2, U32, U8}; use crate::{ error::{Error, LengthError}, @@ -862,6 +862,9 @@ macro_rules! boolean_array_impl_large { //impl store for U8 store_impl!(U8, 64); +//impl store for U12 +store_impl!(U12, 96); + //impl store for U14 store_impl!(U14, 112); @@ -890,6 +893,7 @@ boolean_array_impl_small!(boolean_array_16, BA16, 16, infallible); boolean_array_impl_small!(boolean_array_20, BA20, 20, fallible); boolean_array_impl_small!(boolean_array_32, BA32, 32, infallible); boolean_array_impl_small!(boolean_array_64, BA64, 64, infallible); +boolean_array_impl_small!(boolean_array_96, BA96, 96, infallible); boolean_array_impl_small!(boolean_array_112, BA112, 112, infallible); boolean_array_impl_large!(boolean_array_144, BA144, 144, infallible, U18); boolean_array_impl_large!(boolean_array_256, BA256, 256, infallible, U32); diff --git a/ipa-core/src/helpers/buffers/unordered_receiver.rs b/ipa-core/src/helpers/buffers/unordered_receiver.rs index 92cfbf2e1..3377995cf 100644 --- a/ipa-core/src/helpers/buffers/unordered_receiver.rs +++ b/ipa-core/src/helpers/buffers/unordered_receiver.rs @@ -295,7 +295,6 @@ where inner: Arc>>, } -#[allow(dead_code)] impl UnorderedReceiver where S: Stream + Send, diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index b1c57b3cc..a25321da4 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -30,6 +30,7 @@ use crate::{ protocol::QueryId, sharding::ShardIndex, sync::{Arc, Mutex}, + utils::NonZeroU32PowerOfTwo, }; /// Alias for the currently configured transport. @@ -44,9 +45,9 @@ pub type MpcTransportImpl = TransportImpl; pub type ShardTransportImpl = TransportImpl; #[cfg(feature = "real-world-infra")] -pub type MpcTransportImpl = crate::sync::Arc; +pub type MpcTransportImpl = crate::net::MpcHttpTransport; #[cfg(feature = "real-world-infra")] -pub type ShardTransportImpl = crate::net::HttpShardTransport; +pub type ShardTransportImpl = crate::net::ShardHttpTransport; pub type MpcTransportError = ::Error; @@ -73,7 +74,7 @@ pub struct State { pub struct GatewayConfig { /// The number of items that can be active at the one time. /// This is used to determine the size of sending and receiving buffers. - pub active: NonZeroUsize, + pub active: NonZeroU32PowerOfTwo, /// Number of bytes packed and sent together in one batch down to the network layer. This /// shouldn't be too small to keep the network throughput, but setting it large enough may @@ -81,9 +82,20 @@ pub struct GatewayConfig { /// A rule of thumb is that this should get as close to network packet size as possible. /// /// This will be set for all channels and because they send records of different side, the actual - /// payload may not be exactly this, but it will be the closest multiple of record size to this - /// number. For instance, having 14 bytes records and batch size of 4096 will result in - /// 4088 bytes being sent in a batch. + /// payload may not be exactly this, but it will be the closest multiple of record size smaller than + /// or equal to number. For alignment reasons, this multiple will be a power of two, otherwise + /// a deadlock is possible. See ipa/#1300 for details how it can happen. + /// + /// For instance, having 14 bytes records and batch size of 4096 will result in + /// 3584 bytes being sent in a batch (`2^8 * 14 < 4096, 2^9 * 14 > 4096`). + /// + /// The consequence is that HTTP buffer size may not be perfectly aligned with the target. + /// As long as we use TCP it does not matter, but if we want to switch to UDP and have + /// precise control over the size of chunk sent, we should tune the buffer size at the + /// HTTP layer instead (using Hyper/H3 API or something like that). If we do this, then + /// read size becomes obsolete and should be removed in favor of flushing the entire + /// buffer chunks from the application layer down to HTTP and let network to figure out + /// the best way to slice this data before sending it to a peer. pub read_size: NonZeroUsize, /// Time to wait before checking gateway progress. If no progress has been made between @@ -150,12 +162,15 @@ impl Gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, + active_work: NonZeroU32PowerOfTwo, ) -> send::SendingEnd { let transport = &self.transports.mpc; let channel = self.inner.mpc_senders.get::( channel_id, transport, - self.config, + // we override the active work provided in config if caller + // wants to use a different value. + self.config.set_active_work(active_work), self.query_id, total_records, ); @@ -257,6 +272,11 @@ impl GatewayConfig { /// The configured amount of active work. #[must_use] pub fn active_work(&self) -> NonZeroUsize { + self.active.to_non_zero_usize() + } + + #[must_use] + pub fn active_work_as_power_of_two(&self) -> NonZeroU32PowerOfTwo { self.active } @@ -276,32 +296,60 @@ impl GatewayConfig { // capabilities (see #ipa/1171) to allow that currently. usize::from(value.size), ), - ); + ) + .next_power_of_two(); // we set active to be at least 2, so unwrap is fine. - self.active = NonZeroUsize::new(active).unwrap(); + self.active = NonZeroU32PowerOfTwo::try_from(active).unwrap(); + } + + /// Creates a new configuration by overriding the value of active work. + #[must_use] + pub fn set_active_work(&self, active_work: NonZeroU32PowerOfTwo) -> Self { + Self { + active: active_work, + ..*self + } } } #[cfg(all(test, unit_test))] mod tests { - use std::iter::{repeat, zip}; + use std::{ + iter::{repeat, zip}, + sync::Arc, + }; use futures::{ future::{join, try_join, try_join_all}, + stream, stream::StreamExt, }; + use proptest::proptest; + use tokio::sync::Barrier; use crate::{ - ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions}, - helpers::{Direction, GatewayConfig, MpcMessage, Role, SendingEnd}, + ff::{ + boolean_array::{BA20, BA256, BA3, BA4, BA5, BA6, BA7, BA8}, + FieldType, Fp31, Fp32BitPrime, Gf2, U128Conversions, + }, + helpers::{ + gateway::QueryConfig, + query::{QuerySize, QueryType}, + ChannelId, Direction, GatewayConfig, MpcMessage, MpcReceivingEnd, Role, SendingEnd, + TotalRecords, + }, protocol::{ context::{Context, ShardedContext}, - RecordId, + Gate, RecordId, }, - secret_sharing::replicated::semi_honest::AdditiveShare, + secret_sharing::{ + replicated::semi_honest::AdditiveShare, SharedValue, SharedValueArray, StdArray, + }, + seq_join::seq_join, sharding::ShardConfiguration, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards}, + utils::NonZeroU32PowerOfTwo, }; /// Verifies that [`Gateway`] send buffer capacity is adjusted to the message size. @@ -516,6 +564,129 @@ mod tests { }); } + #[test] + fn custom_active_work() { + run(|| async move { + let world = TestWorld::new_with(TestWorldConfig { + gateway_config: GatewayConfig { + active: 8.try_into().unwrap(), + ..Default::default() + }, + ..Default::default() + }); + let new_active_work = NonZeroU32PowerOfTwo::try_from(4).unwrap(); + assert!( + new_active_work + < world + .gateway(Role::H1) + .config() + .active_work_as_power_of_two() + ); + let sender = world.gateway(Role::H1).get_mpc_sender::( + &ChannelId::new(Role::H2, Gate::default()), + TotalRecords::specified(15).unwrap(), + new_active_work, + ); + try_join_all( + (0..new_active_work.get()) + .map(|record_id| sender.send(record_id.into(), BA3::ZERO)), + ) + .await + .unwrap(); + let recv = world.gateway(Role::H2).get_mpc_receiver::(&ChannelId { + peer: Role::H1, + gate: Gate::default(), + }); + // this will hang if the original active work is used + try_join_all( + (0..new_active_work.get()).map(|record_id| recv.receive(record_id.into())), + ) + .await + .unwrap(); + }); + } + + macro_rules! send_recv_test { + ( + message: $message:expr, + read_size: $read_size:expr, + active_work: $active_work:expr, + total_records: $total_records:expr, + $test_fn: ident + ) => { + #[test] + fn $test_fn() { + run(|| async { + send_recv($read_size, $active_work, $total_records, $message).await; + }); + } + }; + } + + send_recv_test! { + message: BA20::ZERO, + read_size: 5, + active_work: 8, + total_records: 25, + test_ba20_5_10_25 + } + + send_recv_test! { + message: StdArray::::ZERO_ARRAY, + read_size: 2048, + active_work: 16, + total_records: 43, + test_ba256_by_16_2048_10_43 + } + + send_recv_test! { + message: StdArray::::ZERO_ARRAY, + read_size: 2048, + active_work: 32, + total_records: 50, + test_ba8_by_16_2048_37_50 + } + + proptest! { + #[test] + fn send_recv_randomized( + total_records in 1_usize..500, + active in 2_usize..1000, + read_size in (1_usize..32768), + record_size in 1_usize..=8, + ) { + let active = active.next_power_of_two(); + run(move || async move { + match record_size { + 1 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + 2 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + 3 => send_recv(read_size, active, total_records, BA3::ZERO).await, + 4 => send_recv(read_size, active, total_records, BA4::ZERO).await, + 5 => send_recv(read_size, active, total_records, BA5::ZERO).await, + 6 => send_recv(read_size, active, total_records, BA6::ZERO).await, + 7 => send_recv(read_size, active, total_records, BA7::ZERO).await, + 8 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + _ => unreachable!(), + } + }); + } + } + + /// ensures when active work is set from query input, it is always a power of two + #[test] + fn gateway_config_active_work_power_of_two() { + let mut config = GatewayConfig { + active: 2.try_into().unwrap(), + ..Default::default() + }; + config.set_active_work_from_query_config(&QueryConfig { + size: QuerySize::try_from(5).unwrap(), + field_type: FieldType::Fp31, + query_type: QueryType::TestAddInPrimeField, + }); + assert_eq!(8, config.active_work().get()); + } + async fn shard_comms_test(test_world: &TestWorld>) { let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)]; @@ -553,4 +724,112 @@ mod tests { let world_ptr = world as *mut _; (world, world_ptr) } + + /// This serves the purpose of randomized testing of our send channels by providing + /// variable sizes for read size, active work and record size + async fn send_recv(read_size: usize, active_work: usize, total_records: usize, sample: M) + where + M: MpcMessage + Clone + PartialEq, + { + fn duplex_channel( + world: &TestWorld, + left: Role, + right: Role, + total_records: usize, + active_work: usize, + ) -> (SendingEnd, MpcReceivingEnd) { + ( + world.gateway(left).get_mpc_sender::( + &ChannelId::new(right, Gate::default()), + TotalRecords::specified(total_records).unwrap(), + active_work.try_into().unwrap(), + ), + world + .gateway(right) + .get_mpc_receiver::(&ChannelId::new(left, Gate::default())), + ) + } + + async fn circuit( + send_channel: SendingEnd, + recv_channel: MpcReceivingEnd, + active_work: usize, + total_records: usize, + msg: M, + ) where + M: MpcMessage + Clone + PartialEq, + { + let last_batch_size = total_records % active_work; + let last_batch = total_records / active_work; + + let barrier = Arc::new(Barrier::new(active_work)); + let last_batch_barrier = Arc::new(Barrier::new(last_batch_size)); + + // perform "multiplication-like" operation (send + subsequent receive) + // and "validate": block the future until we have at least `active_work` + // futures pending and unblock them all at the same time + seq_join( + active_work.try_into().unwrap(), + stream::iter(std::iter::repeat(msg).take(total_records).enumerate()).map( + |(record_id, msg)| { + let send_channel = &send_channel; + let recv_channel = &recv_channel; + let barrier = Arc::clone(&barrier); + let last_batch_barrier = Arc::clone(&last_batch_barrier); + async move { + send_channel + .send(record_id.into(), msg.clone()) + .await + .unwrap(); + let r = recv_channel.receive(record_id.into()).await.unwrap(); + // this simulates validate_record API by forcing futures to wait + // until the entire batch is validated by the last future in that batch + if record_id >= last_batch * active_work { + last_batch_barrier.wait().await; + } else { + barrier.wait().await; + } + + assert_eq!(msg, r); + } + }, + ), + ) + .collect::>() + .await; + } + + let config = TestWorldConfig { + gateway_config: GatewayConfig { + active: active_work.try_into().unwrap(), + read_size: read_size.try_into().unwrap(), + ..Default::default() + }, + ..Default::default() + }; + + let world = TestWorld::new_with(&config); + let (h1_send_channel, h1_recv_channel) = + duplex_channel(&world, Role::H1, Role::H2, total_records, active_work); + let (h2_send_channel, h2_recv_channel) = + duplex_channel(&world, Role::H2, Role::H1, total_records, active_work); + + join( + circuit( + h1_send_channel, + h1_recv_channel, + active_work, + total_records, + sample.clone(), + ), + circuit( + h2_send_channel, + h2_recv_channel, + active_work, + total_records, + sample, + ), + ) + .await; + } } diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index fc73caf5d..089876243 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -203,10 +203,9 @@ impl GatewaySenders { match self.inner.entry(channel_id.clone()) { Entry::Occupied(entry) => Arc::clone(entry.get()), Entry::Vacant(entry) => { - let sender = Self::new_sender( - &SendChannelConfig::new::(config, total_records), - channel_id.clone(), - ); + let config = SendChannelConfig::new::(config, total_records); + tracing::trace!("send configuration for {channel_id:?}: {config:?}"); + let sender = Self::new_sender(&config, channel_id.clone()); entry.insert(Arc::clone(&sender)); tokio::spawn({ @@ -249,40 +248,69 @@ impl Stream for GatewaySendStream { impl SendChannelConfig { fn new(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self { - debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0"); + Self::new_with(gateway_config, total_records, M::Size::USIZE) + } + fn new_with( + gateway_config: GatewayConfig, + total_records: TotalRecords, + record_size: usize, + ) -> Self { + // this computes the greatest positive power of 2 that is + // less than or equal to target. + fn non_zero_prev_power_of_two(target: usize) -> usize { + let bits = usize::BITS - target.leading_zeros(); + + 1 << (std::cmp::max(1, bits) - 1) + } + + assert!(record_size > 0, "Message size cannot be 0"); - let record_size = M::Size::USIZE; let total_capacity = gateway_config.active.get() * record_size; - Self { + // define read size as a multiplier of record size. The multiplier must be + // a power of two to align perfectly with total capacity. We don't want to exceed + // the target read size, so multiplier * record_size <= read_size. We want to get + // as close as possible to read_size. + let read_size_multiplier = { + let target = gateway_config.read_size.get() / record_size; + // If record_size is greater than read_size, we set the multiplier to 1 + // as read size cannot be 0. + non_zero_prev_power_of_two(target) + }; + + let this = Self { total_capacity: total_capacity.try_into().unwrap(), record_size: record_size.try_into().unwrap(), - read_size: if total_records.is_indeterminate() - || gateway_config.read_size.get() <= record_size - { + read_size: if total_records.is_indeterminate() { record_size } else { - std::cmp::min( - total_capacity, - // closest multiple of record_size to read_size - gateway_config.read_size.get() / record_size * record_size, - ) + std::cmp::min(total_capacity, read_size_multiplier * record_size) } .try_into() .unwrap(), total_records, - } + }; + + // If capacity can't fit all active work items, the protocol deadlocks on + // inserts above the total capacity. + assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + // if capacity is not aligned with read size, we can get a deadlock + // described in ipa/1300 + assert_eq!(0, this.total_capacity.get() % this.read_size.get()); + + this } } -#[cfg(test)] +#[cfg(all(test, unit_test))] mod test { use std::num::NonZeroUsize; + use proptest::proptest; use typenum::Unsigned; use crate::{ ff::{ - boolean_array::{BA16, BA20, BA256, BA3, BA7}, + boolean_array::{BA16, BA20, BA256, BA3, BA32, BA7}, Serializable, }, helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords}, @@ -380,15 +408,82 @@ mod test { fn config_read_size_closest_multiple_to_record_size() { assert_eq!( 6, - send_config::(TotalRecords::Specified(2.try_into().unwrap())) + send_config::(TotalRecords::Specified(2.try_into().unwrap())) .read_size .get() ); assert_eq!( 6, - send_config::(TotalRecords::Specified(2.try_into().unwrap())) + send_config::(TotalRecords::Specified(2.try_into().unwrap())) .read_size .get() ); } + + #[test] + fn config_read_size_record_size_misalignment() { + ensure_config(Some(15), 90, 16, 3); + } + + #[test] + fn config_read_size_multiple_of_record_size() { + // 4 bytes * 8 = 32 bytes total capacity. + // desired read size is 15 bytes, and the closest multiple of BA32 + // to it that is a power of two is 2 (4 gets us over 15 byte target) + assert_eq!(8, send_config::(50.into()).read_size.get()); + + // here, read size is already a power of two + assert_eq!(16, send_config::(50.into()).read_size.get()); + + // read size can be ridiculously small, config adjusts it to fit + // at least one record + assert_eq!(3, send_config::(50.into()).read_size.get()); + } + + fn ensure_config( + total_records: Option, + active: usize, + read_size: usize, + record_size: usize, + ) { + #[allow(clippy::needless_update)] // stall detection feature wants default value + let gateway_config = GatewayConfig { + active: active.next_power_of_two().try_into().unwrap(), + read_size: read_size.try_into().unwrap(), + ..Default::default() + }; + let config = SendChannelConfig::new_with( + gateway_config, + total_records.map_or(TotalRecords::Indeterminate, |v| { + TotalRecords::specified(v).unwrap() + }), + record_size, + ); + + // total capacity checks + assert!(config.total_capacity.get() > 0); + assert!(config.total_capacity.get() >= config.read_size.get()); + assert_eq!(0, config.total_capacity.get() % config.record_size.get()); + assert_eq!( + config.total_capacity.get(), + record_size * gateway_config.active.get() + ); + + // read size checks + assert!(config.read_size.get() > 0); + assert!(config.read_size.get() >= config.record_size.get()); + assert_eq!(0, config.total_capacity.get() % config.read_size.get()); + } + + proptest! { + #[test] + fn config_prop( + total_records in proptest::option::of(1_usize..1 << 32), + active in 1_usize..100_000, + read_size in 1_usize..32768, + record_size in 1_usize..4096, + ) { + ensure_config(total_records, active, read_size, record_size); + } + } } diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 49a879be4..4a844386f 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -80,6 +80,7 @@ mod gateway { protocol::QueryId, sharding::ShardIndex, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; pub struct InstrumentedGateway { @@ -153,12 +154,13 @@ mod gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, + active_work: NonZeroU32PowerOfTwo, ) -> SendingEnd { Observed::wrap( Weak::clone(self.get_sn()), self.inner() .gateway - .get_mpc_sender(channel_id, total_records), + .get_mpc_sender(channel_id, total_records, active_work), ) } diff --git a/ipa-core/src/helpers/hashing.rs b/ipa-core/src/helpers/hashing.rs index 10f50484b..ae9579097 100644 --- a/ipa-core/src/helpers/hashing.rs +++ b/ipa-core/src/helpers/hashing.rs @@ -12,7 +12,7 @@ use crate::{ protocol::prss::FromRandomU128, }; -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct Hash(Output); impl Serializable for Hash { diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index e33a2ec99..a52516786 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -7,6 +7,7 @@ use std::{ convert::Infallible, fmt::{Debug, Display, Formatter}, num::NonZeroUsize, + ops::Not, }; use generic_array::GenericArray; @@ -181,6 +182,10 @@ impl HelperIdentity { pub const TWO: Self = Self { id: 2 }; pub const THREE: Self = Self { id: 3 }; + pub const ONE_STR: &'static str = "A"; + pub const TWO_STR: &'static str = "B"; + pub const THREE_STR: &'static str = "C"; + /// Given a helper identity, return an array of the identities of the other two helpers. // The order that helpers are returned here is not intended to be meaningful, however, // it is currently used directly to determine the assignment of roles in @@ -267,6 +272,17 @@ pub enum Direction { Right, } +impl Not for Direction { + type Output = Self; + + fn not(self) -> Self { + match self { + Direction::Left => Direction::Right, + Direction::Right => Direction::Left, + } + } +} + impl Role { const H1_STR: &'static str = "H1"; const H2_STR: &'static str = "H2"; @@ -618,36 +634,6 @@ where } } -pub struct RepeatN { - element: T, - count: usize, -} - -// As of Apr. 2024, this is unstable in `std::iter`. It is also available in `itertools`. -// The advantage over `repeat(element).take(count)` that we care about is that this -// implements `ExactSizeIterator`. The other advantage is that `repeat_n` can return -// the original value (saving a clone) on the last iteration. -pub fn repeat_n(element: T, count: usize) -> RepeatN { - RepeatN { element, count } -} - -impl Iterator for RepeatN { - type Item = T; - - fn next(&mut self) -> Option { - (self.count > 0).then(|| { - self.count -= 1; - self.element.clone() - }) - } - - fn size_hint(&self) -> (usize, Option) { - (self.count, Some(self.count)) - } -} - -impl ExactSizeIterator for RepeatN {} - #[cfg(all(test, unit_test))] mod tests { use super::*; @@ -817,7 +803,7 @@ mod concurrency_tests { let input = (0u32..11).map(TestField::truncate_from).collect::>(); let config = TestWorldConfig { gateway_config: GatewayConfig { - active: input.len().try_into().unwrap(), + active: input.len().next_power_of_two().try_into().unwrap(), ..Default::default() }, ..Default::default() @@ -875,7 +861,7 @@ mod concurrency_tests { let input = (0u32..11).map(TestField::truncate_from).collect::>(); let config = TestWorldConfig { gateway_config: GatewayConfig { - active: input.len().try_into().unwrap(), + active: input.len().next_power_of_two().try_into().unwrap(), ..Default::default() }, ..Default::default() diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs index 8171ca019..850d6c733 100644 --- a/ipa-core/src/helpers/prss_protocol.rs +++ b/ipa-core/src/helpers/prss_protocol.rs @@ -21,8 +21,16 @@ pub async fn negotiate( let left_channel = ChannelId::new(gateway.role().peer(Direction::Left), gate.clone()); let right_channel = ChannelId::new(gateway.role().peer(Direction::Right), gate.clone()); - let left_sender = gateway.get_mpc_sender::(&left_channel, TotalRecords::ONE); - let right_sender = gateway.get_mpc_sender::(&right_channel, TotalRecords::ONE); + let left_sender = gateway.get_mpc_sender::( + &left_channel, + TotalRecords::ONE, + gateway.config().active_work_as_power_of_two(), + ); + let right_sender = gateway.get_mpc_sender::( + &right_channel, + TotalRecords::ONE, + gateway.config().active_work_as_power_of_two(), + ); let left_receiver = gateway.get_mpc_receiver::(&left_channel); let right_receiver = gateway.get_mpc_receiver::(&right_channel); diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index 42981d097..525edb67e 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -12,7 +12,7 @@ use crate::{ }, query::{ NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError, - QueryStatus, QueryStatusError, + QueryKillStatus, QueryKilled, QueryStatus, QueryStatusError, }, sync::{Arc, Mutex, Weak}, }; @@ -135,6 +135,13 @@ impl From for HelperResponse { } } +impl From for HelperResponse { + fn from(value: QueryKilled) -> Self { + let v = serde_json::to_vec(&json!({"query_id": value.0, "status": "killed"})).unwrap(); + Self { body: v } + } +} + impl> From for HelperResponse { fn from(value: R) -> Self { let v = value.as_ref().to_bytes(); @@ -156,6 +163,8 @@ pub enum Error { #[error(transparent)] QueryStatus(#[from] QueryStatusError), #[error(transparent)] + QueryKill(#[from] QueryKillStatus), + #[error(transparent)] DeserializationFailure(#[from] serde_json::Error), #[error("MalformedRequest: {0}")] BadRequest(BoxError), diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 3c1a9e926..cd7324e89 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -119,7 +119,8 @@ impl InMemoryTransport { | RouteId::PrepareQuery | RouteId::QueryInput | RouteId::QueryStatus - | RouteId::CompleteQuery => { + | RouteId::CompleteQuery + | RouteId::KillQuery => { handler .as_ref() .expect("Handler is set") diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index c3bb307d8..f72814614 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -44,22 +44,58 @@ pub trait Identity: Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord + Hash + Send + Sync + 'static { fn as_str(&self) -> Cow<'static, str>; + + /// Parses a ref to a string representation of this identity + /// + /// # Errors + /// If there where any problems parsing the identity. + fn from_str(s: &str) -> Result; + + /// Returns a 0-based index suitable to index Vec or other containers. + fn as_index(&self) -> usize; } impl Identity for ShardIndex { fn as_str(&self) -> Cow<'static, str> { Cow::Owned(self.to_string()) } + + fn from_str(s: &str) -> Result { + s.parse::() + .map_err(|_e| { + crate::error::Error::InvalidId(format!("The string {s} is an invalid Shard Index")) + }) + .map(ShardIndex::from) + } + + fn as_index(&self) -> usize { + usize::from(*self) + } } impl Identity for HelperIdentity { fn as_str(&self) -> Cow<'static, str> { Cow::Borrowed(match *self { - Self::ONE => "A", - Self::TWO => "B", - Self::THREE => "C", + Self::ONE => Self::ONE_STR, + Self::TWO => Self::TWO_STR, + Self::THREE => Self::THREE_STR, _ => unreachable!(), }) } + + fn from_str(s: &str) -> Result { + match s { + Self::ONE_STR => Ok(Self::ONE), + Self::TWO_STR => Ok(Self::TWO), + Self::THREE_STR => Ok(Self::THREE), + _ => Err(crate::error::Error::InvalidId(format!( + "The string {s} is an invalid Helper Identity" + ))), + } + } + + fn as_index(&self) -> usize { + usize::from(self.id) - 1 + } } /// Role is an identifier of helper peer, only valid within a given query. For every query, there @@ -68,6 +104,25 @@ impl Identity for Role { fn as_str(&self) -> Cow<'static, str> { Cow::Borrowed(Role::as_static_str(self)) } + + fn from_str(s: &str) -> Result { + match s { + Self::H1_STR => Ok(Self::H1), + Self::H2_STR => Ok(Self::H2), + Self::H3_STR => Ok(Self::H3), + _ => Err(crate::error::Error::InvalidId(format!( + "The string {s} is an invalid Role" + ))), + } + } + + fn as_index(&self) -> usize { + match self { + Self::H1 => 0, + Self::H2 => 1, + Self::H3 => 2, + } + } } pub trait ResourceIdentifier: Sized {} @@ -229,3 +284,53 @@ pub trait Transport: Clone + Send + Sync + 'static { ::clone(self) } } + +#[cfg(all(test, unit_test))] +mod tests { + use crate::{ + helpers::{HelperIdentity, Role, TransportIdentity}, + sharding::ShardIndex, + }; + + #[test] + fn helper_from_str() { + assert_eq!(HelperIdentity::from_str("A").unwrap(), HelperIdentity::ONE); + assert_eq!(HelperIdentity::from_str("B").unwrap(), HelperIdentity::TWO); + assert_eq!( + HelperIdentity::from_str("C").unwrap(), + HelperIdentity::THREE + ); + } + + #[test] + #[should_panic(expected = "The string H1 is an invalid Helper Identity")] + fn invalid_helper_from_str() { + assert_eq!(HelperIdentity::from_str("H1").unwrap(), HelperIdentity::ONE); + } + + #[test] + fn shard_from_str() { + assert_eq!(ShardIndex::from_str("42").unwrap(), ShardIndex::from(42)); + assert_eq!(ShardIndex::from_str("9").unwrap(), ShardIndex::from(9)); + assert_eq!(ShardIndex::from_str("0").unwrap(), ShardIndex::from(0)); + } + + #[test] + #[should_panic(expected = "The string -1 is an invalid Shard Index")] + fn invalid_shard_from_str() { + assert_eq!(ShardIndex::from_str("-1").unwrap(), ShardIndex::from(0)); + } + + #[test] + fn role_from_str() { + assert_eq!(Role::from_str("H1").unwrap(), Role::H1); + assert_eq!(Role::from_str("H2").unwrap(), Role::H2); + assert_eq!(Role::from_str("H3").unwrap(), Role::H3); + } + + #[test] + #[should_panic(expected = "The string A is an invalid Role")] + fn invalid_role_from_str() { + assert_eq!(Role::from_str("A").unwrap(), Role::H1); + } +} diff --git a/ipa-core/src/helpers/transport/query/hybrid.rs b/ipa-core/src/helpers/transport/query/hybrid.rs new file mode 100644 index 000000000..2b6906d28 --- /dev/null +++ b/ipa-core/src/helpers/transport/query/hybrid.rs @@ -0,0 +1,32 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)] +#[cfg_attr(feature = "clap", derive(clap::Args))] +pub struct HybridQueryParams { + #[cfg_attr(feature = "clap", arg(long, default_value = "8"))] + pub per_user_credit_cap: u32, + #[cfg_attr(feature = "clap", arg(long, default_value = "5"))] + pub max_breakdown_key: u32, + #[cfg_attr(feature = "clap", arg(short = 'd', long, default_value = "1"))] + pub with_dp: u32, + #[cfg_attr(feature = "clap", arg(short = 'e', long, default_value = "5.0"))] + pub epsilon: f64, + #[cfg_attr(feature = "clap", arg(long))] + #[serde(default)] + pub plaintext_match_keys: bool, +} + +#[cfg(test)] +impl Eq for HybridQueryParams {} + +impl Default for HybridQueryParams { + fn default() -> Self { + Self { + per_user_credit_cap: 8, + max_breakdown_key: 20, + with_dp: 1, + epsilon: 0.10, + plaintext_match_keys: false, + } + } +} diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index 27850ab8d..ac70209b3 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -1,8 +1,11 @@ +mod hybrid; + use std::{ fmt::{Debug, Display, Formatter}, num::NonZeroU32, }; +pub use hybrid::HybridQueryParams; use serde::{Deserialize, Deserializer, Serialize}; use crate::{ @@ -198,14 +201,21 @@ pub enum QueryType { TestMultiply, #[cfg(any(test, feature = "test-fixture", feature = "cli"))] TestAddInPrimeField, - OprfIpa(IpaQueryConfig), + #[cfg(any(test, feature = "test-fixture", feature = "cli"))] + TestShardedShuffle, + SemiHonestOprfIpa(IpaQueryConfig), + MaliciousOprfIpa(IpaQueryConfig), + SemiHonestHybrid(HybridQueryParams), } impl QueryType { /// TODO: strum pub const TEST_MULTIPLY_STR: &'static str = "test-multiply"; pub const TEST_ADD_STR: &'static str = "test-add"; - pub const OPRF_IPA_STR: &'static str = "oprf_ipa"; + pub const TEST_SHARDED_SHUFFLE_STR: &'static str = "test-sharded-shuffle"; + pub const SEMI_HONEST_OPRF_IPA_STR: &'static str = "semi-honest-oprf-ipa"; + pub const MALICIOUS_OPRF_IPA_STR: &'static str = "malicious-oprf-ipa"; + pub const SEMI_HONEST_HYBRID_STR: &'static str = "semi-honest-hybrid"; } /// TODO: should this `AsRef` impl (used for `Substep`) take into account config of IPA? @@ -216,7 +226,11 @@ impl AsRef for QueryType { QueryType::TestMultiply => Self::TEST_MULTIPLY_STR, #[cfg(any(test, feature = "cli", feature = "test-fixture"))] QueryType::TestAddInPrimeField => Self::TEST_ADD_STR, - QueryType::OprfIpa(_) => Self::OPRF_IPA_STR, + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + QueryType::TestShardedShuffle => Self::TEST_SHARDED_SHUFFLE_STR, + QueryType::SemiHonestOprfIpa(_) => Self::SEMI_HONEST_OPRF_IPA_STR, + QueryType::MaliciousOprfIpa(_) => Self::MALICIOUS_OPRF_IPA_STR, + QueryType::SemiHonestHybrid(_) => Self::SEMI_HONEST_HYBRID_STR, } } } diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index 4d8f44796..3d9c2bb5f 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -16,6 +16,7 @@ pub enum RouteId { QueryInput, QueryStatus, CompleteQuery, + KillQuery, } /// The header/metadata of the incoming request. diff --git a/ipa-core/src/helpers/transport/stream/collection.rs b/ipa-core/src/helpers/transport/stream/collection.rs index 09e4f5e63..f19fd7ce5 100644 --- a/ipa-core/src/helpers/transport/stream/collection.rs +++ b/ipa-core/src/helpers/transport/stream/collection.rs @@ -114,6 +114,26 @@ impl StreamCollection { let mut streams = self.inner.lock().unwrap(); streams.clear(); } + + /// Returns the number of streams inside this collection. + /// + /// ## Panics + /// if mutex is poisoned. + #[cfg(test)] + #[must_use] + pub fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } + + /// Returns `true` if this collection is empty. + /// + /// ## Panics + /// if mutex is poisoned. + #[must_use] + #[cfg(test)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } /// Describes the lifecycle of records stream inside [`StreamCollection`] diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index 2f2d6ccc6..59c76cdf4 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -187,7 +187,7 @@ mod tests { let stream = BodyStream::from_bytes_stream(stream::once(future::ready(Ok(Bytes::from(data))))); - stream.try_collect::>().await.unwrap() + stream.try_collect::>().await.unwrap(); }); } } diff --git a/ipa-core/src/hpke/info.rs b/ipa-core/src/hpke/info.rs index 584f46525..e0b7a794b 100644 --- a/ipa-core/src/hpke/info.rs +++ b/ipa-core/src/hpke/info.rs @@ -52,7 +52,7 @@ impl<'a> Info<'a> { /// Converts this instance into an owned byte slice that can further be used to create HPKE /// sender or receiver context. - pub(super) fn to_bytes(&self) -> Box<[u8]> { + pub(crate) fn to_bytes(&self) -> Box<[u8]> { let info_len = DOMAIN.len() + self.helper_origin.len() + self.site_domain.len() diff --git a/ipa-core/src/hpke/mod.rs b/ipa-core/src/hpke/mod.rs index 19b21bf67..e545efa54 100644 --- a/ipa-core/src/hpke/mod.rs +++ b/ipa-core/src/hpke/mod.rs @@ -96,26 +96,21 @@ impl From for CryptError { /// If ciphertext cannot be opened for any reason. /// /// [`HPKE decryption`]: https://datatracker.ietf.org/doc/html/rfc9180#name-encryption-and-decryption -pub fn open_in_place<'a, R: PrivateKeyRegistry>( - key_registry: &R, +pub fn open_in_place<'a>( + sk: &IpaPrivateKey, enc: &[u8], ciphertext: &'a mut [u8], - info: &Info, + info: &[u8], ) -> Result<&'a [u8], CryptError> { - let key_id = info.key_id; - let info = info.to_bytes(); let encap_key = ::EncappedKey::from_bytes(enc)?; let (ct, tag) = ciphertext.split_at_mut(ciphertext.len() - AeadTag::::size()); let tag = AeadTag::::from_bytes(tag)?; - let sk = key_registry - .private_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?; single_shot_open_in_place_detached::<_, IpaKdf, IpaKem>( &OpModeR::Base, sk, &encap_key, - &info, + info, ct, &[], &tag, @@ -136,22 +131,16 @@ pub(crate) type Ciphertext<'a> = ( /// ## Errors /// If the match key cannot be sealed for any reason. -pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>( - key_registry: &K, +pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore>( + pk: &IpaPublicKey, plaintext: &'a mut [u8], - info: &'a Info, + info: &[u8], rng: &mut R, ) -> Result, CryptError> { - let key_id = info.key_id; - let info = info.to_bytes(); - let pk_r = key_registry - .public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?; - let (encap_key, tag) = single_shot_seal_in_place_detached::( &OpModeS::Base, - pk_r, - &info, + pk, + info, plaintext, &[], rng, @@ -169,6 +158,7 @@ mod tests { use rand_core::{CryptoRng, RngCore, SeedableRng}; use typenum::Unsigned; + use super::{PrivateKeyRegistry, PublicKeyRegistry}; use crate::{ ff::{Gf40Bit, Serializable as IpaSerializable}, hpke::{open_in_place, seal_in_place, CryptError, Info, IpaAead, KeyPair, KeyRegistry}, @@ -231,9 +221,12 @@ mod tests { match_key.serialize(&mut plaintext); let (encap_key, ciphertext, tag) = seal_in_place( - &self.registry, + self.registry + .public_key(info.key_id) + .ok_or(CryptError::NoSuchKey(info.key_id)) + .unwrap(), plaintext.as_mut_slice(), - &info, + &info.to_bytes(), &mut self.rng, ) .unwrap(); @@ -282,7 +275,14 @@ mod tests { Self::SITE_DOMAIN, ) .unwrap(); - open_in_place(&self.registry, &enc.enc, enc.ct.as_mut(), &info)?; + open_in_place( + self.registry + .private_key(info.key_id) + .ok_or(CryptError::NoSuchKey(info.key_id))?, + &enc.enc, + enc.ct.as_mut(), + &info.to_bytes(), + )?; // TODO: fix once array split is a thing. Ok(XorReplicated::deserialize_infallible( @@ -467,7 +467,8 @@ mod tests { _ => panic!("bad test setup: only 5 fields can be corrupted, asked to corrupt: {corrupted_info_field}") }; - open_in_place(&suite.registry, &encryption.enc, &mut encryption.ct, &info).unwrap_err(); + open_in_place(suite.registry.private_key(info.key_id) + .ok_or(CryptError::NoSuchKey(info.key_id))?, &encryption.enc, &mut encryption.ct, &info.to_bytes()).unwrap_err(); } } } diff --git a/ipa-core/src/hpke/registry.rs b/ipa-core/src/hpke/registry.rs index 281f4af47..283d1fbd0 100644 --- a/ipa-core/src/hpke/registry.rs +++ b/ipa-core/src/hpke/registry.rs @@ -93,13 +93,9 @@ impl KeyRegistry { Self { keys: Box::new([]) } } - pub fn from_keys>(pairs: [I; N]) -> Self { + pub fn from_keys(pairs: [K; N]) -> Self { Self { - keys: pairs - .into_iter() - .map(Into::into) - .collect::>() - .into_boxed_slice(), + keys: pairs.into_iter().collect::>().into_boxed_slice(), } } diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 8c04fac7a..345bbe0ae 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -32,8 +32,8 @@ mod seq_join; mod serde; pub mod sharding; mod utils; - pub use app::{AppConfig, HelperApp, Setup as AppSetup}; +pub use utils::NonZeroU32PowerOfTwo; extern crate core; #[cfg(all(feature = "shuttle", test))] @@ -70,10 +70,10 @@ pub(crate) mod rand { #[cfg(all(feature = "shuttle", test))] pub(crate) mod task { - pub use shuttle::future::{JoinError, JoinHandle}; + pub use shuttle::future::JoinError; } -#[cfg(all(feature = "multi-threading", feature = "shuttle"))] +#[cfg(feature = "shuttle")] pub(crate) mod shim { use std::any::Any; @@ -94,9 +94,170 @@ pub(crate) mod shim { #[cfg(not(all(feature = "shuttle", test)))] pub(crate) mod task { + #[allow(unused_imports)] pub use tokio::task::{JoinError, JoinHandle}; } +#[cfg(not(feature = "shuttle"))] +pub mod executor { + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; + + use tokio::{ + runtime::{Handle, Runtime}, + task::JoinHandle, + }; + + /// In prod we use Tokio scheduler, so this struct just wraps + /// its runtime handle and mimics the standard executor API. + /// The name was chosen to avoid clashes with tokio runtime + /// when importing it + #[derive(Clone)] + pub struct IpaRuntime(Handle); + + /// Wrapper around Tokio's [`JoinHandle`] + #[pin_project::pin_project] + pub struct IpaJoinHandle(#[pin] JoinHandle); + + impl Default for IpaRuntime { + fn default() -> Self { + Self::current() + } + } + + impl IpaRuntime { + #[must_use] + pub fn current() -> Self { + Self(Handle::current()) + } + + #[must_use] + pub fn spawn(&self, future: F) -> IpaJoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + IpaJoinHandle(self.0.spawn(future)) + } + + /// This is a convenience method to convert a Tokio runtime into + /// an IPA runtime. It does not assume ownership of the Tokio runtime. + /// The caller is responsible for ensuring the Tokio runtime is properly + /// shut down. + #[must_use] + pub fn from_tokio_runtime(rt: &Runtime) -> Self { + Self(rt.handle().clone()) + } + } + + /// allow using [`IpaRuntime`] as Hyper executor + #[cfg(feature = "web-app")] + impl hyper::rt::Executor for IpaRuntime + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + fn execute(&self, fut: Fut) { + // Dropping the handle does not terminate the task + // Clippy wants us to be explicit here. + drop(self.spawn(fut)); + } + } + + impl IpaJoinHandle { + pub fn abort(&self) { + self.0.abort(); + } + } + + impl Future for IpaJoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().0.poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(v), + Poll::Ready(Err(e)) => match e.try_into_panic() { + Ok(p) => std::panic::resume_unwind(p), + Err(e) => panic!("Task is cancelled: {e:?}"), + }, + Poll::Pending => Poll::Pending, + } + } + } +} + +#[cfg(feature = "shuttle")] +pub(crate) mod executor { + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; + + use shuttle_crate::future::{spawn, JoinHandle}; + + use crate::shim::Tokio; + + /// Shuttle does not support more than one runtime + /// so we always use its default + #[derive(Clone, Default)] + pub struct IpaRuntime; + #[pin_project::pin_project] + pub struct IpaJoinHandle(#[pin] JoinHandle); + + #[cfg(feature = "web-app")] + impl hyper::rt::Executor for IpaRuntime + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + fn execute(&self, fut: Fut) { + drop(self.spawn(fut)); + } + } + + impl IpaRuntime { + #[must_use] + pub fn current() -> Self { + Self + } + + #[must_use] + #[allow(clippy::unused_self)] // to conform with runtime API + pub fn spawn(&self, future: F) -> IpaJoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + IpaJoinHandle(spawn(future)) + } + } + + impl IpaJoinHandle { + pub fn abort(&self) { + self.0.abort(); + } + } + + impl Future for IpaJoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().0.poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(v), + Poll::Ready(Err(e)) => match e.try_into_panic() { + Ok(p) => std::panic::resume_unwind(p), + Err(e) => panic!("Task is cancelled: {e:?}"), + }, + Poll::Pending => Poll::Pending, + } + } + } +} + #[cfg(all(feature = "shuttle", test))] pub(crate) mod test_executor { use std::future::Future; @@ -118,14 +279,14 @@ pub(crate) mod test_executor { } } -#[cfg(all(test, unit_test, not(feature = "shuttle")))] +#[cfg(all(test, not(feature = "shuttle")))] pub(crate) mod test_executor { use std::future::Future; - pub fn run_with(f: F) -> T + pub fn run_with(f: F) where F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future, + Fut: Future, { tokio::runtime::Builder::new_multi_thread() // enable_all() is common to use to build Tokio runtime, but it enables both IO and time drivers. @@ -134,15 +295,16 @@ pub(crate) mod test_executor { .enable_time() .build() .unwrap() - .block_on(f()) + .block_on(f()); } - pub fn run(f: F) -> T + #[allow(dead_code)] + pub fn run(f: F) where F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future, + Fut: Future, { - run_with::<_, _, _, 1>(f) + run_with::<_, _, 1>(f); } } @@ -182,3 +344,40 @@ macro_rules! mutually_incompatible { } mutually_incompatible!("in-memory-infra", "real-world-infra"); +#[cfg(not(any(compact_gate, descriptive_gate)))] +compile_error!("At least one of `compact_gate` or `descriptive_gate` features must be enabled"); + +#[cfg(test)] +mod tests { + /// Tests in this module ensure both Shuttle and Tokio runtimes conform to the same API + mod executor { + use crate::{executor::IpaRuntime, test_executor::run}; + + #[test] + #[should_panic(expected = "task panicked")] + fn handle_join_panicked() { + run(|| async move { + let rt = IpaRuntime::current(); + rt.spawn(async { panic!("task panicked") }).await; + }); + } + + #[test] + /// It is nearly impossible to intentionally hang a Shuttle task. Its executor + /// detects that immediately and panics with a deadlock error. We only want to test + /// the API, so it is not that important to panic with cancellation error + #[cfg_attr(not(feature = "shuttle"), should_panic(expected = "Task is cancelled"))] + fn handle_abort() { + run(|| async move { + let rt = IpaRuntime::current(); + let handle = rt.spawn(async { + #[cfg(not(feature = "shuttle"))] + futures::future::pending::<()>().await; + }); + + handle.abort(); + handle.await; + }); + } + } +} diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index e7180b6c5..3693431f2 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -2,6 +2,7 @@ use std::{ collections::HashMap, future::Future, io::{self, BufRead}, + marker::PhantomData, pin::Pin, sync::Arc, task::{ready, Context, Poll}, @@ -18,31 +19,33 @@ use hyper::{header::HeaderName, http::HeaderValue, Request, Response, StatusCode use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, - rt::{TokioExecutor, TokioTimer}, + rt::TokioTimer, }; use pin_project::pin_project; use rustls::RootCertStore; use tracing::error; +use super::{ConnectionFlavor, Helper}; use crate::{ config::{ ClientConfig, HyperClientConfigurator, NetworkConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig, }, + executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, - HelperIdentity, + TransportIdentity, }, - net::{http_serde, server::HTTP_CLIENT_ID_HEADER, Error, CRYPTO_PROVIDER}, + net::{http_serde, Error, CRYPTO_PROVIDER}, protocol::{Gate, QueryId}, }; #[derive(Default)] -pub enum ClientIdentity { +pub enum ClientIdentity { /// Claim the specified helper identity without any additional authentication. /// /// This is only supported for HTTP clients. - Helper(HelperIdentity), + Header(F::Identity), /// Authenticate with an X.509 certificate or a certificate chain. /// @@ -54,7 +57,7 @@ pub enum ClientIdentity { None, } -impl ClientIdentity { +impl ClientIdentity { /// Authenticates clients with an X.509 certificate using the provided certificate and private /// key. Certificate must be in PEM format, private key encoding must be [`PKCS8`]. /// @@ -79,10 +82,10 @@ impl ClientIdentity { /// to own a private key, and we need to create 3 with the same config, we provide Clone /// capabilities via this method to `ClientIdentity`. #[must_use] - pub fn clone_with_key(&self) -> ClientIdentity { + pub fn clone_with_key(&self) -> ClientIdentity { match self { Self::Certificate((c, pk)) => Self::Certificate((c.clone(), pk.clone_key())), - Self::Helper(h) => Self::Helper(*h), + Self::Header(h) => Self::Header(*h), Self::None => Self::None, } } @@ -91,20 +94,22 @@ impl ClientIdentity { /// Wrapper around Hyper's [future](hyper::client::ResponseFuture) interface that keeps around /// request endpoint for nicer error messages if request fails. #[pin_project] -pub struct ResponseFuture<'a> { - authority: &'a uri::Authority, +pub struct ResponseFuture { + /// There used to be a reference here, but there is really no need for that, + /// because `uri::Authority` type uses `Bytes` internally. + authority: uri::Authority, #[pin] inner: hyper_util::client::legacy::ResponseFuture, } /// Similar to [fut](ResponseFuture), wraps the response and keeps the URI authority for better /// error messages that show where error is originated from -pub struct ResponseFromEndpoint<'a> { - authority: &'a uri::Authority, +pub struct ResponseFromEndpoint { + authority: uri::Authority, inner: Response, } -impl<'a> ResponseFromEndpoint<'a> { +impl ResponseFromEndpoint { pub fn endpoint(&self) -> String { self.authority.to_string() } @@ -117,13 +122,13 @@ impl<'a> ResponseFromEndpoint<'a> { self.inner.into_body() } - pub fn into_parts(self) -> (&'a uri::Authority, Body) { + pub fn into_parts(self) -> (uri::Authority, Body) { (self.authority, self.inner.into_body()) } } -impl<'a> Future for ResponseFuture<'a> { - type Output = Result, Error>; +impl Future for ResponseFuture { + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -132,7 +137,7 @@ impl<'a> Future for ResponseFuture<'a> { let (http_parts, http_body) = resp.into_parts(); let axum_resp = Response::from_parts(http_parts, Body::new(http_body)); Poll::Ready(Ok(ResponseFromEndpoint { - authority: this.authority, + authority: this.authority.clone(), inner: axum_resp, })) } @@ -144,36 +149,41 @@ impl<'a> Future for ResponseFuture<'a> { } } +/// Helper to read a possible error response to a request that returns nothing on success +/// +/// # Errors +/// If there was an error reading the response body or if the request itself failed. +pub async fn resp_ok(resp: ResponseFromEndpoint) -> Result<(), Error> { + if resp.status().is_success() { + Ok(()) + } else { + Err(Error::from_failed_resp(resp).await) + } +} + +/// Reads the entire response from the server into Bytes +/// +/// # Errors +/// If there was an error collecting the response stream. +async fn response_to_bytes(resp: ResponseFromEndpoint) -> Result { + Ok(resp.into_body().collect().await?.to_bytes()) +} + /// TODO: we need a client that can be used by any system that is not aware of the internals /// of the helper network. That means that create query and send inputs API need to be /// separated from prepare/step data etc. /// TODO: It probably isn't necessary to always use `[MpcHelperClient; 3]`. Instead, a single /// client can be configured to talk to all three helpers. #[derive(Debug, Clone)] -pub struct MpcHelperClient { +pub struct MpcHelperClient { client: Client, Body>, scheme: uri::Scheme, authority: uri::Authority, auth_header: Option<(HeaderName, HeaderValue)>, + _restriction: PhantomData, } -impl MpcHelperClient { - /// Create a set of clients for the MPC helpers in the supplied helper network configuration. - /// - /// This function returns a set of three clients, which may be used to talk to each of the - /// helpers. - /// - /// `identity` configures whether and how the client will authenticate to the server. It is for - /// the helper making the calls, so the same one is used for all three of the clients. - /// Authentication is not required when calling the report collector APIs. - #[must_use] - #[allow(clippy::missing_panics_doc)] - pub fn from_conf(conf: &NetworkConfig, identity: &ClientIdentity) -> [MpcHelperClient; 3] { - conf.peers() - .each_ref() - .map(|peer_conf| Self::new(&conf.client, peer_conf.clone(), identity.clone_with_key())) - } - +impl MpcHelperClient { /// Create a new client with the given configuration /// /// `identity`, if present, configures whether and how the client will authenticate to the server @@ -183,9 +193,10 @@ impl MpcHelperClient { /// If some aspect of the configuration is not valid. #[must_use] pub fn new( + runtime: IpaRuntime, client_config: &ClientConfig, peer_config: PeerConfig, - identity: ClientIdentity, + identity: ClientIdentity, ) -> Self { let (connector, auth_header) = if peer_config.url.scheme() == Some(&Scheme::HTTP) { // This connector works for both http and https. A regular HttpConnector would suffice, @@ -195,7 +206,10 @@ impl MpcHelperClient { error!("certificate identity ignored for HTTP client"); None } - ClientIdentity::Helper(id) => Some((HTTP_CLIENT_ID_HEADER.clone(), id.into())), + ClientIdentity::Header(id) => Some(( + F::identity_header(), + HeaderValue::from_str(id.as_str().as_ref()).unwrap(), + )), ClientIdentity::None => None, }; ( @@ -225,7 +239,7 @@ impl MpcHelperClient { ClientIdentity::Certificate((cert_chain, pk)) => builder .with_client_auth_cert(cert_chain, pk) .expect("Can setup client authentication with certificate"), - ClientIdentity::Helper(_) => { + ClientIdentity::Header(_) => { error!("header-passed identity ignored for HTTPS client"); builder.with_no_client_auth() } @@ -247,19 +261,27 @@ impl MpcHelperClient { None, ) }; - Self::new_internal(peer_config.url, connector, auth_header, client_config) + Self::new_internal( + runtime, + peer_config.url, + connector, + auth_header, + client_config, + ) } #[must_use] fn new_internal( + runtime: IpaRuntime, addr: Uri, connector: HttpsConnector, auth_header: Option<(HeaderName, HeaderValue)>, conf: &C, ) -> Self { - let mut builder = Client::builder(TokioExecutor::new()); + let mut builder = Client::builder(runtime); // the following timer is necessary for http2, in particular for any timeouts // and waits the clients will need to make + // TODO: implement IpaTimer to allow wrapping other than Tokio runtimes builder.timer(TokioTimer::new()); let client = conf.configure(&mut builder).build(connector); let Parts { @@ -275,39 +297,20 @@ impl MpcHelperClient { scheme, authority, auth_header, + _restriction: PhantomData, } } - pub fn request(&self, mut req: Request) -> ResponseFuture<'_> { + pub fn request(&self, mut req: Request) -> ResponseFuture { if let Some((k, v)) = self.auth_header.clone() { req.headers_mut().insert(k, v); } ResponseFuture { - authority: &self.authority, + authority: self.authority.clone(), inner: self.client.request(req), } } - /// Helper to read a possible error response to a request that returns nothing on success - /// - /// # Errors - /// If there was an error reading the response body or if the request itself failed. - pub async fn resp_ok(resp: ResponseFromEndpoint<'_>) -> Result<(), Error> { - if resp.status().is_success() { - Ok(()) - } else { - Err(Error::from_failed_resp(resp).await) - } - } - - /// Reads the entire response from the server into Bytes - /// - /// # Errors - /// If there was an error collecting the response stream. - async fn response_to_bytes(resp: ResponseFromEndpoint<'_>) -> Result { - Ok(resp.into_body().collect().await?.to_bytes()) - } - /// Responds with whatever input is passed to it /// # Errors /// If the request has illegal arguments, or fails to deliver to helper @@ -319,7 +322,7 @@ impl MpcHelperClient { let resp = self.request(req).await?; let status = resp.status(); if status.is_success() { - let bytes = Self::response_to_bytes(resp).await?; + let bytes = response_to_bytes(resp).await?; let http_serde::echo::Request { mut query_params, .. } = serde_json::from_slice(&bytes)?; @@ -335,6 +338,65 @@ impl MpcHelperClient { } } + /// Sends a batch of messages associated with a query's step to another helper. Messages are a + /// contiguous block of records. Also includes [`crate::protocol::RecordId`] information and + /// [`crate::helpers::network::ChannelId`]. + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + /// # Panics + /// If messages size > max u32 (unlikely) + pub fn step> + Send + 'static>( + &self, + query_id: QueryId, + gate: &Gate, + data: S, + ) -> Result { + let data = data.map(|v| Ok::(Bytes::from(v))); + let body = axum::body::Body::from_stream(data); + let req = http_serde::query::step::Request::new(query_id, gate.clone(), body); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + Ok(self.request(req)) + } + + /// Used to communicate from one helper to another. Specifically, the helper that receives a + /// "create query" from an external party must communicate the intent to start a query to the + /// other helpers, which this prepare query does. + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + pub async fn prepare_query(&self, data: PrepareQuery) -> Result<(), Error> { + let req = http_serde::query::prepare::Request::new(data); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + let resp = self.request(req).await?; + resp_ok(resp).await + } +} + +impl MpcHelperClient { + /// Create a set of clients for the MPC helpers in the supplied helper network configuration. + /// + /// This function returns a set of three clients, which may be used to talk to each of the + /// helpers. + /// + /// `identity` configures whether and how the client will authenticate to the server. It is for + /// the helper making the calls, so the same one is used for all three of the clients. + /// Authentication is not required when calling the report collector APIs. + #[must_use] + #[allow(clippy::missing_panics_doc)] + pub fn from_conf( + runtime: &IpaRuntime, + conf: &NetworkConfig, + identity: &ClientIdentity, + ) -> [Self; 3] { + conf.peers().each_ref().map(|peer_conf| { + Self::new( + runtime.clone(), + &conf.client, + peer_conf.clone(), + identity.clone_with_key(), + ) + }) + } + /// Intended to be called externally, by the report collector. Informs the MPC ring that /// the external party wants to start a new query. /// # Errors @@ -344,7 +406,7 @@ impl MpcHelperClient { let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; let resp = self.request(req).await?; if resp.status().is_success() { - let bytes = Self::response_to_bytes(resp).await?; + let bytes = response_to_bytes(resp).await?; let http_serde::query::create::ResponseBody { query_id } = serde_json::from_slice(&bytes)?; Ok(query_id) @@ -353,18 +415,6 @@ impl MpcHelperClient { } } - /// Used to communicate from one helper to another. Specifically, the helper that receives a - /// "create query" from an external party must communicate the intent to start a query to the - /// other helpers, which this prepare query does. - /// # Errors - /// If the request has illegal arguments, or fails to deliver to helper - pub async fn prepare_query(&self, data: PrepareQuery) -> Result<(), Error> { - let req = http_serde::query::prepare::Request::new(data); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - let resp = self.request(req).await?; - Self::resp_ok(resp).await - } - /// Intended to be called externally, e.g. by the report collector. After the report collector /// calls "create query", it must then send the data for the query to each of the clients. This /// query input contains the data intended for a helper. @@ -374,27 +424,7 @@ impl MpcHelperClient { let req = http_serde::query::input::Request::new(data); let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; let resp = self.request(req).await?; - Self::resp_ok(resp).await - } - - /// Sends a batch of messages associated with a query's step to another helper. Messages are a - /// contiguous block of records. Also includes [`crate::protocol::RecordId`] information and - /// [`crate::helpers::network::ChannelId`]. - /// # Errors - /// If the request has illegal arguments, or fails to deliver to helper - /// # Panics - /// If messages size > max u32 (unlikely) - pub fn step> + Send + 'static>( - &self, - query_id: QueryId, - gate: &Gate, - data: S, - ) -> Result { - let data = data.map(|v| Ok::(Bytes::from(v))); - let body = axum::body::Body::from_stream(data); - let req = http_serde::query::step::Request::new(query_id, gate.clone(), body); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - Ok(self.request(req)) + resp_ok(resp).await } /// Retrieve the status of a query. @@ -411,7 +441,7 @@ impl MpcHelperClient { let resp = self.request(req).await?; if resp.status().is_success() { - let bytes = Self::response_to_bytes(resp).await?; + let bytes = response_to_bytes(resp).await?; let http_serde::query::status::ResponseBody { status } = serde_json::from_slice(&bytes)?; Ok(status) @@ -464,14 +494,14 @@ pub(crate) mod tests { use crate::{ ff::{FieldType, Fp31}, helpers::{ - make_owned_handler, query::QueryType::TestMultiply, BytesStream, HelperResponse, - RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, + make_owned_handler, query::QueryType::TestMultiply, BytesStream, HelperIdentity, + HelperResponse, RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, }, net::test::TestServer, + protocol::step::TestExecutionStep, query::ProtocolResult, secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, sync::Arc, - test_fixture::step::TestExecutionStep, }; #[tokio::test] @@ -487,8 +517,12 @@ pub(crate) mod tests { certificate: None, hpke_config: None, }; - let client = - MpcHelperClient::new(&ClientConfig::default(), peer_config, ClientIdentity::None); + let client = MpcHelperClient::new( + IpaRuntime::current(), + &ClientConfig::default(), + peer_config, + ClientIdentity::::None, + ); // The server's self-signed test cert is not in the system truststore, and we didn't supply // it in the client config, so the connection should fail with a certificate error. @@ -655,9 +689,9 @@ pub(crate) mod tests { .await .unwrap(); - MpcHelperClient::resp_ok(resp).await.unwrap(); + resp_ok(resp).await.unwrap(); - let mut stream = Arc::clone(&transport) + let mut stream = transport .receive(HelperIdentity::ONE, (QueryId, expected_step.clone())) .into_bytes_stream(); diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index e97551f6f..731df19de 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -73,7 +73,7 @@ impl Error { /// /// # Panics /// If the response is not a failure (4xx/5xx status) - pub async fn from_failed_resp(resp: ResponseFromEndpoint<'_>) -> Self { + pub async fn from_failed_resp(resp: ResponseFromEndpoint) -> Self { let status = resp.status(); assert!(status.is_client_error() || status.is_server_error()); // must be failure let (endpoint, body) = resp.into_parts(); diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 92f9ebf48..1965c15ce 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -122,9 +122,13 @@ pub mod query { QueryType::TEST_MULTIPLY_STR => Ok(QueryType::TestMultiply), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] QueryType::TEST_ADD_STR => Ok(QueryType::TestAddInPrimeField), - QueryType::OPRF_IPA_STR => { + QueryType::SEMI_HONEST_OPRF_IPA_STR => { let Query(q) = req.extract().await?; - Ok(QueryType::OprfIpa(q)) + Ok(QueryType::SemiHonestOprfIpa(q)) + } + QueryType::MALICIOUS_OPRF_IPA_STR => { + let Query(q) = req.extract().await?; + Ok(QueryType::MaliciousOprfIpa(q)) } other => Err(Error::bad_query_value("query_type", other)), }?; @@ -148,7 +152,9 @@ pub mod query { match self.query_type { #[cfg(any(test, feature = "test-fixture", feature = "cli"))] QueryType::TestMultiply | QueryType::TestAddInPrimeField => Ok(()), - QueryType::OprfIpa(config) => { + #[cfg(any(test, feature = "test-fixture", feature = "cli"))] + QueryType::TestShardedShuffle => Ok(()), + QueryType::SemiHonestOprfIpa(config) | QueryType::MaliciousOprfIpa(config) => { write!( f, "&per_user_credit_cap={}&max_breakdown_key={}&with_dp={}&epsilon={}", @@ -166,6 +172,22 @@ pub mod query { write!(f, "&attribution_window_seconds={}", window.get())?; } + Ok(()) + } + QueryType::SemiHonestHybrid(config) => { + write!( + f, + "&per_user_credit_cap={}&max_breakdown_key={}&with_dp={}&epsilon={}", + config.per_user_credit_cap, + config.max_breakdown_key, + config.with_dp, + config.epsilon, + )?; + + if config.plaintext_match_keys { + write!(f, "&plaintext_match_keys=true")?; + } + Ok(()) } } @@ -511,4 +533,82 @@ pub mod query { pub const AXUM_PATH: &str = "/:query_id/complete"; } + + pub mod kill { + use serde::{Deserialize, Serialize}; + + use crate::{ + helpers::{routing::RouteId, HelperResponse, NoStep, RouteParams}, + protocol::QueryId, + }; + + pub struct Request { + pub query_id: QueryId, + } + + impl RouteParams for Request { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::KillQuery + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + String::new() + } + } + + impl Request { + /// Currently, it is only possible to kill + /// a query by issuing an HTTP request manually. + /// Maybe report collector can support this API, + /// but for now, only tests exercise this path + /// hence methods here are hidden behind feature + /// flags + #[cfg(all(test, unit_test))] + pub fn new(query_id: QueryId) -> Self { + Self { query_id } + } + + #[cfg(all(test, unit_test))] + pub fn try_into_http_request( + self, + scheme: axum::http::uri::Scheme, + authority: axum::http::uri::Authority, + ) -> crate::net::http_serde::OutgoingRequest { + let uri = axum::http::uri::Uri::builder() + .scheme(scheme) + .authority(authority) + .path_and_query(format!( + "{}/{}/kill", + crate::net::http_serde::query::BASE_AXUM_PATH, + self.query_id.as_ref() + )) + .build()?; + Ok(hyper::Request::post(uri).body(axum::body::Body::empty())?) + } + } + + #[derive(Clone, Debug, Serialize, Deserialize)] + pub struct ResponseBody { + pub query_id: QueryId, + pub status: String, + } + + impl From for ResponseBody { + fn from(value: HelperResponse) -> Self { + serde_json::from_slice(value.into_body().as_slice()).unwrap() + } + } + + pub const AXUM_PATH: &str = "/:query_id/kill"; + } } diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index cb1373c7c..e0fdca35a 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -1,13 +1,19 @@ use std::{ + fmt::Debug, io::{self, BufRead}, sync::Arc, }; +use hyper::header::HeaderName; use once_cell::sync::Lazy; use rustls::crypto::CryptoProvider; use rustls_pki_types::CertificateDer; -use crate::config::{OwnedCertificate, OwnedPrivateKey}; +use crate::{ + config::{OwnedCertificate, OwnedPrivateKey}, + helpers::{HelperIdentity, TransportIdentity}, + sharding::ShardIndex, +}; mod client; mod error; @@ -20,10 +26,12 @@ mod transport; pub use client::{ClientIdentity, MpcHelperClient}; pub use error::Error; pub use server::{MpcHelperServer, TracingSpanMaker}; -pub use transport::{HttpShardTransport, HttpTransport}; +pub use transport::{HttpTransport, MpcHttpTransport, ShardHttpTransport}; -pub const APPLICATION_JSON: &str = "application/json"; -pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; +const APPLICATION_JSON: &str = "application/json"; +const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; +static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity"); +static HTTP_SHARD_INDEX_HEADER: HeaderName = HeaderName::from_static("x-unverified-shard-index"); /// This has the same meaning as const defined in h2 crate, but we don't import it directly. /// According to the [`spec`] it cannot exceed 2^31 - 1. @@ -38,6 +46,51 @@ pub(crate) const MAX_HTTP2_CONCURRENT_STREAMS: u32 = 5000; static CRYPTO_PROVIDER: Lazy> = Lazy::new(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider())); +/// This simple trait is used to make aware on what transport dimnsion one is running. Structs like +/// [`MpcHelperClient`] use it to know whether they are talking to other servers as Shards +/// inside a Helper or as a Helper talking to another Helper in a Ring. This trait can be used to +/// limit the functions exposed by a struct impl, depending on the context that it's being used. +/// Continuing the previous example, the functions a [`MpcHelperClient`] provides are dependent +/// on whether it's communicating with another Shard or another Helper. +/// +/// This trait is a safety restriction so that structs or traits only expose an API that's +/// meaningful for their specific context. When used as a generic bound, it also spreads through +/// the types making it harder to be misused or combining incompatible types, e.g. Using a +/// [`ShardIndex`] with a [`Shard`]. +pub trait ConnectionFlavor: Debug + Send + Sync + Clone + 'static { + /// The meaningful identity used in this transport dimension. + type Identity: TransportIdentity; + + /// The header to be used to identify a HTTP request + fn identity_header() -> HeaderName; +} + +/// Shard-to-shard communication marker. +/// This marker is used to restrict communication inside a single Helper, with other shards. +#[derive(Debug, Copy, Clone)] +pub struct Shard; + +/// Helper-to-helper communication marker. +/// This marker is used to restrict communication between Helpers. This communication usually has +/// more restrictions. 3 Hosts with the same sharding index are conencted in a Ring. +#[derive(Debug, Copy, Clone)] +pub struct Helper; + +impl ConnectionFlavor for Shard { + type Identity = ShardIndex; + + fn identity_header() -> HeaderName { + HTTP_SHARD_INDEX_HEADER.clone() + } +} +impl ConnectionFlavor for Helper { + type Identity = HelperIdentity; + + fn identity_header() -> HeaderName { + HTTP_HELPER_ID_HEADER.clone() + } +} + /// Reads certificates and a private key from the corresponding readers /// /// # Errors diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index 14c9b4c49..3e83c6568 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -3,16 +3,13 @@ mod query; use axum::Router; -use crate::{ - net::{http_serde, HttpTransport}, - sync::Arc, -}; +use crate::net::{http_serde, transport::MpcHttpTransport}; -pub fn router(transport: Arc) -> Router { +pub fn mpc_router(transport: MpcHttpTransport) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, Router::new() - .merge(query::query_router(Arc::clone(&transport))) + .merge(query::query_router(transport.clone())) .merge(query::h2h_router(transport)), ) } diff --git a/ipa-core/src/net/server/handlers/query/create.rs b/ipa-core/src/net/server/handlers/query/create.rs index aa4577ec4..58bf71e3b 100644 --- a/ipa-core/src/net/server/handlers/query/create.rs +++ b/ipa-core/src/net/server/handlers/query/create.rs @@ -2,22 +2,21 @@ use axum::{routing::post, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::{ApiError, BodyStream, Transport}, + helpers::{ApiError, BodyStream}, net::{ http_serde::{self, query::QueryConfigQueryParams}, - Error, HttpTransport, + transport::MpcHttpTransport, + Error, }, query::NewQueryError, - sync::Arc, }; /// Takes details from the HTTP request and creates a `[TransportCommand]::CreateQuery` that is sent /// to the [`HttpTransport`]. async fn handler( - transport: Extension>, + transport: Extension, QueryConfigQueryParams(query_config): QueryConfigQueryParams, ) -> Result, Error> { - let transport = Transport::clone_ref(&*transport); match transport.dispatch(query_config, BodyStream::empty()).await { Ok(resp) => Ok(Json(resp.try_into()?)), Err(err @ ApiError::NewQuery(NewQueryError::State { .. })) => { @@ -27,7 +26,7 @@ async fn handler( } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::create::AXUM_PATH, post(handler)) .layer(Extension(transport)) @@ -90,7 +89,7 @@ mod tests { async fn create_test_ipa_no_attr_window() { create_test( QueryConfig::new( - QueryType::OprfIpa(IpaQueryConfig { + QueryType::SemiHonestOprfIpa(IpaQueryConfig { per_user_credit_cap: 1, max_breakdown_key: 1, attribution_window_seconds: None, @@ -107,10 +106,30 @@ mod tests { } #[tokio::test] - async fn create_test_ipa_no_attr_window_with_dp() { + async fn create_test_semi_honest_ipa_no_attr_window_with_dp_default_padding() { create_test( QueryConfig::new( - QueryType::OprfIpa(IpaQueryConfig { + QueryType::SemiHonestOprfIpa(IpaQueryConfig { + per_user_credit_cap: 8, + max_breakdown_key: 20, + attribution_window_seconds: None, + with_dp: 1, + epsilon: 5.0, + plaintext_match_keys: true, + }), + FieldType::Fp32BitPrime, + 1, + ) + .unwrap(), + ) + .await; + } + + #[tokio::test] + async fn create_test_malicious_ipa_no_attr_window_with_dp_default_padding() { + create_test( + QueryConfig::new( + QueryType::MaliciousOprfIpa(IpaQueryConfig { per_user_credit_cap: 8, max_breakdown_key: 20, attribution_window_seconds: None, @@ -131,7 +150,7 @@ mod tests { create_test(QueryConfig { size: 1.try_into().unwrap(), field_type: FieldType::Fp32BitPrime, - query_type: QueryType::OprfIpa(IpaQueryConfig { + query_type: QueryType::SemiHonestOprfIpa(IpaQueryConfig { per_user_credit_cap: 1, max_breakdown_key: 1, attribution_window_seconds: NonZeroU32::new(86_400), @@ -238,7 +257,7 @@ mod tests { fn default() -> Self { Self { field_type: format!("{:?}", FieldType::Fp32BitPrime), - query_type: QueryType::OPRF_IPA_STR.to_string(), + query_type: QueryType::SEMI_HONEST_OPRF_IPA_STR.to_string(), per_user_credit_cap: "1".into(), max_breakdown_key: "1".into(), attribution_window_seconds: None, diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index 844604485..da47e9386 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -2,14 +2,13 @@ use axum::{extract::Path, routing::post, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::{query::QueryInput, routing::RouteId, BodyStream, Transport}, - net::{http_serde, Error, HttpTransport}, + helpers::{query::QueryInput, routing::RouteId, BodyStream}, + net::{http_serde, transport::MpcHttpTransport, Error}, protocol::QueryId, - sync::Arc, }; async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, input_stream: BodyStream, ) -> Result<(), Error> { @@ -17,7 +16,6 @@ async fn handler( query_id, input_stream, }; - let transport = Transport::clone_ref(&*transport); let _ = transport .dispatch( (RouteId::QueryInput, query_input.query_id), @@ -29,7 +27,7 @@ async fn handler( Ok(()) } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::input::AXUM_PATH, post(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/kill.rs b/ipa-core/src/net/server/handlers/query/kill.rs new file mode 100644 index 000000000..f97fa1657 --- /dev/null +++ b/ipa-core/src/net/server/handlers/query/kill.rs @@ -0,0 +1,134 @@ +use axum::{extract::Path, routing::post, Extension, Json, Router}; +use hyper::StatusCode; + +use crate::{ + helpers::{ApiError, BodyStream}, + net::{ + http_serde::query::kill::{self, Request}, + server::Error, + transport::MpcHttpTransport, + Error::QueryIdNotFound, + }, + protocol::QueryId, + query::QueryKillStatus, +}; + +async fn handler( + transport: Extension, + Path(query_id): Path, +) -> Result, Error> { + let req = Request { query_id }; + match transport.dispatch(req, BodyStream::empty()).await { + Ok(state) => Ok(Json(kill::ResponseBody::from(state))), + Err(ApiError::QueryKill(QueryKillStatus::NoSuchQuery(query_id))) => Err( + Error::application(StatusCode::NOT_FOUND, QueryIdNotFound(query_id)), + ), + Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), + } +} + +pub fn router(transport: MpcHttpTransport) -> Router { + Router::new() + .route(kill::AXUM_PATH, post(handler)) + .layer(Extension(transport)) +} + +#[cfg(all(test, unit_test))] +mod tests { + use axum::{ + body::Body, + http::uri::{Authority, Scheme}, + }; + use hyper::StatusCode; + + use crate::{ + helpers::{ + make_owned_handler, + routing::{Addr, RouteId}, + ApiError, BodyStream, HelperIdentity, HelperResponse, + }, + net::{ + http_serde, + server::handlers::query::test_helpers::{ + assert_fails_with, assert_fails_with_handler, assert_success_with, + }, + }, + protocol::QueryId, + query::{QueryKillStatus, QueryKilled}, + }; + + #[tokio::test] + async fn calls_kill() { + let expected_query_id = QueryId; + + let handler = make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { + let RouteId::KillQuery = addr.route else { + panic!("unexpected call: {addr:?}"); + }; + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(QueryKilled(expected_query_id))) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId); + let req = req + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_success_with(req, handler).await; + } + + #[tokio::test] + async fn no_such_query() { + let handler = make_owned_handler( + move |_addr: Addr, _data: BodyStream| async move { + Err(QueryKillStatus::NoSuchQuery(QueryId).into()) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId) + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_fails_with_handler(req, handler, StatusCode::NOT_FOUND).await; + } + + #[tokio::test] + async fn unknown_error() { + let handler = make_owned_handler( + move |_addr: Addr, _data: BodyStream| async move { + Err(ApiError::DeserializationFailure( + serde_json::from_str::<()>("not-a-json").unwrap_err(), + )) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId) + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_fails_with_handler(req, handler, StatusCode::INTERNAL_SERVER_ERROR).await; + } + + struct OverrideReq { + query_id: String, + } + + impl From for hyper::Request { + fn from(val: OverrideReq) -> Self { + let uri = format!( + "http://localhost{}/{}/kill", + http_serde::query::BASE_AXUM_PATH, + val.query_id + ); + hyper::Request::post(uri).body(Body::empty()).unwrap() + } + } + + #[tokio::test] + async fn malformed_query_id() { + let req = OverrideReq { + query_id: "not-a-query-id".into(), + }; + + assert_fails_with(req.into(), StatusCode::BAD_REQUEST).await; + } +} diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 49f18e0a8..13b3b962d 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -1,5 +1,6 @@ mod create; mod input; +mod kill; mod prepare; mod results; mod status; @@ -17,8 +18,8 @@ use hyper::{Request, StatusCode}; use tower::{layer::layer_fn, Service}; use crate::{ - net::{server::ClientIdentity, HttpTransport}, - sync::Arc, + helpers::HelperIdentity, + net::{server::ClientIdentity, transport::MpcHttpTransport}, }; /// Construct router for IPA query web service @@ -26,11 +27,12 @@ use crate::{ /// In principle, this web service could be backed by either an HTTP-interconnected helper network or /// an in-memory helper network. These are the APIs used by external callers (report collectors) to /// examine attribution results. -pub fn query_router(transport: Arc) -> Router { +pub fn query_router(transport: MpcHttpTransport) -> Router { Router::new() - .merge(create::router(Arc::clone(&transport))) - .merge(input::router(Arc::clone(&transport))) - .merge(status::router(Arc::clone(&transport))) + .merge(create::router(transport.clone())) + .merge(input::router(transport.clone())) + .merge(status::router(transport.clone())) + .merge(kill::router(transport.clone())) .merge(results::router(transport)) } @@ -41,9 +43,9 @@ pub fn query_router(transport: Arc) -> Router { /// particular query, to coordinate servicing that query. // // It might make sense to split the query and h2h handlers into two modules. -pub fn h2h_router(transport: Arc) -> Router { +pub fn h2h_router(transport: MpcHttpTransport) -> Router { Router::new() - .merge(prepare::router(Arc::clone(&transport))) + .merge(prepare::router(transport.clone())) .merge(step::router(transport)) .layer(layer_fn(HelperAuthentication::new)) } @@ -86,7 +88,7 @@ impl, Response = Response>> Service> } fn call(&mut self, req: Request) -> Self::Future { - match req.extensions().get() { + match req.extensions().get::>() { Some(ClientIdentity(_)) => self.inner.call(req).left_future(), None => ready(Ok(( StatusCode::UNAUTHORIZED, @@ -139,6 +141,19 @@ pub mod test_helpers { assert_eq!(resp.status(), expected_status); } + pub async fn assert_fails_with_handler( + req: hyper::Request, + handler: Arc>, + expected_status: StatusCode, + ) { + let test_server = TestServer::builder() + .with_request_handler(handler) + .build() + .await; + let resp = test_server.server.handle_req(req).await; + assert_eq!(resp.status(), expected_status); + } + pub async fn assert_success_with( req: hyper::Request, handler: Arc>, diff --git a/ipa-core/src/net/server/handlers/query/prepare.rs b/ipa-core/src/net/server/handlers/query/prepare.rs index 5ad5431d1..51ed1019d 100644 --- a/ipa-core/src/net/server/handlers/query/prepare.rs +++ b/ipa-core/src/net/server/handlers/query/prepare.rs @@ -2,25 +2,25 @@ use axum::{extract::Path, response::IntoResponse, routing::post, Extension, Json use hyper::StatusCode; use crate::{ - helpers::{query::PrepareQuery, BodyStream, Transport}, + helpers::{query::PrepareQuery, BodyStream, HelperIdentity}, net::{ http_serde::{ self, query::{prepare::RequestBody, QueryConfigQueryParams}, }, server::ClientIdentity, - Error, HttpTransport, + transport::MpcHttpTransport, + Error, }, protocol::QueryId, query::PrepareQueryError, - sync::Arc, }; /// Called by whichever peer helper is the leader for an individual query, to initiatialize /// processing of that query. async fn handler( - transport: Extension>, - _: Extension, // require that client is an authenticated helper + transport: Extension, + _: Extension>, // require that client is an authenticated helper Path(query_id): Path, QueryConfigQueryParams(config): QueryConfigQueryParams, Json(RequestBody { roles }): Json, @@ -30,7 +30,6 @@ async fn handler( config, roles, }; - let transport = Transport::clone_ref(&*transport); let _ = transport .dispatch(data, BodyStream::empty()) .await @@ -45,7 +44,7 @@ impl IntoResponse for PrepareQueryError { } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::prepare::AXUM_PATH, post(handler)) .layer(Extension(transport)) @@ -100,7 +99,7 @@ mod tests { // since we tested `QueryType` with `create`, skip it here // More lenient version of Request, specifically so to test failure scenarios struct OverrideReq { - client_id: Option, + client_id: Option>, query_id: String, field_type: String, size: Option, diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index abd77b947..1c359b659 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -2,31 +2,29 @@ use axum::{extract::Path, routing::get, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::{BodyStream, Transport}, + helpers::BodyStream, net::{ http_serde::{self, query::results::Request}, server::Error, - HttpTransport, + transport::MpcHttpTransport, }, protocol::QueryId, - sync::Arc, }; /// Handles the completion of the query by blocking the sender until query is completed. async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; // TODO: we may be able to stream the response - let transport = Transport::clone_ref(&*transport); match transport.dispatch(req, BodyStream::empty()).await { Ok(resp) => Ok(resp.into_body()), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::results::AXUM_PATH, get(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/status.rs b/ipa-core/src/net/server/handlers/query/status.rs index dcd4e1c62..0056b76d0 100644 --- a/ipa-core/src/net/server/handlers/query/status.rs +++ b/ipa-core/src/net/server/handlers/query/status.rs @@ -2,29 +2,27 @@ use axum::{extract::Path, routing::get, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::{BodyStream, Transport}, + helpers::BodyStream, net::{ http_serde::query::status::{self, Request}, server::Error, - HttpTransport, + transport::MpcHttpTransport, }, protocol::QueryId, - sync::Arc, }; async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; - let transport = Transport::clone_ref(&*transport); match transport.dispatch(req, BodyStream::empty()).await { Ok(state) => Ok(Json(status::ResponseBody::from(state))), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(status::AXUM_PATH, get(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/step.rs b/ipa-core/src/net/server/handlers/query/step.rs index 07e511c65..2c112be92 100644 --- a/ipa-core/src/net/server/handlers/query/step.rs +++ b/ipa-core/src/net/server/handlers/query/step.rs @@ -1,30 +1,28 @@ use axum::{extract::Path, routing::post, Extension, Router}; use crate::{ - helpers::{BodyStream, Transport}, + helpers::{BodyStream, HelperIdentity}, net::{ http_serde, server::{ClientIdentity, Error}, - HttpTransport, + transport::MpcHttpTransport, }, protocol::{Gate, QueryId}, - sync::Arc, }; #[allow(clippy::unused_async)] // axum doesn't like synchronous handler #[tracing::instrument(level = "trace", "step", skip_all, fields(from = ?**from, gate = ?gate))] async fn handler( - transport: Extension>, - from: Extension, + transport: Extension, + from: Extension>, Path((query_id, gate)): Path<(QueryId, Gate)>, body: BodyStream, ) -> Result<(), Error> { - let transport = Transport::clone_ref(&*transport); transport.receive_stream(query_id, gate, **from, body); Ok(()) } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::step::AXUM_PATH, post(handler)) .layer(Extension(transport)) @@ -41,7 +39,7 @@ mod tests { use super::*; use crate::{ - helpers::{HelperIdentity, MESSAGE_PAYLOAD_SIZE_BYTES}, + helpers::{HelperIdentity, Transport, MESSAGE_PAYLOAD_SIZE_BYTES}, net::{ server::handlers::query::test_helpers::{assert_fails_with, MaybeExtensionExt}, test::TestServer, @@ -65,7 +63,8 @@ mod tests { test_server.server.handle_req(req.into()).await; - let mut stream = Arc::clone(&test_server.transport) + let mut stream = test_server + .transport .receive(HelperIdentity::TWO, (QueryId, step)) .into_bytes_stream(); @@ -76,7 +75,7 @@ mod tests { } struct OverrideReq { - client_id: Option, + client_id: Option>, query_id: String, gate: Gate, payload: Vec, diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index 87d7ee2cd..f0229f4e5 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -4,6 +4,7 @@ mod handlers; use std::{ borrow::Cow, io, + marker::PhantomData, net::{Ipv4Addr, SocketAddr, TcpListener}, ops::Deref, task::{Context, Poll}, @@ -15,6 +16,7 @@ use ::tokio::{ net::TcpStream, }; use axum::{ + http::HeaderValue, response::{IntoResponse, Response}, routing::IntoMakeService, Router, @@ -25,32 +27,34 @@ use axum_server::{ tls_rustls::{RustlsAcceptor, RustlsConfig}, Handle, Server, }; -use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use futures::{ future::{ready, BoxFuture, Either, Ready}, - Future, FutureExt, + FutureExt, }; -use hyper::{body::Incoming, header::HeaderName, Request}; +use hyper::{body::Incoming, Request}; use metrics::increment_counter; use rustls::{server::WebPkiClientVerifier, RootCertStore}; -use rustls_pki_types::CertificateDer; -#[cfg(all(feature = "shuttle", test))] -use shuttle::future as tokio; use tokio_rustls::server::TlsStream; use tower::{layer::layer_fn, Service}; use tower_http::trace::TraceLayer; use tracing::{error, Span}; +use super::{ + transport::{MpcHttpTransport, ShardHttpTransport}, + Shard, +}; use crate::{ - config::{NetworkConfig, OwnedCertificate, OwnedPrivateKey, ServerConfig, TlsConfig}, + config::{ + NetworkConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig, ServerConfig, TlsConfig, + }, error::BoxError, - helpers::HelperIdentity, + executor::{IpaJoinHandle, IpaRuntime}, + helpers::TransportIdentity, net::{ - parse_certificate_and_private_key_bytes, server::config::HttpServerConfig, Error, - HttpTransport, CRYPTO_PROVIDER, + parse_certificate_and_private_key_bytes, server::config::HttpServerConfig, + ConnectionFlavor, Error, Helper, CRYPTO_PROVIDER, }, sync::Arc, - task::JoinHandle, telemetry::metrics::{web::RequestProtocolVersion, REQUESTS_RECEIVED}, }; @@ -76,34 +80,54 @@ impl TracingSpanMaker for () { /// IPA helper web service /// -/// `MpcHelperServer` handles requests from both peer helpers and external clients. -pub struct MpcHelperServer { - transport: Arc, +/// `MpcHelperServer` handles requests from peer helpers, shards within the same helper and +/// external clients. +/// +/// The Transport Restriction generic is used to make the server aware whether it should offer a +/// HTTP API for shards or for other Helpers. External clients can reach out to both APIs to push +/// the input data among other things. +pub struct MpcHelperServer { config: ServerConfig, - network_config: NetworkConfig, + network_config: NetworkConfig, + router: Router, } -impl MpcHelperServer { - pub fn new( - transport: Arc, +impl MpcHelperServer { + #[must_use] + pub fn new_mpc( + transport: &MpcHttpTransport, config: ServerConfig, - network_config: NetworkConfig, + network_config: NetworkConfig, ) -> Self { + let router = handlers::mpc_router(transport.clone()); MpcHelperServer { - transport, config, network_config, + router, } } +} - fn router(&self) -> Router { - handlers::router(Arc::clone(&self.transport)) +impl MpcHelperServer { + #[must_use] + pub fn new_shards( + _transport: &ShardHttpTransport, + config: ServerConfig, + network_config: NetworkConfig, + ) -> Self { + MpcHelperServer { + config, + network_config, + router: Router::new(), + } } +} +impl MpcHelperServer { #[cfg(all(test, unit_test))] async fn handle_req(&self, req: hyper::Request) -> axum::response::Response { use tower::ServiceExt; - self.router().oneshot(req).await.unwrap() + self.router.clone().oneshot(req).await.unwrap() } /// Starts the MPC helper service. @@ -121,9 +145,10 @@ impl MpcHelperServer { /// configured, it must be valid.) pub async fn start_on( &self, + runtime: &IpaRuntime, listener: Option, tracing: T, - ) -> (SocketAddr, JoinHandle<()>) { + ) -> (SocketAddr, IpaJoinHandle<()>) { // This should probably come from the server config. // Note that listening on 0.0.0.0 requires accepting a MacOS security // warning on each test run. @@ -132,7 +157,7 @@ impl MpcHelperServer { #[cfg(not(test))] const BIND_ADDRESS: Ipv4Addr = Ipv4Addr::UNSPECIFIED; - let svc = self.router().layer( + let svc = self.router.clone().layer( TraceLayer::new_for_http() .make_span_with(move |_request: &hyper::Request<_>| tracing.make_span()) .on_request(|request: &hyper::Request<_>, _: &Span| { @@ -145,22 +170,29 @@ impl MpcHelperServer { let task_handle = match (self.config.disable_https, listener) { (true, Some(listener)) => { let svc = svc - .layer(layer_fn(SetClientIdentityFromHeader::new)) + .layer(layer_fn(SetClientIdentityFromHeader::<_, F>::new)) .into_make_service(); - spawn_server(axum_server::from_tcp(listener), handle.clone(), svc).await + spawn_server( + runtime, + axum_server::from_tcp(listener), + handle.clone(), + svc, + ) + .await } (true, None) => { let addr = SocketAddr::new(BIND_ADDRESS.into(), self.config.port.unwrap_or(0)); let svc = svc - .layer(layer_fn(SetClientIdentityFromHeader::new)) + .layer(layer_fn(SetClientIdentityFromHeader::<_, F>::new)) .into_make_service(); - spawn_server(axum_server::bind(addr), handle.clone(), svc).await + spawn_server(runtime, axum_server::bind(addr), handle.clone(), svc).await } (false, Some(listener)) => { - let rustls_config = rustls_config(&self.config, &self.network_config) + let rustls_config = rustls_config(&self.config, self.network_config.vec_peers()) .await .expect("invalid TLS configuration"); spawn_server( + runtime, axum_server::from_tcp_rustls(listener, rustls_config).map(|a| { ClientCertRecognizingAcceptor::new(a, self.network_config.clone()) }), @@ -171,10 +203,11 @@ impl MpcHelperServer { } (false, None) => { let addr = SocketAddr::new(BIND_ADDRESS.into(), self.config.port.unwrap_or(0)); - let rustls_config = rustls_config(&self.config, &self.network_config) + let rustls_config = rustls_config(&self.config, self.network_config.vec_peers()) .await .expect("invalid TLS configuration"); spawn_server( + runtime, axum_server::bind_rustls(addr, rustls_config).map(|a| { ClientCertRecognizingAcceptor::new(a, self.network_config.clone()) }), @@ -201,30 +234,24 @@ impl MpcHelperServer { ); (bound_addr, task_handle) } - - pub fn start( - &self, - tracing: T, - ) -> impl Future)> + '_ { - self.start_on(None, tracing) - } } /// Spawns a new server with the given configuration. /// This function glues Tower, Axum, Hyper and Axum-Server together, hence the trait bounds. #[allow(clippy::unused_async)] async fn spawn_server( + runtime: &IpaRuntime, mut server: Server, handle: Handle, svc: IntoMakeService, -) -> JoinHandle<()> +) -> IpaJoinHandle<()> where A: Accept + Clone + Send + Sync + 'static, A::Stream: AsyncRead + AsyncWrite + Unpin + Send, A::Service: SendService> + Send + Service>, A::Future: Send, { - tokio::spawn({ + runtime.spawn({ async move { // Apply configuration HttpServerConfig::apply(&mut server.http_builder().http2()); @@ -273,16 +300,12 @@ async fn certificate_and_key( /// If there is a problem with the TLS configuration. async fn rustls_config( config: &ServerConfig, - network: &NetworkConfig, + certs: Vec, ) -> Result { let (cert, key) = certificate_and_key(config).await?; let mut trusted_certs = RootCertStore::empty(); - for cert in network - .peers() - .iter() - .filter_map(|peer| peer.certificate.clone()) - { + for cert in certs.into_iter().filter_map(|peer| peer.certificate) { // Note that this uses `webpki::TrustAnchor::try_from_cert_der`, which *does not* validate // the certificate. That is not required for security, but might be desirable to flag // configuration errors. @@ -306,68 +329,61 @@ async fn rustls_config( Ok(RustlsConfig::from_config(Arc::new(config))) } -/// Axum `Extension` indicating the authenticated remote helper identity, if any. +/// Axum `Extension` indicating the authenticated remote identity, if any. This can be either a +/// Shard authenticating or another Helper. // -// Presence or absence of authentication is indicated by presence or absence of the extension. Even -// at some inconvenience (e.g. `MaybeExtensionExt`), we avoid using `Option` within the extension, -// to avoid possible confusion about how many times the return from `req.extensions().get()` must be -// unwrapped to ensure valid authentication. -#[derive(Clone, Copy, Debug)] -struct ClientIdentity(pub HelperIdentity); +/// Presence or absence of authentication is indicated by presence or absence of the extension. Even +/// at some inconvenience (e.g. `MaybeExtensionExt`), we avoid using `Option` within the extension, +/// to avoid possible confusion about how many times the return from `req.extensions().get()` must be +/// unwrapped to ensure valid authentication. +#[derive(Clone, Copy, Debug, PartialEq)] +struct ClientIdentity(pub I); -impl Deref for ClientIdentity { - type Target = HelperIdentity; +impl Deref for ClientIdentity { + type Target = I; fn deref(&self) -> &Self::Target { &self.0 } } +impl TryFrom<&HeaderValue> for ClientIdentity { + type Error = Error; + + fn try_from(value: &HeaderValue) -> Result { + let header_str = value.to_str()?; + I::from_str(header_str) + .map_err(|e| Error::InvalidHeader(Box::new(e))) + .map(ClientIdentity) + } +} + /// `Accept`or that sets an axum `Extension` indiciating the authenticated remote helper identity. +/// Validating the certificate is something that happens earlier at connection time, this just +/// provide identity to the inner server handlers. #[derive(Clone)] -struct ClientCertRecognizingAcceptor { +struct ClientCertRecognizingAcceptor { inner: RustlsAcceptor, - network_config: Arc, + network_config: Arc>, } -impl ClientCertRecognizingAcceptor { - fn new(inner: RustlsAcceptor, network_config: NetworkConfig) -> Self { +impl ClientCertRecognizingAcceptor { + fn new(inner: RustlsAcceptor, network_config: NetworkConfig) -> Self { Self { inner, network_config: Arc::new(network_config), } } - - // This can't be a method (at least not that takes `&self`) because it needs to go in a 'static future. - fn identify_client( - network_config: &NetworkConfig, - cert_option: Option<&CertificateDer>, - ) -> Option { - let cert = cert_option?; - // We currently require an exact match with the peer cert (i.e. we don't support verifying - // the certificate against a truststore and identifying the peer by the certificate - // subject). This could be changed if the need arises. - for (id, peer) in network_config.enumerate_peers() { - if peer.certificate.as_ref() == Some(cert) { - return Some(ClientIdentity(id)); - } - } - // It might be nice to log something here. We could log the certificate base64? - error!( - "A client certificate was presented that does not match a known helper. Certificate: {}", - BASE64.encode(cert), - ); - None - } } -impl Accept for ClientCertRecognizingAcceptor +impl Accept for ClientCertRecognizingAcceptor where I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static, + F: ConnectionFlavor, { type Stream = TlsStream; - type Service = SetClientIdentityFromCertificate; + type Service = SetClientIdentityFromCertificate; type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>; fn accept(&self, stream: I, service: S) -> Self::Future { @@ -376,7 +392,7 @@ where Box::pin(async move { let (stream, service) = acceptor.accept(stream, service).await.map_err(|err| { - error!("[ClientCertRecognizingAcceptor] connection error: {err}"); + error!("[ClientCertRecognizingAcceptor] Internal acceptor error: {err}"); err })?; @@ -387,27 +403,33 @@ where // certificate here, because the certificate must have passed full verification at // connection time. But it's possible the certificate subject is not something we // recognize as a helper. - let id = Self::identify_client( - &network_config, - stream - .get_ref() - .1 - .peer_certificates() - .and_then(<[_]>::first), - ); - let service = SetClientIdentityFromCertificate { inner: service, id }; + let opt_cert = stream + .get_ref() + .1 + .peer_certificates() + .and_then(<[_]>::first); + let option_id: Option = network_config.identify_cert(opt_cert); + let client_id = option_id.map(ClientIdentity); + let service = SetClientIdentityFromCertificate { + inner: service, + id: client_id, + }; Ok((stream, service)) }) } } #[derive(Clone)] -struct SetClientIdentityFromCertificate { +struct SetClientIdentityFromCertificate { inner: S, - id: Option, + id: Option>, } -impl>> Service> for SetClientIdentityFromCertificate { +impl Service> for SetClientIdentityFromCertificate +where + S: Service>, + F: ConnectionFlavor, +{ type Response = S::Response; type Error = S::Error; type Future = S::Future; @@ -425,27 +447,29 @@ impl>> Service> for SetClientIdentityFromCer } } -/// Name of the header that passes the client identity when not using HTTPS. -pub static HTTP_CLIENT_ID_HEADER: HeaderName = - HeaderName::from_static("x-unverified-client-identity"); - /// Service wrapper that gets a client helper identity from a header. /// /// Since this allows a client to claim any identity, it is completely /// insecure. It must only be used in contexts where that is acceptable. #[derive(Clone)] -struct SetClientIdentityFromHeader { +struct SetClientIdentityFromHeader { inner: S, + _restriction: PhantomData, } -impl SetClientIdentityFromHeader { +impl SetClientIdentityFromHeader { fn new(inner: S) -> Self { - Self { inner } + Self { + inner, + _restriction: PhantomData, + } } } -impl, Response = Response>> Service> - for SetClientIdentityFromHeader +impl Service> for SetClientIdentityFromHeader +where + S: Service, Response = Response>, + F: ConnectionFlavor, { type Response = Response; type Error = S::Error; @@ -457,11 +481,10 @@ impl, Response = Response>> Service> } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(header_value) = req.headers().get(&HTTP_CLIENT_ID_HEADER) { - let id_result = serde_json::from_slice(header_value.as_ref()) - .map_err(|e| Error::InvalidHeader(format!("{HTTP_CLIENT_ID_HEADER}: {e}").into())); + if let Some(header_value) = req.headers().get(F::identity_header()) { + let id_result = ClientIdentity::::try_from(header_value); match id_result { - Ok(id) => req.extensions_mut().insert(ClientIdentity(id)), + Ok(id) => req.extensions_mut().insert(id), Err(err) => return ready(Ok(err.into_response())).right_future(), }; } @@ -469,6 +492,28 @@ impl, Response = Response>> Service> } } +#[cfg(all(test, unit_test))] +mod tests { + use axum::http::HeaderValue; + + use crate::{helpers::HelperIdentity, net::server::ClientIdentity}; + + #[test] + fn identify_from_header_happy_case() { + let h = HeaderValue::from_static("A"); + let id = ClientIdentity::::try_from(&h); + assert_eq!(id.unwrap(), ClientIdentity(HelperIdentity::ONE)); + } + + #[test] + #[should_panic = "The string H1 is an invalid Helper Identity"] + fn identify_from_header_wrong_header() { + let h = HeaderValue::from_static("H1"); + let id = ClientIdentity::::try_from(&h); + id.unwrap(); + } +} + #[cfg(all(test, unit_test))] mod e2e_tests { use std::collections::HashMap; @@ -489,6 +534,7 @@ mod e2e_tests { client::danger::{ServerCertVerified, ServerCertVerifier}, pki_types::ServerName, }; + use rustls_pki_types::CertificateDer; use tracing::Level; use super::*; @@ -721,6 +767,7 @@ mod e2e_tests { let expected = expected_req(addr.to_string()); let req = http_req(&expected, uri::Scheme::HTTP, addr.to_string()); let response = client.request(req).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); assert_eq!( diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index e6edcc0f6..e62bccce6 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -15,16 +15,17 @@ use std::{ use once_cell::sync::Lazy; use rustls_pki_types::CertificateDer; -use tokio::task::JoinHandle; +use super::transport::MpcHttpTransport; use crate::{ config::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, TlsConfig, }, + executor::{IpaJoinHandle, IpaRuntime}, helpers::{HandlerBox, HelperIdentity, RequestHandler}, hpke::{Deserializable as _, IpaPublicKey}, - net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHelperServer}, + net::{ClientIdentity, Helper, MpcHelperClient, MpcHelperServer}, sync::Arc, test_fixture::metrics::MetricsHandle, }; @@ -33,7 +34,7 @@ pub const DEFAULT_TEST_PORTS: [u16; 3] = [3000, 3001, 3002]; pub struct TestConfig { pub disable_https: bool, - pub network: NetworkConfig, + pub network: NetworkConfig, pub servers: [ServerConfig; 3], pub sockets: Option<[TcpListener; 3]>, } @@ -174,16 +175,13 @@ impl TestConfigBuilder { )) }, }) - .collect::>() - .try_into() - .unwrap(); - let network = NetworkConfig { + .collect::>(); + let network = NetworkConfig::::new_mpc( peers, - client: self - .use_http1 + self.use_http1 .then(ClientConfig::use_http1) .unwrap_or_default(), - }; + ); let servers = if self.disable_https { ports.map(|ports| server_config_insecure_http(ports, !self.disable_matchkey_encryption)) } else { @@ -201,9 +199,9 @@ impl TestConfigBuilder { pub struct TestServer { pub addr: SocketAddr, - pub handle: JoinHandle<()>, - pub transport: Arc, - pub server: MpcHelperServer, + pub handle: IpaJoinHandle<()>, + pub transport: MpcHttpTransport, + pub server: MpcHelperServer, pub client: MpcHelperClient, pub request_handler: Option>>, } @@ -273,7 +271,7 @@ impl TestServerBuilder { pub async fn build(self) -> TestServer { let identity = if self.disable_https { - ClientIdentity::Helper(HelperIdentity::ONE) + ClientIdentity::Header(HelperIdentity::ONE) } else { get_test_identity(HelperIdentity::ONE) }; @@ -291,22 +289,24 @@ impl TestServerBuilder { else { panic!("TestConfig should have allocated ports"); }; - let clients = MpcHelperClient::from_conf(&network_config, &identity.clone_with_key()); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + &network_config, + &identity.clone_with_key(), + ); let handler = self.handler.as_ref().map(HandlerBox::owning_ref); - let (transport, server) = HttpTransport::new( + let client = clients[0].clone(); + let (transport, server) = MpcHttpTransport::new( + IpaRuntime::current(), HelperIdentity::ONE, server_config, network_config.clone(), - clients, + &clients, handler, ); - let (addr, handle) = server.start_on(Some(server_socket), self.metrics).await; - // Get the config for HelperIdentity::ONE - let h1_peer_config = network_config.peers.into_iter().next().unwrap(); - // At some point it might be appropriate to return two clients here -- the first being - // another helper and the second being a report collector. For now we use the same client - // for both types of calls. - let client = MpcHelperClient::new(&network_config.client, h1_peer_config, identity); + let (addr, handle) = server + .start_on(&IpaRuntime::current(), Some(server_socket), self.metrics) + .await; TestServer { addr, handle, diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 81d4bdcce..a0fcd92a3 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -9,14 +9,16 @@ use async_trait::async_trait; use futures::{Stream, TryFutureExt}; use pin_project::{pin_project, pinned_drop}; +use super::{client::resp_ok, ConnectionFlavor, Helper, Shard}; use crate::{ config::{NetworkConfig, ServerConfig}, + executor::IpaRuntime, helpers::{ query::QueryConfig, routing::{Addr, RouteId}, ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, - StepBinding, StreamCollection, Transport, + StepBinding, StreamCollection, Transport, TransportIdentity, }, net::{client::MpcHelperClient, error::Error, MpcHelperServer}, protocol::{Gate, QueryId}, @@ -24,20 +26,26 @@ use crate::{ sync::Arc, }; -/// HTTP transport for IPA helper service. -/// TODO: rename to MPC -pub struct HttpTransport { - identity: HelperIdentity, - clients: [MpcHelperClient; 3], - // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we - // only allow one query at a time. - record_streams: StreamCollection, - handler: Option, +/// Shared implementation used by [`MpcHttpTransport`] and [`ShardHttpTransport`] +pub struct HttpTransport { + http_runtime: IpaRuntime, + identity: F::Identity, + clients: Vec>, + record_streams: StreamCollection, + handler: Option>, } -/// A stub for HTTP transport implementation, suitable for serviing inter-shard traffic -#[derive(Clone, Default)] -pub struct HttpShardTransport; +/// HTTP transport for helper to helper traffic. +#[derive(Clone)] +pub struct MpcHttpTransport { + inner_transport: Arc>, +} + +/// A stub for HTTP transport implementation, suitable for serving shard-to-shard traffic +#[derive(Clone)] +pub struct ShardHttpTransport { + inner_transport: Arc>, +} impl RouteParams for QueryConfig { type Params = String; @@ -59,31 +67,65 @@ impl RouteParams for QueryConfig { } } -impl HttpTransport { - #[must_use] - pub fn new( - identity: HelperIdentity, - server_config: ServerConfig, - network_config: NetworkConfig, - clients: [MpcHelperClient; 3], - handler: Option, - ) -> (Arc, MpcHelperServer) { - let transport = Self::new_internal(identity, clients, handler); - let server = MpcHelperServer::new(Arc::clone(&transport), server_config, network_config); - (transport, server) +impl HttpTransport { + async fn send< + D: Stream> + Send + 'static, + Q: QueryIdBinding, + S: StepBinding, + R: RouteParams, + >( + &self, + dest: F::Identity, + route: R, + data: D, + ) -> Result<(), Error> + where + Option: From, + Option: From, + { + let route_id = route.resource_identifier(); + let client_ix = dest.as_index(); + match route_id { + RouteId::Records => { + // TODO(600): These fallible extractions aren't really necessary. + let query_id = >::from(route.query_id()) + .expect("query_id required when sending records"); + let step = + >::from(route.gate()).expect("step required when sending records"); + let resp_future = self.clients[client_ix].step(query_id, &step, data)?; + // Use a dedicated HTTP runtime to poll this future for several reasons: + // - avoid blocking this task, if the current runtime is overloaded + // - use the runtime that enables IO (current runtime may not). + self.http_runtime + .spawn(resp_future.map_err(Into::into).and_then(resp_ok)) + .await?; + Ok(()) + } + RouteId::PrepareQuery => { + let req = serde_json::from_str(route.extra().borrow()).unwrap(); + self.clients[client_ix].prepare_query(req).await + } + evt @ (RouteId::QueryInput + | RouteId::ReceiveQuery + | RouteId::QueryStatus + | RouteId::CompleteQuery + | RouteId::KillQuery) => { + unimplemented!( + "attempting to send client-specific request {evt:?} to another helper" + ) + } + } } - fn new_internal( - identity: HelperIdentity, - clients: [MpcHelperClient; 3], - handler: Option, - ) -> Arc { - Arc::new(Self { - identity, - clients, - handler, - record_streams: StreamCollection::default(), - }) + fn receive>( + &self, + from: F::Identity, + route: &R, + ) -> ReceiveRecords { + ReceiveRecords::new( + (route.query_id(), from, route.gate()), + self.record_streams.clone(), + ) } /// Dispatches the given request to the [`RequestHandler`] connected to this transport. @@ -107,13 +149,13 @@ impl HttpTransport { /// This implementation is a poor man's safety net and only works because we run /// one query at a time and don't use query identifiers. #[pin_project(PinnedDrop)] - struct ClearOnDrop { - transport: Arc, + struct ClearOnDrop { + transport: Arc>, #[pin] inner: F, } - impl Future for ClearOnDrop { + impl Future for ClearOnDrop { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -122,7 +164,7 @@ impl HttpTransport { } #[pinned_drop] - impl PinnedDrop for ClearOnDrop { + impl PinnedDrop for ClearOnDrop { fn drop(self: Pin<&mut Self>) { self.transport.record_streams.clear(); } @@ -135,7 +177,7 @@ impl HttpTransport { .expect("A Handler should be set by now") .handle(Addr::from_route(None, req), body); - if let RouteId::CompleteQuery = route_id { + if let RouteId::CompleteQuery | RouteId::KillQuery = route_id { ClearOnDrop { transport: Arc::clone(&self), inner: r, @@ -145,30 +187,75 @@ impl HttpTransport { r.await } } +} + +impl MpcHttpTransport { + #[must_use] + pub fn new( + http_runtime: IpaRuntime, + identity: HelperIdentity, + server_config: ServerConfig, + network_config: NetworkConfig, + clients: &[MpcHelperClient; 3], + handler: Option>, + ) -> (Self, MpcHelperServer) { + let transport = Self { + inner_transport: Arc::new(HttpTransport { + http_runtime, + identity, + clients: clients.to_vec(), + handler, + record_streams: StreamCollection::default(), + }), + }; + + let server = MpcHelperServer::new_mpc(&transport, server_config, network_config); + (transport, server) + } - /// Connect an inbound stream of MPC record data. + /// Connect an inbound stream of record data. /// /// This is called by peer helpers via the HTTP server. pub fn receive_stream( - self: Arc, + &self, query_id: QueryId, gate: Gate, from: HelperIdentity, stream: BodyStream, ) { - self.record_streams + self.inner_transport + .record_streams .add_stream((query_id, from, gate), stream); } + + /// Dispatches the given request to the [`RequestHandler`] connected to this transport. + /// + /// ## Errors + /// Returns an error, if handler rejects the request for any reason. + /// + /// ## Panics + /// This will panic if request handler hasn't been previously set for this transport. + pub async fn dispatch>( + &self, + req: R, + body: BodyStream, + ) -> Result + where + Option: From, + { + let t = Arc::clone(&self.inner_transport); + t.dispatch(req, body).await + } } #[async_trait] -impl Transport for Arc { +impl Transport for MpcHttpTransport { type Identity = HelperIdentity; - type RecordsStream = ReceiveRecords; + type RecordsStream = ReceiveRecords; type Error = Error; - fn identity(&self) -> HelperIdentity { - self.identity + fn identity(&self) -> Self::Identity { + self.inner_transport.identity } async fn send< @@ -178,7 +265,7 @@ impl Transport for Arc { R: RouteParams, >( &self, - dest: HelperIdentity, + dest: Self::Identity, route: R, data: D, ) -> Result<(), Error> @@ -186,65 +273,58 @@ impl Transport for Arc { Option: From, Option: From, { - let route_id = route.resource_identifier(); - match route_id { - RouteId::Records => { - // TODO(600): These fallible extractions aren't really necessary. - let query_id = >::from(route.query_id()) - .expect("query_id required when sending records"); - let step = - >::from(route.gate()).expect("step required when sending records"); - let resp_future = self.clients[dest].step(query_id, &step, data)?; - // we don't need to spawn a task here. Gateway's sender interface already does that - // so this can just poll this future. - resp_future - .map_err(Into::into) - .and_then(MpcHelperClient::resp_ok) - .await?; - Ok(()) - } - RouteId::PrepareQuery => { - let req = serde_json::from_str(route.extra().borrow()).unwrap(); - self.clients[dest].prepare_query(req).await - } - evt @ (RouteId::QueryInput - | RouteId::ReceiveQuery - | RouteId::QueryStatus - | RouteId::CompleteQuery) => { - unimplemented!( - "attempting to send client-specific request {evt:?} to another helper" - ) - } - } + self.inner_transport.send(dest, route, data).await } fn receive>( &self, - from: HelperIdentity, + from: Self::Identity, route: R, ) -> Self::RecordsStream { - ReceiveRecords::new( - (route.query_id(), from, route.gate()), - self.record_streams.clone(), - ) + self.inner_transport.receive(from, &route) + } +} + +impl ShardHttpTransport { + #[must_use] + pub fn new( + http_runtime: IpaRuntime, + identity: ShardIndex, + server_config: ServerConfig, + network_config: NetworkConfig, + clients: Vec>, + handler: Option>, + ) -> (Self, MpcHelperServer) { + let transport = Self { + inner_transport: Arc::new(HttpTransport { + http_runtime, + identity, + clients, + handler, + record_streams: StreamCollection::default(), + }), + }; + + let server = MpcHelperServer::new_shards(&transport, server_config, network_config); + (transport, server) } } #[async_trait] -impl Transport for HttpShardTransport { +impl Transport for ShardHttpTransport { type Identity = ShardIndex; type RecordsStream = ReceiveRecords; - type Error = (); + type Error = Error; fn identity(&self) -> Self::Identity { - unimplemented!() + self.inner_transport.identity } async fn send( &self, - _dest: Self::Identity, - _route: R, - _data: D, + dest: Self::Identity, + route: R, + data: D, ) -> Result<(), Self::Error> where Option: From, @@ -254,15 +334,15 @@ impl Transport for HttpShardTransport { R: RouteParams, D: Stream> + Send + 'static, { - unimplemented!() + self.inner_transport.send(dest, route, data).await } fn receive>( &self, - _from: Self::Identity, - _route: R, + from: Self::Identity, + route: R, ) -> Self::RecordsStream { - unimplemented!() + self.inner_transport.receive(from, &route) } } @@ -283,7 +363,10 @@ mod tests { use crate::{ config::{NetworkConfig, ServerConfig}, ff::{FieldType, Fp31, Serializable}, - helpers::query::{QueryInput, QueryType::TestMultiply}, + helpers::{ + make_owned_handler, + query::{QueryInput, QueryType::TestMultiply}, + }, net::{ client::ClientIdentity, test::{get_test_identity, TestConfig, TestConfigBuilder, TestServer}, @@ -295,6 +378,32 @@ mod tests { static STEP: Lazy = Lazy::new(|| Gate::from("http-transport")); + #[tokio::test] + async fn clean_on_kill() { + let noop_handler = make_owned_handler(|_, _| async move { + { + Ok(HelperResponse::ok()) + } + }); + let TestServer { transport, .. } = TestServer::builder() + .with_request_handler(Arc::clone(&noop_handler)) + .build() + .await; + + transport.inner_transport.record_streams.add_stream( + (QueryId, HelperIdentity::ONE, Gate::default()), + BodyStream::empty(), + ); + assert_eq!(1, transport.inner_transport.record_streams.len()); + + Transport::clone_ref(&transport) + .dispatch((RouteId::KillQuery, QueryId), BodyStream::empty()) + .await + .unwrap(); + + assert!(transport.inner_transport.record_streams.is_empty()); + } + #[tokio::test] async fn receive_stream() { let (tx, rx) = channel::>>(1); @@ -306,10 +415,10 @@ mod tests { let body = BodyStream::from_bytes_stream(ReceiverStream::new(rx)); // Register the stream with the transport (normally called by step data HTTP API handler) - Arc::clone(&transport).receive_stream(QueryId, STEP.clone(), HelperIdentity::TWO, body); + transport.receive_stream(QueryId, STEP.clone(), HelperIdentity::TWO, body); // Request step data reception (normally called by protocol) - let mut stream = Arc::clone(&transport) + let mut stream = transport .receive(HelperIdentity::TWO, (QueryId, STEP.clone())) .into_bytes_stream(); @@ -341,29 +450,51 @@ mod tests { async fn make_helpers( sockets: [TcpListener; 3], server_config: [ServerConfig; 3], - network_config: &NetworkConfig, + network_config: &NetworkConfig, disable_https: bool, ) -> [HelperApp; 3] { join_all( zip(HelperIdentity::make_three(), zip(sockets, server_config)).map( |(id, (socket, server_config))| async move { let identity = if disable_https { - ClientIdentity::Helper(id) + ClientIdentity::Header(id) } else { get_test_identity(id) }; let (setup, handler) = AppSetup::new(AppConfig::default()); - let clients = MpcHelperClient::from_conf(network_config, &identity); - let (transport, server) = HttpTransport::new( + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + network_config, + &identity, + ); + let (transport, server) = MpcHttpTransport::new( + IpaRuntime::current(), id, - server_config, + server_config.clone(), network_config.clone(), - clients, + &clients, Some(handler), ); - server.start_on(Some(socket), ()).await; + // TODO: Following is just temporary until Shard Transport is actually used. + let shard_clients_config = network_config.client.clone(); + let shard_server_config = server_config; + let shard_network_config = + NetworkConfig::new_shards(vec![], shard_clients_config); + let (shard_transport, _shard_server) = ShardHttpTransport::new( + IpaRuntime::current(), + ShardIndex::FIRST, + shard_server_config, + shard_network_config, + vec![], + None, + ); + // --- + + server + .start_on(&IpaRuntime::current(), Some(socket), ()) + .await; - setup.connect(transport, HttpShardTransport) + setup.connect(transport, shard_transport) }, ), ) @@ -374,7 +505,11 @@ mod tests { } async fn test_three_helpers(mut conf: TestConfig) { - let clients = MpcHelperClient::from_conf(&conf.network, &ClientIdentity::None); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + &conf.network, + &ClientIdentity::None, + ); let _helpers = make_helpers( conf.sockets.take().unwrap(), conf.servers, @@ -389,7 +524,11 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn happy_case_twice() { let mut conf = TestConfigBuilder::with_open_ports().build(); - let clients = MpcHelperClient::from_conf(&conf.network, &ClientIdentity::None); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + &conf.network, + &ClientIdentity::None, + ); let _helpers = make_helpers( conf.sockets.take().unwrap(), conf.servers, diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index d0315b8cf..cd2e92be6 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -89,7 +89,10 @@ impl<'a, B: ShardBinding> BooleanProtocols> { } -impl<'a> BooleanProtocols> for AdditiveShare {} +impl<'a, B: ShardBinding> BooleanProtocols> + for AdditiveShare +{ +} // Used for aggregation tests impl<'a, B: ShardBinding> BooleanProtocols, 8> @@ -107,7 +110,7 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, PRF_CHUNK> +impl<'a, B: ShardBinding> BooleanProtocols, PRF_CHUNK> for AdditiveShare { } @@ -124,7 +127,7 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, AGG_CHUNK> +impl<'a, B: ShardBinding> BooleanProtocols, AGG_CHUNK> for AdditiveShare { } @@ -149,12 +152,20 @@ impl<'a, B: ShardBinding> BooleanProtocols BooleanProtocols, 3> + for AdditiveShare +{ +} + impl<'a, B: ShardBinding> BooleanProtocols, 32> for AdditiveShare { } -impl<'a> BooleanProtocols, 32> for AdditiveShare {} +impl<'a, B: ShardBinding> BooleanProtocols, 32> + for AdditiveShare +{ +} const_assert_eq!( AGG_CHUNK, diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index 74df755c3..acccd670a 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs +++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs @@ -13,6 +13,7 @@ use crate::{ RecordId, }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, + sharding::{NotSharded, ShardBinding}, }; /// This function implements an MPC multiply using the standard strategy, i.e. via computing the @@ -27,13 +28,14 @@ use crate::{ /// back via the error response /// ## Panics /// Panics if the mutex is found to be poisoned -pub async fn zkp_multiply<'a, F, const N: usize>( - ctx: DZKPUpgradedMaliciousContext<'a>, +pub async fn zkp_multiply<'a, B, F, const N: usize>( + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, a: &Replicated, b: &Replicated, ) -> Result, Error> where + B: ShardBinding, F: Field + DZKPCompatibleField, { // Shared randomness used to mask the values that are sent. @@ -62,17 +64,17 @@ where /// Implement secure multiplication for malicious contexts with replicated secret sharing. #[async_trait] -impl<'a, F: Field + DZKPCompatibleField, const N: usize> - SecureMul> for Replicated +impl<'a, B: ShardBinding, F: Field + DZKPCompatibleField, const N: usize> + SecureMul> for Replicated { async fn multiply<'fut>( &self, rhs: &Self, - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, ) -> Result where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, NotSharded>: 'fut, { zkp_multiply(ctx, record_id, self, rhs).await } @@ -84,7 +86,7 @@ mod test { ff::boolean::Boolean, protocol::{ basics::SecureMul, - context::{dzkp_validator::DZKPValidator, Context, UpgradableContext}, + context::{dzkp_validator::DZKPValidator, Context, UpgradableContext, TEST_DZKP_STEPS}, RecordId, }, rand::{thread_rng, Rng}, @@ -101,7 +103,7 @@ mod test { let res = world .malicious((a, b), |ctx, (a, b)| async move { - let validator = ctx.dzkp_validator(10); + let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let mctx = validator.context(); let result = a .multiply(&b, mctx.set_total_records(1), RecordId::from(0)) diff --git a/ipa-core/src/protocol/basics/mul/malicious.rs b/ipa-core/src/protocol/basics/mul/malicious.rs index e55d855d6..92bb6bee7 100644 --- a/ipa-core/src/protocol/basics/mul/malicious.rs +++ b/ipa-core/src/protocol/basics/mul/malicious.rs @@ -16,6 +16,7 @@ use crate::{ malicious::{AdditiveShare as MaliciousReplicated, ExtendableFieldSimd}, semi_honest::AdditiveShare as Replicated, }, + sharding::ShardBinding, }; /// @@ -49,8 +50,8 @@ use crate::{ /// back via the error response /// ## Panics /// Panics if the mutex is found to be poisoned -pub async fn mac_multiply( - ctx: UpgradedMaliciousContext<'_, F>, +pub async fn mac_multiply( + ctx: UpgradedMaliciousContext<'_, F, B>, record_id: RecordId, a: &MaliciousReplicated, b: &MaliciousReplicated, @@ -108,19 +109,19 @@ where /// Implement secure multiplication for malicious contexts with replicated secret sharing. #[async_trait] -impl<'a, F: ExtendableFieldSimd, const N: usize> SecureMul> - for MaliciousReplicated +impl<'a, F: ExtendableFieldSimd, B: ShardBinding, const N: usize> + SecureMul> for MaliciousReplicated where Replicated: FromPrss, { async fn multiply<'fut>( &self, rhs: &Self, - ctx: UpgradedMaliciousContext<'a, F>, + ctx: UpgradedMaliciousContext<'a, F, B>, record_id: RecordId, ) -> Result where - UpgradedMaliciousContext<'a, F>: 'fut, + UpgradedMaliciousContext<'a, F, B>: 'fut, { mac_multiply(ctx, record_id, self, rhs).await } diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 3194be9fb..89f5e107a 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -27,7 +27,7 @@ use crate::{ mod dzkp_malicious; pub(crate) mod malicious; mod semi_honest; -pub(in crate::protocol) mod step; +pub(crate) mod step; pub use semi_honest::sh_multiply as semi_honest_multiply; @@ -123,17 +123,19 @@ macro_rules! boolean_array_mul { } } - impl<'a> BooleanArrayMul> for Replicated<$vec> { + impl<'a, B: sharding::ShardBinding> BooleanArrayMul> + for Replicated<$vec> + { type Vectorized = Replicated; fn multiply<'fut>( - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, ) -> impl Future> + Send + 'fut where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, B>: 'fut, { use crate::protocol::basics::mul::dzkp_malicious::zkp_multiply; zkp_multiply(ctx, record_id, a, b) diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 51d5b2891..2344896a8 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -1,3 +1,27 @@ +// Several of the reveal impls use distinct type parameters for the value being revealed +// and the context-assiciated field. +// +// For MAC, this takes the form of distinct `V` and `CtxF` type parameters. For DZKP, +// this takes the form of a `V` type parameter different from the implicit `Boolean` +// used by the context. +// +// This decoupling is needed to support: +// +// 1. The PRF evaluation protocol, which uses `Fp25519` for the malicious context, but +// needs to reveal `RP25519` values. +// 2. The breakdown reveal aggregation protocol, which uses `Boolean` for the malicious +// context, but needs to reveal `BK` values. +// +// The malicious reveal protocol must check the shares being revealed for consistency, +// but doesn't care that they are in the same field as is used for the malicious +// context. Contrast with multiplication, which can only be supported in the malicious +// context's field. +// +// It also doesn't matter that `V` and `CtxF` support the same vectorization dimension +// `N`, but the compiler would not be able to infer the value of a decoupled +// vectorization dimension for `CtxF` from context, so it's easier to make them the same +// absent a need for them to be different. + use std::{ future::Future, iter::{repeat, zip}, @@ -8,7 +32,6 @@ use futures::{FutureExt, TryFutureExt}; use crate::{ error::Error, - ff::boolean::Boolean, helpers::{Direction, MaybeFuture, Role}, protocol::{ boolean::step::TwoHundredFiftySixBitOpStep, @@ -98,7 +121,7 @@ where ctx.parallel_join(zip(&**self, repeat(ctx.clone())).enumerate().map( |(i, (bit, ctx))| async move { generic_reveal( - ctx.narrow(&TwoHundredFiftySixBitOpStep::Bit(i)), + ctx.narrow(&TwoHundredFiftySixBitOpStep::from(i)), record_id, excluded, bit, @@ -170,8 +193,6 @@ where } } -// Like the impl for `UpgradedMaliciousContext`, this impl uses distinct `V` and `CtxF` type -// parameters. See the comment on that impl for more details. impl<'a, B, V, CtxF, const N: usize> Reveal> for Replicated where @@ -194,12 +215,12 @@ where } } -impl<'a, B, const N: usize> Reveal> for Replicated +impl<'a, V, B, const N: usize> Reveal> for Replicated where B: ShardBinding, - Boolean: Vectorizable, + V: SharedValue + Vectorizable, { - type Output = >::Array; + type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, @@ -270,15 +291,6 @@ where } } -// This impl uses distinct `V` and `CtxF` type parameters to support the PRF evaluation protocol, -// which uses `Fp25519` for the malicious context, but needs to reveal `RP25519` values. The -// malicious reveal protocol must check the shares being revealed for consistency, but doesn't care -// that they are in the same field as is used for the malicious context. Contrast with -// multiplication, which can only be supported in the malicious context's field. -// -// It also doesn't matter that `V` and `CtxF` support the same vectorization dimension `N`, but the -// compiler would not be able to infer the value of a decoupled vectorization dimension for `CtxF` -// from context, so it's easier to make them the same absent a need for them to be different. impl<'a, V, const N: usize, CtxF> Reveal> for Replicated where CtxF: ExtendableField, @@ -321,20 +333,21 @@ where } } -impl<'a, const N: usize> Reveal> for Replicated +impl<'a, V, B, const N: usize> Reveal> for Replicated where - Boolean: Vectorizable, + B: ShardBinding, + V: SharedValue + Vectorizable, { - type Output = >::Array; + type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, excluded: Option, ) -> Result, Error> where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, B>: 'fut, { malicious_reveal(ctx, record_id, excluded, self).await } diff --git a/ipa-core/src/protocol/boolean/and.rs b/ipa-core/src/protocol/boolean/and.rs index c05f5cdf8..9afbfd9af 100644 --- a/ipa-core/src/protocol/boolean/and.rs +++ b/ipa-core/src/protocol/boolean/and.rs @@ -51,7 +51,7 @@ where BitDecomposed::try_from( ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| { - let ctx = ctx.narrow(&EightBitStep::Bit(i)); + let ctx = ctx.narrow(&EightBitStep::from(i)); a.multiply(b, ctx, record_id) })) .await?, diff --git a/ipa-core/src/protocol/boolean/or.rs b/ipa-core/src/protocol/boolean/or.rs index c8aa611c9..da07c35cd 100644 --- a/ipa-core/src/protocol/boolean/or.rs +++ b/ipa-core/src/protocol/boolean/or.rs @@ -1,9 +1,11 @@ use std::iter::zip; +use ipa_step::StepNarrow; + use crate::{ error::Error, ff::{boolean::Boolean, Field}, - protocol::{basics::SecureMul, boolean::step::SixteenBitStep, context::Context, RecordId}, + protocol::{basics::SecureMul, boolean::NBitStep, context::Context, Gate, RecordId}, secret_sharing::{ replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd, Linear as LinearSecretSharing, @@ -34,7 +36,7 @@ pub async fn or + SecureMul>( // // Supplying an iterator saves constructing a complete copy of the argument // in memory when it is a uniform constant. -pub async fn bool_or<'a, C, BI, const N: usize>( +pub async fn bool_or<'a, C, S, BI, const N: usize>( ctx: C, record_id: RecordId, a: &BitDecomposed>, @@ -42,17 +44,19 @@ pub async fn bool_or<'a, C, BI, const N: usize>( ) -> Result>, Error> where C: Context, + S: NBitStep, BI: IntoIterator, ::IntoIter: ExactSizeIterator> + Send, Boolean: FieldSimd, AdditiveShare: SecureMul, + Gate: StepNarrow, { let b = b.into_iter(); assert_eq!(a.len(), b.len()); BitDecomposed::try_from( ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| { - let ctx = ctx.narrow(&SixteenBitStep::Bit(i)); + let ctx = ctx.narrow(&S::from(i)); async move { let ab = a.multiply(b, ctx, record_id).await?; Ok::<_, Error>(-ab + a + b) diff --git a/ipa-core/src/protocol/boolean/step.rs b/ipa-core/src/protocol/boolean/step.rs index 869f92726..0128cb037 100644 --- a/ipa-core/src/protocol/boolean/step.rs +++ b/ipa-core/src/protocol/boolean/step.rs @@ -1,63 +1,22 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] -pub enum EightBitStep { - #[step(count = 8)] - Bit(usize), -} - -impl From for EightBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 8, name = "bit")] +pub struct EightBitStep(usize); #[derive(CompactStep)] -pub enum SixteenBitStep { - #[step(count = 16)] - Bit(usize), -} - -impl From for SixteenBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 16, name = "bit")] +pub struct SixteenBitStep(usize); #[derive(CompactStep)] -pub enum ThirtyTwoBitStep { - #[step(count = 32)] - Bit(usize), -} - -impl From for ThirtyTwoBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 32, name = "bit")] +pub struct ThirtyTwoBitStep(usize); #[derive(CompactStep)] -pub enum TwoHundredFiftySixBitOpStep { - #[step(count = 256)] - Bit(usize), -} - -impl From for TwoHundredFiftySixBitOpStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 256, name = "bit")] +pub struct TwoHundredFiftySixBitOpStep(usize); #[cfg(test)] #[derive(CompactStep)] -pub enum DefaultBitStep { - #[step(count = 256)] - Bit(usize), -} - -#[cfg(test)] -impl From for DefaultBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 256, name = "bit")] +pub struct DefaultBitStep(usize); diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index 96337b629..974968bcf 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -6,8 +6,8 @@ use tokio::sync::watch; use crate::{ error::Error, helpers::TotalRecords, - protocol::RecordId, - sync::{Arc, Mutex}, + protocol::{context::dzkp_validator::TARGET_PROOF_SIZE, RecordId}, + sync::Mutex, }; /// Manages validation of batches of records for malicious protocols. @@ -88,20 +88,24 @@ impl<'a, B> Batcher<'a, B> { records_per_batch: usize, total_records: T, batch_constructor: Box B + Send + 'a>, - ) -> Arc> { - Arc::new(Mutex::new(Self { + ) -> Mutex { + Mutex::new(Self { batches: VecDeque::new(), first_batch: 0, records_per_batch, total_records: total_records.into(), batch_constructor, - })) + }) } pub fn set_total_records>(&mut self, total_records: T) { self.total_records = self.total_records.overwrite(total_records.into()); } + pub fn records_per_batch(&self) -> usize { + self.records_per_batch + } + fn batch_offset(&self, record_id: RecordId) -> usize { let batch_index = usize::from(record_id) / self.records_per_batch; batch_index @@ -112,13 +116,14 @@ impl<'a, B> Batcher<'a, B> { fn get_batch_by_offset(&mut self, batch_offset: usize) -> &mut BatchState { if self.batches.len() <= batch_offset { self.batches.reserve(batch_offset - self.batches.len() + 1); + let pending_records_capacity = self.records_per_batch.min(TARGET_PROOF_SIZE); while self.batches.len() <= batch_offset { let (validation_result, _) = watch::channel::(false); let state = BatchState { - batch: (self.batch_constructor)(self.first_batch + batch_offset), + batch: (self.batch_constructor)(self.first_batch + self.batches.len()), validation_result, pending_count: 0, - pending_records: bitvec![0; self.records_per_batch], + pending_records: bitvec![0; pending_records_capacity], }; self.batches.push_back(Some(state)); } @@ -153,10 +158,16 @@ impl<'a, B> Batcher<'a, B> { let total_count = min(self.records_per_batch, remaining_records); let record_offset_in_batch = usize::from(record_id) - first_record_in_batch; let batch = self.get_batch_by_offset(batch_offset); - assert!( - !batch.pending_records[record_offset_in_batch], - "validate_record called twice for record {record_id}", - ); + if batch.pending_records.len() <= record_offset_in_batch { + batch + .pending_records + .resize(record_offset_in_batch + 1, false); + } else { + assert!( + !batch.pending_records[record_offset_in_batch], + "validate_record called twice for record {record_id}", + ); + } // This assertion is stricter than the bounds check in `BitVec::set` when the // batch size is not a multiple of 8, or for a partial final batch. assert!( @@ -274,7 +285,10 @@ impl<'a, B> Batcher<'a, B> { mod tests { use std::{future::ready, pin::pin}; - use futures::future::{poll_immediate, try_join, try_join3, try_join4}; + use futures::{ + future::{join_all, poll_immediate, try_join, try_join3, try_join4}, + FutureExt, + }; use super::*; @@ -297,6 +311,23 @@ mod tests { ); } + #[test] + fn makes_batches_out_of_order() { + // Regression test for a bug where, when adding batches i..j to fill in a gap in + // the batch deque prior to out-of-order requested batch j, the batcher passed + // batch index `j` to the constructor for all of them, as opposed to the correct + // sequence of indices i..=j. + + let batcher = Batcher::new(1, 2, Box::new(std::convert::identity)); + let mut batcher = batcher.lock().unwrap(); + + batcher.get_batch(RecordId::from(1)); + batcher.get_batch(RecordId::from(0)); + + assert_eq!(batcher.get_batch(RecordId::from(0)).batch, 0); + assert_eq!(batcher.get_batch(RecordId::from(1)).batch, 1); + } + #[tokio::test] async fn validates_batches() { let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new())); @@ -537,6 +568,55 @@ mod tests { )); } + #[tokio::test] + async fn large_batch() { + // This test exercises the case where the preallocated size of `pending_records` + // was limited to `TARGET_PROOF_SIZE`, and we need to grow it alter. + let batcher = Batcher::new( + TARGET_PROOF_SIZE + 1, + TotalRecords::specified(TARGET_PROOF_SIZE + 1).unwrap(), + Box::new(|_| Vec::new()), + ); + + let mut futs = (0..TARGET_PROOF_SIZE) + .map(|i| { + batcher + .lock() + .unwrap() + .get_batch(RecordId::from(i)) + .batch + .push(i); + batcher + .lock() + .unwrap() + .validate_record(RecordId::from(i), |_i, _b| async { unreachable!() }) + .map(Result::unwrap) + .boxed() + }) + .collect::>(); + + batcher + .lock() + .unwrap() + .get_batch(RecordId::from(TARGET_PROOF_SIZE)) + .batch + .push(TARGET_PROOF_SIZE); + futs.push( + batcher + .lock() + .unwrap() + .validate_record(RecordId::from(TARGET_PROOF_SIZE), |i, b| { + assert!(i == 0 && b.as_slice() == (0..=TARGET_PROOF_SIZE).collect::>()); + ready(Ok(())) + }) + .map(Result::unwrap) + .boxed(), + ); + join_all(futs).await; + + assert!(batcher.lock().unwrap().is_empty()); + } + #[test] fn into_single_batch() { let batcher = Batcher::new(2, TotalRecords::Unspecified, Box::new(|_| Vec::new())); @@ -550,7 +630,7 @@ mod tests { .push(i); } - let batcher = Arc::into_inner(batcher).unwrap().into_inner().unwrap(); + let batcher = batcher.into_inner().unwrap(); assert_eq!(batcher.into_single_batch(), vec![0, 1]); } @@ -568,7 +648,7 @@ mod tests { .push(i); } - let batcher = Arc::into_inner(batcher).unwrap().into_inner().unwrap(); + let batcher = batcher.into_inner().unwrap(); batcher.into_single_batch(); } @@ -602,7 +682,7 @@ mod tests { }); assert_eq!(try_join(fut1, fut2).await.unwrap(), ((), ())); - let batcher = Arc::into_inner(batcher).unwrap().into_inner().unwrap(); + let batcher = batcher.into_inner().unwrap(); batcher.into_single_batch(); } } diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index b553e41b5..951b5b4c4 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -11,34 +11,65 @@ use crate::{ helpers::{MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords}, protocol::{ context::{ - batcher::Batcher, - dzkp_validator::{Batch, Segment}, + dzkp_validator::{Batch, MaliciousDZKPValidatorInner, Segment}, prss::InstrumentedIndexedSharedRandomness, - step::ZeroKnowledgeProofValidateStep, Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness, MaliciousContext, }, Gate, RecordId, }, seq_join::SeqJoin, - sync::{Arc, Mutex, Weak}, + sharding::ShardBinding, + sync::{Arc, Weak}, }; -pub(super) type DzkpBatcher<'a> = Mutex>; - /// Represents protocol context in malicious setting when using zero-knowledge proofs, /// i.e. secure against one active adversary in 3 party MPC ring. #[derive(Clone)] -pub struct DZKPUpgraded<'a> { - batcher: Weak>, - base_ctx: MaliciousContext<'a>, +pub struct DZKPUpgraded<'a, B: ShardBinding> { + validator_inner: Weak>, + base_ctx: MaliciousContext<'a, B>, } -impl<'a> DZKPUpgraded<'a> { - pub(super) fn new(batch: &Arc>, base_ctx: MaliciousContext<'a>) -> Self { +impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> { + pub(super) fn new( + validator_inner: &Arc>, + base_ctx: MaliciousContext<'a, B>, + ) -> Self { + let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); + let active_work = if records_per_batch == 1 || records_per_batch == usize::MAX { + // If records_per_batch is 1, let active_work be anything. This only happens + // in tests; there shouldn't be a risk of deadlocks with one record per + // batch; and UnorderedReceiver capacity (which is set from active_work) + // must be at least two. + // + // Also rely on the protocol to ensure an appropriate active_work if + // records_per_batch is `usize::MAX` (unlimited batch size). Allocating + // storage for `usize::MAX` active records won't work. + base_ctx.active_work() + } else { + // Adjust active_work to match records_per_batch. If it is less, we will + // certainly stall, since every record in the batch remains incomplete until + // the batch is validated. It is possible that it can be larger, but making + // it the same seems safer for now. + let active_work = NonZeroUsize::new(records_per_batch).unwrap(); + tracing::debug!( + "Changed active_work from {} to {} to match batch size", + base_ctx.active_work().get(), + active_work, + ); + active_work + }; Self { - batcher: Arc::downgrade(batch), - base_ctx, + validator_inner: Arc::downgrade(validator_inner), + // This overrides the active work for this context and all children + // created from it by using narrow, clone, etc. + // This allows all steps participating in malicious validation + // to use the same active work window and prevent deadlocks. + // + // This also checks that active work is a power of two and + // panics if it is not. + base_ctx: base_ctx.set_active_work(active_work.get().try_into().unwrap()), } } @@ -49,36 +80,32 @@ impl<'a> DZKPUpgraded<'a> { } fn with_batch T, T>(&self, record_id: RecordId, action: C) -> T { - let batcher = self.batcher.upgrade().expect("Validator is active"); + let validator_inner = self.validator_inner.upgrade().expect("Validator is active"); - let mut batch = batcher.lock().unwrap(); - let state = batch.get_batch(record_id); + let mut batcher = validator_inner.batcher.lock().unwrap(); + let state = batcher.get_batch(record_id); (action)(&mut state.batch) } } #[async_trait] -impl<'a> DZKPContext for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> { async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> { - let validation_future = self + let validator_inner = self.validator_inner.upgrade().expect("validator is active"); + + let ctx = validator_inner.validate_ctx.clone(); + + let validation_future = validator_inner .batcher - .upgrade() - .expect("Validation batch is active") .lock() .unwrap() - .validate_record(record_id, |batch_idx, batch| { - batch.validate( - self.base_ctx - .narrow(&ZeroKnowledgeProofValidateStep::DZKPValidate(batch_idx)) - .validator_context(), - ) - }); + .validate_record(record_id, |batch_idx, batch| batch.validate(ctx, batch_idx)); validation_future.await } } -impl<'a> super::Context for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> super::Context for DZKPUpgraded<'a, B> { fn role(&self) -> Role { self.base_ctx.role() } @@ -130,13 +157,13 @@ impl<'a> super::Context for DZKPUpgraded<'a> { } } -impl<'a> SeqJoin for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> SeqJoin for DZKPUpgraded<'a, B> { fn active_work(&self) -> NonZeroUsize { self.base_ctx.active_work() } } -impl Debug for DZKPUpgraded<'_> { +impl Debug for DZKPUpgraded<'_, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DZKPMaliciousContext") } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 73a32e62b..2b33181b1 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -1,8 +1,9 @@ -use std::{cmp, collections::BTreeMap, fmt::Debug, future::ready}; +use std::{collections::BTreeMap, fmt::Debug, future::ready}; use async_trait::async_trait; use bitvec::prelude::{BitArray, BitSlice, Lsb0}; use futures::{stream, Future, FutureExt, Stream, StreamExt}; +use ipa_step::StepNarrow; use crate::{ error::{BoxError, Error}, @@ -12,17 +13,20 @@ use crate::{ context::{ batcher::Batcher, dzkp_field::{DZKPBaseField, UVTupleBlock}, - dzkp_malicious::{DZKPUpgraded as MaliciousDZKPUpgraded, DzkpBatcher}, + dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded, dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, - step::ZeroKnowledgeProofValidateStep as Step, - Base, Context, DZKPContext, MaliciousContext, + step::DzkpValidationProtocolStep as Step, + Base, Context, DZKPContext, MaliciousContext, MaliciousProtocolSteps, }, - ipa_prf::validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - Gate, RecordId, + ipa_prf::{ + validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, + LargeProofGenerator, SmallProofGenerator, + }, + Gate, RecordId, RecordIdRange, }, seq_join::{seq_join, SeqJoin}, sharding::ShardBinding, - sync::Arc, + sync::{Arc, Mutex}, }; pub type Array256Bit = BitArray<[u8; 32], Lsb0>; @@ -33,8 +37,44 @@ const BIT_ARRAY_LEN: usize = 256; const BIT_ARRAY_MASK: usize = BIT_ARRAY_LEN - 1; const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; +// The target size of a zero-knowledge proof, in GF(2) multiplies. Seven intermediate +// values are stored for each multiply, so the amount memory required is 7 times this +// value. +// +// To enable computing a read size for `OrdereringSender` that achieves good network +// utilization, the number of records in a proof must be a power of two. Protocols +// typically compute the size of a proof batch by dividing TARGET_PROOF_SIZE by +// an approximate number of multiplies per record, and then rounding up to a power +// of two. Thus, it is not necessary for TARGET_PROOF_SIZE to be a power of two. +// +// A smaller value is used for tests, to enable covering some corner cases with a +// reasonable runtime. Some of these tests use TARGET_PROOF_SIZE directly, so for tests +// it does need to be a power of two. +#[cfg(test)] +pub const TARGET_PROOF_SIZE: usize = 8192; +#[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; +/// Maximum proof recursion depth. +// +// This is a hard limit. Each GF(2) multiply generates four G values and four H values, +// and the last level of the proof is limited to (small_recursion_factor - 1), so the +// restriction is: +// +// $$ large_recursion_factor * (small_recursion_factor - 1) +// * small_recursion_factor ^ (depth - 2) >= 4 * target_proof_size $$ +// +// With large_recursion_factor = 32 and small_recursion_factor = 8, this means: +// +// $$ depth >= log_8 (8/7 * target_proof_size) $$ +// +// Because the number of records in a proof batch is often rounded up to a power of two +// (and less significantly, because multiplication intermediate storage gets rounded up +// to blocks of 256), leaving some margin is advised. +// +// The implementation requires that MAX_PROOF_RECURSION is at least 2. +pub const MAX_PROOF_RECURSION: usize = 9; + /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values /// that occur duringa multiplication. /// These values need to be verified since there might have been malicious behavior. @@ -92,6 +132,31 @@ impl MultiplicationInputsBlock { }) } + /// set using bitslices + /// ## Errors + /// Errors when length of slices is not 256 bit + #[allow(clippy::too_many_arguments)] + fn set( + &mut self, + x_left: &BitSliceType, + x_right: &BitSliceType, + y_left: &BitSliceType, + y_right: &BitSliceType, + prss_left: &BitSliceType, + prss_right: &BitSliceType, + z_right: &BitSliceType, + ) -> Result<(), BoxError> { + self.x_left = BitArray::try_from(x_left)?; + self.x_right = BitArray::try_from(x_right)?; + self.y_left = BitArray::try_from(y_left)?; + self.y_right = BitArray::try_from(y_right)?; + self.prss_left = BitArray::try_from(prss_left)?; + self.prss_right = BitArray::try_from(prss_right)?; + self.z_right = BitArray::try_from(z_right)?; + + Ok(()) + } + /// `Convert` allows to convert `MultiplicationInputs` into a format compatible with DZKPs /// This is the convert function called by the prover. fn convert_prover(&self) -> Vec> { @@ -220,29 +285,29 @@ impl<'a> SegmentEntry<'a> { /// `MultiplicationInputsBatch` stores a batch of multiplication inputs in a vector of `MultiplicationInputsBlock`. /// `first_record` is the first `RecordId` for the current batch. -/// `last_record` keeps track of the highest record that has been added to the batch. /// `max_multiplications` is the maximum amount of multiplications performed within a this batch. /// It is used to determine the vector length during the allocation. /// If there are more multiplications, it will cause a panic! /// `multiplication_bit_size` is the bit size of a single multiplication. The size will be consistent /// across all multiplications of a gate. -/// `is_empty` keeps track of whether any value has been added #[derive(Clone, Debug)] struct MultiplicationInputsBatch { - first_record: RecordId, - last_record: RecordId, + first_record: Option, max_multiplications: usize, multiplication_bit_size: usize, - is_empty: bool, vec: Vec, } impl MultiplicationInputsBatch { - /// Creates a new store for multiplication intermediates for records starting from - /// `first_record`. The size of the allocated vector is + /// Creates a new store for multiplication intermediates. The first record is + /// specified by `first_record`, or if that is `None`, is set automatically the + /// first time a segment is added to the batch. Once the first record is set, + /// attempting to add a segment before the first record will panic. + /// + /// The size of the allocated vector is /// `ceil((max_multiplications * multiplication_bit_size) / BIT_ARRAY_LEN)`. fn new( - first_record: RecordId, + first_record: Option, max_multiplications: usize, multiplication_bit_size: usize, ) -> Self { @@ -256,14 +321,12 @@ impl MultiplicationInputsBatch { // records. let capacity_bits = usize::min( TARGET_PROOF_SIZE, - max_multiplications * multiplication_bit_size, + max_multiplications.saturating_mul(multiplication_bit_size), ); Self { first_record, - last_record: first_record, max_multiplications, multiplication_bit_size, - is_empty: false, vec: Vec::with_capacity((capacity_bits + BIT_ARRAY_MASK) >> BIT_ARRAY_SHIFT), } } @@ -276,7 +339,7 @@ impl MultiplicationInputsBatch { /// returns whether the store is empty fn is_empty(&self) -> bool { - self.is_empty + self.vec.is_empty() } /// `insert_segment` allows to include a new segment in `MultiplicationInputsBatch`. @@ -291,21 +354,27 @@ impl MultiplicationInputsBatch { // check segment size debug_assert_eq!(segment.len(), self.multiplication_bit_size); + let first_record = *self.first_record.get_or_insert(record_id); + // panics when record_id is out of bounds - assert!(record_id >= self.first_record); assert!( - record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)), + record_id >= first_record, + "record_id out of range in insert_segment. record {record_id} is before \ + first record {first_record}", + ); + assert!( + usize::from(record_id) + < self + .max_multiplications + .saturating_add(usize::from(first_record)), "record_id out of range in insert_segment. record {record_id} is beyond \ segment of length {} starting at {}", self.max_multiplications, - self.first_record, + first_record, ); - // update last record - self.last_record = cmp::max(self.last_record, record_id); - // panics when record_id is too large to fit in, i.e. when it is out of bounds - if segment.len() <= 256 { + if segment.len() < 256 { self.insert_segment_small(record_id, segment); } else { self.insert_segment_large(record_id, &segment); @@ -320,17 +389,8 @@ impl MultiplicationInputsBatch { /// than the first record of the batch, i.e. `first_record` /// or too large, i.e. `first_record+max_multiplications` fn insert_segment_small(&mut self, record_id: RecordId, segment: Segment) { - // check length - debug_assert!(segment.len() <= 256); - - // panics when record_id is out of bounds - assert!(record_id >= self.first_record); - assert!( - record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)) - ); - // panics when record_id is less than first_record - let id_within_batch = usize::from(record_id) - usize::from(self.first_record); + let id_within_batch = usize::from(record_id) - usize::from(self.first_record.unwrap()); // round up segment length to a power of two since we want to have divisors of 256 let length = segment.len().next_power_of_two(); @@ -371,16 +431,7 @@ impl MultiplicationInputsBatch { /// than the first record of the batch, i.e. `first_record` /// or too large, i.e. `first_record+max_multiplications` fn insert_segment_large(&mut self, record_id: RecordId, segment: &Segment) { - // check length - debug_assert_eq!(segment.len() % 256, 0); - - // panics when record_id is out of bounds - assert!(record_id >= self.first_record); - assert!( - record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)) - ); - - let id_within_batch = usize::from(record_id) - usize::from(self.first_record); + let id_within_batch = usize::from(record_id) - usize::from(self.first_record.unwrap()); let block_id = (segment.len() * id_within_batch) >> BIT_ARRAY_SHIFT; let length_in_blocks = segment.len() >> BIT_ARRAY_SHIFT; if self.vec.len() < block_id { @@ -389,8 +440,9 @@ impl MultiplicationInputsBatch { } for i in 0..length_in_blocks { - self.vec.push( - MultiplicationInputsBlock::clone_from( + if self.vec.len() > block_id + i { + MultiplicationInputsBlock::set( + &mut self.vec[block_id + i], &segment.x_left.0[256 * i..256 * (i + 1)], &segment.x_right.0[256 * i..256 * (i + 1)], &segment.y_left.0[256 * i..256 * (i + 1)], @@ -399,8 +451,21 @@ impl MultiplicationInputsBatch { &segment.prss_right.0[256 * i..256 * (i + 1)], &segment.z_right.0[256 * i..256 * (i + 1)], ) - .unwrap(), - ); + .unwrap(); + } else { + self.vec.push( + MultiplicationInputsBlock::clone_from( + &segment.x_left.0[256 * i..256 * (i + 1)], + &segment.x_right.0[256 * i..256 * (i + 1)], + &segment.y_left.0[256 * i..256 * (i + 1)], + &segment.y_right.0[256 * i..256 * (i + 1)], + &segment.prss_left.0[256 * i..256 * (i + 1)], + &segment.prss_right.0[256 * i..256 * (i + 1)], + &segment.z_right.0[256 * i..256 * (i + 1)], + ) + .unwrap(), + ); + } } } @@ -444,12 +509,24 @@ impl MultiplicationInputsBatch { #[derive(Debug)] pub(super) struct Batch { max_multiplications_per_gate: usize, - first_record: RecordId, + first_record: Option, inner: BTreeMap, } impl Batch { - fn new(first_record: RecordId, max_multiplications_per_gate: usize) -> Self { + /// Creates a new `Batch` for multiplication intermediates from multiple gates. The + /// first record is specified by `first_record`, or if that is `None`, is set + /// automatically for each gate the first time a segment from that gate is added. + /// + /// Once the first record is set, attempting to add a segment before the first + /// record will panic. It is likely, but not guaranteed, that protocol execution + /// proceeds in order, so a problem here can easily escape testing. + /// * When using the `Batcher` in multi-batch mode, `first_record` is calculated + /// from the batch index and the number of records in a batch, so there is no + /// possibility of attempting to add a record before the start of the batch. + /// * The only protocol that manages batches explicitly is the aggregation protocol + /// (`breakdown_reveal_aggregation`). It is structured to operate in order. + fn new(first_record: Option, max_multiplications_per_gate: usize) -> Self { Self { max_multiplications_per_gate, first_record, @@ -519,9 +596,22 @@ impl Batch { /// ## Panics /// If `usize` to `u128` conversion fails. - pub(super) async fn validate(self, ctx: Base<'_>) -> Result<(), Error> { + pub(super) async fn validate( + self, + ctx: Base<'_, B>, + batch_index: usize, + ) -> Result<(), Error> { + const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH + + 2; // P and Q masks + let proof_ctx = ctx.narrow(&Step::GenerateProof); + let record_id = RecordId::from(batch_index); + let prss_record_id_start = RecordId::from(batch_index * PRSS_RECORDS_PER_BATCH); + let prss_record_id_end = RecordId::from((batch_index + 1) * PRSS_RECORDS_PER_BATCH); + let prss_record_ids = RecordIdRange::from(prss_record_id_start..prss_record_id_end); + if self.is_empty() { return Ok(()); } @@ -533,11 +623,12 @@ impl Batch { q_mask_from_left_prover, ) = { // generate BatchToVerify - ProofBatch::generate(&proof_ctx, self.get_field_values_prover()) + ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover()) }; let chunk_batch = BatchToVerify::generate_batch_to_verify( proof_ctx, + record_id, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -547,7 +638,7 @@ impl Batch { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = chunk_batch - .generate_challenges(ctx.narrow(&Step::Challenge)) + .generate_challenges(ctx.narrow(&Step::Challenge), record_id) .await; let (sum_of_uv, p_r_right_prover, q_r_left_prover) = { @@ -581,6 +672,7 @@ impl Batch { chunk_batch .verify( ctx.narrow(&Step::VerifyProof), + record_id, sum_of_uv, p_r_right_prover, q_r_left_prover, @@ -616,7 +708,18 @@ pub trait DZKPValidator: Send + Sync { /// /// # Panics /// May panic if the above restrictions on validator usage are not followed. - async fn validate(self) -> Result<(), Error>; + async fn validate(self) -> Result<(), Error> + where + Self: Sized, + { + self.validate_indexed(0).await + } + + /// Validates all of the multiplies associated with this validator, specifying + /// an explicit batch index. + /// + /// This should be used when the protocol is explicitly managing batches. + async fn validate_indexed(self, batch_index: usize) -> Result<(), Error>; /// `is_verified` checks that there are no `MultiplicationInputs` that have not been verified /// within the associated `DZKPBatch` @@ -662,6 +765,20 @@ pub trait DZKPValidator: Send + Sync { } } +// Wrapper to avoid https://github.com/rust-lang/rust/issues/100013. +pub fn validated_seq_join<'st, V, S, F, O>( + validator: V, + source: S, +) -> impl Stream> + Send + 'st +where + V: DZKPValidator + 'st, + S: Stream + Send + 'st, + F: Future> + Send + 'st, + O: Send + Sync + 'static, +{ + validator.validated_seq_join(source) +} + #[derive(Clone)] pub struct SemiHonestDZKPValidator<'a, B: ShardBinding> { context: SemiHonestDZKPUpgraded<'a, B>, @@ -687,7 +804,7 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> { // Semi-honest validator doesn't do anything, so doesn't care. } - async fn validate(self) -> Result<(), Error> { + async fn validate_indexed(self, _batch_index: usize) -> Result<(), Error> { Ok(()) } @@ -696,50 +813,59 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> { } } +type DzkpBatcher<'a> = Batcher<'a, Batch>; + +/// The DZKP validator, and all associated contexts, each hold a reference to a single +/// instance of `MaliciousDZKPValidatorInner`. +pub(super) struct MaliciousDZKPValidatorInner<'a, B: ShardBinding> { + pub(super) batcher: Mutex>, + pub(super) validate_ctx: Base<'a, B>, +} + /// `MaliciousDZKPValidator` corresponds to pub struct `Malicious` and implements the trait `DZKPValidator` /// The implementation of `validate` of the `DZKPValidator` trait depends on generic `DF` -pub struct MaliciousDZKPValidator<'a> { +pub struct MaliciousDZKPValidator<'a, B: ShardBinding> { // This is an `Option` because we want to consume it in `DZKPValidator::validate`, // but we also want to implement `Drop`. Note that the `is_verified` check in `Drop` // does nothing when `batcher_ref` is already `None`. - batcher_ref: Option>>, - protocol_ctx: MaliciousDZKPUpgraded<'a>, - validate_ctx: MaliciousContext<'a>, + inner_ref: Option>>, + protocol_ctx: MaliciousDZKPUpgraded<'a, B>, } #[async_trait] -impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { - type Context = MaliciousDZKPUpgraded<'a>; +impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { + type Context = MaliciousDZKPUpgraded<'a, B>; - fn context(&self) -> MaliciousDZKPUpgraded<'a> { + fn context(&self) -> MaliciousDZKPUpgraded<'a, B> { self.protocol_ctx.clone() } fn set_total_records>(&mut self, total_records: T) { - self.batcher_ref + self.inner_ref .as_ref() - .unwrap() + .expect("validator should be active") + .batcher .lock() .unwrap() .set_total_records(total_records); } - async fn validate(mut self) -> Result<(), Error> { - let batcher_arc = self - .batcher_ref + async fn validate_indexed(mut self, batch_index: usize) -> Result<(), Error> { + let arc = self + .inner_ref .take() .expect("nothing else should be consuming the batcher"); - let batcher_mutex = Arc::into_inner(batcher_arc) + let MaliciousDZKPValidatorInner { + batcher: batcher_mutex, + validate_ctx, + } = Arc::into_inner(arc) .expect("validator should hold the only strong reference to batcher"); + let batcher = batcher_mutex.into_inner().unwrap(); batcher .into_single_batch() - .validate( - self.validate_ctx - .narrow(&Step::DZKPValidate(0)) - .validator_context(), - ) + .validate(validate_ctx, batch_index) .await } @@ -749,7 +875,13 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { /// ## Errors /// Errors when there are `MultiplicationInputs` that have not been verified. fn is_verified(&self) -> Result<(), Error> { - let batcher = self.batcher_ref.as_ref().unwrap().lock().unwrap(); + let batcher = self + .inner_ref + .as_ref() + .expect("validator should be active") + .batcher + .lock() + .unwrap(); if batcher.is_empty() { Ok(()) } else { @@ -758,32 +890,50 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { } } -impl<'a> MaliciousDZKPValidator<'a> { +impl<'a, B: ShardBinding> MaliciousDZKPValidator<'a, B> { #[must_use] - pub fn new(ctx: MaliciousContext<'a>, max_multiplications_per_gate: usize) -> Self { + #[allow(clippy::needless_pass_by_value)] + pub fn new( + ctx: MaliciousContext<'a, B>, + steps: MaliciousProtocolSteps, + max_multiplications_per_gate: usize, + ) -> Self + where + Gate: StepNarrow, + S: ipa_step::Step + ?Sized, + { let batcher = Batcher::new( max_multiplications_per_gate, ctx.total_records(), Box::new(move |batch_index| { - Batch::new( - RecordId::from(batch_index * max_multiplications_per_gate), - max_multiplications_per_gate, - ) + let first_record = (max_multiplications_per_gate != usize::MAX) + .then(|| RecordId::from(batch_index * max_multiplications_per_gate)); + Batch::new(first_record, max_multiplications_per_gate) }), ); - let protocol_ctx = - MaliciousDZKPUpgraded::new(&batcher, ctx.narrow(&Step::DZKPMaliciousProtocol)); + let inner = Arc::new(MaliciousDZKPValidatorInner { + batcher, + validate_ctx: ctx.narrow(steps.validate).validator_context(), + }); + let protocol_ctx = MaliciousDZKPUpgraded::new(&inner, ctx.narrow(steps.protocol)); Self { - batcher_ref: Some(batcher), + inner_ref: Some(inner), protocol_ctx, - validate_ctx: ctx, } } } -impl<'a> Drop for MaliciousDZKPValidator<'a> { +impl<'a, B: ShardBinding> Drop for MaliciousDZKPValidator<'a, B> { fn drop(&mut self) { - if self.batcher_ref.is_some() { + // If `validate` has not been called, and we are not unwinding, check that the + // validator is not holding unverified multiplies. + // * If `validate` has been called (i.e. the validator was used in the + // non-`validate_record` mode of operation), then `self.inner_ref` is `None`, + // because validation consumed the batcher via `self.inner_ref`. + // * Unwinding can happen at any time, so complaining about incomplete + // validation is likely just extra noise, and the additional panic + // during unwinding could be confusing. + if self.inner_ref.is_some() && !std::thread::panicking() { self.is_verified().unwrap(); } } @@ -798,36 +948,158 @@ mod tests { }; use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; - use futures::{StreamExt, TryStreamExt}; + use futures::{stream, StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prop_compose, proptest, sample::select}; - use rand::{thread_rng, Rng}; + use proptest::{ + prelude::{Just, Strategy}, + prop_compose, prop_oneof, proptest, + test_runner::Config as ProptestConfig, + }; + use rand::{distributions::Standard, prelude::Distribution}; use crate::{ error::Error, - ff::{boolean::Boolean, Fp61BitPrime}, + ff::{ + boolean::Boolean, + boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, + Fp61BitPrime, + }, protocol::{ - basics::SecureMul, + basics::{select, BooleanArrayMul, SecureMul}, context::{ dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, dzkp_validator::{ - Batch, DZKPValidator, Segment, SegmentEntry, Step, BIT_ARRAY_LEN, - TARGET_PROOF_SIZE, + Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, }, - Context, UpgradableContext, + Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, + UpgradableContext, TEST_DZKP_STEPS, }, Gate, RecordId, }, + rand::{thread_rng, Rng}, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, Vectorizable, }, seq_join::{seq_join, SeqJoin}, + sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld}, }; + async fn test_select_semi_honest() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let sh_ctx = v.context(); + + let result = select( + sh_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + + #[tokio::test] + async fn select_semi_honest() { + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + } + + async fn test_select_malicious() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.malicious_contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context(); + + let result = select( + m_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + #[tokio::test] - async fn dzkp_malicious() { + async fn select_malicious() { + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + } + + #[tokio::test] + async fn two_multiplies_malicious() { const COUNT: usize = 32; let mut rng = thread_rng(); @@ -839,11 +1111,8 @@ mod tests { .malicious( original_inputs.clone().into_iter(), |ctx, input_shares| async move { - let v = ctx.dzkp_validator(COUNT); - let m_ctx = v - .context() - .narrow(&Step::DZKPMaliciousProtocol) - .set_total_records(COUNT - 1); + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, COUNT); + let m_ctx = v.context().set_total_records(COUNT - 1); let m_results = seq_join( NonZeroUsize::new(COUNT).unwrap(), @@ -890,9 +1159,54 @@ mod tests { } } + /// Similar to `test_select_malicious`, but operating on vectors + async fn multi_select_malicious(count: usize, max_multiplications_per_gate: usize) + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let mut rng = thread_rng(); + + let bit: Vec = repeat_with(|| rng.gen::()).take(count).collect(); + let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); + let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); + + let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); + let m_ctx = v.context(); + + v.validated_seq_join(stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + )) + .try_collect() + .await + }, + ) + .await + .map(Result::unwrap); + + let ab: Vec = [ab0, ab1, ab2].reconstruct(); + + for i in 0..count { + assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); + } + } + /// test for testing `validated_seq_join` - /// similar to `complex_circuit` in `validator.rs` - async fn complex_circuit_dzkp( + /// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) + async fn chained_multiplies_dzkp( count: usize, max_multiplications_per_gate: usize, ) -> Result<(), Error> { @@ -922,7 +1236,7 @@ mod tests { .map(|(ctx, input_shares)| async move { let v = ctx .set_total_records(count - 1) - .dzkp_validator(ctx.active_work().get()); + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); let m_ctx = v.context(); let m_results = v @@ -951,7 +1265,7 @@ mod tests { .into_iter() .zip([h1_shares, h2_shares, h3_shares]) .map(|(ctx, input_shares)| async move { - let v = ctx.dzkp_validator(max_multiplications_per_gate); + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); let m_ctx = v.context(); let m_results = v @@ -998,38 +1312,182 @@ mod tests { Ok(()) } + fn record_count_strategy() -> impl Strategy { + // The chained_multiplies test has count - 1 records, so 1 is not a valid input size. + // It is for multi_select though. + prop_oneof![2usize..=512, (1u32..=9).prop_map(|i| 1usize << i)] + } + + fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy { + let max_max_mults = record_count.min(128); + (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) + } + prop_compose! { - fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { - (1usize< (usize, usize) + { + (record_count, max_mults) } } proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] #[test] - fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){ - let future = async { - let _ = complex_circuit_dzkp(count, multiplication_amount).await; - }; - tokio::runtime::Runtime::new().unwrap().block_on(future); + fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { + println!("record_count {record_count} batch {max_multiplications_per_gate}"); + // This condition is correct only for active_work = 16 and record size of 1 byte. + if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { + // TODO: #1300, read_size | batch_size. + // Note: for active work < 2048, read size matches active work. + + // Besides read_size | batch_size, there is also a constraint + // something like active_work > read_size + batch_size - 1. + println!("skipping config due to read_size vs. batch_size constraints"); + } else { + tokio::runtime::Runtime::new().unwrap().block_on(async { + chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + }); + } + } + } + + #[tokio::test] + async fn large_batch() { + multi_select_malicious::(2 * TARGET_PROOF_SIZE, 2 * TARGET_PROOF_SIZE).await; + } + + // Similar to multi_select_malicious, but instead of using `validated_seq_join`, passes + // `usize::MAX` as the batch size and does a single `v.validate()`. + #[tokio::test] + async fn large_single_batch() { + let count: usize = TARGET_PROOF_SIZE + 1; + let mut rng = thread_rng(); + + let bit: Vec = repeat_with(|| rng.gen::()).take(count).collect(); + let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); + let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); + + let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, usize::MAX); + let m_ctx = v.context(); + + let result = seq_join( + m_ctx.active_work(), + stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + ), + ) + .try_collect() + .await + .unwrap(); + + v.validate().await.unwrap(); + + result + }, + ) + .await; + + let ab: Vec = [ab0, ab1, ab2].reconstruct(); + + for i in 0..count { + assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); + } + } + + #[tokio::test] + #[should_panic(expected = "ContextUnsafe(\"DZKPMaliciousContext\")")] + async fn missing_validate() { + let mut rng = thread_rng(); + + let a = rng.gen::(); + let b = rng.gen::(); + + TestWorld::default() + .malicious((a, b), |ctx, (a, b)| async move { + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context().set_total_records(1); + + a.multiply(&b, m_ctx, RecordId::FIRST).await.unwrap() + + // `validate` should appear here. + }) + .await; + } + + #[tokio::test] + #[should_panic(expected = "panicking before validate")] + #[allow(unreachable_code)] + async fn missing_validate_panic() { + let mut rng = thread_rng(); + + let a = rng.gen::(); + let b = rng.gen::(); + + TestWorld::default() + .malicious((a, b), |ctx, (a, b)| async move { + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context().set_total_records(1); + + let _result = a.multiply(&b, m_ctx, RecordId::FIRST).await.unwrap(); + + panic!("panicking before validate"); + }) + .await; + } + + fn segment_from_entry(entry: SegmentEntry) -> Segment { + Segment::from_entries( + entry.clone(), + entry.clone(), + entry.clone(), + entry.clone(), + entry.clone(), + entry.clone(), + entry, + ) + } + + impl Batch { + fn with_implicit_first_record(max_multiplications_per_gate: usize) -> Self { + Batch::new(None, max_multiplications_per_gate) } } #[test] fn batch_allocation_small() { const SIZE: usize = 1; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); batch.push(Gate::default(), RecordId::FIRST, segment); assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 1); assert!(batch.inner.get(&Gate::default()).unwrap().vec.capacity() >= SIZE); @@ -1039,19 +1497,12 @@ mod tests { #[test] fn batch_allocation_big() { const SIZE: usize = 2 * TARGET_PROOF_SIZE; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); batch.push(Gate::default(), RecordId::FIRST, segment); assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 1); assert!( @@ -1067,19 +1518,12 @@ mod tests { #[test] fn batch_fill() { const SIZE: usize = 10; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); for i in 0..SIZE { batch.push(Gate::default(), RecordId::from(i), segment.clone()); } @@ -1088,25 +1532,131 @@ mod tests { assert!(batch.inner.get(&Gate::default()).unwrap().vec.capacity() <= 2); } + #[test] + fn batch_fill_out_of_order() { + let mut batch = Batch::with_implicit_first_record(3); + let ba0 = BA256::from((0, 0)); + let ba1 = BA256::from((0, 1)); + let ba2 = BA256::from((0, 2)); + let segment = segment_from_entry(>::as_segment_entry( + &ba0, + )); + batch.push(Gate::default(), RecordId::from(0), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba2, + )); + batch.push(Gate::default(), RecordId::from(2), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba1, + )); + batch.push(Gate::default(), RecordId::from(1), segment.clone()); + assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 3); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[0].x_left, + ba0.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[1].x_left, + ba1.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[2].x_left, + ba2.as_bitslice() + ); + } + + #[test] + fn batch_fill_at_offset() { + const SIZE: usize = 3; + let mut batch = Batch::with_implicit_first_record(SIZE); + let ba0 = BA256::from((0, 0)); + let ba1 = BA256::from((0, 1)); + let ba2 = BA256::from((0, 2)); + let segment = segment_from_entry(>::as_segment_entry( + &ba0, + )); + batch.push(Gate::default(), RecordId::from(4), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba1, + )); + batch.push(Gate::default(), RecordId::from(5), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba2, + )); + batch.push(Gate::default(), RecordId::from(6), segment.clone()); + assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 3); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[0].x_left, + ba0.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[1].x_left, + ba1.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[2].x_left, + ba2.as_bitslice() + ); + } + + #[test] + fn batch_explicit_first_record() { + const SIZE: usize = 3; + let mut batch = Batch::new(Some(RecordId::from(4)), SIZE); + let ba6 = BA256::from((0, 6)); + let segment = segment_from_entry(>::as_segment_entry( + &ba6, + )); + batch.push(Gate::default(), RecordId::from(6), segment.clone()); + assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 3); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[2].x_left, + ba6.as_bitslice() + ); + } + + #[test] + fn batch_is_empty() { + const SIZE: usize = 10; + let mut batch = Batch::with_implicit_first_record(SIZE); + assert!(batch.is_empty()); + let zero = Boolean::ZERO; + let zero_vec: >::Array = zero.into_array(); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); + batch.push(Gate::default(), RecordId::FIRST, segment); + assert!(!batch.is_empty()); + } + + #[test] + #[should_panic( + expected = "record_id out of range in insert_segment. record 0 is before first record 10" + )] + fn batch_underflow() { + const SIZE: usize = 10; + let mut batch = Batch::with_implicit_first_record(SIZE); + let zero = Boolean::ZERO; + let zero_vec: >::Array = zero.into_array(); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); + batch.push(Gate::default(), RecordId::from(10), segment.clone()); + batch.push(Gate::default(), RecordId::from(0), segment.clone()); + } + #[test] #[should_panic( expected = "record_id out of range in insert_segment. record 10 is beyond segment of length 10 starting at 0" )] fn batch_overflow() { const SIZE: usize = 10; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); for i in 0..=SIZE { batch.push(Gate::default(), RecordId::from(i), segment.clone()); } @@ -1237,13 +1787,13 @@ mod tests { // test for small and large segments, i.e. 8bit and 512 bit for segment_size in [8usize, 512usize] { // generate batch for the prover - let mut batch_prover = Batch::new(RecordId::FIRST, 1024 / segment_size); + let mut batch_prover = Batch::with_implicit_first_record(1024 / segment_size); // generate batch for the verifier on the left of the prover - let mut batch_left = Batch::new(RecordId::FIRST, 1024 / segment_size); + let mut batch_left = Batch::with_implicit_first_record(1024 / segment_size); // generate batch for the verifier on the right of the prover - let mut batch_right = Batch::new(RecordId::FIRST, 1024 / segment_size); + let mut batch_right = Batch::with_implicit_first_record(1024 / segment_size); // fill the batches with random values populate_batch( @@ -1311,14 +1861,16 @@ mod tests { let [h1_batch, h2_batch, h3_batch] = world .malicious((a, b), |ctx, (a, b)| async move { - let mut validator = ctx.dzkp_validator(10); + let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let mctx = validator.context(); let _ = a .multiply(&b, mctx.set_total_records(1), RecordId::from(0)) .await .unwrap(); - let batcher_mutex = Arc::into_inner(validator.batcher_ref.take().unwrap()).unwrap(); + let batcher_mutex = Arc::into_inner(validator.inner_ref.take().unwrap()) + .unwrap() + .batcher; batcher_mutex.into_inner().unwrap().into_single_batch() }) .await; diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index d001b8ab6..0253a810d 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -13,11 +13,14 @@ use crate::{ protocol::{ basics::mul::{semi_honest_multiply, step::MaliciousMultiplyStep::RandomnessForValidation}, context::{ - batcher::Batcher, dzkp_validator::MaliciousDZKPValidator, - prss::InstrumentedIndexedSharedRandomness, step::UpgradeStep, upgrade::Upgradable, - validator, validator::BatchValidator, Base, Context as ContextTrait, - InstrumentedSequentialSharedRandomness, SpecialAccessToUpgradedContext, - UpgradableContext, UpgradedContext, + batcher::Batcher, + dzkp_validator::MaliciousDZKPValidator, + prss::InstrumentedIndexedSharedRandomness, + step::UpgradeStep, + upgrade::Upgradable, + validator::{self, BatchValidator}, + Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness, + SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext, }, prss::{Endpoint as PrssEndpoint, FromPrss}, Gate, RecordId, @@ -27,33 +30,48 @@ use crate::{ semi_honest::AdditiveShare as Replicated, }, seq_join::SeqJoin, - sharding::NotSharded, + sharding::{NotSharded, ShardBinding}, sync::Arc, }; +pub struct MaliciousProtocolSteps<'a, S: Step + ?Sized> { + pub protocol: &'a S, + pub validate: &'a S, +} + +#[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] +pub(crate) const TEST_DZKP_STEPS: MaliciousProtocolSteps< + 'static, + super::step::MaliciousProtocolStep, +> = MaliciousProtocolSteps { + protocol: &super::step::MaliciousProtocolStep::MaliciousProtocol, + validate: &super::step::MaliciousProtocolStep::Validate, +}; + #[derive(Clone)] -pub struct Context<'a> { - inner: Base<'a>, +pub struct Context<'a, B: ShardBinding> { + inner: Base<'a, B>, } -impl<'a> Context<'a> { +impl<'a> Context<'a, NotSharded> { pub fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { - Self::new_with_gate(participant, gateway, Gate::default()) + Self::new_with_gate(participant, gateway, Gate::default(), NotSharded) } +} - pub fn new_with_gate(participant: &'a PrssEndpoint, gateway: &'a Gateway, gate: Gate) -> Self { +impl<'a, B: ShardBinding> Context<'a, B> { + pub fn new_with_gate( + participant: &'a PrssEndpoint, + gateway: &'a Gateway, + gate: Gate, + shard: B, + ) -> Self { Self { - inner: Base::new_complete( - participant, - gateway, - gate, - TotalRecords::Unspecified, - NotSharded, - ), + inner: Base::new_complete(participant, gateway, gate, TotalRecords::Unspecified, shard), } } - pub(crate) fn validator_context(self) -> Base<'a> { + pub(crate) fn validator_context(self) -> Base<'a, B> { // The DZKP validator uses communcation channels internally. We don't want any TotalRecords // set by the protocol to apply to those channels. Base { @@ -61,9 +79,16 @@ impl<'a> Context<'a> { ..self.inner } } + + #[must_use] + pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self { + Self { + inner: self.inner.set_active_work(new_active_work), + } + } } -impl<'a> super::Context for Context<'a> { +impl<'a, B: ShardBinding> super::Context for Context<'a, B> { fn role(&self) -> Role { self.inner.role() } @@ -113,46 +138,67 @@ impl<'a> super::Context for Context<'a> { } } -impl<'a> UpgradableContext for Context<'a> { - type Validator = BatchValidator<'a, F>; +impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> { + type Validator = BatchValidator<'a, F, B>; fn validator(self) -> Self::Validator { BatchValidator::new(self) } - type DZKPValidator = MaliciousDZKPValidator<'a>; + type DZKPValidator = MaliciousDZKPValidator<'a, B>; - fn dzkp_validator(self, max_multiplications_per_gate: usize) -> Self::DZKPValidator { - MaliciousDZKPValidator::new(self, max_multiplications_per_gate) + fn dzkp_validator( + self, + steps: MaliciousProtocolSteps, + max_multiplications_per_gate: usize, + ) -> Self::DZKPValidator + where + Gate: StepNarrow, + S: Step + ?Sized, + { + MaliciousDZKPValidator::new(self, steps, max_multiplications_per_gate) } } -impl<'a> SeqJoin for Context<'a> { +impl<'a, B: ShardBinding> SeqJoin for Context<'a, B> { fn active_work(&self) -> NonZeroUsize { self.inner.active_work() } } -impl Debug for Context<'_> { +impl Debug for Context<'_, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousContext") } } -use crate::sync::{Mutex, Weak}; +use crate::{ + sync::{Mutex, Weak}, + utils::NonZeroU32PowerOfTwo, +}; -pub(super) type MacBatcher<'a, F> = Mutex>>; +pub(super) type MacBatcher<'a, F, B> = Mutex>>; /// Represents protocol context in malicious setting, i.e. secure against one active adversary /// in 3 party MPC ring. #[derive(Clone)] -pub struct Upgraded<'a, F: ExtendableField> { - batch: Weak>, - base_ctx: Context<'a>, +pub struct Upgraded<'a, F: ExtendableField, B: ShardBinding> { + batch: Weak>, + base_ctx: Context<'a, B>, } -impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { +impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> { + pub(super) fn new(batch: &Arc>, ctx: Context<'a, B>) -> Self { + // The DZKP malicious context adjusts active_work to match records_per_batch. + // The MAC validator currently configures the batcher with records_per_batch = + // active_work. If the latter behavior changes, this code may need to be + // updated. + let records_per_batch = batch.lock().unwrap().records_per_batch(); + let active_work = ctx.active_work().get(); + assert_eq!( + records_per_batch, active_work, + "Expect MAC validation batch size ({records_per_batch}) to match active work ({active_work})", + ); Self { batch: Arc::downgrade(batch), base_ctx: ctx, @@ -188,7 +234,7 @@ impl<'a, F: ExtendableField> Upgraded<'a, F> { self.with_batch(record_id, |v| v.r_share().clone()) } - fn with_batch) -> T, T>( + fn with_batch) -> T, T>( &self, record_id: RecordId, action: C, @@ -202,7 +248,7 @@ impl<'a, F: ExtendableField> Upgraded<'a, F> { } #[async_trait] -impl<'a, F: ExtendableField> UpgradedContext for Upgraded<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> UpgradedContext for Upgraded<'a, F, B> { type Field = F; async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> { @@ -218,7 +264,7 @@ impl<'a, F: ExtendableField> UpgradedContext for Upgraded<'a, F> { } } -impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> super::Context for Upgraded<'a, F, B> { fn role(&self) -> Role { self.base_ctx.role() } @@ -270,7 +316,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { } } -impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> SeqJoin for Upgraded<'a, F, B> { fn active_work(&self) -> NonZeroUsize { self.base_ctx.active_work() } @@ -280,15 +326,17 @@ impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { /// protocols should be generic over `SecretShare` trait and not requiring this cast and taking /// `ProtocolContext<'a, S: SecretShare, F: Field>` as the context. If that is not possible, /// this implementation makes it easier to reinterpret the context as semi-honest. -impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, F> { - type Base = Base<'a>; +impl<'a, F: ExtendableField, B: ShardBinding> SpecialAccessToUpgradedContext + for Upgraded<'a, F, B> +{ + type Base = Base<'a, B>; fn base_context(self) -> Self::Base { self.base_ctx.inner } } -impl Debug for Upgraded<'_, F> { +impl Debug for Upgraded<'_, F, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousContext<{:?}>", type_name::()) } @@ -297,7 +345,8 @@ impl Debug for Upgraded<'_, F> { /// Upgrading a semi-honest replicated share using malicious context produces /// a MAC-secured share with the same vectorization factor. #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> for Replicated +impl<'a, V: ExtendableFieldSimd, B: ShardBinding, const N: usize> Upgradable> + for Replicated where Replicated<::ExtendedField, N>: FromPrss, { @@ -305,7 +354,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { let ctx = ctx.narrow(&UpgradeStep); @@ -339,7 +388,7 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, B: ShardBinding, const N: usize> Upgradable> for (Replicated, Replicated) where Replicated<::ExtendedField, N>: FromPrss, @@ -348,7 +397,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { let (l, r) = self; @@ -360,12 +409,12 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableField> Upgradable> for () { +impl<'a, V: ExtendableField, B: ShardBinding> Upgradable> for () { type Output = (); async fn upgrade( self, - _context: Upgraded<'a, V>, + _context: Upgraded<'a, V, B>, _record_id: RecordId, ) -> Result { Ok(()) @@ -374,28 +423,30 @@ impl<'a, V: ExtendableField> Upgradable> for () { #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V, U> Upgradable> for Vec +impl<'a, V, U, B> Upgradable> for Vec where V: ExtendableField, - U: Upgradable, Output: Send> + Send + 'a, + U: Upgradable, Output: Send> + Send + 'a, + B: ShardBinding, { type Output = Vec; async fn upgrade( self, - ctx: Upgraded<'a, V>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { /// Need a standalone function to avoid GAT issue that apparently can manifest /// even with `async_trait`. - fn upgrade_vec<'a, V, U>( - ctx: Upgraded<'a, V>, + fn upgrade_vec<'a, V, U, B>( + ctx: Upgraded<'a, V, B>, record_id: RecordId, input: Vec, ) -> impl std::future::Future, Error>> + 'a where V: ExtendableField, - U: Upgradable> + 'a, + U: Upgradable> + 'a, + B: ShardBinding, { let mut upgraded = Vec::with_capacity(input.len()); async move { diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index af2e0142b..d84dd0926 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -9,30 +9,34 @@ pub mod step; pub mod upgrade; mod batcher; -/// Validators are not used in IPA v3 yet. Once we make use of MAC-based validation, -/// this flag can be removed -#[allow(dead_code)] pub mod validator; -use std::{collections::HashMap, iter, num::NonZeroUsize, pin::pin}; +use std::{collections::HashMap, num::NonZeroUsize, pin::pin}; use async_trait::async_trait; pub use dzkp_malicious::DZKPUpgraded as DZKPUpgradedMaliciousContext; pub use dzkp_semi_honest::DZKPUpgraded as DZKPUpgradedSemiHonestContext; -use futures::{stream, Stream, StreamExt}; +use futures::{stream, Stream, StreamExt, TryStreamExt}; use ipa_step::{Step, StepNarrow}; -pub use malicious::{Context as MaliciousContext, Upgraded as UpgradedMaliciousContext}; +pub use malicious::MaliciousProtocolSteps; use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness}; pub use semi_honest::Upgraded as UpgradedSemiHonestContext; pub use validator::Validator; pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>; pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; +pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; +pub type ShardedMaliciousContext<'a> = malicious::Context<'a, Sharded>; +pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; + +#[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] +pub(crate) use malicious::TEST_DZKP_STEPS; + use crate::{ error::Error, helpers::{ - ChannelId, Direction, Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, - ShardReceivingEnd, TotalRecords, + stream::ExactSizeStream, ChannelId, Direction, Gateway, Message, MpcMessage, + MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, TotalRecords, }, protocol::{ context::dzkp_validator::DZKPValidator, @@ -42,6 +46,7 @@ use crate::{ secret_sharing::replicated::malicious::ExtendableField, seq_join::SeqJoin, sharding::{NotSharded, ShardBinding, ShardConfiguration, ShardIndex, Sharded}, + utils::NonZeroU32PowerOfTwo, }; /// Context used by each helper to perform secure computation. Provides access to shared randomness @@ -109,7 +114,14 @@ pub trait UpgradableContext: Context { type DZKPValidator: DZKPValidator; - fn dzkp_validator(self, max_multiplications_per_gate: usize) -> Self::DZKPValidator; + fn dzkp_validator( + self, + steps: MaliciousProtocolSteps, + max_multiplications_per_gate: usize, + ) -> Self::DZKPValidator + where + Gate: StepNarrow, + S: Step + ?Sized; } pub type MacUpgraded = <::Validator as Validator>::Context; @@ -153,6 +165,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> { inner: Inner<'a>, gate: Gate, total_records: TotalRecords, + active_work: NonZeroU32PowerOfTwo, /// This indicates whether the system uses sharding or no. It's not ideal that we keep it here /// because it gets cloned often, a potential solution to that, if this shows up on flame graph, /// would be to move it to [`Inner`] struct. @@ -171,9 +184,18 @@ impl<'a, B: ShardBinding> Base<'a, B> { inner: Inner::new(participant, gateway), gate, total_records, + active_work: gateway.config().active_work_as_power_of_two(), sharding, } } + + #[must_use] + pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self { + Self { + active_work: new_active_work, + ..self.clone() + } + } } impl ShardedContext for Base<'_, Sharded> { @@ -208,6 +230,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { inner: self.inner.clone(), gate: self.gate.narrow(step), total_records: self.total_records, + active_work: self.active_work, sharding: self.sharding.clone(), } } @@ -217,6 +240,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { inner: self.inner.clone(), gate: self.gate.clone(), total_records: self.total_records.overwrite(total_records), + active_work: self.active_work, sharding: self.sharding.clone(), } } @@ -245,9 +269,11 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { } fn send_channel(&self, role: Role) -> SendingEnd { - self.inner - .gateway - .get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + self.inner.gateway.get_mpc_sender( + &ChannelId::new(role, self.gate.clone()), + self.total_records, + self.active_work, + ) } fn recv_channel(&self, role: Role) -> MpcReceivingEnd { @@ -313,7 +339,7 @@ impl ShardConfiguration for Base<'_, Sharded> { impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> { fn active_work(&self) -> NonZeroUsize { - self.inner.gateway.config().active_work() + self.active_work.to_non_zero_usize() } } @@ -339,6 +365,12 @@ impl<'a> Inner<'a> { /// N per shard). Each channel stays open until the very last row is processed, then they are explicitly /// closed, even if nothing has been communicated between that pair. /// +/// ## Stream size +/// [`reshard_try_stream`] takes a regular stream, but will panic at runtime, if the stream +/// upper bound size is not known. Opting out for a runtime check is necessary for it to work +/// with query inputs, where the submitter stream is truncated to take at most `sz` elements. +/// This would mean that stream may have less than `sz` elements and resharding should work. +/// /// ## Shard picking considerations /// It is expected for `shard_picker` to select shards uniformly, by either using [`prss`] or sampling /// random values with enough entropy. Failure to do so may lead to extra memory overhead - this @@ -348,25 +380,29 @@ impl<'a> Inner<'a> { /// /// [`calculations`]: https://docs.google.com/document/d/1vej6tYgNV3GWcldD4tl7a4Z9EeZwda3F5u7roPGArlU/ /// +/// /// ## Panics -/// When `shard_picker` returns an out-of-bounds index. +/// When `shard_picker` returns an out-of-bounds index or if the input stream size +/// upper bound is not known. The latter may be the case for infinite streams. /// /// ## Errors -/// If cross-shard communication fails -pub async fn reshard( +/// If cross-shard communication fails or if an input stream +/// yields an `Err` element. +/// +pub async fn reshard_try_stream( ctx: C, input: L, shard_picker: S, ) -> Result, crate::error::Error> where - L: IntoIterator, - L::IntoIter: ExactSizeIterator, + L: Stream>, S: Fn(C, RecordId, &K) -> ShardIndex, K: Message + Clone, C: ShardedContext, { - let input = input.into_iter(); - let input_len = input.len(); + let (_, Some(input_len)) = input.size_hint() else { + panic!("input stream must have size upper bound for resharding to work") + }; // We set channels capacity to be at least 1 to be able to open send channels to all peers. // It is prohibited to create them if total records is not set. We also over-provision here @@ -392,14 +428,18 @@ where // Request data from all shards. let rcv_stream = ctx .recv_from_shards::() - .map(|(shard_id, v)| { - ( - shard_id, - v.map(Option::Some).map_err(crate::error::Error::from), - ) + .map(|(shard_id, v)| match v { + Ok(v) => Ok((shard_id, Some(v))), + Err(e) => Err(e), }) .fuse(); + let input = pin!(input); + // Annoying consequence of not having async closures stable. async blocks + // cannot capture `Copy` values and there is no way to express that + // only some things need to be moved in Rust + let mut counter = 0_u32; + // This produces a stream of outcomes of send requests. // In order to make it compatible with receive stream, it also returns records that must // stay on this shard, according to `shard_picker`'s decision. @@ -408,36 +448,42 @@ where // whole resharding process. // If send was successful, we set the argument to Ok(None). Only records assigned to this shard // by the `shard_picker` will have the value of Ok(Some(Value)) - let send_stream = futures::stream::unfold( + let send_stream = futures::stream::try_unfold( // it is crucial that the following execution is completed sequentially, in order for record id // tracking per shard to work correctly. If tasks complete out of order, this will cause share // misplacement on the recipient side. - ( - input.enumerate().zip(iter::repeat(ctx.clone())), - &mut send_channels, - ), - |(mut input, send_channels)| async { - // Process more data as it comes in, or close the sending channels, if there is nothing - // left. - if let Some(((i, val), ctx)) = input.next() { - let dest_shard = shard_picker(ctx, RecordId::from(i), &val); - if dest_shard == my_shard { - Some(((my_shard, Ok(Some(val.clone()))), (input, send_channels))) + (input, &mut send_channels, &mut counter), + |(mut input, send_channels, i)| { + let ctx = ctx.clone(); + async { + // Process more data as it comes in, or close the sending channels, if there is nothing + // left. + if let Some(val) = input.try_next().await? { + if usize::try_from(*i).unwrap() >= input_len { + return Err(crate::error::Error::RecordIdOutOfRange { + record_id: RecordId::from(*i), + total_records: input_len, + }); + } + + let dest_shard = shard_picker(ctx, RecordId::from(*i), &val); + *i += 1; + if dest_shard == my_shard { + Ok(Some(((my_shard, Some(val)), (input, send_channels, i)))) + } else { + let (record_id, se) = send_channels.get_mut(&dest_shard).unwrap(); + se.send(*record_id, val) + .await + .map_err(crate::error::Error::from)?; + *record_id += 1; + Ok(Some(((my_shard, None), (input, send_channels, i)))) + } } else { - let (record_id, se) = send_channels.get_mut(&dest_shard).unwrap(); - let send_result = se - .send(*record_id, val) - .await - .map_err(crate::error::Error::from) - .map(|()| None); - *record_id += 1; - Some(((my_shard, send_result), (input, send_channels))) - } - } else { - for (last_record, send_channel) in send_channels.values() { - send_channel.close(*last_record).await; + for (last_record, send_channel) in send_channels.values() { + send_channel.close(*last_record).await; + } + Ok(None) } - None } }, ) @@ -469,8 +515,8 @@ where // This approach makes sure we do what we can - send or receive. let mut send_recv = pin!(futures::stream::select(send_stream, rcv_stream)); - while let Some((shard_id, v)) = send_recv.next().await { - if let Some(m) = v? { + while let Some((shard_id, v)) = send_recv.try_next().await? { + if let Some(m) = v { r[usize::from(shard_id)].push(m); } } @@ -478,6 +524,71 @@ where Ok(r.into_iter().flatten().collect()) } +/// Provides the same functionality as [`reshard_try_stream`] on +/// infallible streams +/// +/// ## Stream size +/// Note that it currently works for streams where size is known in advance. Mainly because +/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard. +/// Other than that, there are no technical limitation here, and it could be possible to make it +/// work with regular streams or opt-out to runtime checks as [`reshard_try_stream`] does. +/// +/// +/// ```compile_fail +/// use futures::stream::{self, StreamExt}; +/// use ipa_core::protocol::context::reshard_stream; +/// use ipa_core::ff::boolean::Boolean; +/// use ipa_core::secret_sharing::SharedValue; +/// async { +/// let a = [Boolean::ZERO]; +/// let mut s = stream::iter(a.into_iter()).cycle(); +/// // this should fail to compile: +/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied +/// reshard_stream(todo!(), s, todo!()).await; +/// }; +/// ``` +/// ## Panics +/// When `shard_picker` returns an out-of-bounds index. +/// +/// ## Errors +/// If cross-shard communication fails +pub async fn reshard_stream( + ctx: C, + input: L, + shard_picker: S, +) -> Result, crate::error::Error> +where + L: ExactSizeStream, + S: Fn(C, RecordId, &K) -> ShardIndex, + K: Message + Clone, + C: ShardedContext, +{ + reshard_try_stream(ctx, input.map(Ok), shard_picker).await +} + +/// Same as [`reshard_stream`] but takes an iterator with the known size +/// as input. +/// +/// ## Panics +/// When `shard_picker` returns an out-of-bounds index. +/// +/// ## Errors +/// If cross-shard communication fails +pub async fn reshard_iter( + ctx: C, + input: L, + shard_picker: S, +) -> Result, crate::error::Error> +where + L: IntoIterator, + L::IntoIter: ExactSizeIterator, + S: Fn(C, RecordId, &K) -> ShardIndex, + K: Message + Clone, + C: ShardedContext, +{ + reshard_stream(ctx, stream::iter(input.into_iter()), shard_picker).await +} + /// trait for contexts that allow MPC multiplications that are protected against a malicious helper by using a DZKP #[async_trait] pub trait DZKPContext: Context { @@ -498,10 +609,11 @@ pub trait DZKPContext: Context { #[cfg(all(test, unit_test))] mod tests { - use std::{iter, iter::repeat}; + use std::{iter, iter::repeat, pin::Pin, task::Poll}; - use futures::{future::join_all, stream::StreamExt, try_join}; + use futures::{future::join_all, ready, stream, stream::StreamExt, try_join, Stream}; use ipa_step::StepNarrow; + use pin_project::pin_project; use rand::{ distributions::{Distribution, Standard}, Rng, @@ -517,16 +629,20 @@ mod tests { protocol::{ basics::ShareKnownValue, context::{ - reshard, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, - Context, ShardedContext, UpgradableContext, Validator, + reshard_iter, reshard_stream, reshard_try_stream, + step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, Context, + ShardedContext, UpgradableContext, Validator, }, prss::SharedRandomness, RecordId, }, - secret_sharing::replicated::{ - malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, - semi_honest::AdditiveShare as Replicated, - ReplicatedSecretSharing, + secret_sharing::{ + replicated::{ + malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, + semi_honest::AdditiveShare as Replicated, + ReplicatedSecretSharing, + }, + SharedValue, }, sharding::{ShardConfiguration, ShardIndex}, telemetry::metrics::{ @@ -796,7 +912,34 @@ mod tests { /// Ensure global record order across shards is consistent. #[test] - fn shard_picker() { + fn reshard_stream_test() { + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + + let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + let shard_input = stream::iter(shard_input); + reshard_stream(ctx, shard_input, |_, record_id, _| { + ShardIndex::from(u32::from(record_id) % SHARDS) + }) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + /// Ensure global record order across shards is consistent. + #[test] + fn reshard_iter_test() { run(|| async move { const SHARDS: u32 = 5; let world: TestWorld> = @@ -804,7 +947,7 @@ mod tests { let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); let r = world .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { - reshard(ctx, shard_input, |_, record_id, _| { + reshard_iter(ctx, shard_input, |_, record_id, _| { ShardIndex::from(u32::from(record_id) % SHARDS) }) .await @@ -819,6 +962,201 @@ mod tests { }); } + #[test] + fn reshard_try_stream_basic() { + run(|| async move { + const SHARDS: u32 = 5; + let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream(ctx, stream::iter(shard_input).map(Ok), |_, record_id, _| { + ShardIndex::from(u32::from(record_id) % SHARDS) + }) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + #[should_panic(expected = "RecordIdOutOfRange { record_id: RecordId(1), total_records: 1 }")] + fn reshard_try_stream_more_items_than_expected() { + #[pin_project] + struct AdversaryStream { + #[pin] + inner: S, + wrong_length: usize, + } + + impl AdversaryStream { + fn new(inner: S, wrong_length: usize) -> Self { + assert!(wrong_length > 0); + Self { + inner, + wrong_length, + } + } + } + + impl Stream for AdversaryStream { + type Item = S::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + + this.inner.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.wrong_length)) + } + } + + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input: Vec<_> = (0..5 * SHARDS).map(BA8::truncate_from).collect(); + world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream( + ctx, + AdversaryStream::new(stream::iter(shard_input).map(Ok), 1), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap() + }) + .await; + }); + } + + #[test] + fn reshard_try_stream_less_items_than_expected() { + /// This allows advertising higher upper bound limit + /// that actual number of elements in the stream. + /// reshard should be able to tolerate that + #[pin_project] + struct Wrapper { + #[pin] + inner: S, + expected_len: usize, + } + + impl Wrapper { + fn new(inner: S, expected_len: usize) -> Self { + assert!(expected_len > 0); + Self { + inner, + expected_len, + } + } + } + + impl Stream for Wrapper { + type Item = S::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + let r = match ready!(this.inner.poll_next(cx)) { + Some(val) => { + *this.expected_len -= 1; + Poll::Ready(Some(val)) + } + None => Poll::Ready(None), + }; + + assert!( + *this.expected_len > 0, + "Stream should have less elements than expected" + ); + r + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.expected_len)) + } + } + + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input: Vec<_> = (0..5 * SHARDS).map(BA8::truncate_from).collect(); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream( + ctx, + Wrapper::new(stream::iter(shard_input).map(Ok), 25), + |_, record_id, _| ShardIndex::from(u32::from(record_id) % SHARDS), + ) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + #[should_panic(expected = "input stream must have size upper bound for resharding to work")] + fn reshard_try_stream_infinite() { + run(|| async move { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::::new().into_iter(), |ctx, _| async move { + reshard_try_stream(ctx, stream::repeat(BA8::ZERO).map(Ok), |_, _, _| { + ShardIndex::FIRST + }) + .await + .unwrap() + }) + .await; + }); + } + + #[test] + fn reshard_try_stream_err() { + run(|| async move { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::::new().into_iter(), |ctx, _| async move { + let err = reshard_try_stream( + ctx, + stream::iter(vec![ + Ok(BA8::ZERO), + Err(crate::error::Error::InconsistentShares), + ]), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap_err(); + assert!(matches!(err, crate::error::Error::InconsistentShares)); + }) + .await; + }); + } + #[test] fn prss_one_side() { run(|| async { diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs index 1be359879..bd8c2e260 100644 --- a/ipa-core/src/protocol/context/semi_honest.rs +++ b/ipa-core/src/protocol/context/semi_honest.rs @@ -16,9 +16,10 @@ use crate::{ }, protocol::{ context::{ - dzkp_validator::SemiHonestDZKPValidator, upgrade::Upgradable, - validator::SemiHonest as Validator, Base, InstrumentedIndexedSharedRandomness, - InstrumentedSequentialSharedRandomness, ShardedContext, SpecialAccessToUpgradedContext, + dzkp_validator::SemiHonestDZKPValidator, step::MaliciousProtocolStep, + upgrade::Upgradable, validator::SemiHonest as Validator, Base, Context as _, + InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness, + MaliciousProtocolSteps, ShardedContext, SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext, }, prss::Endpoint as PrssEndpoint, @@ -151,13 +152,21 @@ impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> { type Validator = Validator<'a, B, F>; fn validator(self) -> Self::Validator { - Self::Validator::new(self.inner) + Self::Validator::new(self.inner.narrow(&MaliciousProtocolStep::MaliciousProtocol)) } type DZKPValidator = SemiHonestDZKPValidator<'a, B>; - fn dzkp_validator(self, _max_multiplications_per_gate: usize) -> Self::DZKPValidator { - Self::DZKPValidator::new(self.inner) + fn dzkp_validator( + self, + steps: MaliciousProtocolSteps, + _max_multiplications_per_gate: usize, + ) -> Self::DZKPValidator + where + S: ipa_step::Step + ?Sized, + Gate: StepNarrow, + { + Self::DZKPValidator::new(self.inner.narrow(steps.protocol)) } } @@ -293,14 +302,14 @@ impl Debug for Upgraded<'_, B, F> { } #[async_trait] -impl<'a, V: ExtendableField + Vectorizable, const N: usize> - Upgradable> for Replicated +impl<'a, V: ExtendableField + Vectorizable, B: ShardBinding, const N: usize> + Upgradable> for Replicated { type Output = Replicated; async fn upgrade( self, - _context: Upgraded<'a, NotSharded, V>, + _context: Upgraded<'a, B, V>, _record_id: RecordId, ) -> Result { Ok(self) diff --git a/ipa-core/src/protocol/context/step.rs b/ipa-core/src/protocol/context/step.rs index a3fcdab02..24a8872be 100644 --- a/ipa-core/src/protocol/context/step.rs +++ b/ipa-core/src/protocol/context/step.rs @@ -2,7 +2,7 @@ use ipa_step_derive::CompactStep; /// Upgrades all use this step to distinguish protocol steps from the step that is used to upgrade inputs. #[derive(CompactStep)] -#[step(name = "upgrade")] +#[step(name = "upgrade", child = crate::protocol::basics::mul::step::MaliciousMultiplyStep)] pub(crate) struct UpgradeStep; /// Steps used by the validation component of malicious protocol execution. @@ -10,8 +10,10 @@ pub(crate) struct UpgradeStep; #[derive(CompactStep)] pub(crate) enum MaliciousProtocolStep { /// For the execution of the malicious protocol. + #[step(child = crate::protocol::ipa_prf::step::PrfStep)] MaliciousProtocol, /// The final validation steps. + #[step(child = ValidateStep)] Validate, } @@ -22,23 +24,25 @@ pub(crate) enum ValidateStep { /// Reveal the value of `r`, necessary for validation. RevealR, /// Check that there is no disagreement between accumulated values. + #[step(child = crate::protocol::basics::step::CheckZeroStep)] CheckZero, } -/// Steps used by the validation component of the DZKP #[derive(CompactStep)] -pub(crate) enum ZeroKnowledgeProofValidateStep { - /// For the execution of the malicious protocol. - DZKPMaliciousProtocol, - /// Step for computing `p * q` between proof verifiers - PTimesQ, - /// Step for producing challenge between proof verifiers - Challenge, - /// Steps for validating the DZK proofs for each batch. - #[step(count = 256)] - DZKPValidate(usize), +pub(crate) enum DzkpValidationProtocolStep { /// Step for proof generation GenerateProof, + /// Step for producing challenge between proof verifiers + Challenge, /// Step for proof verification + #[step(child = DzkpProofVerifyStep)] VerifyProof, } + +#[derive(CompactStep)] +pub(crate) enum DzkpProofVerifyStep { + /// Step for computing `p * q` between proof verifiers + PTimesQ, + /// Step for computing `G_diff` between proof verifiers + Diff, +} diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 63a212a2c..a71b395c3 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -199,45 +199,45 @@ impl MaliciousAccumulator { /// When batch is validated, `r` is revealed and can never be /// used again. In fact, it gets out of scope after successful validation /// so no code can get access to it. -pub struct BatchValidator<'a, F: ExtendableField> { - batches_ref: Arc>, - protocol_ctx: MaliciousContext<'a>, +pub struct BatchValidator<'a, F: ExtendableField, B: ShardBinding> { + batches_ref: Arc>, + protocol_ctx: MaliciousContext<'a, B>, } -impl<'a, F: ExtendableField> BatchValidator<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> BatchValidator<'a, F, B> { /// Create a new validator for malicious context. /// /// ## Panics /// If total records is not set. #[must_use] - pub fn new(ctx: MaliciousContext<'a>) -> Self { + pub fn new(ctx: MaliciousContext<'a, B>) -> Self { let TotalRecords::Specified(total_records) = ctx.total_records() else { panic!("Total records must be specified before creating the validator"); }; // TODO: Right now we set the batch work to be equal to active_work, // but it does not need to be. We can make this configurable if needed. - let records_per_batch = ctx.active_work().get().min(total_records.get()); + let records_per_batch = ctx.active_work().get(); Self { protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), - batches_ref: Batcher::new( + batches_ref: Arc::new(Batcher::new( records_per_batch, total_records, Box::new(move |batch_index| Malicious::new(ctx.clone(), batch_index)), - ), + )), } } } -pub struct Malicious<'a, F: ExtendableField> { +pub struct Malicious<'a, F: ExtendableField, B: ShardBinding> { r_share: Replicated, pub(super) accumulator: MaliciousAccumulator, - validate_ctx: Base<'a>, + validate_ctx: Base<'a, B>, offset: usize, } -impl Malicious<'_, F> { +impl Malicious<'_, F, B> { /// ## Errors /// If the two information theoretic MACs are not equal (after multiplying by `r`), this indicates that one of the parties /// must have launched an additive attack. At this point the honest parties should abort the protocol. This method throws an @@ -294,21 +294,21 @@ impl Malicious<'_, F> { } } -impl<'a, F> Validator for BatchValidator<'a, F> +impl<'a, F, B: ShardBinding> Validator for BatchValidator<'a, F, B> where F: ExtendableField, { - type Context = UpgradedMaliciousContext<'a, F>; + type Context = UpgradedMaliciousContext<'a, F, B>; fn context(&self) -> Self::Context { UpgradedMaliciousContext::new(&self.batches_ref, self.protocol_ctx.clone()) } } -impl<'a, F: ExtendableField> Malicious<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> Malicious<'a, F, B> { #[must_use] #[allow(clippy::needless_pass_by_value)] - pub fn new(ctx: MaliciousContext<'a>, offset: usize) -> Self { + pub fn new(ctx: MaliciousContext<'a, B>, offset: usize) -> Self { // Each invocation requires 3 calls to PRSS to generate the state. // Validation occurs in batches and `offset` indicates which batch // we're in right now. @@ -386,7 +386,7 @@ impl<'a, F: ExtendableField> Malicious<'a, F> { } } -impl Debug for Malicious<'_, F> { +impl Debug for Malicious<'_, F, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousValidator<{:?}>", type_name::()) } diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index a44b73e6d..856073b77 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -15,12 +15,16 @@ use crate::{ helpers::{query::DpMechanism, Direction, Role, TotalRecords}, protocol::{ boolean::step::ThirtyTwoBitStep, - context::{dzkp_validator::DZKPValidator, Context, DZKPUpgraded, UpgradableContext}, + context::{ + dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, + UpgradableContext, + }, dp::step::{ApplyDpNoise, DPStep}, ipa_prf::{ aggregation::{aggregate_values, aggregate_values_proof_chunk}, boolean_ops::addition_sequential::integer_add, oprf_padding::insecure::OPRFPaddingDp, + step::IpaPrfStep, }, prss::{FromPrss, SharedRandomness}, BooleanProtocols, RecordId, @@ -167,7 +171,7 @@ where let aggregation_input = Box::pin(stream::iter(vector_input_to_agg.into_iter()).map(Ok)); // Step 3: Call `aggregate_values` to sum up Bernoulli noise. let noise_vector: Result>, Error> = - aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli).await; + aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli, None).await; noise_vector } /// `apply_dp_noise` takes the noise distribution parameters (`num_bernoulli` and in the future `quantization_scale`) @@ -224,6 +228,7 @@ where /// # Panics /// may panic from asserts down in `gen_binomial_noise` /// +#[allow(clippy::too_many_lines)] pub async fn dp_for_histogram( ctx: C, histogram_bin_values: BitDecomposed>, @@ -240,6 +245,10 @@ where BitDecomposed>: for<'a> TransposeFrom<&'a [AdditiveShare; B], Error = Infallible>, { + let steps = MaliciousProtocolSteps { + protocol: &IpaPrfStep::DifferentialPrivacy, + validate: &IpaPrfStep::DifferentialPrivacyValidate, + }; match dp_params { DpMechanism::NoDp => Ok(Vec::transposed_from(&histogram_bin_values)?), DpMechanism::Binomial { epsilon } => { @@ -286,7 +295,8 @@ where "num_bernoulli of {num_bernoulli} may result in excessively large DZKP" ); } - let dp_validator = ctx.dzkp_validator(num_bernoulli); + + let dp_validator = ctx.dzkp_validator(steps, num_bernoulli); let noisy_histogram = apply_dp_noise::<_, B, OV>( dp_validator.context(), @@ -328,7 +338,7 @@ where OV::BITS, ); - let dp_validator = ctx.dzkp_validator(1); + let dp_validator = ctx.dzkp_validator(steps, 1); let noised_output = apply_laplace_noise_pass::<_, OV, B>( &dp_validator.context().narrow(&DPStep::LaplacePass1), @@ -543,7 +553,7 @@ fn delta_constraint(num_bernoulli: u32, noise_params: &NoiseParams) -> bool { lhs >= rhs } /// error of mechanism in Thm 1 -#[allow(dead_code)] +#[cfg(all(test, unit_test))] fn error(num_bernoulli: u32, noise_params: &NoiseParams) -> f64 { noise_params.dimensions * noise_params.quantization_scale.powi(2) diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs new file mode 100644 index 000000000..482f6e939 --- /dev/null +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -0,0 +1,79 @@ +pub(crate) mod step; + +use crate::{ + error::Error, + ff::{ + boolean_array::{BooleanArray, BA5, BA8}, + U128Conversions, + }, + helpers::query::DpMechanism, + protocol::{ + context::{ShardedContext, UpgradableContext}, + ipa_prf::{oprf_padding::PaddingParameters, shuffle::Shuffle}, + }, + report::hybrid::IndistinguishableHybridReport, + secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, +}; + +// In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count) +// ≤ 2^|bk|, with all three values distinct, but at present, there is no runtime configuration and +// the latter two must be equal. The implementation of `move_single_value_to_bucket` does support a +// runtime-specified count via the `breakdown_count` parameter, and implements a runtime check of +// its value. +// +// It would usually be more appropriate to make `MAX_BREAKDOWNS` an associated constant rather than +// a const parameter. However, we want to use it to enforce a correct pairing of the `BK` type +// parameter and the `B` const parameter, and specifying a constraint like +// `BreakdownKey` on an associated constant is not currently supported. (Nor is +// supplying an associated constant `::MAX_BREAKDOWNS` as the value of a const +// parameter.) Structured the way we have it, it probably doesn't make sense to use the +// `BreakdownKey` trait in places where the `B` const parameter is not already available. +// +// These could be imported from src/protocl/ipa_prf/mod.rs +// however we've copy/pasted them here with the intention of deleting that file [TODO] +pub trait BreakdownKey: BooleanArray + U128Conversions {} +impl BreakdownKey<32> for BA5 {} +impl BreakdownKey<256> for BA8 {} + +/// The Hybrid Protocol +/// +/// This protocol takes in a [`Vec>`] +/// and aggregates it into a summary report. `HybridReport`s are either +/// impressions or conversion. The protocol joins these based on their matchkeys, +/// sums the values from conversions grouped by the breakdown key on impressions. +/// To accomplish this, hte protocol performs the follwoing steps +/// 1. Generates a random number of "dummy records" (needed to mask the information that will +/// be revealed in step 4, and thereby provide a differential privacy guarantee on +/// that information leakage) +/// 2. Shuffles the input +/// 3. Computes an OPRF of these elliptic curve points and reveals this "pseudonym" +/// 4. Groups together rows with the same OPRF and sums both the breakdown keys and values. +/// 5. Generates a random number of "dummy records" (needed to mask the information that will +/// be revealed in step 7) +/// 6. Shuffles the input +/// 7. Reveals breakdown keys +/// 8. Sums the values by breakdown keys +/// 9. Adds random noise to the total value for each breakdown key (to provide a +/// differential privacy guarantee) +/// +/// # Errors +/// Propagates errors from config issues or while running the protocol +/// # Panics +/// Propagates errors from config issues or while running the protocol +pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const SS_BITS: usize, const B: usize>( + _ctx: C, + input_rows: Vec>, + _dp_params: DpMechanism, + _dp_padding_params: PaddingParameters, +) -> Result>, Error> +where + C: UpgradableContext + 'ctx + Shuffle + ShardedContext, + BK: BreakdownKey, + V: BooleanArray + U128Conversions, + HV: BooleanArray + U128Conversions, +{ + if input_rows.is_empty() { + return Ok(vec![Replicated::ZERO; B]); + } + unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented") +} diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs new file mode 100644 index 000000000..5de0051be --- /dev/null +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -0,0 +1,6 @@ +use ipa_step_derive::CompactStep; + +#[derive(CompactStep)] +pub(crate) enum HybridStep { + ReshardByTag, +} diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index b08a3ece9..b0b17396a 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -1,12 +1,9 @@ -use std::{ - convert::Infallible, - pin::{pin, Pin}, -}; +use std::{convert::Infallible, pin::pin}; -use futures::{stream, Stream}; +use futures::stream; use futures_util::{StreamExt, TryStreamExt}; -use super::{aggregate_values, AggResult}; +use super::aggregate_values; use crate::{ error::{Error, UnwrapInfallible}, ff::{ @@ -16,10 +13,15 @@ use crate::{ }, helpers::TotalRecords, protocol::{ - basics::semi_honest_reveal, - context::Context, + basics::{reveal, Reveal}, + context::{ + dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, + UpgradableContext, + }, ipa_prf::{ - aggregation::step::AggregationStep, + aggregation::{ + aggregate_values_proof_chunk, step::AggregationStep as Step, AGGREGATE_DEPTH, + }, oprf_padding::{apply_dp_padding, PaddingParameters}, prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, shuffle::shuffle_attribution_outputs, @@ -29,7 +31,7 @@ use crate::{ }, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, - TransposeFrom, + TransposeFrom, Vectorizable, }, seq_join::seq_join, }; @@ -48,34 +50,97 @@ use crate::{ /// 2. Reveal breakdown keys. This is the key difference to the previous /// aggregation (see [`reveal_breakdowns`]). /// 3. Add all values for each breakdown. +/// +/// This protocol explicitly manages proof batches for DZKP-based malicious security by +/// processing chunks of values from `intermediate_results.chunks()`. Procession +/// through record IDs is not uniform for all of the gates in the protocol. The first +/// layer of the reduction adds N pairs of records, the second layer adds N/2 pairs of +/// records, etc. This has a few consequences: +/// * We must specify a batch size of `usize::MAX` when calling `dzkp_validator`. +/// * We must track record IDs across chunks, so that subsequent chunks can +/// start from the last record ID that was used in the previous chunk. +/// * Because the first record ID in the proof batch is set implicitly, we must +/// guarantee that it submits multiplication intermediates before any other +/// record. This is currently ensured by the serial operation of the aggregation +/// protocol (i.e. by not using `seq_join`). +#[tracing::instrument(name = "breakdown_reveal_aggregation", skip_all, fields(total = attributed_values.len()))] pub async fn breakdown_reveal_aggregation( ctx: C, attributed_values: Vec>, + padding_params: &PaddingParameters, ) -> Result>, Error> where - C: Context, + C: UpgradableContext, Boolean: FieldSimd, - Replicated: BooleanProtocols, + Replicated: BooleanProtocols, B>, BK: BreakdownKey, + Replicated: Reveal, Output = >::Array>, TV: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, BitDecomposed>: for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, { - let dp_padding_params = PaddingParameters::default(); // Apply DP padding for Breakdown Reveal Aggregation let attributed_values_padded = apply_dp_padding::<_, AttributionOutputs, Replicated>, B>( - ctx.narrow(&AggregationStep::PaddingDp), + ctx.narrow(&Step::PaddingDp), attributed_values, - dp_padding_params, + padding_params, ) .await?; - let attributions = shuffle_attributions(&ctx, attributed_values_padded).await?; - let grouped_tvs = reveal_breakdowns(&ctx, attributions).await?; - let num_rows = grouped_tvs.max_len; - aggregate_values::<_, HV, B>(ctx, grouped_tvs.into_stream(), num_rows).await + let attributions = shuffle_attributions::<_, BK, TV, B>(&ctx, attributed_values_padded).await?; + // Revealing the breakdowns doesn't do any multiplies, so won't make it as far as + // doing a proof, but we need the validator to obtain an upgraded malicious context. + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::Reveal, + validate: &Step::RevealValidate, + }, + usize::MAX, + ); + let grouped_tvs = reveal_breakdowns(&validator.context(), attributions).await?; + validator.validate().await?; + let mut intermediate_results: Vec>> = grouped_tvs.into(); + + // Any real-world aggregation should be able to complete in two layers (two + // iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE` + // may exceed that. + let mut chunk_counter = 0; + let mut depth = 0; + let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); + + while intermediate_results.len() > 1 { + let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; + let mut next_intermediate_results = Vec::new(); + for chunk in intermediate_results.chunks(agg_proof_chunk) { + let chunk_len = chunk.len(); + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::aggregate(depth), + validate: &Step::AggregateValidate, + }, + usize::MAX, // See note about batching above. + ); + let result = aggregate_values::<_, HV, B>( + validator.context(), + stream::iter(chunk).map(|v| Ok(v.clone())).boxed(), + chunk_len, + Some(&mut record_ids), + ) + .await?; + validator.validate_indexed(chunk_counter).await?; + chunk_counter += 1; + next_intermediate_results.push(result); + } + depth += 1; + intermediate_results = next_intermediate_results; + } + + Ok(intermediate_results + .into_iter() + .next() + .expect("aggregation input must not be empty")) } /// Shuffles attribution Breakdown key and Trigger Value secret shares. Input @@ -92,7 +157,7 @@ where BK: BreakdownKey, TV: BooleanArray + U128Conversions, { - let shuffle_ctx = parent_ctx.narrow(&AggregationStep::Shuffle); + let shuffle_ctx = parent_ctx.narrow(&Step::Shuffle); shuffle_attribution_outputs::<_, BK, TV, BA64>(shuffle_ctx, contribs).await } @@ -114,25 +179,17 @@ where Replicated: BooleanProtocols, Boolean: FieldSimd, BK: BreakdownKey, + Replicated: Reveal>::Array>, TV: BooleanArray + U128Conversions, { - let reveal_ctx = parent_ctx - .narrow(&AggregationStep::RevealStep) - .set_total_records(TotalRecords::specified(attributions.len())?); + let reveal_ctx = parent_ctx.set_total_records(TotalRecords::specified(attributions.len())?); let reveal_work = stream::iter(attributions).enumerate().map(|(i, ao)| { let record_id = RecordId::from(i); let reveal_ctx = reveal_ctx.clone(); async move { - let revealed_bk = semi_honest_reveal( - reveal_ctx, - record_id, - None, - &ao.attributed_breakdown_key_bits, - ) - .await? - // Full reveal is used, meaning it is not possible to return None here - .unwrap(); + let revealed_bk = + reveal(reveal_ctx, record_id, &ao.attributed_breakdown_key_bits).await?; let revealed_bk = BK::from_array(&revealed_bk); let Ok(bk) = usize::try_from(revealed_bk.as_u128()) else { return Err(Error::Internal); @@ -171,22 +228,27 @@ impl GroupedTriggerValues { self.max_len = self.tvs[bk].len(); } } +} - fn into_stream<'fut>(mut self) -> Pin> + Send + 'fut>> - where - Boolean: FieldSimd, - BitDecomposed>: - for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, - { - let iter = (0..self.max_len).map(move |_| { - let slice: [Replicated; B] = self +impl From> + for Vec>> +where + Boolean: FieldSimd, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, +{ + fn from( + mut grouped_tvs: GroupedTriggerValues, + ) -> Vec>> { + let iter = (0..grouped_tvs.max_len).map(move |_| { + let slice: [Replicated; B] = grouped_tvs .tvs .each_mut() .map(|tv| tv.pop().unwrap_or(Replicated::ZERO)); - Ok(BitDecomposed::transposed_from(&slice).unwrap_infallible()) + BitDecomposed::transposed_from(&slice).unwrap_infallible() }); - Box::pin(stream::iter(iter)) + iter.collect() } } @@ -195,6 +257,8 @@ pub mod tests { use futures::TryFutureExt; use rand::{seq::SliceRandom, Rng}; + #[cfg(not(feature = "shuttle"))] + use crate::{ff::boolean_array::BA16, test_executor::run}; use crate::{ ff::{ boolean::Boolean, @@ -203,12 +267,13 @@ pub mod tests { }, protocol::ipa_prf::{ aggregation::breakdown_reveal::breakdown_reveal_aggregation, + oprf_padding::PaddingParameters, prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs}, }, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, }, - test_executor::run, + test_executor::run_with, test_fixture::{Reconstruct, Runner, TestWorld}, }; @@ -222,7 +287,11 @@ pub mod tests { #[test] fn semi_honest_happy_path() { - run(|| async { + // if shuttle executor is enabled, run this test only once. + // it is a very expensive test to explore all possible states, + // sometimes github bails after 40 minutes of running it + // (workers there are really slow). + run_with::<_, _, 3>(|| async { let world = TestWorld::default(); let mut rng = rand::thread_rng(); let mut expectation = Vec::new(); @@ -242,7 +311,7 @@ pub mod tests { } inputs.shuffle(&mut rng); let result: Vec<_> = world - .upgraded_semi_honest(inputs.into_iter(), |ctx, input_rows| async move { + .semi_honest(inputs.into_iter(), |ctx, input_rows| async move { let aos = input_rows .into_iter() .map(|ti| SecretSharedAttributionOutputs { @@ -251,12 +320,16 @@ pub mod tests { }) .collect(); let r: Vec> = - breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>(ctx, aos) - .map_ok(|d: BitDecomposed>| { - Vec::transposed_from(&d).unwrap() - }) - .await - .unwrap(); + breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>( + ctx, + aos, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap(); r }) .await @@ -266,4 +339,58 @@ pub mod tests { assert_eq!(result, expectation); }); } + + #[test] + #[cfg(not(feature = "shuttle"))] // too slow + fn malicious_happy_path() { + type HV = BA16; + run(|| async { + let world = TestWorld::default(); + let mut rng = rand::thread_rng(); + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..512)); + } + // The size of input needed here to get complete coverage (more precisely, + // the size of input to the final aggregation using `aggregate_values`) + // depends on `TARGET_PROOF_SIZE`. + let expectation = expectation; // no more mutability for safety + let mut inputs = Vec::new(); + for (bk, expected_hv) in expectation.iter().enumerate() { + let mut remainder = *expected_hv; + while remainder > 7 { + let tv = rng.gen_range(0u128..8); + remainder -= tv; + inputs.push(input_row(bk, tv)); + } + inputs.push(input_row(bk, remainder)); + } + inputs.shuffle(&mut rng); + let result: Vec<_> = world + .malicious(inputs.into_iter(), |ctx, input_rows| async move { + let aos = input_rows + .into_iter() + .map(|ti| SecretSharedAttributionOutputs { + attributed_breakdown_key_bits: ti.0, + capped_attributed_trigger_value: ti.1, + }) + .collect(); + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + aos, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() + }) + .await + .reconstruct(); + let result = result.iter().map(|v: &HV| v.as_u128()).collect::>(); + assert_eq!(32, result.len()); + assert_eq!(result, expectation); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs b/ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs deleted file mode 100644 index dea77c2f5..000000000 --- a/ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs +++ /dev/null @@ -1,287 +0,0 @@ -use embed_doc_image::embed_doc_image; - -use crate::{ - error::Error, - ff::boolean::Boolean, - helpers::repeat_n, - protocol::{ - basics::SecureMul, boolean::and::bool_and_8_bit, context::Context, - ipa_prf::aggregation::step::BucketStep, RecordId, - }, - secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd}, -}; - -const MAX_BREAKDOWNS: usize = 512; // constrained by the compact step ability to generate dynamic steps - -#[derive(thiserror::Error, Debug)] -pub enum MoveToBucketError { - #[error("Bad value for the breakdown key: {0}")] - InvalidBreakdownKey(String), -} - -impl From for Error { - fn from(error: MoveToBucketError) -> Self { - match error { - e @ MoveToBucketError::InvalidBreakdownKey(_) => { - Error::InvalidQueryParameter(Box::new(e)) - } - } - } -} - -#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")] -/// This function moves a single value to a correct bucket using tree aggregation approach -/// -/// Here is how it works -/// The combined value, [`value`] forms the root of a binary tree as follows: -/// ![Tree propagation][tree-aggregation] -/// -/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications. -/// In the first round, r=BK-1, multiply the most significant bit ,[`bd_key`]_r by the value to get [`bd_key`]_r.[`value`]. From that, -/// produce [`row_contribution`]_r,0 =[`value`]-[`bd_key`]_r.[`value`] and [`row_contribution`]_r,1=[`bd_key`]_r.[`value`]. -/// This takes the most significant bit of `bd_key` and places value in one of the two child nodes of the binary tree. -/// At each successive round, the next most significant bit is propagated from the leaf nodes of the tree into further leaf nodes: -/// [`row_contribution`]_r+1,q,0 =[`row_contribution`]_r,q - [`bd_key`]_r+1.[`row_contribution`]_r,q and [`row_contribution`]_r+1,q,1 =[`bd_key`]_r+1.[`row_contribution`]_r,q. -/// The work of each iteration therefore doubles relative to the one preceding. -/// -/// In case a malicious entity sends a out of range breakdown key (i.e. greater than the max count) to this function, we need to do some -/// extra processing to ensure contribution doesn't end up in a wrong bucket. However, this requires extra multiplications. -/// This would potentially not be needed in IPA (as the breakdown key is provided by the report collector, so a bad value only spoils their own result) but useful for PAM. -/// This can be by passing `robust` as true. -/// -/// ## Errors -/// If `breakdown_count` does not fit into `BK` bits or greater than or equal to $2^9$ -#[allow(dead_code)] -pub async fn move_single_value_to_bucket( - ctx: C, - record_id: RecordId, - bd_key: BitDecomposed>, - value: BitDecomposed>, - breakdown_count: usize, - robust: bool, -) -> Result>>, Error> -where - C: Context, - Boolean: FieldSimd, - AdditiveShare: SecureMul, -{ - let mut step: usize = 1 << bd_key.len(); - - if breakdown_count > step { - Err(MoveToBucketError::InvalidBreakdownKey(format!( - "Asking for more buckets ({breakdown_count}) than bits in the breakdown key ({}) allow", - bd_key.len() - )))?; - } - - if breakdown_count > MAX_BREAKDOWNS { - Err(MoveToBucketError::InvalidBreakdownKey( - "Our step implementation (BucketStep) cannot go past {MAX_BREAKDOWNS} breakdown keys" - .to_string(), - ))?; - } - - let mut row_contribution = vec![value; breakdown_count]; - - // To move a value to one of 2^bd_key_bits buckets requires 2^bd_key_bits - 1 multiplications - // They happen in a tree like fashion: - // 1 multiplication for the first bit - // 2 for the second bit - // 4 for the 3rd bit - // And so on. Simply ordering them sequentially is a functional way - // of enumerating them without creating more step transitions than necessary - let mut multiplication_channel = 0; - - for bit_of_bdkey in bd_key.iter().rev() { - let span = step >> 1; - if !robust && span > breakdown_count { - step = span; - continue; - } - - let contributions = ctx - .parallel_join((0..breakdown_count).step_by(step).enumerate().filter_map( - |(i, tree_index)| { - let bucket_c = ctx.narrow(&BucketStep::from(multiplication_channel + i)); - - let index_contribution = &row_contribution[tree_index]; - - (robust || tree_index + span < breakdown_count).then(|| { - bool_and_8_bit( - bucket_c, - record_id, - index_contribution, - repeat_n(bit_of_bdkey, index_contribution.len()), - ) - }) - }, - )) - .await?; - multiplication_channel += contributions.len(); - - for (index, bdbit_contribution) in contributions.into_iter().enumerate() { - let left_index = index * step; - let right_index = left_index + span; - - // bdbit_contribution is either zero or equal to row_contribution. So it - // is okay to do a carryless "subtraction" here. - for (r, b) in row_contribution[left_index] - .iter_mut() - .zip(bdbit_contribution.iter()) - { - *r -= b; - } - if right_index < breakdown_count { - for (r, b) in row_contribution[right_index] - .iter_mut() - .zip(bdbit_contribution) - { - *r = b; - } - } - } - step = span; - } - Ok(row_contribution) -} - -#[cfg(all(test, unit_test))] -pub mod tests { - use rand::thread_rng; - - use super::move_single_value_to_bucket; - use crate::{ - ff::{boolean::Boolean, boolean_array::BA8, Gf8Bit, Gf9Bit, U128Conversions}, - protocol::{context::Context, RecordId}, - rand::Rng, - secret_sharing::{BitDecomposed, SharedValue}, - test_executor::run, - test_fixture::{Reconstruct, Runner, TestWorld}, - }; - - const MAX_BREAKDOWN_COUNT: usize = 256; - const VALUE: u32 = 10; - - async fn move_to_bucket(count: usize, breakdown_key: usize, robust: bool) -> Vec { - let breakdown_key_bits = BitDecomposed::decompose(Gf8Bit::BITS, |i| { - Boolean::from((breakdown_key >> i) & 1 == 1) - }); - let value = - BitDecomposed::decompose(Gf8Bit::BITS, |i| Boolean::from((VALUE >> i) & 1 == 1)); - - TestWorld::default() - .semi_honest( - (breakdown_key_bits, value), - |ctx, (breakdown_key_share, value_share)| async move { - move_single_value_to_bucket::<_, 1>( - ctx.set_total_records(1), - RecordId::from(0), - breakdown_key_share, - value_share, - count, - robust, - ) - .await - .unwrap() - }, - ) - .await - .reconstruct() - .into_iter() - .map(|val| val.into_iter().collect()) - .collect() - } - - #[test] - fn semi_honest_move_in_range() { - run(|| async move { - let mut rng = thread_rng(); - let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); - let breakdown_key = rng.gen_range(0..count); - let mut expected = vec![BA8::ZERO; count]; - expected[breakdown_key] = BA8::truncate_from(VALUE); - - let result = move_to_bucket(count, breakdown_key, false).await; - assert_eq!(result, expected, "expected value at index {breakdown_key}"); - }); - } - - #[test] - fn semi_honest_move_in_range_robust() { - run(|| async move { - let mut rng = thread_rng(); - let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); - let breakdown_key = rng.gen_range(0..count); - let mut expected = vec![BA8::ZERO; count]; - expected[breakdown_key] = BA8::truncate_from(VALUE); - - let result = move_to_bucket(count, breakdown_key, true).await; - assert_eq!(result, expected, "expected value at index {breakdown_key}"); - }); - } - - #[test] - fn semi_honest_move_out_of_range() { - run(move || async move { - let mut rng: rand::rngs::ThreadRng = thread_rng(); - let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1); - let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); - - let result = move_to_bucket(count, breakdown_key, false).await; - assert_eq!(result.len(), count); - assert_eq!( - result.into_iter().fold(0, |acc, v| acc + v.as_u128()), - u128::from(VALUE) - ); - }); - } - - #[test] - fn semi_honest_move_out_of_range_robust() { - run(move || async move { - let mut rng: rand::rngs::ThreadRng = thread_rng(); - let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1); - let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); - - let result = move_to_bucket(count, breakdown_key, true).await; - assert_eq!(result.len(), count); - assert!(result.into_iter().all(|x| x == BA8::ZERO)); - }); - } - - #[test] - #[should_panic(expected = "Asking for more buckets")] - fn move_out_of_range_too_many_buckets_type() { - run(move || async move { - _ = move_to_bucket(MAX_BREAKDOWN_COUNT + 1, 0, false).await; - }); - } - - #[test] - #[should_panic(expected = "Asking for more buckets")] - fn move_out_of_range_too_many_buckets_steps() { - run(move || async move { - let breakdown_key_bits = BitDecomposed::decompose(Gf9Bit::BITS, |_| Boolean::FALSE); - let value = - BitDecomposed::decompose(Gf8Bit::BITS, |i| Boolean::from((VALUE >> i) & 1 == 1)); - - _ = TestWorld::default() - .semi_honest( - (breakdown_key_bits, value), - |ctx, (breakdown_key_share, value_share)| async move { - move_single_value_to_bucket::<_, 1>( - ctx.set_total_records(1), - RecordId::from(0), - breakdown_key_share, - value_share, - 513, - false, - ) - .await - .unwrap() - }, - ) - .await; - }); - } -} diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 9248e5a65..a5ca281e6 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -1,30 +1,23 @@ -use std::{any::type_name, convert::Infallible, iter, pin::Pin}; +use std::{any::type_name, cmp::max, iter, pin::Pin}; -use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use tracing::Instrument; -use typenum::Const; use crate::{ - error::{Error, LengthError, UnwrapInfallible}, + error::{Error, LengthError}, ff::{boolean::Boolean, boolean_array::BooleanArray, U128Conversions}, helpers::{ - stream::{ - div_round_up, process_stream_by_chunks, ChunkBuffer, FixedLength, TryFlattenItersExt, - }, + stream::{ChunkBuffer, FixedLength}, TotalRecords, }, protocol::{ - basics::{BooleanArrayMul, BooleanProtocols, SecureMul}, + basics::BooleanProtocols, boolean::{step::ThirtyTwoBitStep, NBitStep}, - context::{ - dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, - Context, DZKPContext, UpgradableContext, - }, + context::{dzkp_validator::TARGET_PROOF_SIZE, Context}, ipa_prf::{ - aggregation::step::{AggregateChunkStep, AggregateValuesStep, AggregationStep as Step}, + aggregation::step::{AggregateChunkStep, AggregateValuesStep}, boolean_ops::addition_sequential::{integer_add, integer_sat_add}, - prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, - BreakdownKey, AGG_CHUNK, + prf_sharding::AttributionOutputs, }, RecordId, }, @@ -35,7 +28,6 @@ use crate::{ }; pub(crate) mod breakdown_reveal; -mod bucket; pub(crate) mod step; type AttributionOutputsChunk = AttributionOutputs< @@ -91,169 +83,6 @@ where } } -// Aggregation -// -// The input to aggregation is a stream of tuples of (attributed breakdown key, attributed trigger -// value) for each record. -// -// The first stage of aggregation decodes the breakdown key to produce a vector of trigger value -// to be added to each output bucket. At most one element of this vector can be non-zero, -// corresponding to the breakdown key value. This stage is implemented by the -// `move_single_value_to_bucket` function. -// -// The second stage of aggregation sums these vectors across all records, to produce the final -// output histogram. -// -// The first stage of aggregation is vectorized over records, meaning that a chunk of N -// records is collected, and the `move_single_value_to_bucket` function is called to -// decode the breakdown keys for all of those records simultaneously. -// -// The second stage of aggregation is vectorized over histogram buckets, meaning that -// the values in all `B` output buckets are added simultaneously. -// -// An intermediate transpose occurs between the two stages of aggregation, to convert from the -// record-vectorized representation to the bucket-vectorized representation. -// -// The input to this transpose is `&[BitDecomposed>]`, indexed -// by buckets, bits of trigger value, and contribution rows. -// -// The output is `&[BitDecomposed>]`, indexed by -// contribution rows, bits of trigger value, and buckets. -#[tracing::instrument(name = "aggregate", skip_all, fields(streams = contributions_stream_len))] -pub async fn aggregate_contributions<'ctx, C, St, BK, TV, HV, const B: usize>( - ctx: C, - contributions_stream: St, - mut contributions_stream_len: usize, -) -> Result>, Error> -where - C: UpgradableContext + 'ctx, - St: Stream, Error>> + Send, - BK: BreakdownKey, - TV: BooleanArray + U128Conversions, - HV: BooleanArray + U128Conversions, - Boolean: FieldSimd, - Replicated: BooleanProtocols<::Context, B>, - Replicated: SecureMul<::Context>, - Replicated: BooleanArrayMul<::Context>, - Replicated: BooleanArrayMul<::Context>, - BitDecomposed>: - for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, - BitDecomposed>: - for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, - Vec>>: for<'a> TransposeFrom< - &'a [BitDecomposed>], - Error = Infallible, - >, - Vec>: - for<'a> TransposeFrom<&'a BitDecomposed>, Error = LengthError>, -{ - assert!(contributions_stream_len != 0); - - let move_to_bucket_chunk_size = TARGET_PROOF_SIZE / B / usize::try_from(TV::BITS).unwrap(); - let move_to_bucket_records = - TotalRecords::specified(div_round_up(contributions_stream_len, Const::))?; - let validator = ctx - .narrow(&Step::MoveToBucket) - .set_total_records(move_to_bucket_records) - .dzkp_validator(move_to_bucket_chunk_size); - let bucket_ctx = validator.context(); - // move each value to the correct bucket - let row_contribution_chunk_stream = process_stream_by_chunks( - contributions_stream, - AttributionOutputs { - attributed_breakdown_key_bits: vec![], - capped_attributed_trigger_value: vec![], - }, - move |idx, chunk: AttributionOutputsChunk| { - let record_id = RecordId::from(idx); - let validate_ctx = bucket_ctx.clone(); - let ctx = bucket_ctx - .clone() - .set_total_records(TotalRecords::Indeterminate); - async move { - let result = bucket::move_single_value_to_bucket::<_, AGG_CHUNK>( - ctx.clone(), - record_id, - chunk.attributed_breakdown_key_bits, - chunk.capped_attributed_trigger_value, - B, - false, - ) - .instrument(tracing::debug_span!("move_to_bucket", chunk = idx)) - .await; - - validate_ctx.validate_record(record_id).await?; - - result - } - }, - ); - - let mut aggregation_input = row_contribution_chunk_stream - // Rather than transpose out of record-vectorized form and then transpose again back - // into bucket-vectorized form, we use a special transpose (the "aggregation - // intermediate transpose") that combines the two steps. - // - // Since the bucket-vectorized representation is separable by records, we do the - // transpose within the `Chunk` wrapper using `Chunk::map`, and then invoke - // `Chunk::into_iter` via `try_flatten_iters` to produce an unchunked stream of - // records, vectorized by buckets. - .then(|fut| { - fut.map(|res| { - res.map(|chunk| { - chunk.map(|data| Vec::transposed_from(data.as_slice()).unwrap_infallible()) - }) - }) - }) - .try_flatten_iters() - .boxed(); - - let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); - let chunks = iter::from_fn(|| { - if contributions_stream_len >= agg_proof_chunk { - contributions_stream_len -= agg_proof_chunk; - Some(agg_proof_chunk) - } else if contributions_stream_len > 0 { - let chunk = contributions_stream_len; - contributions_stream_len = 0; - Some(chunk) - } else { - None - } - }); - let mut intermediate_results = Vec::new(); - let mut chunk_counter = 0; - for chunk in chunks { - let ctx = ctx.narrow(&Step::AggregateChunk(chunk_counter)); - chunk_counter += 1; - let stream = aggregation_input.by_ref().take(chunk); - let validator = ctx.dzkp_validator(agg_proof_chunk); - let result = - aggregate_values::<_, HV, B>(validator.context(), stream.boxed(), chunk).await?; - validator.validate().await?; - intermediate_results.push(Ok(result)); - } - - if intermediate_results.len() > 1 { - let ctx = ctx.narrow(&Step::AggregateChunk(chunk_counter)); - let validator = ctx.dzkp_validator(agg_proof_chunk); - let stream_len = intermediate_results.len(); - let aggregated_result = aggregate_values::<_, HV, B>( - validator.context(), - stream::iter(intermediate_results).boxed(), - stream_len, - ) - .await?; - validator.validate().await?; - Ok(aggregated_result) - } else { - intermediate_results - .into_iter() - .next() - .expect("aggregation input must not be empty") - } -} - /// A vector of histogram contributions for each output bucket. /// /// Aggregation is vectorized over histogram buckets, so bit 0 for every histogram bucket is stored @@ -268,9 +97,13 @@ pub type AggResult = Result /// /// $\sum_{i = 1}^k 2^{k - i} (b + i - 1) \approx 2^k (b + 1) = N (b + 1)$ pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) -> usize { - TARGET_PROOF_SIZE / input_width / (input_item_bits + 1) + max(2, TARGET_PROOF_SIZE / input_width / (input_item_bits + 1)).next_power_of_two() } +// This is the step count for AggregateChunkStep. We need it to size RecordId arrays. +// This value must be at least the log of the aggregation chunk size. +pub const AGGREGATE_DEPTH: usize = 24; + /// Aggregate output contributions /// /// In the case of attribution, each item in `aggregated_stream` is a vector of values to be added @@ -287,10 +120,12 @@ pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) /// /// It might be possible to save some cost by using naive wrapping arithmetic. Another /// possibility would be to combine all carries into a single "overflow detected" bit. +#[tracing::instrument(name = "aggregate_values", skip_all, fields(num_rows = num_rows))] pub async fn aggregate_values<'ctx, 'fut, C, OV, const B: usize>( ctx: C, mut aggregated_stream: Pin> + Send + 'fut>>, mut num_rows: usize, + record_ids: Option<&mut [RecordId; AGGREGATE_DEPTH]>, ) -> Result>, Error> where 'ctx: 'fut, @@ -308,24 +143,31 @@ where OV::BITS, ); + let mut record_id_store = None; + let record_ids = + record_ids.unwrap_or_else(|| record_id_store.insert([RecordId::FIRST; AGGREGATE_DEPTH])); + let mut depth = 0; while num_rows > 1 { // Indeterminate TotalRecords is currently required because aggregation does not poll // futures in parallel (thus cannot reach a batch of records). // // We reduce pairwise, passing through the odd record at the end if there is one, so the - // number of outputs (`next_num_rows`) gets rounded up. If calculating an explicit total - // records, that would get rounded down. + // number of outputs (`next_num_rows`) gets rounded up. The number of addition operations + // (number of used record IDs) gets rounded down. let par_agg_ctx = ctx - .narrow(&AggregateChunkStep::Aggregate(depth)) + .narrow(&AggregateChunkStep::from(depth)) .set_total_records(TotalRecords::Indeterminate); let next_num_rows = (num_rows + 1) / 2; + let base_record_id = record_ids[depth]; + record_ids[depth] += num_rows / 2; aggregated_stream = Box::pin( FixedLength::new(aggregated_stream, num_rows) .try_chunks(2) .enumerate() .then(move |(i, chunk_res)| { let ctx = par_agg_ctx.clone(); + let record_id = base_record_id + i; async move { match chunk_res { Err(e) => { @@ -340,7 +182,6 @@ where assert_eq!(chunk_pair.len(), 2); let b = chunk_pair.pop().unwrap(); let a = chunk_pair.pop().unwrap(); - let record_id = RecordId::from(i); if a.len() < usize::try_from(OV::BITS).unwrap() { // If we have enough output bits, add and keep the carry. let (mut sum, carry) = integer_add::<_, AdditionStep, B>( @@ -368,7 +209,7 @@ where "reduce", depth = depth, rows = num_rows, - record = i + record = u32::from(record_id), )) }), ); @@ -438,7 +279,7 @@ pub mod tests { let result: BitDecomposed = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -461,7 +302,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -487,7 +328,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -515,7 +356,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -542,7 +383,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -561,7 +402,7 @@ pub mod tests { run(|| async move { let result = TestWorld::default() .upgraded_semi_honest((), |ctx, ()| { - aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0) + aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0, None) }) .await .map(Result::unwrap) @@ -582,7 +423,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await; @@ -602,7 +443,7 @@ pub mod tests { let _ = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len() + 1; - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -624,7 +465,7 @@ pub mod tests { let _ = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len() - 1; - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -710,6 +551,7 @@ pub mod tests { ctx, stream::iter(inputs).boxed(), num_rows, + None, ) }) .await diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index f75654088..8be4fdcd1 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -9,30 +9,19 @@ pub(crate) enum AggregationStep { PaddingDp, #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] Shuffle, - RevealStep, - #[step(child = BucketStep)] - MoveToBucket, - #[step(count = 32, child = AggregateChunkStep)] - AggregateChunk(usize), -} - -/// the number of steps must be kept in sync with `MAX_BREAKDOWNS` defined -/// [here](https://tinyurl.com/mwnbbnj6) -#[derive(CompactStep)] -#[step(count = 512, child = crate::protocol::boolean::step::EightBitStep, name = "b")] -pub struct BucketStep(usize); - -impl From for BucketStep { - fn from(v: usize) -> Self { - Self(v) - } + Reveal, + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + RevealValidate, // only partly used -- see code + #[step(count = 4, child = AggregateChunkStep, name = "chunks")] + Aggregate(usize), + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + AggregateValidate, } +// The step count here is duplicated as the AGGREGATE_DEPTH constant in the code. #[derive(CompactStep)] -pub(crate) enum AggregateChunkStep { - #[step(count = 32, child = AggregateValuesStep)] - Aggregate(usize), -} +#[step(count = 24, child = AggregateValuesStep, name = "fold")] +pub(crate) struct AggregateChunkStep(usize); #[derive(CompactStep)] pub(crate) enum AggregateValuesStep { diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs index eff85fe2d..f5fcefe20 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs @@ -1,11 +1,10 @@ -use std::iter::repeat; +use std::iter::{repeat, repeat_n}; use ipa_step::StepNarrow; use crate::{ error::Error, ff::boolean::Boolean, - helpers::repeat_n, protocol::{ basics::{BooleanProtocols, SecureMul}, boolean::{or::bool_or, NBitStep}, @@ -73,7 +72,7 @@ where .await?; // if carry==1 then {all ones} else {result} - bool_or( + bool_or::<_, S, _, N>( ctx.narrow::(&Step::Select), record_id, &result, diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index 22a6c4ae5..2dabdc3f4 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -367,7 +367,7 @@ where #[cfg(all(test, unit_test))] mod tests { - use std::iter::{self, repeat_with}; + use std::iter::{self, repeat_n, repeat_with}; use curve25519_dalek::Scalar; use futures::stream::TryStreamExt; @@ -378,9 +378,9 @@ mod tests { use super::*; use crate::{ ff::{boolean_array::BA64, Serializable}, - helpers::{repeat_n, stream::process_slice_by_chunks}, + helpers::stream::process_slice_by_chunks, protocol::{ - context::{dzkp_validator::DZKPValidator, UpgradableContext}, + context::{dzkp_validator::DZKPValidator, UpgradableContext, TEST_DZKP_STEPS}, ipa_prf::{CONV_CHUNK, CONV_PROOF_CHUNK, PRF_CHUNK}, }, rand::thread_rng, @@ -415,7 +415,7 @@ mod tests { let [res0, res1, res2] = world .semi_honest(records.into_iter(), |ctx, records| async move { let c_ctx = ctx.set_total_records((COUNT + CONV_CHUNK - 1) / CONV_CHUNK); - let validator = &c_ctx.dzkp_validator(CONV_PROOF_CHUNK); + let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, CONV_PROOF_CHUNK); let m_ctx = validator.context(); seq_join( m_ctx.active_work(), @@ -477,7 +477,7 @@ mod tests { let [res0, res1, res2] = world .malicious(records.into_iter(), |ctx, records| async move { let c_ctx = ctx.set_total_records(TOTAL_RECORDS); - let validator = &c_ctx.dzkp_validator(PROOF_CHUNK); + let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, PROOF_CHUNK); let m_ctx = validator.context(); seq_join( m_ctx.active_work(), @@ -518,7 +518,7 @@ mod tests { TestWorld::default() .semi_honest(iter::empty::(), |ctx, _records| async move { let c_ctx = ctx.set_total_records(1); - let validator = &c_ctx.dzkp_validator(1); + let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, 1); let m_ctx = validator.context(); let match_keys = BitDecomposed::new(repeat_n( AdditiveShare::::ZERO, @@ -532,7 +532,7 @@ mod tests { .await .unwrap() }) - .await + .await; }); } diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs index c27a8db44..056f5889a 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs @@ -1,15 +1,15 @@ use ipa_step_derive::CompactStep; /// FIXME: This step is not generic enough to be used in the `saturated_addition` protocol. -/// It constrains the input to be at most 2 bytes and it will panic in runtime if it is greater +/// It constrains the input to be at most 4 bytes and it will panic in runtime if it is greater /// than that. The issue is that compact gate requires concrete type to be put as child. /// If we ever see it being an issue, we should make a few implementations of this similar to what /// we've done for bit steps #[derive(CompactStep)] pub(crate) enum SaturatedAdditionStep { - #[step(child = crate::protocol::boolean::step::SixteenBitStep)] + #[step(child = crate::protocol::boolean::step::ThirtyTwoBitStep)] Add, - #[step(child = crate::protocol::boolean::step::SixteenBitStep)] + #[step(child = crate::protocol::boolean::step::ThirtyTwoBitStep)] Select, } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 2477f0867..32dbc3a18 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -50,6 +50,16 @@ where } } +impl Default for CanonicalLagrangeDenominator +where + F: PrimeField + TryFrom, + >::Error: Debug, +{ + fn default() -> Self { + Self::new() + } +} + /// `LagrangeTable` is a precomputed table for the Lagrange evaluation. /// Allows to compute points on the polynomial, i.e. output points, /// given enough points on the polynomial, i.e. input points, diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index 780eabcf2..808dc4476 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -3,14 +3,14 @@ use std::{borrow::Borrow, iter::zip, marker::PhantomData}; #[cfg(all(test, unit_test))] use crate::ff::Fp31; use crate::{ - error::{Error, Error::DZKPMasks}, + error::Error::{self, DZKPMasks}, ff::{Fp61BitPrime, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ context::Context, ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, }; @@ -179,7 +179,7 @@ impl ProofGenerat .collect::>() } - fn gen_proof_shares_from_prss(ctx: &C, record_counter: &mut RecordId) -> ([F; P], [F; P]) + fn gen_proof_shares_from_prss(ctx: &C, record_ids: &mut RecordIdRange) -> ([F; P], [F; P]) where C: Context, { @@ -187,9 +187,9 @@ impl ProofGenerat let mut out_right = [F::ZERO; P]; // use PRSS for i in 0..P { - let (left, right) = ctx.prss().generate_fields::(*record_counter); - - *record_counter += 1; + let (left, right) = ctx + .prss() + .generate_fields::(record_ids.expect_next()); out_left[i] = left; out_right[i] = right; @@ -215,7 +215,7 @@ impl ProofGenerat /// `my_proof_left_share` has type `Vec<[F; P]>`, pub fn gen_artefacts_from_recursive_step( ctx: &C, - record_counter: &mut RecordId, + record_ids: &mut RecordIdRange, lagrange_table: &LagrangeTable, uv_iterator: J, ) -> (UVValues, [F; P], [F; P]) @@ -230,7 +230,7 @@ impl ProofGenerat // generate proof shares from prss let (share_of_proof_from_prover_left, my_proof_right_share) = - Self::gen_proof_shares_from_prss(ctx, record_counter); + Self::gen_proof_shares_from_prss(ctx, record_ids); // generate prover left proof let my_proof_left_share = Self::gen_other_proof_share(my_proof, my_proof_right_share); @@ -267,7 +267,7 @@ mod test { lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, prover::{LargeProofGenerator, SmallProofGenerator, TestProofGenerator, UVValues}, }, - RecordId, + RecordId, RecordIdRange, }, seq_join::SeqJoin, test_executor::run, @@ -396,11 +396,11 @@ mod test { // first iteration let world = TestWorld::default(); - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; let (uv_values, _, _) = TestProofGenerator::gen_artefacts_from_recursive_step::<_, _, _, 4>( &world.contexts()[0], - &mut record_counter, + &mut record_ids, &lagrange_table, uv_1.iter(), ); @@ -496,11 +496,11 @@ mod test { let world = TestWorld::default(); let [helper_1_proofs, helper_2_proofs, helper_3_proofs] = world .semi_honest((), |ctx, ()| async move { - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; (0..NUM_PROOFS) .map(|i| { - assert_eq!(i * 7, usize::from(record_counter)); - TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_counter) + assert_eq!(i * 7, usize::from(record_ids.peek_first())); + TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_ids) }) .collect::>() }) @@ -550,9 +550,9 @@ mod test { let [(h1_proof_left, h1_proof_right), (h2_proof_left, h2_proof_right), (h3_proof_left, h3_proof_right)] = world .semi_honest((), |ctx, ()| async move { - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; let (proof_share_left, my_share_of_right) = - TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_counter); + TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_ids); let proof_u128 = match ctx.role() { Role::H1 => PROOF_1, Role::H2 => PROOF_2, diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 6dc108524..55a02f9f8 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -21,8 +21,8 @@ use crate::{ protocol::{ basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, Context, DZKPUpgraded, DZKPUpgradedSemiHonestContext, - MacUpgraded, SemiHonestContext, UpgradableContext, UpgradedSemiHonestContext, + dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, + UpgradableContext, }, ipa_prf::{ boolean_ops::convert_to_fp25519, @@ -31,6 +31,7 @@ use crate::{ prf_sharding::{ attribute_cap_aggregate, histograms_ranges_sortkeys, PrfShardedIpaInputRow, }, + step::IpaPrfStep, }, prss::FromPrss, RecordId, @@ -40,7 +41,6 @@ use crate::{ SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, - sharding::NotSharded, }; pub(crate) mod aggregation; @@ -55,6 +55,8 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; +pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; + /// Match key type pub type MatchKey = BA64; /// Match key size @@ -96,7 +98,7 @@ use crate::{ protocol::{ context::Validator, dp::dp_for_histogram, - ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing}, + ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle}, }, secret_sharing::replicated::semi_honest::AdditiveShare, }; @@ -220,25 +222,33 @@ where /// Propagates errors from config issues or while running the protocol /// # Panics /// Propagates errors from config issues or while running the protocol -pub async fn oprf_ipa<'ctx, BK, TV, HV, TS, const SS_BITS: usize, const B: usize>( - ctx: SemiHonestContext<'ctx>, +pub async fn oprf_ipa<'ctx, C, BK, TV, HV, TS, const SS_BITS: usize, const B: usize>( + ctx: C, input_rows: Vec>, attribution_window_seconds: Option, dp_params: DpMechanism, dp_padding_params: PaddingParameters, ) -> Result>, Error> where + C: UpgradableContext + 'ctx + Shuffle, BK: BreakdownKey, TV: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, TS: BooleanArray + U128Conversions, Boolean: FieldSimd, - Replicated: - BooleanProtocols, B>, - Replicated: BooleanProtocols, B>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, + Replicated: BooleanProtocols>, + Replicated: BooleanProtocols, B>, + Replicated: BooleanProtocols, AGG_CHUNK>, + Replicated: BooleanProtocols, CONV_CHUNK>, + Replicated: BooleanProtocols, SORT_CHUNK>, + Replicated: + PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, + Replicated: + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, BitDecomposed>: for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, BitDecomposed>: @@ -262,7 +272,7 @@ where let padded_input_rows = apply_dp_padding::<_, OPRFIPAInputRow, B>( ctx.narrow(&Step::PaddingDp), input_rows, - dp_padding_params, + &dp_padding_params, ) .await?; @@ -290,23 +300,22 @@ where prfd_inputs, attribution_window_seconds, &row_count_histogram, + &dp_padding_params, ) .await?; - let noisy_output_histogram = dp_for_histogram::<_, B, HV, SS_BITS>( - ctx.narrow(&Step::DifferentialPrivacy), - output_histogram, - dp_params, - ) - .await?; + let noisy_output_histogram = + dp_for_histogram::<_, B, HV, SS_BITS>(ctx, output_histogram, dp_params).await?; Ok(noisy_output_histogram) } // We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor // is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of // multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. -const CONV_PROOF_CHUNK: usize = 400; +// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so +// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M +// multiplications per batch +const CONV_PROOF_CHUNK: usize = 256; #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] async fn compute_prf_for_inputs( @@ -327,11 +336,15 @@ where let conv_records = TotalRecords::specified(div_round_up(input_rows.len(), Const::))?; let eval_records = TotalRecords::specified(div_round_up(input_rows.len(), Const::))?; - let convert_ctx = ctx - .narrow(&Step::ConvertFp25519) - .set_total_records(conv_records); + let convert_ctx = ctx.set_total_records(conv_records); - let validator = convert_ctx.dzkp_validator(CONV_PROOF_CHUNK); + let validator = convert_ctx.dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::ConvertFp25519, + validate: &Step::ConvertFp25519Validate, + }, + CONV_PROOF_CHUNK, + ); let m_ctx = validator.context(); let curve_pts = seq_join( @@ -351,9 +364,11 @@ where .try_collect::>() .await?; - let eval_ctx = ctx.narrow(&Step::EvalPrf).set_total_records(eval_records); - let prf_key = gen_prf_key(&eval_ctx); - let validator = eval_ctx.validator::(); + let prf_key = gen_prf_key(&ctx.narrow(&IpaPrfStep::PrfKeyGen)); + let validator = ctx + .narrow(&Step::EvalPrf) + .set_total_records(eval_records) + .validator::(); let eval_ctx = validator.context(); let prf_of_match_keys = seq_join( @@ -440,11 +455,58 @@ pub mod tests { ]; // trigger value of 2 attributes to earlier source row with breakdown 1 and trigger // value of 5 attributes to source row with breakdown 2. let dp_params = DpMechanism::NoDp; - let padding_params = PaddingParameters::relaxed(); + let padding_params = if cfg!(feature = "shuttle") { + // To reduce runtime. There is also a hard upper limit in the shuttle + // config (`max_steps`), that may need to be increased to support larger + // runs. + PaddingParameters::no_padding() + } else { + PaddingParameters::relaxed() + }; let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA16, BA20, 5, 32>( + ctx, + input_rows, + None, + dp_params, + padding_params, + ) + .await + .unwrap() + }) + .await + .reconstruct(); + result.truncate(EXPECTED.len()); + assert_eq!( + result.iter().map(|&v| v.as_u128()).collect::>(), + EXPECTED, + ); + }); + } + + #[test] + fn malicious() { + const EXPECTED: &[u128] = &[0, 2, 5, 0, 0, 0, 0, 0]; + + run(|| async { + let world = TestWorld::default(); + + let records: Vec = vec![ + test_input(0, 12345, false, 1, 0), + test_input(5, 12345, false, 2, 0), + test_input(10, 12345, true, 0, 5), + test_input(0, 68362, false, 1, 0), + test_input(20, 68362, true, 0, 2), + ]; // trigger value of 2 attributes to earlier source row with breakdown 1 and trigger + // value of 5 attributes to source row with breakdown 2. + let dp_params = DpMechanism::NoDp; + let padding_params = PaddingParameters::no_padding(); + + let mut result: Vec<_> = world + .malicious(records.into_iter(), |ctx, input_rows| async move { + oprf_ipa::<_, BA5, BA3, BA16, BA20, 5, 32>( ctx, input_rows, None, @@ -501,7 +563,7 @@ pub mod tests { ]; let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA16, BA20, SS_BITS, B>( ctx, input_rows, None, @@ -562,7 +624,7 @@ pub mod tests { let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA8, BA20, 5, 32>( ctx, input_rows, None, @@ -598,7 +660,7 @@ pub mod tests { let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA8, BA20, 5, 32>( ctx, input_rows, None, @@ -652,7 +714,124 @@ pub mod tests { let padding_params = PaddingParameters::no_padding(); let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA8, BA3, BA16, BA20, 5, 256>( + ctx, + input_rows, + None, + dp_params, + padding_params, + ) + .await + .unwrap() + }) + .await + .reconstruct(); + result.truncate(EXPECTED.len()); + assert_eq!( + result.iter().map(|&v| v.as_u128()).collect::>(), + EXPECTED, + ); + }); + } +} + +#[cfg(all(test, all(compact_gate, feature = "in-memory-infra")))] +mod compact_gate_tests { + use ipa_step::{CompactStep, StepNarrow}; + + use crate::{ + ff::{ + boolean_array::{BA20, BA5, BA8}, + U128Conversions, + }, + helpers::query::DpMechanism, + protocol::{ + ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters}, + step::{ProtocolGate, ProtocolStep}, + }, + test_executor::run, + test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld, TestWorldConfig}, + }; + + #[test] + fn step_count_limit() { + // This is an arbitrary limit intended to catch changes that unintentionally + // blow up the step count. It can be increased, within reason. + const STEP_COUNT_LIMIT: u32 = 24_000; + assert!( + ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT, + "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.", + actual = ProtocolStep::STEP_COUNT, + ); + } + + #[test] + fn saturated_agg() { + const EXPECTED: &[u128] = &[0, 255, 255, 0, 0, 0, 0, 0]; + + run(|| async { + let world = TestWorld::new_with(TestWorldConfig { + initial_gate: Some(ProtocolGate::default().narrow(&ProtocolStep::IpaPrf)), + ..Default::default() + }); + + let records: Vec = vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 5, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 255, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 255, + }, + TestRawDataRecord { + timestamp: 30, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 255, + }, + TestRawDataRecord { + timestamp: 0, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 1, + trigger_value: 255, + }, + ]; + let dp_params = DpMechanism::NoDp; + let padding_params = PaddingParameters::relaxed(); + + let mut result: Vec<_> = world + .semi_honest(records.into_iter(), |ctx, input_rows| async move { + oprf_ipa::<_, BA5, BA8, BA8, BA20, 5, 32>( ctx, input_rows, None, diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs index b7268900d..d37edfecb 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use std::f64::consts::E; use rand::distributions::{BernoulliError, Distribution}; @@ -77,6 +75,7 @@ impl Dp { }) } + #[cfg(all(test, unit_test))] fn apply(&self, mut input: I, rng: &mut R) where R: RngCore + CryptoRng, @@ -521,16 +520,18 @@ mod test { println!("A sample value equal to {sample} occurred {count} time(s)",); } } + + #[test] fn test_oprf_padding_dp_constructor() { let mut actual = OPRFPaddingDp::new(-1.0, 1e-6, 10); // (epsilon, delta, sensitivity) let mut expected = Err(Error::BadEpsilon(-1.0)); - assert_eq!(expected, Ok(actual)); + assert_eq!(expected, actual); actual = OPRFPaddingDp::new(1.0, -1e-6, 10); // (epsilon, delta, sensitivity) expected = Err(Error::BadDelta(-1e-6)); - assert_eq!(expected, Ok(actual)); - actual = OPRFPaddingDp::new(1.0, -1e-6, 1_000_001); // (epsilon, delta, sensitivity) + assert_eq!(expected, actual); + actual = OPRFPaddingDp::new(1.0, 1e-6, 1_000_001); // (epsilon, delta, sensitivity) expected = Err(Error::BadSensitivity(1_000_001)); - assert_eq!(expected, Ok(actual)); + assert_eq!(expected, actual); } #[test] diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs index ce2d1ceda..207dd2a43 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs @@ -274,10 +274,11 @@ where /// # Errors /// Will propagate errors from `apply_dp_padding_pass` +#[tracing::instrument(name = "apply_dp_padding", skip_all)] pub async fn apply_dp_padding( ctx: C, mut input: Vec, - padding_params: PaddingParameters, + padding_params: &PaddingParameters, ) -> Result, Error> where C: Context, @@ -290,7 +291,7 @@ where ctx.narrow(&PaddingDpStep::PaddingDpPass1), input, Role::H3, - &padding_params, + padding_params, ) .await?; @@ -299,7 +300,7 @@ where ctx.narrow(&PaddingDpStep::PaddingDpPass2), input, Role::H2, - &padding_params, + padding_params, ) .await?; @@ -308,7 +309,7 @@ where ctx.narrow(&PaddingDpStep::PaddingDpPass3), input, Role::H1, - &padding_params, + padding_params, ) .await?; diff --git a/ipa-core/src/protocol/ipa_prf/prf_eval.rs b/ipa-core/src/protocol/ipa_prf/prf_eval.rs index 1870c9ef7..60cc41bdc 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_eval.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_eval.rs @@ -78,7 +78,6 @@ where C: UpgradableContext, Fp25519: Vectorizable, { - let ctx = ctx.narrow(&Step::PRFKeyGen); let v: AdditiveShare = ctx.prss().generate(RecordId::FIRST); v.expand() diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs index 19d6caac0..708998ed0 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs @@ -1,4 +1,7 @@ -use std::{convert::Infallible, iter::zip}; +use std::{ + convert::Infallible, + iter::{repeat_n, zip}, +}; use futures::stream; use futures_util::{future::try_join, stream::unfold, Stream, StreamExt}; @@ -6,7 +9,7 @@ use futures_util::{future::try_join, stream::unfold, Stream, StreamExt}; use crate::{ error::{Error, LengthError, UnwrapInfallible}, ff::{boolean::Boolean, boolean_array::BooleanArray, Field, U128Conversions}, - helpers::{repeat_n, stream::TryFlattenItersExt, TotalRecords}, + helpers::{stream::TryFlattenItersExt, TotalRecords}, protocol::{ basics::{SecureMul, ShareKnownValue}, boolean::{and::bool_and_8_bit, or::or}, @@ -258,7 +261,7 @@ where seq_join(sh_ctx.active_work(), stream::iter(chunked_user_results)).try_flatten_iters(), ); let aggregated_result: BitDecomposed> = - aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs).await?; + aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs, None).await?; let transposed_aggregated_result: Vec> = Vec::transposed_from(&aggregated_result)?; diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 482123d73..03994ef6c 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -1,7 +1,6 @@ use std::{ convert::Infallible, - iter, - iter::zip, + iter::{self, repeat_n, zip}, num::NonZeroU32, ops::{Not, Range}, }; @@ -12,7 +11,7 @@ use futures::{ FutureExt, Stream, StreamExt, TryStreamExt, }; -use super::aggregation::{aggregate_contributions, breakdown_reveal::breakdown_reveal_aggregation}; +use super::aggregation::breakdown_reveal::breakdown_reveal_aggregation; use crate::{ error::{Error, LengthError}, ff::{ @@ -20,9 +19,9 @@ use crate::{ boolean_array::{BooleanArray, BA32, BA7}, ArrayAccess, Field, U128Conversions, }, - helpers::{repeat_n, stream::TryFlattenItersExt, TotalRecords}, + helpers::{stream::TryFlattenItersExt, TotalRecords}, protocol::{ - basics::{select, BooleanArrayMul, BooleanProtocols, SecureMul, ShareKnownValue}, + basics::{select, BooleanArrayMul, BooleanProtocols, Reveal, SecureMul, ShareKnownValue}, boolean::{ or::or, step::{EightBitStep, ThirtyTwoBitStep}, @@ -30,7 +29,7 @@ use crate::{ }, context::{ dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, - Context, DZKPContext, DZKPUpgraded, UpgradableContext, + Context, DZKPContext, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ boolean_ops::{ @@ -38,6 +37,7 @@ use crate::{ comparison_and_subtraction_sequential::{compare_gt, integer_sub}, expand_shared_array_in_place, }, + oprf_padding::PaddingParameters, prf_sharding::step::{ AttributionPerRowStep as PerRowStep, AttributionStep as Step, AttributionWindowStep as WindowStep, @@ -49,7 +49,7 @@ use crate::{ }, secret_sharing::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, - BitDecomposed, FieldSimd, SharedValue, TransposeFrom, + BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, }; @@ -319,14 +319,14 @@ pub struct AttributionOutputs { pub type SecretSharedAttributionOutputs = AttributionOutputs, Replicated>; -#[cfg(all(test, any(unit_test, feature = "shuttle")))] +#[cfg(test)] #[derive(Debug, Clone, Ord, PartialEq, PartialOrd, Eq)] pub struct AttributionOutputsTestInput { pub bk: BK, pub tv: TV, } -#[cfg(all(test, any(unit_test, feature = "shuttle")))] +#[cfg(test)] impl crate::secret_sharing::IntoShares<(Replicated, Replicated)> for AttributionOutputsTestInput where @@ -385,18 +385,10 @@ where (histogram, ranges) } -fn set_up_contexts( - root_ctx: C, - chunk_size: usize, - histogram: &[usize], -) -> Result<(C::DZKPValidator, Vec>), Error> +fn set_up_contexts(ctx: &C, histogram: &[usize]) -> Result, Error> where - C: UpgradableContext, + C: Context, { - let mut dzkp_validator = root_ctx.dzkp_validator(chunk_size); - let ctx = dzkp_validator.context(); - dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); - let mut context_per_row_depth = Vec::with_capacity(histogram.len()); for (row_number, num_users_having_that_row_number) in histogram.iter().enumerate() { if row_number == 0 { @@ -409,7 +401,7 @@ where context_per_row_depth.push(ctx_for_row_number); } } - Ok((dzkp_validator, context_per_row_depth)) + Ok(context_per_row_depth) } /// @@ -476,6 +468,7 @@ pub async fn attribute_cap_aggregate< input_rows: Vec>, attribution_window_seconds: Option, histogram: &[usize], + padding_parameters: &PaddingParameters, ) -> Result>, Error> where C: UpgradableContext + 'ctx, @@ -487,9 +480,10 @@ where Replicated: BooleanProtocols>, Replicated: BooleanProtocols, B>, Replicated: BooleanProtocols, AGG_CHUNK>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, BitDecomposed>: for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, BitDecomposed>: @@ -512,9 +506,17 @@ where * multiplications_per_record::(attribution_window_seconds)); // Tricky hacks to work around the limitations of our current infrastructure - let num_outputs = input_rows.len() - histogram[0]; - let (dzkp_validator, ctx_for_row_number) = - set_up_contexts(sh_ctx.narrow(&Step::Attribute), chunk_size, histogram)?; + let mut dzkp_validator = sh_ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::Attribute, + validate: &Step::AttributeValidate, + }, + // TODO: this should not be necessary, but probably can't be removed + // until we align read_size with the batch size. + std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()), + ); + dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); + let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; // Chunk the incoming stream of records into stream of vectors of records with the same PRF let mut input_stream = stream::iter(input_rows); @@ -535,24 +537,13 @@ where attribution_window_seconds, ); - let aggregation_validator = sh_ctx.narrow(&Step::Aggregate).dzkp_validator(0); - let ctx = aggregation_validator.context(); - - // New aggregation is still experimental, we need proofs that it is private, - // hence it is only enabled behind a feature flag. - if cfg!(feature = "reveal-aggregation") { - // If there was any error in attribution we stop the execution with an error - tracing::warn!("Using the experimental aggregation based on revealing breakdown keys"); - let user_contributions = flattened_user_results.try_collect::>().await?; - breakdown_reveal_aggregation::<_, _, _, HV, B>(ctx, user_contributions).await - } else { - aggregate_contributions::<_, _, _, _, HV, B>( - sh_ctx.narrow(&Step::Aggregate), - flattened_user_results, - num_outputs, - ) - .await - } + let user_contributions = flattened_user_results.try_collect::>().await?; + breakdown_reveal_aggregation::<_, BK, TV, HV, B>( + sh_ctx.narrow(&Step::Aggregate), + user_contributions, + padding_parameters, + ) + .await } #[tracing::instrument(name = "attribute_cap", skip_all, fields(unique_match_keys = input.len()))] @@ -568,9 +559,9 @@ where BK: BreakdownKey, TV: BooleanArray + U128Conversions, TS: BooleanArray + U128Conversions, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, { let chunked_user_results = input @@ -608,9 +599,9 @@ where BK: BooleanArray + U128Conversions, TV: BooleanArray + U128Conversions, TS: BooleanArray + U128Conversions, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, { assert!(!rows_for_user.is_empty()); if rows_for_user.len() == 1 { @@ -885,7 +876,7 @@ where #[cfg(all(test, unit_test))] pub mod tests { - use std::num::NonZeroU32; + use std::{iter::repeat_n, num::NonZeroU32}; use super::{AttributionOutputs, PrfShardedIpaInputRow}; use crate::{ @@ -894,7 +885,9 @@ pub mod tests { boolean_array::{BooleanArray, BA16, BA20, BA3, BA5, BA8}, Field, U128Conversions, }, - protocol::ipa_prf::prf_sharding::attribute_cap_aggregate, + protocol::ipa_prf::{ + oprf_padding::PaddingParameters, prf_sharding::attribute_cap_aggregate, + }, rand::Rng, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, @@ -904,6 +897,7 @@ pub mod tests { test_fixture::{Reconstruct, Runner, TestWorld}, }; + #[derive(Clone)] struct PreShardedAndSortedOPRFTestInput { prf_of_match_key: u64, is_trigger_bit: Boolean, @@ -1079,7 +1073,11 @@ pub mod tests { .malicious(records.into_iter(), |ctx, input_rows| async move { Vec::transposed_from( &attribute_cap_aggregate::<_, BA5, BA3, BA16, BA20, 5, 32>( - ctx, input_rows, None, &histogram, + ctx, + input_rows, + None, + &histogram, + &PaddingParameters::relaxed(), ) .await .unwrap(), @@ -1097,6 +1095,7 @@ pub mod tests { ); }); } + #[test] fn semi_honest_aggregation_capping_attribution_with_attribution_window() { const ATTRIBUTION_WINDOW_SECONDS: u32 = 200; @@ -1139,6 +1138,7 @@ pub mod tests { input_rows, NonZeroU32::new(ATTRIBUTION_WINDOW_SECONDS), &histogram, + &PaddingParameters::relaxed(), ) .await .unwrap(), @@ -1157,6 +1157,33 @@ pub mod tests { }); } + #[test] + #[should_panic(expected = "Step index 64 out of bounds for UserNthRowStep with count 64.")] + fn attribution_too_many_records_per_user() { + run(|| async move { + let world = TestWorld::default(); + + let records: Vec> = + repeat_n(oprf_test_input(123, false, 17, 0), 65).collect(); + + let histogram = repeat_n(1, 65).collect::>(); + let histogram_ref = histogram.as_slice(); + + world + .malicious(records.into_iter(), |ctx, input_rows| async move { + attribute_cap_aggregate::<_, BA5, BA3, BA16, BA20, 5, 32>( + ctx, + input_rows, + None, + histogram_ref, + &PaddingParameters::relaxed(), + ) + .await + .unwrap() + }) + .await; + }); + } #[test] fn capping_bugfix() { const HISTOGRAM: [usize; 10] = [5, 5, 5, 5, 5, 5, 5, 2, 1, 1]; @@ -1236,7 +1263,13 @@ pub mod tests { BA20, { SaturatingSumType::BITS as usize }, 256, - >(ctx, input_rows, None, &HISTOGRAM) + >( + ctx, + input_rows, + None, + &HISTOGRAM, + &PaddingParameters::relaxed(), + ) .await .unwrap(), ) diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs index d3f7c9edb..d3f123bf3 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs @@ -1,21 +1,15 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] -pub enum UserNthRowStep { - #[step(count = 64, child = AttributionPerRowStep)] - Row(usize), -} - -impl From for UserNthRowStep { - fn from(v: usize) -> Self { - Self::Row(v) - } -} +#[step(count = 64, child = AttributionPerRowStep, name = "row")] +pub struct UserNthRowStep(usize); #[derive(CompactStep)] pub(crate) enum AttributionStep { #[step(child = UserNthRowStep)] Attribute, + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + AttributeValidate, #[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)] Aggregate, } diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index c3794c222..943dfb1ec 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -14,7 +14,10 @@ use crate::{ protocol::{ basics::reveal, boolean::{step::ThirtyTwoBitStep, NBitStep}, - context::{dzkp_validator::DZKPValidator, Context, DZKPUpgraded, UpgradableContext}, + context::{ + dzkp_validator::{validated_seq_join, DZKPValidator, TARGET_PROOF_SIZE}, + Context, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext, + }, ipa_prf::{ boolean_ops::comparison_and_subtraction_sequential::compare_gt, step::{QuicksortPassStep, QuicksortStep as Step}, @@ -94,6 +97,10 @@ where } } +fn quicksort_proof_chunk(key_bits: usize) -> usize { + (TARGET_PROOF_SIZE / key_bits / SORT_CHUNK).next_power_of_two() +} + /// Insecure quicksort using MPC comparisons and a key extraction function `get_key`. /// /// `get_key` takes as input an element in the slice and outputs the key by which we sort by @@ -166,12 +173,13 @@ where let total_records_usize = div_round_up(num_comparisons_needed, Const::); let total_records = TotalRecords::specified(total_records_usize) .expect("num_comparisons_needed should not be zero"); - let v = ctx - .narrow(&Step::QuicksortPass(quicksort_pass)) - .set_total_records(total_records) - // TODO: use something like this when validating in chunks - //.dzkp_validator(TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK); - .dzkp_validator(total_records_usize); + let v = ctx.set_total_records(total_records).dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::quicksort_pass(quicksort_pass), + validate: &Step::quicksort_pass_validate(quicksort_pass), + }, + quicksort_proof_chunk(usize::try_from(K::BITS).unwrap()), + ); let c = v.context(); let cmp_ctx = c.narrow(&QuicksortPassStep::Compare); let rvl_ctx = c.narrow(&QuicksortPassStep::Reveal); @@ -180,7 +188,7 @@ where stream::iter(ranges_to_sort.clone().into_iter().filter(|r| r.len() > 1)) .flat_map(|range| { // set up iterator - let mut iterator = list[range.clone()].iter().map(get_key).cloned(); + let mut iterator = list[range].iter().map(get_key).cloned(); // first element is pivot, apply key extraction function f let pivot = iterator.next().unwrap(); repeat(pivot).zip(stream::iter(iterator)) @@ -191,8 +199,8 @@ where K::BITS <= ThirtyTwoBitStep::BITS, "ThirtyTwoBitStep is not large enough to accommodate this sort" ); - let compare_results = seq_join( - ctx.active_work(), + let compare_results = validated_seq_join( + v, process_stream_by_chunks::<_, _, _, _, _, _, SORT_CHUNK>( compare_index_pairs, (Vec::new(), Vec::new()), @@ -212,9 +220,6 @@ where .try_collect::>() .await?; - // TODO: validate in chunks rather than for the entire input - v.validate().await?; - let revealed: BitVec = seq_join( ctx.active_work(), stream::iter(compare_results).enumerate().map(|(i, chunk)| { @@ -269,7 +274,7 @@ where #[cfg(all(test, unit_test))] pub mod tests { use std::{ - cmp::Ordering, + cmp::{min, Ordering}, iter::{repeat, repeat_with}, }; @@ -386,6 +391,57 @@ pub mod tests { }); } + #[test] + fn test_quicksort_insecure_malicious_batching() { + run(|| async move { + const COUNT: usize = 600; + let world = TestWorld::default(); + let mut rng = thread_rng(); + + // generate vector of random values + let records: Vec = repeat_with(|| rng.gen()).take(COUNT).collect(); + + // Smaller ranges means fewer passes, makes the test faster. + // (With no impact on proof size, because there is a proof per pass.) + let ranges = (0..COUNT) + .step_by(8) + .map(|i| i..min(i + 8, COUNT)) + .collect::>(); + + // convert expected into more readable format + let mut expected: Vec = + records.clone().into_iter().map(|x| x.as_u128()).collect(); + // sort expected + for range in ranges.iter().cloned() { + expected[range].sort_unstable(); + } + + // compute mpc sort + let result: Vec<_> = world + .malicious(records.into_iter(), |ctx, mut r| { + let ranges_copy = ranges.clone(); + async move { + #[allow(clippy::single_range_in_vec_init)] + quicksort_ranges_by_key_insecure(ctx, &mut r, false, |x| x, ranges_copy) + .await + .unwrap(); + r + } + }) + .await + .reconstruct(); + + assert_eq!( + // convert into more readable format + result + .into_iter() + .map(|x| x.as_u128()) + .collect::>(), + expected + ); + }); + } + #[test] fn test_quicksort_insecure_semi_honest_trivial() { run(|| async move { diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index a34477d69..83343b739 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -18,7 +18,22 @@ use crate::{ /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle( +pub async fn semi_honest_shuffle(ctx: C, shares: I) -> Result>, Error> +where + C: Context, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + S: SharedValue + Add, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + Standard: Distribution, +{ + Ok(shuffle_protocol(ctx, shares).await?.0) +} + +/// # Errors +/// Will propagate errors from transport and a few typecasts +pub async fn shuffle_protocol( ctx: C, shares: I, ) -> Result<(Vec>, IntermediateShuffleMessages), Error> @@ -47,13 +62,12 @@ where let zs = generate_random_tables_with_peers(shares_len, &ctx_z); match ctx.role() { - Role::H1 => run_h1(&ctx, shares_len, shares, zs).await, - Role::H2 => run_h2(&ctx, shares_len, shares, zs).await, - Role::H3 => run_h3(&ctx, shares_len, zs).await, + Role::H1 => Box::pin(run_h1(&ctx, shares_len, shares, zs)).await, + Role::H2 => Box::pin(run_h2(&ctx, shares_len, shares, zs)).await, + Role::H3 => Box::pin(run_h3(&ctx, shares_len, zs)).await, } } -#[allow(dead_code)] /// This struct stores some intermediate messages during the shuffle. /// In a maliciously secure shuffle, /// these messages need to be checked for consistency across helpers. @@ -64,7 +78,6 @@ pub struct IntermediateShuffleMessages { x2_or_y2: Option>, } -#[allow(dead_code)] impl IntermediateShuffleMessages { /// When `IntermediateShuffleMessages` is initialized correctly, /// this function returns `x1` when `Role = H1` @@ -430,7 +443,7 @@ where pub mod tests { use rand::{thread_rng, Rng}; - use super::shuffle; + use super::shuffle_protocol; use crate::{ ff::{Gf40Bit, U128Conversions}, secret_sharing::replicated::ReplicatedSecretSharing, @@ -453,7 +466,7 @@ pub mod tests { // Stable seed is used to get predictable shuffle results. let mut actual = TestWorld::new_with(TestWorldConfig::default().with_seed(123)) .semi_honest(records.clone().into_iter(), |ctx, shares| async move { - shuffle(ctx, shares).await.unwrap().0 + shuffle_protocol(ctx, shares).await.unwrap().0 }) .await .reconstruct(); @@ -484,7 +497,7 @@ pub mod tests { let [h1, h2, h3] = world .semi_honest(records.clone().into_iter(), |ctx, records| async move { - shuffle(ctx, records).await + shuffle_protocol(ctx, records).await }) .await; diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index 68ba7120f..b03363288 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -1,4 +1,4 @@ -use std::iter; +use std::{iter, ops::Add}; use futures::stream::TryStreamExt; use futures_util::{ @@ -6,6 +6,7 @@ use futures_util::{ stream::iter, }; use generic_array::GenericArray; +use rand::distributions::{Distribution, Standard}; use crate::{ error::Error, @@ -17,7 +18,12 @@ use crate::{ protocol::{ basics::{malicious_reveal, mul::semi_honest_multiply}, context::Context, - ipa_prf::shuffle::{base::IntermediateShuffleMessages, step::OPRFShuffleStep}, + ipa_prf::shuffle::{ + base::IntermediateShuffleMessages, + shuffle_protocol, + step::{OPRFShuffleStep, VerifyShuffleStep}, + }, + prss::SharedRandomness, RecordId, }, secret_sharing::{ @@ -27,30 +33,133 @@ use crate::{ seq_join::seq_join, }; +/// This function executes the maliciously secure shuffle protocol on the input: `shares`. +/// +/// ## Errors +/// Propagates network, multiplication and conversion errors from sub functions. +/// +/// ## Panics +/// Panics when `S::Bits + 32 != B::Bits` or type conversions fail. +pub async fn malicious_shuffle( + ctx: C, + shares: I, +) -> Result>, Error> +where + C: Context, + S: BooleanArray, + B: BooleanArray, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, +{ + // assert lengths + assert_eq!(S::BITS + 32, B::BITS); + // compute amount of MAC keys + let amount_of_keys: usize = (usize::try_from(S::BITS).unwrap() + 31) / 32; + // // generate MAC keys + let keys = (0..amount_of_keys) + .map(|i| ctx.prss().generate(RecordId::from(i))) + .collect::>>(); + + // compute and append tags to rows + let shares_and_tags: Vec> = + compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; + + // shuffle + let (shuffled_shares, messages) = shuffle_protocol(ctx.clone(), shares_and_tags).await?; + + // verify the shuffle + verify_shuffle::<_, S, B>( + ctx.narrow(&OPRFShuffleStep::VerifyShuffle), + &keys, + &shuffled_shares, + messages, + ) + .await?; + + // truncate tags from output_shares + // verify_shuffle ensures that truncate_tags yields the correct rows + Ok(truncate_tags(&shuffled_shares)) +} + +/// This function truncates the tags from the output shares of the shuffle protocol +/// +/// ## Panics +/// Panics when `S::Bits > B::Bits`. +fn truncate_tags(shares_and_tags: &[AdditiveShare]) -> Vec> +where + S: BooleanArray, + B: BooleanArray, +{ + shares_and_tags + .iter() + .map(|row_with_tag| { + AdditiveShare::new( + split_row_and_tag(row_with_tag.left()).0, + split_row_and_tag(row_with_tag.right()).0, + ) + }) + .collect() +} + +/// This function splits a row with tag into +/// a row without tag and a tag. +/// +/// When `row_with_tag` does not have the correct format, +/// i.e. deserialization returns an error, +/// the output row and tag will be the default values. +/// +/// ## Panics +/// Panics when the lengths are incorrect: +/// `S` in bytes needs to be equal to `tag_offset`. +/// `B` in bytes needs to be equal to `tag_offset + 4`. +fn split_row_and_tag(row_with_tag: B) -> (S, Gf32Bit) { + let tag_offset = usize::try_from((S::BITS + 7) / 8).unwrap(); + let mut buf = GenericArray::default(); + row_with_tag.serialize(&mut buf); + ( + S::deserialize(GenericArray::from_slice(&buf.as_slice()[0..tag_offset])) + .unwrap_or_default(), + Gf32Bit::deserialize(GenericArray::from_slice(&buf.as_slice()[tag_offset..])) + .unwrap_or_default(), + ) +} + /// This function verifies the `shuffled_shares` and the `IntermediateShuffleMessages`. /// /// ## Errors /// Propagates network errors. /// Further, returns an error when messages are inconsistent with the MAC tags. -async fn verify_shuffle( +async fn verify_shuffle( ctx: C, key_shares: &[AdditiveShare], - shuffled_shares: &[AdditiveShare], - messages: IntermediateShuffleMessages, + shuffled_shares: &[AdditiveShare], + messages: IntermediateShuffleMessages, ) -> Result<(), Error> { // reveal keys let k_ctx = ctx - .narrow(&OPRFShuffleStep::RevealMACKey) + .narrow(&VerifyShuffleStep::RevealMACKey) .set_total_records(TotalRecords::specified(key_shares.len())?); - let keys = reveal_keys(&k_ctx, key_shares).await?; + let keys = reveal_keys(&k_ctx, key_shares) + .await? + .iter() + .map(Gf32Bit::from_array) + .collect::>(); // verify messages and shares match ctx.role() { - Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await, - Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await, + Role::H1 => { + h1_verify::<_, S, B>(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await + } + Role::H2 => { + h2_verify::<_, S, B>(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await + } Role::H3 => { let (y1, y2) = messages.get_both_x_or_ys(); - h3_verify(ctx, &keys, shuffled_shares, y1, y2).await + h3_verify::<_, S, B>(ctx, &keys, shuffled_shares, y1, y2).await } } } @@ -64,17 +173,17 @@ async fn verify_shuffle( /// Propagates network errors. Further it returns an error when /// `hash_x1 != hash_y1` or `hash_c_h2 != hash_a_xor_b` /// or `hash_c_h3 != hash_a_xor_b`. -async fn h1_verify( +async fn h1_verify( ctx: C, - keys: &[StdArray], - share_a_and_b: &[AdditiveShare], - x1: Vec, + keys: &[Gf32Bit], + share_a_and_b: &[AdditiveShare], + x1: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x1 - let hash_x1 = compute_row_hash(keys, x1); + let hash_x1 = compute_and_hash_tags::(keys, x1); // compute hash for A xor B - let hash_a_xor_b = compute_row_hash( + let hash_a_xor_b = compute_and_hash_tags::( keys, share_a_and_b .iter() @@ -83,10 +192,10 @@ async fn h1_verify( // setup channels let h3_ctx = ctx - .narrow(&OPRFShuffleStep::HashesH3toH1) + .narrow(&VerifyShuffleStep::HashesH3toH1) .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx - .narrow(&OPRFShuffleStep::HashH2toH1) + .narrow(&VerifyShuffleStep::HashH2toH1) .set_total_records(TotalRecords::ONE); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let channel_h2 = &h2_ctx.recv_channel::(ctx.role().peer(Direction::Right)); @@ -131,27 +240,27 @@ async fn h1_verify( /// ## Errors /// Propagates network errors. Further it returns an error when /// `hash_x2 != hash_y2`. -async fn h2_verify( +async fn h2_verify( ctx: C, - keys: &[StdArray], - share_b_and_c: &[AdditiveShare], - x2: Vec, + keys: &[Gf32Bit], + share_b_and_c: &[AdditiveShare], + x2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x2 - let hash_x2 = compute_row_hash(keys, x2); + let hash_x2 = compute_and_hash_tags::(keys, x2); // compute hash for C - let hash_c = compute_row_hash( + let hash_c = compute_and_hash_tags::( keys, share_b_and_c.iter().map(ReplicatedSecretSharing::right), ); // setup channels let h1_ctx = ctx - .narrow(&OPRFShuffleStep::HashH2toH1) + .narrow(&VerifyShuffleStep::HashH2toH1) .set_total_records(TotalRecords::specified(1)?); let h3_ctx = ctx - .narrow(&OPRFShuffleStep::HashH3toH2) + .narrow(&VerifyShuffleStep::HashH3toH2) .set_total_records(TotalRecords::specified(1)?); let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Left)); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Right)); @@ -179,30 +288,30 @@ async fn h2_verify( /// /// ## Errors /// Propagates network errors. -async fn h3_verify( +async fn h3_verify( ctx: C, - keys: &[StdArray], - share_c_and_a: &[AdditiveShare], - y1: Vec, - y2: Vec, + keys: &[Gf32Bit], + share_c_and_a: &[AdditiveShare], + y1: Vec, + y2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for y1 - let hash_y1 = compute_row_hash(keys, y1); + let hash_y1 = compute_and_hash_tags::(keys, y1); // compute hash for y2 - let hash_y2 = compute_row_hash(keys, y2); + let hash_y2 = compute_and_hash_tags::(keys, y2); // compute hash for C - let hash_c = compute_row_hash( + let hash_c = compute_and_hash_tags::( keys, share_c_and_a.iter().map(ReplicatedSecretSharing::left), ); // setup channels let h1_ctx = ctx - .narrow(&OPRFShuffleStep::HashesH3toH1) + .narrow(&VerifyShuffleStep::HashesH3toH1) .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx - .narrow(&OPRFShuffleStep::HashH3toH2) + .narrow(&VerifyShuffleStep::HashH3toH2) .set_total_records(TotalRecords::specified(1)?); let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Right)); let channel_h2 = &h2_ctx.send_channel::(ctx.role().peer(Direction::Left)); @@ -223,19 +332,26 @@ async fn h3_verify( /// /// ## Panics /// Panics when conversion from `BooleanArray` to `Vec(keys: &[StdArray], row_iterator: I) -> Hash +fn compute_and_hash_tags(keys: &[Gf32Bit], row_iterator: I) -> Hash where S: BooleanArray, - I: IntoIterator, + B: BooleanArray, + I: IntoIterator, { - let iterator = row_iterator - .into_iter() - .map(|row| >>::try_into(row).unwrap()); - compute_hash(iterator.map(|row| { - row.into_iter() + let iterator = row_iterator.into_iter().map(|row_with_tag| { + // when split_row_and_tags returns the default value, the verification will fail + // except 2^-security_parameter, i.e. 2^-32 + let (row, tag) = split_row_and_tag(row_with_tag); + >>::try_into(row) + .unwrap() + .into_iter() + .chain(iter::once(tag)) + }); + compute_hash(iterator.map(|row_entry_iterator| { + row_entry_iterator .zip(keys) .fold(Gf32Bit::ZERO, |acc, (row_entry, key)| { - acc + row_entry * *key.first() + acc + row_entry * *key }) })) } @@ -255,6 +371,7 @@ async fn reveal_keys( // reveal MAC keys let keys = ctx .parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move { + // uses malicious_reveal directly since we malicious_shuffle always needs the malicious_revel malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await })) .await? @@ -282,19 +399,27 @@ async fn reveal_keys( /// ## Panics /// When conversion fails, when `S::Bits + 32 != B::Bits` /// or when `rows` is empty or elements in `rows` have length `0`. -async fn compute_and_add_tags( +async fn compute_and_add_tags( ctx: C, keys: &[AdditiveShare], - rows: &[AdditiveShare], -) -> Result>, Error> { - let length = rows.len(); + rows: I, +) -> Result>, Error> +where + C: Context, + S: BooleanArray, + B: BooleanArray, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator + Send, +{ + let row_iterator = rows.into_iter(); + let length = row_iterator.len(); let row_length = keys.len(); // make sure total records is not 0 debug_assert!(length * row_length != 0); let tag_ctx = ctx.set_total_records(TotalRecords::specified(length * row_length)?); let p_ctx = &tag_ctx; - let futures = rows.iter().enumerate().map(|(i, row)| async move { + let futures = row_iterator.enumerate().map(|(i, row)| async move { let row_entries_iterator = row.to_gf32bit()?; // compute tags via inner product between row and keys let row_tag = p_ctx @@ -313,7 +438,7 @@ async fn compute_and_add_tags( .iter() .fold(AdditiveShare::::ZERO, |acc, x| acc + x); // combine row and row_tag - Ok::, Error>(concatenate_row_and_tag::(row, &row_tag)) + Ok::, Error>(concatenate_row_and_tag::(&row, &row_tag)) }); seq_join(ctx.active_work(), iter(futures)) @@ -352,14 +477,94 @@ mod tests { use crate::{ ff::{ boolean_array::{BA112, BA144, BA20, BA32, BA64}, - Serializable, + Serializable, U128Conversions, }, - protocol::ipa_prf::shuffle::base::shuffle, + helpers::in_memory_config::{MaliciousHelper, MaliciousHelperContext}, + protocol::ipa_prf::shuffle::base::shuffle_protocol, secret_sharing::SharedValue, test_executor::run, - test_fixture::{Reconstruct, Runner, TestWorld}, + test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; + /// Test the hashing of `BA112` and tag equality. + #[test] + fn hash() { + run(|| async { + let world = TestWorld::default(); + + let mut rng = thread_rng(); + let record = rng.gen::(); + + let (keys, result) = world + .semi_honest(record, |ctx, record| async move { + // compute amount of MAC keys + let amount_of_keys: usize = (usize::try_from(BA112::BITS).unwrap() + 31) / 32; + // // generate MAC keys + let keys = (0..amount_of_keys) + .map(|i| ctx.prss().generate_fields(RecordId::from(i))) + .map(|(left, right)| AdditiveShare::new(left, right)) + .collect::>>(); + + // compute and append tags to rows + let shares_and_tags: Vec> = compute_and_add_tags( + ctx.narrow(&OPRFShuffleStep::GenerateTags), + &keys, + iter::once(record), + ) + .await + .unwrap(); + + (keys, shares_and_tags) + }) + .await + .reconstruct(); + + let result_ba = BA112::deserialize_from_slice(&result[0].as_raw_slice()[0..14]); + + assert_eq!(record, result_ba); + + let tag = Vec::::try_from(record) + .unwrap() + .iter() + .zip(keys) + .fold(Gf32Bit::ZERO, |acc, (entry, key)| acc + *entry * key); + + let tag_mpc = Vec::::try_from(BA32::deserialize_from_slice( + &result[0].as_raw_slice()[14..18], + )) + .unwrap(); + assert_eq!(tag, tag_mpc[0]); + }); + } + + /// This test checks the correctness of the malicious shuffle. + /// It does not check the security against malicious behavior. + #[test] + fn check_shuffle_correctness() { + const RECORD_AMOUNT: usize = 10; + run(|| async { + let world = TestWorld::default(); + let mut rng = thread_rng(); + let mut records = (0..RECORD_AMOUNT) + .map(|_| rng.gen()) + .collect::>(); + + let mut result = world + .semi_honest(records.clone().into_iter(), |ctx, records| async move { + malicious_shuffle::<_, BA112, BA144, _>(ctx, records) + .await + .unwrap() + }) + .await + .reconstruct(); + + records.sort_by_key(BA112::as_u128); + result.sort_by_key(BA112::as_u128); + + assert_eq!(records, result); + }); + } + /// This test checks the correctness of the malicious shuffle /// when all parties behave honestly /// and all the MAC keys are `Gf32Bit::ONE`. @@ -385,11 +590,17 @@ mod tests { // trivial shares of Gf32Bit::ONE let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE); 1]; // run shuffle - let (shares, messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap(); + let (shares, messages) = + shuffle_protocol(ctx.narrow("shuffle"), rows).await.unwrap(); // verify it - verify_shuffle(ctx.narrow("verify"), &key_shares, &shares, messages) - .await - .unwrap(); + verify_shuffle::<_, BA32, BA64>( + ctx.narrow("verify"), + &key_shares, + &shares, + messages, + ) + .await + .unwrap(); }) .await; }); @@ -489,9 +700,13 @@ mod tests { // convert key let mac_key: Vec> = key_shares.to_gf32bit().unwrap().collect::>(); - compute_and_add_tags(ctx, &mac_key, &row_shares) - .await - .unwrap() + compute_and_add_tags( + ctx.narrow(&OPRFShuffleStep::GenerateTags), + &mac_key, + row_shares, + ) + .await + .unwrap() }, ) .await @@ -531,4 +746,116 @@ mod tests { fn bad_initialization_too_small() { check_tags::(); } + + #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait + fn interceptor_h1_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { + // H1 runs an additive attack against H2 by + // changing x2 + if ctx.gate.as_ref().contains("transfer_x2") && ctx.dest == Role::H2 { + data[0] ^= 1u8; + } + } + + #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait + fn interceptor_h2_to_h3(ctx: &MaliciousHelperContext, data: &mut Vec) { + // H2 runs an additive attack against H3 by + // changing y1 + if ctx.gate.as_ref().contains("transfer_y1") && ctx.dest == Role::H3 { + data[0] ^= 1u8; + } + } + + #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait + fn interceptor_h3_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { + // H3 runs an additive attack against H2 by + // changing c_hat_2 + if ctx.gate.as_ref().contains("transfer_c_hat") && ctx.dest == Role::H2 { + data[0] ^= 1u8; + } + } + + /// This test checks that the malicious sort fails + /// under a simple bit flip attack by H1. + /// + /// `x2` will be inconsistent which is checked by `H2`. + #[test] + #[should_panic(expected = "X2 is inconsistent")] + fn fail_under_bit_flip_attack_on_x2() { + const RECORD_AMOUNT: usize = 10; + + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H1, config.role_assignment(), interceptor_h1_to_h2); + + let world = TestWorld::new_with(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let [_, h2, _] = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + let _ = h2.unwrap(); + }); + } + + /// This test checks that the malicious sort fails + /// under a simple bit flip attack by H2. + /// + /// `y1` will be inconsistent which is checked by `H1`. + #[test] + #[should_panic(expected = "Y1 is inconsistent")] + fn fail_under_bit_flip_attack_on_y1() { + const RECORD_AMOUNT: usize = 10; + + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H2, config.role_assignment(), interceptor_h2_to_h3); + + let world = TestWorld::new_with(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let [h1, _, _] = world + .malicious(records.into_iter(), |ctx, shares| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + let _ = h1.unwrap(); + }); + } + + /// This test checks that the malicious sort fails + /// under a simple bit flip attack by H3. + /// + /// `c` from `H2` will be inconsistent + /// which is checked by `H1`. + #[test] + #[should_panic(expected = "C from H2 is inconsistent")] + fn fail_under_bit_flip_attack_on_c() { + const RECORD_AMOUNT: usize = 10; + + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H3, config.role_assignment(), interceptor_h3_to_h2); + + let world = TestWorld::new_with(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let [h1, h2, _] = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + // x2 should be consistent with y2 + let _ = h2.unwrap(); + + // but this should fail + let _ = h1.unwrap(); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 2908bf066..582445190 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -1,8 +1,8 @@ -use std::ops::Add; +use std::{future::Future, ops::Add}; -use rand::distributions::Standard; +use rand::distributions::{Distribution, Standard}; -use self::base::shuffle; +use self::base::shuffle_protocol; use super::{ boolean_ops::{expand_shared_array_in_place, extract_from_shared_array}, prf_sharding::SecretSharedAttributionOutputs, @@ -11,30 +11,99 @@ use crate::{ error::Error, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA112, BA64}, + boolean_array::{BooleanArray, BA112, BA144, BA64, BA96}, ArrayAccess, }, - protocol::{context::Context, ipa_prf::OPRFIPAInputRow}, + protocol::{ + context::{Context, MaliciousContext, SemiHonestContext}, + ipa_prf::{ + shuffle::{base::semi_honest_shuffle, malicious::malicious_shuffle}, + OPRFIPAInputRow, + }, + }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, }, + sharding::ShardBinding, }; pub mod base; -#[allow(dead_code)] pub mod malicious; #[cfg(descriptive_gate)] mod sharded; pub(crate) mod step; +pub trait Shuffle: Context { + fn shuffle( + self, + shares: I, + ) -> impl Future>, Error>> + Send + where + S: BooleanArray, + B: BooleanArray, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + Standard: Distribution; +} + +impl<'b, T: ShardBinding> Shuffle for SemiHonestContext<'b, T> { + fn shuffle( + self, + shares: I, + ) -> impl Future>, Error>> + Send + where + S: BooleanArray, + B: BooleanArray, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + Standard: Distribution, + { + semi_honest_shuffle::<_, I, S>(self, shares) + } +} + +impl<'b> Shuffle for MaliciousContext<'b> { + fn shuffle( + self, + shares: I, + ) -> impl Future>, Error>> + Send + where + S: BooleanArray, + B: BooleanArray, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + Standard: Distribution, + { + malicious_shuffle::<_, S, B, I>(self, shares) + } +} + #[tracing::instrument(name = "shuffle_inputs", skip_all)] pub async fn shuffle_inputs( ctx: C, input: Vec>, ) -> Result>, Error> where - C: Context, + C: Context + Shuffle, BK: BooleanArray, TV: BooleanArray, TS: BooleanArray, @@ -44,7 +113,7 @@ where .map(|item| oprfreport_to_shuffle_input::(&item)) .collect::>(); - let (shuffled, _) = shuffle(ctx, shuffle_input).await?; + let shuffled = ctx.shuffle::(shuffle_input).await?; Ok(shuffled .into_iter() @@ -71,7 +140,7 @@ where .map(|item| attribution_outputs_to_shuffle_input::(&item)) .collect::>(); - let (shuffled, _) = shuffle(ctx, shuffle_input).await?; + let shuffled = malicious_shuffle::<_, R, BA96, _>(ctx, shuffle_input).await?; Ok(shuffled .into_iter() diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index 0a7f94d76..48c02c103 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -18,7 +18,7 @@ use crate::{ ff::{boolean_array::BA64, U128Conversions}, helpers::{Direction, Error, Role, TotalRecords}, protocol::{ - context::{reshard, ShardedContext}, + context::{reshard_iter, ShardedContext}, prss::{FromRandom, FromRandomU128, SharedRandomness}, RecordId, }, @@ -88,7 +88,7 @@ trait ShuffleContext: ShardedContext { let data = data.into_iter(); async move { let masking_ctx = self.narrow(&ShuffleStep::Mask); - let mut resharded = assert_send(reshard( + let mut resharded = assert_send(reshard_iter( self.clone(), data.enumerate().map(|(i, item)| { // FIXME(1029): update PRSS trait to compute only left or right part @@ -495,7 +495,7 @@ mod tests { let inputs = [1_u32, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] .map(BA8::truncate_from) .to_vec(); - let mut result = sharded_shuffle::<3, D>(inputs.clone()).await; + let mut result = sharded_shuffle::(inputs.clone()).await; assert_ne!(inputs, result); result.sort_by_key(U128Conversions::as_u128); diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index c9de371b3..126996574 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -9,6 +9,13 @@ pub(crate) enum OPRFShuffleStep { TransferCHat, TransferX2, TransferY1, + GenerateTags, + #[step(child = crate::protocol::ipa_prf::shuffle::step::VerifyShuffleStep)] + VerifyShuffle, +} + +#[derive(CompactStep)] +pub(crate) enum VerifyShuffleStep { RevealMACKey, HashesH3toH1, HashH2toH1, diff --git a/ipa-core/src/protocol/ipa_prf/step.rs b/ipa-core/src/protocol/ipa_prf/step.rs index 2f0ab5d92..5020a7185 100644 --- a/ipa-core/src/protocol/ipa_prf/step.rs +++ b/ipa-core/src/protocol/ipa_prf/step.rs @@ -6,10 +6,12 @@ pub(crate) enum IpaPrfStep { PaddingDp, #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] Shuffle, - // ConvertInputRowsToPrf, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::Fp25519ConversionStep)] ConvertFp25519, - #[step(child = PrfStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + ConvertFp25519Validate, + PrfKeyGen, + #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] EvalPrf, #[step(child = QuicksortStep)] SortByTimestamp, @@ -17,6 +19,8 @@ pub(crate) enum IpaPrfStep { Attribution, #[step(child = crate::protocol::dp::step::DPStep, name = "dp")] DifferentialPrivacy, + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + DifferentialPrivacyValidate, } #[derive(CompactStep)] @@ -24,6 +28,8 @@ pub(crate) enum QuicksortStep { /// Sort up to 1B rows. We can't exceed that limit for other reasons as well `record_id`. #[step(count = 30, child = crate::protocol::ipa_prf::step::QuicksortPassStep)] QuicksortPass(usize), + #[step(count = 30, child = crate::protocol::context::step::DzkpValidationProtocolStep)] + QuicksortPassValidate(usize), } #[derive(CompactStep)] @@ -35,10 +41,12 @@ pub(crate) enum QuicksortPassStep { #[derive(CompactStep)] pub(crate) enum PrfStep { - PRFKeyGen, GenRandomMask, + #[step(child = crate::protocol::context::step::UpgradeStep)] UpgradeY, + #[step(child = crate::protocol::context::step::UpgradeStep)] UpgradeMask, + #[step(child = crate::protocol::basics::mul::step::MaliciousMultiplyStep)] MultMaskWithPRFInput, RevealR, Revealz, diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index 5eccdc084..cb2754e5f 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -1,10 +1,16 @@ +use std::{array, iter::zip}; + +use typenum::{UInt, UTerm, Unsigned, B0, B1}; + use crate::{ + const_assert_eq, error::Error, - ff::Fp61BitPrime, - helpers::{Direction, TotalRecords}, + ff::{Fp61BitPrime, Serializable}, + helpers::{Direction, MpcMessage, TotalRecords}, protocol::{ context::{ dzkp_field::{UVTupleBlock, BLOCK_SIZE}, + dzkp_validator::MAX_PROOF_RECURSION, Context, }, ipa_prf::malicious_security::{ @@ -12,8 +18,9 @@ use crate::{ prover::{LargeProofGenerator, SmallProofGenerator}, }, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, + secret_sharing::SharedValue, }; /// This a `ProofBatch` generated by a prover. @@ -47,11 +54,18 @@ impl ProofBatch { self.proofs.len() * SmallProofGenerator::PROOF_LENGTH + LargeProofGenerator::PROOF_LENGTH } - /// This function returns an iterator over the field elements of all proofs. - fn iter(&self) -> impl Iterator { - self.first_proof + #[allow(clippy::unnecessary_box_returns)] // clippy bug? `Array` exceeds unnecessary-box-size + fn to_array(&self) -> Box { + assert!(self.len() <= ARRAY_LEN); + let iter = self + .first_proof .iter() - .chain(self.proofs.iter().flat_map(|x| x.iter())) + .chain(self.proofs.iter().flat_map(|x| x.iter())); + let mut array = Box::new(array::from_fn(|_| Fp61BitPrime::ZERO)); + for (i, v) in iter.enumerate() { + array[i] = *v; + } + array } /// Each helper party generates a set of proofs, which are secret-shared. @@ -66,7 +80,11 @@ impl ProofBatch { /// ## Panics /// Panics when the function fails to set the masks without overwritting `u` and `v` values. /// This only happens when there is an issue in the recursion. - pub fn generate(ctx: &C, uv_tuple_inputs: I) -> (Self, Self, Fp61BitPrime, Fp61BitPrime) + pub fn generate( + ctx: &C, + mut prss_record_ids: RecordIdRange, + uv_tuple_inputs: I, + ) -> (Self, Self, Fp61BitPrime, Fp61BitPrime) where C: Context, I: Iterator> + Clone, @@ -77,9 +95,6 @@ impl ProofBatch { const SLL: usize = SmallProofGenerator::LAGRANGE_LENGTH; const SPL: usize = SmallProofGenerator::PROOF_LENGTH; - // set up record counter - let mut record_counter = RecordId::FIRST; - // precomputation for first proof let first_denominator = CanonicalLagrangeDenominator::::new(); let first_lagrange_table = LagrangeTable::::from(first_denominator); @@ -88,32 +103,40 @@ impl ProofBatch { let (mut uv_values, first_proof_from_left, my_first_proof_left_share) = LargeProofGenerator::gen_artefacts_from_recursive_step( ctx, - &mut record_counter, + &mut prss_record_ids, &first_lagrange_table, ProofBatch::polynomials_from_inputs(uv_tuple_inputs), ); - // approximate length of proof vector (rounded up) - let uv_len_bits: u32 = usize::BITS - uv_values.len().leading_zeros(); - let small_recursion_factor_bits: u32 = usize::BITS - SRF.leading_zeros(); - let expected_len = 1 << (uv_len_bits - small_recursion_factor_bits); + // `MAX_PROOF_RECURSION - 2` because: + // * The first level of recursion has already happened. + // * We need (SRF - 1) at the last level to have room for the masks. + let max_uv_values: usize = + (SRF - 1) * SRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); + assert!( + uv_values.len() <= max_uv_values, + "Proof batch is too large: have {} uv_values, max is {}", + uv_values.len(), + max_uv_values, + ); // storage for other proofs - let mut my_proofs_left_shares = Vec::<[Fp61BitPrime; SPL]>::with_capacity(expected_len); + let mut my_proofs_left_shares = + Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); let mut shares_of_proofs_from_prover_left = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(expected_len); + Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); // generate masks // Prover `P_i` and verifier `P_{i-1}` both compute p(x) // therefore the "right" share computed by this verifier corresponds to that which // was used by the prover to the right. - let (my_p_mask, p_mask_from_right_prover) = ctx.prss().generate_fields(record_counter); - record_counter += 1; + let (my_p_mask, p_mask_from_right_prover) = + ctx.prss().generate_fields(prss_record_ids.expect_next()); // Prover `P_i` and verifier `P_{i+1}` both compute q(x) // therefore the "left" share computed by this verifier corresponds to that which // was used by the prover to the left. - let (q_mask_from_left_prover, my_q_mask) = ctx.prss().generate_fields(record_counter); - record_counter += 1; + let (q_mask_from_left_prover, my_q_mask) = + ctx.prss().generate_fields(prss_record_ids.expect_next()); let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table = LagrangeTable::::from(denominator); @@ -135,7 +158,7 @@ impl ProofBatch { let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = SmallProofGenerator::gen_artefacts_from_recursive_step( ctx, - &mut record_counter, + &mut prss_record_ids, &lagrange_table, uv_values.iter(), ); @@ -165,52 +188,40 @@ impl ProofBatch { /// /// ## Errors /// Propagates error from sending values over the network channel. - pub async fn send_to_left(&self, ctx: &C) -> Result<(), Error> + pub async fn send_to_left(&self, ctx: &C, record_id: RecordId) -> Result<(), Error> where C: Context, { - // set up context for the communication over the network - let communication_ctx = ctx.set_total_records(TotalRecords::specified(self.len())?); - - // set up channel - let send_channel_left = - &communication_ctx.send_channel::(ctx.role().peer(Direction::Left)); - - // send to left - // we send the proof batch via sending the individual field elements - communication_ctx - .parallel_join( - self.iter().enumerate().map(|(i, x)| async move { - send_channel_left.send(RecordId::from(i), x).await - }), - ) - .await?; - Ok(()) + Ok(ctx + .set_total_records(TotalRecords::Indeterminate) + .send_channel::>(ctx.role().peer(Direction::Left)) + .send(record_id, self.to_array()) + .await?) } /// This function receives a `Proof` from the party on the right. /// /// ## Errors /// Propagates errors from receiving values over the network channel. - pub async fn receive_from_right(ctx: &C, length: usize) -> Result + /// + /// ## Panics + /// If the recursion depth implied by `length` exceeds `MAX_PROOF_RECURSION`. + pub async fn receive_from_right( + ctx: &C, + record_id: RecordId, + length: usize, + ) -> Result where C: Context, { - // set up context - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - // set up channel - let receive_channel_right = - &communication_ctx.recv_channel::(ctx.role().peer(Direction::Right)); - - // receive from the right + assert!(length <= ARRAY_LEN); Ok(ctx - .parallel_join( - (0..length) - .map(|i| async move { receive_channel_right.receive(RecordId::from(i)).await }), - ) + .set_total_records(TotalRecords::Indeterminate) + .recv_channel::>(ctx.role().peer(Direction::Right)) + .receive(record_id) .await? .into_iter() + .take(length) .collect()) } @@ -251,6 +262,47 @@ impl ProofBatch { } } +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +#[rustfmt::skip] +type U1464 = UInt, B0>, B1>, B1>, B0>, B1>, B1>, B1>, B0>, B0>, B0>; + +const ARRAY_LEN: usize = 183; +type Array = [Fp61BitPrime; ARRAY_LEN]; + +impl Serializable for Box { + type Size = U1464; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + **self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Fp61BitPrime::deserialize(buf.try_into().unwrap())) + .collect::, _>>()? + .try_into() + .unwrap()) + } +} + +impl MpcMessage for Box {} + #[cfg(all(test, unit_test))] mod test { use rand::{thread_rng, Rng}; @@ -263,6 +315,7 @@ mod test { proof_generation::ProofBatch, validation::{test::simple_proof_check, BatchToVerify}, }, + RecordId, RecordIdRange, }, secret_sharing::replicated::ReplicatedSecretSharing, test_executor::run, @@ -312,11 +365,13 @@ mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, uv_tuple_vec.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index 456c7cdf0..f0430e996 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -1,16 +1,23 @@ -use std::iter::{once, repeat}; +use std::{ + array, + iter::{once, repeat, zip}, +}; use futures_util::future::{try_join, try_join4}; +use typenum::{Unsigned, U288, U80}; use crate::{ - error::Error, - ff::Fp61BitPrime, + const_assert_eq, + error::{Error, UnwrapInfallible}, + ff::{Fp61BitPrime, Serializable}, helpers::{ hashing::{compute_hash, hash_to_field, Hash}, - Direction, TotalRecords, + Direction, MpcMessage, TotalRecords, }, protocol::{ - context::{step::ZeroKnowledgeProofValidateStep as Step, Context}, + context::{ + dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context, + }, ipa_prf::{ malicious_security::{ prover::{LargeProofGenerator, SmallProofGenerator}, @@ -55,6 +62,7 @@ impl BatchToVerify { /// Panics when send and receive over the network channels fail. pub async fn generate_batch_to_verify( ctx: C, + record_id: RecordId, my_batch_left_shares: ProofBatch, shares_of_batch_from_left_prover: ProofBatch, p_mask_from_right_prover: Fp61BitPrime, @@ -66,8 +74,8 @@ impl BatchToVerify { // send one batch left and receive one batch from the right let length = my_batch_left_shares.len(); let ((), shares_of_batch_from_right_prover) = try_join( - my_batch_left_shares.send_to_left(&ctx), - ProofBatch::receive_from_right(&ctx, length), + my_batch_left_shares.send_to_left(&ctx, record_id), + ProofBatch::receive_from_right(&ctx, record_id, length), ) .await .unwrap(); @@ -88,7 +96,11 @@ impl BatchToVerify { /// ## Panics /// Panics when recursion factor constant cannot be converted to `u128` /// or when sending and receiving hashes over the network fails. - pub async fn generate_challenges(&self, ctx: C) -> (Vec, Vec) + pub async fn generate_challenges( + &self, + ctx: C, + record_id: RecordId, + ) -> (Vec, Vec) where C: Context, { @@ -101,15 +113,25 @@ impl BatchToVerify { let exclude_small = u128::try_from(SRF).unwrap(); // generate hashes - let my_hashes_prover_left = ProofHashes::generate_hashes(self, Side::Left); - let my_hashes_prover_right = ProofHashes::generate_hashes(self, Side::Right); + let my_hashes_prover_left = ProofHashes::generate_hashes(self, Direction::Left); + let my_hashes_prover_right = ProofHashes::generate_hashes(self, Direction::Right); // receive hashes from the other verifier let ((), (), other_hashes_prover_left, other_hashes_prover_right) = try_join4( - my_hashes_prover_left.send_hashes(&ctx, Side::Left), - my_hashes_prover_right.send_hashes(&ctx, Side::Right), - ProofHashes::receive_hashes(&ctx, my_hashes_prover_left.hashes.len(), Side::Left), - ProofHashes::receive_hashes(&ctx, my_hashes_prover_right.hashes.len(), Side::Right), + my_hashes_prover_left.send_hashes(&ctx, record_id, Direction::Left), + my_hashes_prover_right.send_hashes(&ctx, record_id, Direction::Right), + ProofHashes::receive_hashes( + &ctx, + record_id, + my_hashes_prover_left.hashes.len(), + Direction::Left, + ), + ProofHashes::receive_hashes( + &ctx, + record_id, + my_hashes_prover_right.hashes.len(), + Direction::Right, + ), ) .await .unwrap(); @@ -174,6 +196,7 @@ impl BatchToVerify { /// This function computes and outputs the final `p_r_right_prover * q_r_right_prover` value. async fn compute_p_times_q( ctx: C, + record_id: RecordId, p_r_right_prover: Fp61BitPrime, q_r_left_prover: Fp61BitPrime, ) -> Result @@ -181,7 +204,7 @@ impl BatchToVerify { C: Context, { // send to the left - let communication_ctx = ctx.set_total_records(TotalRecords::specified(1usize)?); + let communication_ctx = ctx.set_total_records(TotalRecords::Indeterminate); let send_right = communication_ctx.send_channel::(ctx.role().peer(Direction::Right)); @@ -189,8 +212,8 @@ impl BatchToVerify { communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let ((), q_r_right_prover) = try_join( - send_right.send(RecordId::FIRST, q_r_left_prover), - receive_left.receive(RecordId::FIRST), + send_right.send(record_id, q_r_left_prover), + receive_left.receive(record_id), ) .await?; @@ -201,9 +224,14 @@ impl BatchToVerify { /// /// ## Errors /// Propagates network errors or when the proof fails to verify. + /// + /// ## Panics + /// If the proof exceeds `MAX_PROOF_RECURSION`. + #[allow(clippy::too_many_arguments)] pub async fn verify( &self, ctx: C, + record_id: RecordId, sum_of_uv_right: Fp61BitPrime, p_r_right_prover: Fp61BitPrime, q_r_left_prover: Fp61BitPrime, @@ -221,6 +249,7 @@ impl BatchToVerify { let p_times_q_right = Self::compute_p_times_q( ctx.narrow(&Step::PTimesQ), + record_id, p_r_right_prover, q_r_left_prover, ) @@ -243,31 +272,26 @@ impl BatchToVerify { p_times_q_right, ); - // send dif_left to the right + // send diff_left to the right let length = diff_left.len(); - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - let send_channel = - communication_ctx.send_channel::(ctx.role().peer(Direction::Right)); - let receive_channel = - communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)); - - let send_channel_ref = &send_channel; - let receive_channel_ref = &receive_channel; + assert!(length <= MAX_PROOF_RECURSION + 1); - let send_future = communication_ctx.parallel_join( - diff_left - .iter() - .enumerate() - .map(|(i, f)| async move { send_channel_ref.send(RecordId::from(i), f).await }), - ); + let communication_ctx = ctx + .narrow(&Step::Diff) + .set_total_records(TotalRecords::Indeterminate); - let receive_future = communication_ctx.parallel_join( - (0..length) - .map(|i| async move { receive_channel_ref.receive(RecordId::from(i)).await }), - ); + let send_data = array::from_fn(|i| *diff_left.get(i).unwrap_or(&Fp61BitPrime::ZERO)); - let (_, diff_right_from_other_verifier) = try_join(send_future, receive_future).await?; + let ((), receive_data) = try_join( + communication_ctx + .send_channel::(ctx.role().peer(Direction::Right)) + .send(record_id, send_data), + communication_ctx + .recv_channel::(ctx.role().peer(Direction::Left)) + .receive(record_id), + ) + .await?; + let diff_right_from_other_verifier = receive_data[0..length].to_vec(); // compare recombined dif to zero for i in 0..length { @@ -284,21 +308,15 @@ struct ProofHashes { hashes: Vec, } -#[derive(Clone, Copy, Debug)] -enum Side { - Left, - Right, -} - impl ProofHashes { - // Generates hashes for proofs received from prover indicated by `side` - fn generate_hashes(batch_to_verify: &BatchToVerify, side: Side) -> Self { - let (first_proof, other_proofs) = match side { - Side::Left => ( + // Generates hashes for proofs received from prover indicated by `direction` + fn generate_hashes(batch_to_verify: &BatchToVerify, direction: Direction) -> Self { + let (first_proof, other_proofs) = match direction { + Direction::Left => ( &batch_to_verify.first_proof_from_left_prover, &batch_to_verify.proofs_from_left_prover, ), - Side::Right => ( + Direction::Right => ( &batch_to_verify.first_proof_from_right_prover, &batch_to_verify.proofs_from_right_prover, ), @@ -312,54 +330,116 @@ impl ProofHashes { } /// Sends the one verifier's hashes to the other verifier - /// `side` indicates the direction of the prover. - async fn send_hashes(&self, ctx: &C, side: Side) -> Result<(), Error> { - let communication_ctx = ctx.set_total_records(TotalRecords::specified(self.hashes.len())?); - - let send_channel = match side { - // send left hashes to the right - Side::Left => communication_ctx.send_channel::(ctx.role().peer(Direction::Right)), - // send right hashes to the left - Side::Right => communication_ctx.send_channel::(ctx.role().peer(Direction::Left)), - }; - let send_channel_ref = &send_channel; - - communication_ctx - .parallel_join(self.hashes.iter().enumerate().map(|(i, hash)| async move { - send_channel_ref.send(RecordId::from(i), hash).await - })) + /// `direction` indicates the direction of the prover. + async fn send_hashes( + &self, + ctx: &C, + record_id: RecordId, + direction: Direction, + ) -> Result<(), Error> { + assert!(self.hashes.len() <= MAX_PROOF_RECURSION); + let hashes_send = + array::from_fn(|i| self.hashes.get(i).unwrap_or(&Hash::default()).clone()); + let verifier_direction = !direction; + ctx.set_total_records(TotalRecords::Indeterminate) + .send_channel::<[Hash; MAX_PROOF_RECURSION]>(ctx.role().peer(verifier_direction)) + .send(record_id, hashes_send) .await?; Ok(()) } /// This function receives hashes from the other verifier - /// `side` indicates the direction of the prover. - async fn receive_hashes(ctx: &C, length: usize, side: Side) -> Result { - // set up context for the communication over the network - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - let recv_channel = match side { - // receive left hashes from the right helper - Side::Left => communication_ctx.recv_channel::(ctx.role().peer(Direction::Right)), - // reeive right hashes from the left helper - Side::Right => communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)), - }; - let recv_channel_ref = &recv_channel; - - let hashes_received = communication_ctx - .parallel_join( - (0..length) - .map(|i| async move { recv_channel_ref.receive(RecordId::from(i)).await }), - ) + /// `direction` indicates the direction of the prover. + async fn receive_hashes( + ctx: &C, + record_id: RecordId, + length: usize, + direction: Direction, + ) -> Result { + assert!(length <= MAX_PROOF_RECURSION); + let verifier_direction = !direction; + let hashes_received = ctx + .set_total_records(TotalRecords::Indeterminate) + .recv_channel::<[Hash; MAX_PROOF_RECURSION]>(ctx.role().peer(verifier_direction)) + .receive(record_id) .await?; - Ok(Self { - hashes: hashes_received, + hashes: hashes_received[0..length].to_vec(), }) } } +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +impl Serializable for [Hash; MAX_PROOF_RECURSION] { + type Size = U288; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Hash::deserialize(buf.try_into().unwrap()).unwrap_infallible()) + .collect::>() + .try_into() + .unwrap()) + } +} + +impl MpcMessage for [Hash; MAX_PROOF_RECURSION] {} + +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +type ProofDiff = [Fp61BitPrime; MAX_PROOF_RECURSION + 1]; + +impl Serializable for ProofDiff { + type Size = U80; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Fp61BitPrime::deserialize(buf.try_into().unwrap())) + .collect::, _>>()? + .try_into() + .unwrap()) + } +} + +impl MpcMessage for ProofDiff {} + #[cfg(all(test, unit_test))] pub mod test { use futures_util::future::try_join; @@ -382,7 +462,7 @@ pub mod test { validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, }, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, test_executor::run, @@ -526,11 +606,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, uv_tuple_vec.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -539,7 +621,9 @@ pub mod test { .await; // generate and output challenges - batch_to_verify.generate_challenges(ctx).await + batch_to_verify + .generate_challenges(ctx, RecordId::FIRST) + .await }) .await; @@ -637,11 +721,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -652,7 +738,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; assert_eq!( @@ -741,11 +827,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -756,7 +844,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; assert_eq!( @@ -771,7 +859,10 @@ pub mod test { vec_v_from_left_prover.into_iter(), ); - let p_times_q = BatchToVerify::compute_p_times_q(ctx, p, q).await.unwrap(); + let p_times_q = + BatchToVerify::compute_p_times_q(ctx, RecordId::FIRST, p, q) + .await + .unwrap(); let denominator = CanonicalLagrangeDenominator::< Fp61BitPrime, @@ -826,11 +917,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -859,7 +952,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; let (p, q) = batch_to_verify.compute_p_and_q_r( @@ -872,6 +965,7 @@ pub mod test { batch_to_verify .verify( v_ctx, + RecordId::FIRST, sum_of_uv_right, p, q, diff --git a/ipa-core/src/protocol/mod.rs b/ipa-core/src/protocol/mod.rs index 9401cec8d..18dfc6221 100644 --- a/ipa-core/src/protocol/mod.rs +++ b/ipa-core/src/protocol/mod.rs @@ -2,6 +2,7 @@ pub mod basics; pub mod boolean; pub mod context; pub mod dp; +pub mod hybrid; pub mod ipa_prf; pub mod prss; pub mod step; @@ -9,7 +10,7 @@ pub mod step; use std::{ fmt::{Debug, Display, Formatter}, hash::Hash, - ops::{Add, AddAssign}, + ops::{Add, AddAssign, Range}, }; pub use basics::{BasicProtocols, BooleanProtocols}; @@ -106,6 +107,7 @@ impl From for RecordId { impl RecordId { pub(crate) const FIRST: Self = Self(0); + pub(crate) const LAST: Self = Self(u32::MAX); } impl From for u128 { @@ -146,6 +148,30 @@ impl AddAssign for RecordId { } } +pub struct RecordIdRange(Range); + +impl RecordIdRange { + pub const ALL: RecordIdRange = RecordIdRange(RecordId::FIRST..RecordId::LAST); + + #[cfg(all(test, unit_test))] + fn peek_first(&self) -> RecordId { + self.0.start + } + + fn expect_next(&mut self) -> RecordId { + assert!(self.0.start < self.0.end, "RecordIdRange exhausted"); + let val = self.0.start; + self.0.start += 1; + val + } +} + +impl From> for RecordIdRange { + fn from(value: Range) -> Self { + Self(value) + } +} + /// Helper used when an operation may or may not be associated with a specific record. This is /// also used to prevent some kinds of invalid uses of record ID iteration. For example, trying to /// use the record ID to iterate over both the inner and outer vectors in a `Vec>` is an diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index a38d50593..cf3658018 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -7,11 +7,15 @@ pub enum ProtocolStep { Prss, #[step(child = crate::protocol::ipa_prf::step::IpaPrfStep)] IpaPrf, + #[step(child = crate::protocol::hybrid::step::HybridStep)] + Hybrid, Multiply, PrimeFieldAddition, - #[cfg(any(test, feature = "test-fixture"))] - #[step(count = 10, child = crate::test_fixture::step::TestExecutionStep)] - Test(usize), + /// Steps used in unit tests are grouped under this one. Ideally it should be + /// gated behind test configuration, but it does not work with build.rs that + /// does not enable any features when creating protocol gate file + #[step(child = TestExecutionStep)] + Test, /// This step includes all the steps that are currently not linked into a top-level protocol. /// @@ -29,22 +33,17 @@ impl<'de> serde::Deserialize<'de> for ProtocolGate { #[derive(CompactStep)] pub enum DeadCodeStep { - #[step(child = crate::protocol::basics::step::CheckZeroStep)] - CheckZero, - #[step(child = crate::protocol::basics::mul::step::MaliciousMultiplyStep)] - MaliciousMultiply, - #[step(child = crate::protocol::context::step::UpgradeStep)] - UpgradeShare, - #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] - MaliciousProtocol, - #[step(child = crate::protocol::context::step::ValidateStep)] - MaliciousValidation, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::SaturatedSubtractionStep)] SaturatedSubtraction, #[step(child = crate::protocol::ipa_prf::prf_sharding::step::FeatureLabelDotProductStep)] FeatureLabelDotProduct, - #[step(child = crate::protocol::context::step::ZeroKnowledgeProofValidateStep)] - ZeroKnowledgeProofValidate, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)] Multiplication, } + +/// Provides a unique per-iteration context in tests. +#[derive(CompactStep)] +pub enum TestExecutionStep { + #[step(count = 999)] + Iter(usize), +} diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index 923f3fafe..edd1662e4 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -15,29 +15,39 @@ use generic_array::GenericArray; use ipa_step::StepNarrow; use rand::rngs::StdRng; use rand_core::SeedableRng; -#[cfg(all(feature = "shuttle", test))] -use shuttle::future as tokio; use typenum::Unsigned; -#[cfg(any(test, feature = "cli", feature = "test-fixture"))] -use crate::{ - ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field, -}; +#[cfg(any( + test, + feature = "cli", + feature = "test-fixture", + feature = "weak-field" +))] +use crate::ff::FieldType; use crate::{ - ff::{boolean_array::BA32, FieldType, Serializable}, + executor::IpaRuntime, + ff::{boolean_array::BA32, Serializable}, helpers::{ negotiate_prss, query::{QueryConfig, QueryType}, BodyStream, Gateway, }, hpke::PrivateKeyRegistry, - protocol::{context::SemiHonestContext, prss::Endpoint as PrssEndpoint, Gate}, + protocol::{ + context::{MaliciousContext, SemiHonestContext}, + prss::Endpoint as PrssEndpoint, + Gate, + }, query::{ runner::{OprfIpaQuery, QueryResult}, state::RunningQuery, }, sync::Arc, }; +#[cfg(any(test, feature = "cli", feature = "test-fixture"))] +use crate::{ + ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field, +}; pub trait Result: Send + Debug { fn to_bytes(&self) -> Vec; @@ -63,6 +73,7 @@ where /// Needless pass by value because IPA v3 does not make use of key registry yet. #[allow(clippy::too_many_lines, clippy::needless_pass_by_value)] pub fn execute( + runtime: &IpaRuntime, config: QueryConfig, key_registry: Arc, gateway: Gateway, @@ -70,73 +81,95 @@ pub fn execute( ) -> RunningQuery { match (config.query_type, config.field_type) { #[cfg(any(test, feature = "weak-field"))] - (QueryType::TestMultiply, FieldType::Fp31) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestMultiply, FieldType::Fp31) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(execute_test_multiply::( prss, gateway, input, )) - }) - } + }, + ), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] - (QueryType::TestMultiply, FieldType::Fp32BitPrime) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestMultiply, FieldType::Fp32BitPrime) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(execute_test_multiply::(prss, gateway, input)) - }) - } + }, + ), + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + (QueryType::TestShardedShuffle, _) => do_query( + runtime, + config, + gateway, + input, + |_prss, _gateway, _config, _input| unimplemented!(), + ), #[cfg(any(test, feature = "weak-field"))] - (QueryType::TestAddInPrimeField, FieldType::Fp31) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(test_add_in_prime_field::( prss, gateway, input, )) - }) - } + }, + ), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] - (QueryType::TestAddInPrimeField, FieldType::Fp32BitPrime) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestAddInPrimeField, FieldType::Fp32BitPrime) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(test_add_in_prime_field::( prss, gateway, input, )) - }) - } + }, + ), // TODO(953): This is really using BA32, not Fp32bitPrime. The `FieldType` mechanism needs // to be reworked. - (QueryType::OprfIpa(ipa_config), FieldType::Fp32BitPrime) => do_query( + (QueryType::SemiHonestOprfIpa(ipa_config), _) => do_query( + runtime, config, gateway, input, move |prss, gateway, config, input| { let ctx = SemiHonestContext::new(prss, gateway); Box::pin( - OprfIpaQuery::::new(ipa_config, key_registry) + OprfIpaQuery::<_, BA32, R>::new(ipa_config, key_registry) .execute(ctx, config.size, input) .then(|res| ready(res.map(|out| Box::new(out) as Box))), ) }, ), - // TODO(953): This is not doing anything differently than the Fp32BitPrime case, except - // using 16 bits for histogram values - #[cfg(any(test, feature = "weak-field"))] - (QueryType::OprfIpa(ipa_config), FieldType::Fp31) => do_query( + (QueryType::MaliciousOprfIpa(ipa_config), _) => do_query( + runtime, config, gateway, input, move |prss, gateway, config, input| { - let ctx = SemiHonestContext::new(prss, gateway); + let ctx = MaliciousContext::new(prss, gateway); Box::pin( - OprfIpaQuery::::new( - ipa_config, - key_registry, - ) - .execute(ctx, config.size, input) - .then(|res| ready(res.map(|out| Box::new(out) as Box))), + OprfIpaQuery::<_, BA32, R>::new(ipa_config, key_registry) + .execute(ctx, config.size, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), ) }, ), + (QueryType::SemiHonestHybrid(_), _) => todo!(), } } pub fn do_query( + executor_handle: &IpaRuntime, config: QueryConfig, gateway: B, input_stream: BodyStream, @@ -155,7 +188,7 @@ where { let (tx, rx) = oneshot::channel(); - let join_handle = tokio::spawn(async move { + let join_handle = executor_handle.spawn(async move { let gateway = gateway.borrow(); // TODO: make it a generic argument for this function let mut rng = StdRng::from_entropy(); @@ -207,6 +240,7 @@ mod tests { use tokio::sync::Barrier; use crate::{ + executor::IpaRuntime, ff::{FieldType, Fp31, U128Conversions}, helpers::{ query::{QueryConfig, QueryType}, @@ -327,6 +361,7 @@ mod tests { Fut: Future + Send, { do_query( + &IpaRuntime::current(), QueryConfig { size: 1.try_into().unwrap(), field_type: FieldType::Fp31, diff --git a/ipa-core/src/query/mod.rs b/ipa-core/src/query/mod.rs index aaa437b7a..6e6650862 100644 --- a/ipa-core/src/query/mod.rs +++ b/ipa-core/src/query/mod.rs @@ -8,7 +8,7 @@ use completion::Handle as CompletionHandle; pub use executor::Result as ProtocolResult; pub use processor::{ NewQueryError, PrepareQueryError, Processor as QueryProcessor, QueryCompletionError, - QueryInputError, QueryStatusError, + QueryInputError, QueryKillStatus, QueryKilled, QueryStatusError, }; pub use runner::OprfIpaQuery; pub use state::QueryStatus; diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a779b5fa6..2d6619f68 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -1,13 +1,14 @@ use std::{ collections::hash_map::Entry, fmt::{Debug, Formatter}, - num::NonZeroUsize, }; use futures::{future::try_join, stream}; +use serde::Serialize; use crate::{ error::Error as ProtocolError, + executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, @@ -21,6 +22,7 @@ use crate::{ CompletionHandle, ProtocolResult, }, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; /// `Processor` accepts and tracks requests to initiate new queries on this helper party @@ -43,7 +45,8 @@ use crate::{ pub struct Processor { queries: RunningQueries, key_registry: Arc>, - active_work: Option, + active_work: Option, + runtime: IpaRuntime, } impl Default for Processor { @@ -52,6 +55,7 @@ impl Default for Processor { queries: RunningQueries::default(), key_registry: Arc::new(KeyRegistry::::empty()), active_work: None, + runtime: IpaRuntime::current(), } } } @@ -117,12 +121,14 @@ impl Processor { #[must_use] pub fn new( key_registry: KeyRegistry, - active_work: Option, + active_work: Option, + runtime: IpaRuntime, ) -> Self { Self { queries: RunningQueries::default(), key_registry: Arc::new(key_registry), active_work, + runtime, } } @@ -248,6 +254,7 @@ impl Processor { queries.insert( input.query_id, QueryState::Running(executor::execute( + &self.runtime, config, Arc::clone(&self.key_registry), gateway, @@ -328,6 +335,36 @@ impl Processor { Ok(handle.await?) } + + /// Terminates a query with the given id. If query is running, then it + /// is unregistered and its task is terminated. + /// + /// ## Errors + /// if query is not registered on this helper. + /// + /// ## Panics + /// If failed to obtain exclusive access to the query collection. + pub fn kill(&self, query_id: QueryId) -> Result { + let mut queries = self.queries.inner.lock().unwrap(); + let Some(state) = queries.remove(&query_id) else { + return Err(QueryKillStatus::NoSuchQuery(query_id)); + }; + + if let QueryState::Running(handle) = state { + handle.join_handle.abort(); + } + + Ok(QueryKilled(query_id)) + } +} + +#[derive(Clone, Serialize)] +pub struct QueryKilled(pub QueryId); + +#[derive(thiserror::Error, Debug)] +pub enum QueryKillStatus { + #[error("failed to kill a query: {0} does not exist.")] + NoSuchQuery(QueryId), } #[cfg(all(test, unit_test))] @@ -549,6 +586,105 @@ mod tests { } } + mod kill { + use std::sync::Arc; + + use crate::{ + executor::IpaRuntime, + ff::FieldType, + helpers::{ + query::{ + QueryConfig, + QueryType::{TestAddInPrimeField, TestMultiply}, + }, + HandlerBox, HelperIdentity, InMemoryMpcNetwork, Transport, + }, + protocol::QueryId, + query::{ + processor::{tests::respond_ok, Processor}, + state::{QueryState, RunningQuery}, + QueryKillStatus, + }, + test_executor::run, + }; + + #[test] + fn non_existent_query() { + run(|| async { + let processor = Processor::default(); + assert!(matches!( + processor.kill(QueryId), + Err(QueryKillStatus::NoSuchQuery(QueryId)) + )); + }); + } + + #[test] + fn existing_query() { + run(|| async move { + let h2 = respond_ok(); + let h3 = respond_ok(); + let network = InMemoryMpcNetwork::new([ + None, + Some(HandlerBox::owning_ref(&h2)), + Some(HandlerBox::owning_ref(&h3)), + ]); + let identities = HelperIdentity::make_three(); + let processor = Processor::default(); + let transport = network.transport(identities[0]); + processor + .new_query( + Transport::clone_ref(&transport), + QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(), + ) + .await + .unwrap(); + + processor.kill(QueryId).unwrap(); + + // start query again - it should work because the query was killed + processor + .new_query( + transport, + QueryConfig::new(TestAddInPrimeField, FieldType::Fp32BitPrime, 1).unwrap(), + ) + .await + .unwrap(); + }); + } + + #[test] + fn aborts_protocol_task() { + run(|| async move { + let processor = Processor::default(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let counter = Arc::new(1); + let task = IpaRuntime::current().spawn({ + let counter = Arc::clone(&counter); + async move { + loop { + tokio::task::yield_now().await; + let _ = *counter.as_ref(); + } + } + }); + processor.queries.inner.lock().unwrap().insert( + QueryId, + QueryState::Running(RunningQuery { + result: rx, + join_handle: task, + }), + ); + + assert_eq!(2, Arc::strong_count(&counter)); + processor.kill(QueryId).unwrap(); + while Arc::strong_count(&counter) > 1 { + tokio::task::yield_now().await; + } + }); + } + } + mod e2e { use std::time::Duration; @@ -677,7 +813,7 @@ mod tests { QueryConfig { size: record_count.try_into().unwrap(), field_type: FieldType::Fp31, - query_type: QueryType::OprfIpa(IpaQueryConfig { + query_type: QueryType::SemiHonestOprfIpa(IpaQueryConfig { per_user_credit_cap: 8, max_breakdown_key: 3, attribution_window_seconds: None, diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs new file mode 100644 index 000000000..06cc2da4a --- /dev/null +++ b/ipa-core/src/query/runner/hybrid.rs @@ -0,0 +1,442 @@ +use std::{convert::Into, marker::PhantomData, sync::Arc}; + +use futures::{stream::iter, StreamExt, TryStreamExt}; + +use crate::{ + error::Error, + ff::{ + boolean_array::{BooleanArray, BA20, BA3, BA8}, + U128Conversions, + }, + helpers::{ + query::{DpMechanism, HybridQueryParams, QuerySize}, + BodyStream, LengthDelimitedStream, + }, + hpke::PrivateKeyRegistry, + protocol::{ + context::{ShardedContext, UpgradableContext}, + hybrid::{hybrid_protocol, step::HybridStep}, + ipa_prf::{oprf_padding::PaddingParameters, shuffle::Shuffle}, + step::ProtocolStep::Hybrid, + }, + query::runner::reshard_tag::reshard_aad, + report::hybrid::{ + EncryptedHybridReport, IndistinguishableHybridReport, UniqueTag, UniqueTagValidator, + }, + secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, +}; + +#[allow(dead_code)] +pub struct Query { + config: HybridQueryParams, + key_registry: Arc, + phantom_data: PhantomData<(C, HV)>, +} + +#[allow(dead_code)] +impl Query { + pub fn new(query_params: HybridQueryParams, key_registry: Arc) -> Self { + Self { + config: query_params, + key_registry, + phantom_data: PhantomData, + } + } +} + +impl Query +where + C: UpgradableContext + Shuffle + ShardedContext, + HV: BooleanArray + U128Conversions, + R: PrivateKeyRegistry, +{ + #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] + pub async fn execute( + self, + ctx: C, + query_size: QuerySize, + input_stream: BodyStream, + ) -> Result>, Error> { + let Self { + config, + key_registry, + phantom_data: _, + } = self; + + tracing::info!("New hybrid query: {config:?}"); + let ctx = ctx.narrow(&Hybrid); + let sz = usize::from(query_size); + + if config.plaintext_match_keys { + return Err(Error::Unsupported( + "Hybrid queries do not currently support plaintext match keys".to_string(), + )); + } + + let stream = LengthDelimitedStream::::new(input_stream) + .map_err(Into::::into) + .map_ok(|enc_reports| { + iter(enc_reports.into_iter().map({ + |enc_report| { + let dec_report = enc_report + .decrypt::(key_registry.as_ref()) + .map_err(Into::::into); + let unique_tag = UniqueTag::from_unique_bytes(&enc_report); + dec_report.map(|dec_report1| (dec_report1, unique_tag)) + } + })) + }) + .try_flatten() + .take(sz); + let (decrypted_reports, resharded_tags) = reshard_aad( + ctx.narrow(&HybridStep::ReshardByTag), + stream, + |ctx, _, tag| tag.shard_picker(ctx.shard_count()), + ) + .await?; + + // this should use ? but until this returns a result, + //we want to capture the panic for the test + let mut unique_encrypted_hybrid_reports = UniqueTagValidator::new(resharded_tags.len()); + unique_encrypted_hybrid_reports + .check_duplicates(&resharded_tags) + .unwrap(); + + let indistinguishable_reports: Vec> = + decrypted_reports.into_iter().map(Into::into).collect(); + + let dp_params: DpMechanism = match config.with_dp { + 0 => DpMechanism::NoDp, + _ => DpMechanism::DiscreteLaplace { + epsilon: config.epsilon, + }, + }; + + #[cfg(feature = "relaxed-dp")] + let padding_params = PaddingParameters::relaxed(); + #[cfg(not(feature = "relaxed-dp"))] + let padding_params = PaddingParameters::default(); + + match config.per_user_credit_cap { + 1 => hybrid_protocol::<_, BA8, BA3, HV, 1, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + 2 | 4 => hybrid_protocol::<_, BA8, BA3, HV, 2, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + 8 => hybrid_protocol::<_, BA8, BA3, HV, 3, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + 16 => hybrid_protocol::<_, BA8, BA3, HV, 4, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + 32 => hybrid_protocol::<_, BA8, BA3, HV, 5, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + 64 => hybrid_protocol::<_, BA8, BA3, HV, 6, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + 128 => hybrid_protocol::<_, BA8, BA3, HV, 7, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await, + _ => panic!( + "Invalid value specified for per-user cap: {:?}. Must be one of 1, 2, 4, 8, 16, 32, 64, or 128.", + config.per_user_credit_cap + ), + } + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{iter::zip, sync::Arc}; + + use rand::rngs::StdRng; + use rand_core::SeedableRng; + + use crate::{ + ff::{ + boolean_array::{BA16, BA20, BA3, BA8}, + U128Conversions, + }, + helpers::{ + query::{HybridQueryParams, QuerySize}, + BodyStream, + }, + hpke::{KeyPair, KeyRegistry}, + query::runner::hybrid::Query as HybridQuery, + report::{OprfReport, DEFAULT_KEY_ID}, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + test_fixture::{ + flatten3v, ipa::TestRawDataRecord, Reconstruct, RoundRobinInputDistribution, TestWorld, + TestWorldConfig, WithShards, + }, + }; + + const EXPECTED: &[u128] = &[0, 8, 5]; + + fn build_records() -> Vec { + // TODO: When Encryption/Decryption exists for HybridReports + // update these to use that, rather than generating OprfReports + vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 4, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 5, + }, + TestRawDataRecord { + timestamp: 12, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 2, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 30, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 1, + trigger_value: 7, + }, + ] + } + + struct BufferAndKeyRegistry { + buffers: [Vec>; 3], + key_registry: Arc>, + query_sizes: Vec, + } + + fn build_buffers_from_records(records: &[TestRawDataRecord], s: usize) -> BufferAndKeyRegistry { + let mut rng = StdRng::seed_from_u64(42); + let key_id = DEFAULT_KEY_ID; + let key_registry = Arc::new(KeyRegistry::::random(1, &mut rng)); + + let mut buffers: [_; 3] = std::array::from_fn(|_| vec![Vec::new(); s]); + let shares: [Vec>; 3] = records.iter().cloned().share(); + for (buf, shares) in zip(&mut buffers, shares) { + for (i, share) in shares.into_iter().enumerate() { + share + .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, &mut buf[i % s]) + .unwrap(); + } + } + + let total_query_size = records.len(); + let base_size = total_query_size / s; + let remainder = total_query_size % s; + let query_sizes: Vec<_> = (0..s) + .map(|i| { + if i < remainder { + base_size + 1 + } else { + base_size + } + }) + .map(|size| QuerySize::try_from(size).unwrap()) + .collect(); + + BufferAndKeyRegistry { + buffers, + key_registry, + query_sizes, + } + } + + #[tokio::test] + // placeholder until the protocol is complete. can be updated to make sure we + // get to the unimplemented() call + #[should_panic( + expected = "not implemented: protocol::hybrid::hybrid_protocol is not fully implemented" + )] + async fn encrypted_hybrid_reports() { + // While this test currently checks for an unimplemented panic it is + // designed to test for a correct result for a complete implementation. + + const SHARDS: usize = 2; + let records = build_records(); + + let BufferAndKeyRegistry { + buffers, + key_registry, + query_sizes, + } = build_buffers_from_records(&records, SHARDS); + + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let contexts = world.contexts(); + + #[allow(clippy::large_futures)] + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) + .await; + + let results: Vec<[Vec>; 3]> = results + .chunks(3) + .map(|chunk| { + [ + chunk[0].as_ref().unwrap().clone(), + chunk[1].as_ref().unwrap().clone(), + chunk[2].as_ref().unwrap().clone(), + ] + }) + .collect(); + + assert_eq!( + results.into_iter().next().unwrap().reconstruct()[0..3] + .iter() + .map(U128Conversions::as_u128) + .collect::>(), + EXPECTED + ); + } + + // cannot test for Err directly because join3v calls unwrap. This should be sufficient. + #[tokio::test] + #[should_panic(expected = "DuplicateBytes")] + async fn duplicate_encrypted_hybrid_reports() { + const SHARDS: usize = 2; + let records = build_records(); + + let BufferAndKeyRegistry { + mut buffers, + key_registry, + query_sizes, + } = build_buffers_from_records(&records, SHARDS); + + // this is double, since we duplicate the data below + let query_sizes = query_sizes + .into_iter() + .map(|query_size| QuerySize::try_from(usize::from(query_size) * 2).unwrap()) + .collect::>(); + + // duplicate all the data across shards + + for helper_buffers in &mut buffers { + // Get the last shard buffer to use for the first shard buffer extension + let last_shard_buffer = helper_buffers.last().unwrap().clone(); + let len = helper_buffers.len(); + for i in 0..len { + if i > 0 { + let previous = &helper_buffers[i - 1].clone(); + helper_buffers[i].extend_from_slice(previous); + } else { + helper_buffers[i].extend_from_slice(&last_shard_buffer); + } + } + } + + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let contexts = world.contexts(); + + #[allow(clippy::large_futures)] + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) + .await; + + results.into_iter().map(|r| r.unwrap()).for_each(drop); + } + + // cannot test for Err directly because join3v calls unwrap. This should be sufficient. + #[tokio::test] + #[should_panic( + expected = "Unsupported(\"Hybrid queries do not currently support plaintext match keys\")" + )] + async fn unsupported_plaintext_match_keys_hybrid_query() { + const SHARDS: usize = 2; + let records = build_records(); + + let BufferAndKeyRegistry { + buffers, + key_registry, + query_sizes, + } = build_buffers_from_records(&records, SHARDS); + + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let contexts = world.contexts(); + + #[allow(clippy::large_futures)] + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: true, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) + .await; + + results.into_iter().map(|r| r.unwrap()).for_each(drop); + } +} diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 4c7240cbb..3f1b59f55 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -1,6 +1,8 @@ #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod add_in_prime_field; +mod hybrid; mod oprf_ipa; +mod reshard_tag; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod test_multiply; diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index 320d2246b..11846c86c 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -8,6 +8,8 @@ use crate::{ ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA20, BA3, BA8}, + curve_points::RP25519, + ec_prime_field::Fp25519, Field, Serializable, U128Conversions, }, helpers::{ @@ -16,26 +18,31 @@ use crate::{ }, hpke::PrivateKeyRegistry, protocol::{ - basics::ShareKnownValue, - context::{Context, SemiHonestContext}, - ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters, OPRFIPAInputRow}, + basics::{BooleanArrayMul, Reveal, ShareKnownValue}, + context::{DZKPUpgraded, MacUpgraded, UpgradableContext}, + ipa_prf::{ + oprf_ipa, oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle, + OPRFIPAInputRow, AGG_CHUNK, CONV_CHUNK, PRF_CHUNK, SORT_CHUNK, + }, + prss::FromPrss, step::ProtocolStep::IpaPrf, + BooleanProtocols, }, report::{EncryptedOprfReport, EventType}, secret_sharing::{ replicated::semi_honest::{AdditiveShare as Replicated, AdditiveShare}, - BitDecomposed, SharedValue, TransposeFrom, + BitDecomposed, SharedValue, TransposeFrom, Vectorizable, }, sync::Arc, }; -pub struct OprfIpaQuery<'a, HV, R: PrivateKeyRegistry> { +pub struct OprfIpaQuery { config: IpaQueryConfig, key_registry: Arc, - phantom_data: PhantomData<&'a HV>, + phantom_data: PhantomData<(C, HV)>, } -impl<'a, HV, R: PrivateKeyRegistry> OprfIpaQuery<'a, HV, R> { +impl OprfIpaQuery { pub fn new(config: IpaQueryConfig, key_registry: Arc) -> Self { Self { config, @@ -46,11 +53,25 @@ impl<'a, HV, R: PrivateKeyRegistry> OprfIpaQuery<'a, HV, R> { } #[allow(clippy::too_many_lines)] -impl<'ctx, HV, R> OprfIpaQuery<'ctx, HV, R> +impl OprfIpaQuery where + C: UpgradableContext + Shuffle, HV: BooleanArray + U128Conversions, R: PrivateKeyRegistry, - Replicated: Serializable + ShareKnownValue, Boolean>, + Replicated: Serializable + ShareKnownValue, + Replicated: BooleanProtocols>, + Replicated: BooleanProtocols, 256>, + Replicated: BooleanProtocols, AGG_CHUNK>, + Replicated: BooleanProtocols, CONV_CHUNK>, + Replicated: BooleanProtocols, SORT_CHUNK>, + Replicated: + PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, + Replicated: + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, Vec>: for<'a> TransposeFrom<&'a BitDecomposed>, Error = LengthError>, BitDecomposed>: @@ -59,7 +80,7 @@ where #[tracing::instrument("oprf_ipa_query", skip_all, fields(sz=%query_size))] pub async fn execute( self, - ctx: SemiHonestContext<'ctx>, + ctx: C, query_size: QuerySize, input_stream: BodyStream, ) -> Result>, Error> { @@ -122,18 +143,20 @@ where }, }; - #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + #[cfg(feature = "relaxed-dp")] let padding_params = PaddingParameters::relaxed(); - #[cfg(not(any(test, feature = "cli", feature = "test-fixture")))] + #[cfg(not(feature = "relaxed-dp"))] let padding_params = PaddingParameters::default(); match config.per_user_credit_cap { - 8 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 16 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 32 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 64 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 128 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, + 1 => oprf_ipa::<_, BA8, BA3, HV, BA20, 1, 256>(ctx, input, aws, dp_params, padding_params).await, + 2 | 4 => oprf_ipa::<_, BA8, BA3, HV, BA20, 2, 256>(ctx, input, aws, dp_params, padding_params).await, + 8 => oprf_ipa::<_, BA8, BA3, HV, BA20, 3, 256>(ctx, input, aws, dp_params, padding_params).await, + 16 => oprf_ipa::<_, BA8, BA3, HV, BA20, 4, 256>(ctx, input, aws, dp_params, padding_params).await, + 32 => oprf_ipa::<_, BA8, BA3, HV, BA20, 5, 256>(ctx, input, aws, dp_params, padding_params).await, + 64 => oprf_ipa::<_, BA8, BA3, HV, BA20, 6, 256>(ctx, input, aws, dp_params, padding_params).await, + 128 => oprf_ipa::<_, BA8, BA3, HV, BA20, 7, 256>(ctx, input, aws, dp_params, padding_params).await, _ => panic!( - "Invalid value specified for per-user cap: {:?}. Must be one of 8, 16, 32, 64, or 128.", + "Invalid value specified for per-user cap: {:?}. Must be one of 1, 2, 4, 8, 16, 32, 64, or 128.", config.per_user_credit_cap ), } @@ -243,8 +266,11 @@ mod tests { }; let input = BodyStream::from(buffer); - OprfIpaQuery::>::new(query_config, Arc::clone(&key_registry)) - .execute(ctx, query_size, input) + OprfIpaQuery::<_, BA16, KeyRegistry>::new( + query_config, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) })) .await; diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs new file mode 100644 index 000000000..5d1c3b8f5 --- /dev/null +++ b/ipa-core/src/query/runner/reshard_tag.rs @@ -0,0 +1,149 @@ +use std::{ + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::{ready, Stream}; +use pin_project::pin_project; + +use crate::{ + error::Error, + helpers::Message, + protocol::{ + context::{reshard_try_stream, ShardedContext}, + RecordId, + }, + sharding::ShardIndex, +}; + +type DataWithTag = Result<(D, A), Error>; + +/// Helper function to work with inputs to hybrid queries. Each encryption needs +/// to be checked for uniqueness and we use AAD tag for that. While reports are +/// being collected, AAD tags need to be resharded. This function does both at the same +/// time which should reduce the perceived latency of queries. +/// +/// The output contains two separate collections: one for data and another one +/// for AAD tags that are "owned" by this shard. The tags can later be checked for +/// uniqueness. +/// +/// ## Errors +/// This will return an error, if input stream contains at least one `Err` element. +#[allow(dead_code)] +pub async fn reshard_aad( + ctx: C, + input: L, + shard_picker: S, +) -> Result<(Vec, Vec), crate::error::Error> +where + L: Stream>, + S: Fn(C, RecordId, &A) -> ShardIndex + Send, + A: Message + Clone, + C: ShardedContext, +{ + let mut k_buf = Vec::with_capacity(input.size_hint().1.unwrap_or(0)); + let splitter = StreamSplitter { + inner: input, + buf: &mut k_buf, + }; + let a_buf = reshard_try_stream(ctx, splitter, shard_picker).await?; + + Ok((k_buf, a_buf)) +} + +/// Takes a fallible input stream that yields a tuple `(K, A)` and produces a new stream +/// over `A` while collecting `K` elements into the provided buffer. +/// Any error encountered from the input stream is propagated. +#[pin_project] +struct StreamSplitter<'a, S: Stream>, K, A> { + #[pin] + inner: S, + buf: &'a mut Vec, +} + +impl>, K, A> Stream for StreamSplitter<'_, S, K, A> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match ready!(this.inner.poll_next(cx)) { + Some(Ok((k, a))) => { + this.buf.push(k); + Poll::Ready(Some(Ok(a))) + } + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use futures::{stream, StreamExt}; + + use crate::{ + error::Error, + ff::{boolean_array::BA8, U128Conversions}, + query::runner::reshard_tag::reshard_aad, + secret_sharing::SharedValue, + sharding::{ShardConfiguration, ShardIndex}, + test_executor::run, + test_fixture::{Runner, TestWorld, TestWorldConfig, WithShards}, + }; + + #[test] + fn reshard_basic() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest( + vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), + |ctx, input| async move { + let shard_id = ctx.shard_id(); + let sz = input.len(); + let (values, tags) = reshard_aad( + ctx, + stream::iter(input).map(|v| Ok((v, BA8::ZERO))), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap(); + assert_eq!(sz, values.len()); + match shard_id { + ShardIndex::FIRST => assert_eq!(2, tags.len()), + _ => assert_eq!(0, tags.len()), + } + }, + ) + .await; + }); + } + + #[test] + #[should_panic(expected = "InconsistentShares")] + fn reshard_err() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest( + vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), + |ctx, input| async move { + reshard_aad( + ctx, + stream::iter(input) + .map(|_| Err::<(BA8, BA8), _>(Error::InconsistentShares)), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap(); + }, + ) + .await; + }); + } +} diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 3c4359ca9..460296022 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -10,16 +10,15 @@ use futures::{ready, FutureExt}; use serde::{Deserialize, Serialize}; use crate::{ + executor::IpaJoinHandle, helpers::{query::QueryConfig, RoleAssignment}, protocol::QueryId, query::runner::QueryResult, sync::Mutex, - task::JoinHandle, }; /// The status of query processing #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -#[allow(dead_code)] pub enum QueryStatus { /// Only query running on the coordinator helper can be in this state. Means that coordinator /// sent out requests to other helpers and asked them to assume a given role for this query. @@ -87,7 +86,7 @@ pub struct RunningQuery { /// /// We could return the result via the `JoinHandle`, except that we want to check the status /// of the task, and shuttle doesn't implement `JoinHandle::is_finished`. - pub join_handle: JoinHandle<()>, + pub join_handle: IpaJoinHandle<()>, } impl RunningQuery { diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs new file mode 100644 index 000000000..62d66a797 --- /dev/null +++ b/ipa-core/src/report/hybrid.rs @@ -0,0 +1,855 @@ +//! Provides report types which are aggregated by the Hybrid protocol +//! +//! The `IndistinguishableHybridReport` is the primary data type which each helpers uses +//! to aggreate in the Hybrid protocol. +//! +//! From each Helper's POV, the Report Collector POSTs a length delimited byte +//! stream, which is then processed as follows: +//! +//! `BodyStream` → `EncryptedHybridReport` → `HybridReport` → `IndistinguishableHybridReport` +//! +//! The difference between a `HybridReport` and a `IndistinguishableHybridReport` is that a +//! a `HybridReport` is an `enum` with two possible options: `Impression` and `Conversion`. +//! These two options are implemented as `HybridImpressionReport` and `HybridConversionReport`. +//! A `IndistinguishableHybridReport` contains the union of the fields across +//! `HybridImpressionReport` and `HybridConversionReport`. Those fields are secret sharings, +//! which allows for building a collection of `IndistinguishableHybridReport` which carry +//! the information of the underlying `HybridImpressionReport` and `HybridConversionReport` +//! (and secret sharings of zero in the fields unique to each report type) without the +//! ability to infer if a given report is a `HybridImpressionReport` +//! or a `HybridConversionReport`. + +//! Note: immediately following convertion of a `HybridReport` into a +//! `IndistinguishableHybridReport`, each helper will know which type it was built from, +//! both from the position in the collection as well as the fact that both replicated +//! secret shares for one or more fields are zero. A shuffle is required to delink +//! a `IndistinguishableHybridReport`'s position in a collection, which also rerandomizes +//! all secret sharings (including the sharings of zero), making the collection of reports +//! cryptographically indistinguishable. + +use std::{ + collections::HashSet, + convert::Infallible, + marker::PhantomData, + ops::{Add, Deref}, +}; + +use bytes::{BufMut, Bytes}; +use generic_array::{ArrayLength, GenericArray}; +use hpke::Serializable as _; +use rand_core::{CryptoRng, RngCore}; +use typenum::{Sum, Unsigned, U16}; + +use crate::{ + const_assert_eq, + error::{BoxError, Error}, + ff::{boolean_array::BA64, Serializable}, + hpke::{ + open_in_place, seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, + PublicKeyRegistry, TagSize, + }, + report::{ + hybrid_info::HybridImpressionInfo, EncryptedOprfReport, EventType, InvalidReportError, + KeyIdentifier, + }, + secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, + sharding::ShardIndex, +}; + +// TODO(679): This needs to come from configuration. +static HELPER_ORIGIN: &str = "github.com/private-attribution"; + +#[derive(Debug, thiserror::Error)] +#[error("string contains non-ascii symbols: {0}")] +pub struct NonAsciiStringError(String); + +impl From<&'_ str> for NonAsciiStringError { + fn from(input: &str) -> Self { + Self(input.to_string()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum InvalidHybridReportError { + #[error("bad site_domain: {0}")] + NonAsciiString(#[from] NonAsciiStringError), + #[error("en/decryption failure: {0}")] + Crypt(#[from] CryptError), + #[error("failed to deserialize field {0}: {1}")] + DeserializationError(&'static str, #[source] BoxError), + #[error("report is too short: {0}, expected length at least: {1}")] + Length(usize, usize), +} + +/// Reports for impression events are represented here. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct HybridImpressionReport +where + BK: SharedValue, +{ + match_key: Replicated, + breakdown_key: Replicated, +} + +impl Serializable for HybridImpressionReport +where + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add< as Serializable>::Size>>:: Output: ArrayLength, +{ + type Size = < as Serializable>::Size as Add< as Serializable>::Size>>:: Output; + type DeserializationError = InvalidHybridReportError; + + fn serialize(&self, buf: &mut GenericArray) { + let mk_sz = as Serializable>::Size::USIZE; + let bk_sz = as Serializable>::Size::USIZE; + + self.match_key + .serialize(GenericArray::from_mut_slice(&mut buf[..mk_sz])); + + self.breakdown_key + .serialize(GenericArray::from_mut_slice(&mut buf[mk_sz..mk_sz + bk_sz])); + } + fn deserialize(buf: &GenericArray) -> Result { + let mk_sz = as Serializable>::Size::USIZE; + let bk_sz = as Serializable>::Size::USIZE; + let match_key = + Replicated::::deserialize(GenericArray::from_slice(&buf[..mk_sz])) + .map_err(|e| InvalidHybridReportError::DeserializationError("match_key", e.into()))?; + let breakdown_key = + Replicated::::deserialize(GenericArray::from_slice(&buf[mk_sz..mk_sz + bk_sz])) + .map_err(|e| InvalidHybridReportError::DeserializationError("breakdown_key", e.into()))?; + Ok(Self { match_key, breakdown_key }) + } +} + +impl HybridImpressionReport +where + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add< as Serializable>::Size>>:: Output: ArrayLength, +{ + const BTT_END: usize = as Serializable>::Size::USIZE; + + /// # Panics + /// If report length does not fit in `u16`. + pub fn encrypted_len(&self) -> u16 { + let len = EncryptedHybridImpressionReport::::SITE_DOMAIN_OFFSET; + len.try_into().unwrap() + } + /// # Errors + /// If there is a problem encrypting the report. + pub fn encrypt( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + ) -> Result, InvalidHybridReportError> { + let mut out = Vec::with_capacity(usize::from(self.encrypted_len())); + self.encrypt_to(key_id, key_registry, rng, &mut out)?; + debug_assert_eq!(out.len(), usize::from(self.encrypted_len())); + Ok(out) + } + + /// # Errors + /// If there is a problem encrypting the report. + pub fn encrypt_to( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + out: &mut B, + ) -> Result<(), InvalidHybridReportError> { + let info = HybridImpressionInfo::new(key_id, HELPER_ORIGIN)?; + + let mut plaintext_mk = GenericArray::default(); + self.match_key.serialize(&mut plaintext_mk); + + let mut plaintext_btt = vec![0u8; Self::BTT_END]; + self.breakdown_key + .serialize(GenericArray::from_mut_slice(&mut plaintext_btt[..])); + + let pk = key_registry.public_key(key_id).ok_or(CryptError::NoSuchKey(key_id))?; + + let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( + pk, + plaintext_mk.as_mut(), + &info.to_bytes(), + rng, + )?; + + let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( + pk, + plaintext_btt.as_mut(), + &info.to_bytes(), + rng, + )?; + + out.put_slice(&encap_key_mk.to_bytes()); + out.put_slice(ciphertext_mk); + out.put_slice(&tag_mk.to_bytes()); + out.put_slice(&encap_key_btt.to_bytes()); + out.put_slice(ciphertext_btt); + out.put_slice(&tag_btt.to_bytes()); + out.put_slice(&[key_id]); + + Ok(()) + } +} + +/// Reports for conversion events are represented here. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct HybridConversionReport +where + V: SharedValue, +{ + match_key: Replicated, + value: Replicated, +} + +/// This enum contains both report types, impression and conversion. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum HybridReport +where + BK: SharedValue, + V: SharedValue, +{ + Impression(HybridImpressionReport), + Conversion(HybridConversionReport), +} + +impl HybridReport +where + BK: SharedValue, + V: SharedValue, +{ + /// # Errors + /// If there is a problem encrypting the report. + pub fn encrypt( + &self, + _key_id: KeyIdentifier, + _key_registry: &impl PublicKeyRegistry, + _rng: &mut R, + ) -> Result, InvalidReportError> { + unimplemented!() + } +} + +/// `HybridImpressionReport`s are encrypted when they arrive to the helpers, +/// which is represented here. A `EncryptedHybridImpressionReport` decrypts +/// into a `HybridImpressionReport`. +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct EncryptedHybridImpressionReport +where + B: Deref, + BK: SharedValue, +{ + data: B, + phantom_data: PhantomData, +} + +impl EncryptedHybridImpressionReport +where + B: Deref, + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add>::Output: ArrayLength, +{ + const ENCAP_KEY_MK_OFFSET: usize = 0; + const CIPHERTEXT_MK_OFFSET: usize = Self::ENCAP_KEY_MK_OFFSET + EncapsulationSize::USIZE; + const ENCAP_KEY_BTT_OFFSET: usize = (Self::CIPHERTEXT_MK_OFFSET + + TagSize::USIZE + + as Serializable>::Size::USIZE); + const CIPHERTEXT_BTT_OFFSET: usize = Self::ENCAP_KEY_BTT_OFFSET + EncapsulationSize::USIZE; + + const KEY_IDENTIFIER_OFFSET: usize = (Self::CIPHERTEXT_BTT_OFFSET + + TagSize::USIZE + + as Serializable>::Size::USIZE); + const SITE_DOMAIN_OFFSET: usize = Self::KEY_IDENTIFIER_OFFSET + 1; + + pub fn encap_key_mk(&self) -> &[u8] { + &self.data[Self::ENCAP_KEY_MK_OFFSET..Self::CIPHERTEXT_MK_OFFSET] + } + + pub fn mk_ciphertext(&self) -> &[u8] { + &self.data[Self::CIPHERTEXT_MK_OFFSET..Self::ENCAP_KEY_BTT_OFFSET] + } + + pub fn encap_key_btt(&self) -> &[u8] { + &self.data[Self::ENCAP_KEY_BTT_OFFSET..Self::CIPHERTEXT_BTT_OFFSET] + } + + pub fn btt_ciphertext(&self) -> &[u8] { + &self.data[Self::CIPHERTEXT_BTT_OFFSET..Self::KEY_IDENTIFIER_OFFSET] + } + + pub fn key_id(&self) -> KeyIdentifier { + self.data[Self::KEY_IDENTIFIER_OFFSET] + } + + /// ## Errors + /// If the report contents are invalid. + pub fn from_bytes(bytes: B) -> Result { + if bytes.len() < Self::SITE_DOMAIN_OFFSET { + return Err(InvalidHybridReportError::Length( + bytes.len(), + Self::SITE_DOMAIN_OFFSET, + )); + } + Ok(Self { + data: bytes, + phantom_data: PhantomData, + }) + } + + /// ## Errors + /// If the match key shares in the report cannot be decrypted (e.g. due to a + /// failure of the authenticated encryption). + /// ## Panics + /// Should not panic. Only panics if a `Report` constructor failed to validate the + /// contents properly, which would be a bug. + pub fn decrypt( + &self, + key_registry: &P, + ) -> Result, InvalidHybridReportError> { + type CTMKLength = Sum< as Serializable>::Size, TagSize>; + type CTBTTLength = < as Serializable>::Size as Add>::Output; + + let info = HybridImpressionInfo::new(self.key_id(), HELPER_ORIGIN).unwrap(); // validated on construction + + let mut ct_mk: GenericArray = + *GenericArray::from_slice(self.mk_ciphertext()); + let sk = key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?; + let plaintext_mk = open_in_place(sk, self.encap_key_mk(), &mut ct_mk, &info.to_bytes())?; + let mut ct_btt: GenericArray> = + GenericArray::from_slice(self.btt_ciphertext()).clone(); + + let plaintext_btt = open_in_place(sk, self.encap_key_btt(), &mut ct_btt, &info.to_bytes())?; + + Ok(HybridImpressionReport:: { + match_key: Replicated::::deserialize_infallible(GenericArray::from_slice( + plaintext_mk, + )), + breakdown_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_btt)) + .map_err(|e| { + InvalidHybridReportError::DeserializationError("is_trigger", e.into()) + })?, + }) + } +} + +/// This struct is designed to fit both `HybridConversionReport`s +/// and `HybridImpressionReport`s so that they can be made indistingushable. +/// Note: these need to be shuffled (and secret shares need to be rerandomized) +/// to provide any formal indistinguishability. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct IndistinguishableHybridReport +where + BK: SharedValue, + V: SharedValue, +{ + match_key: Replicated, + value: Replicated, + breakdown_key: Replicated, +} + +impl From> for IndistinguishableHybridReport +where + BK: SharedValue, + V: SharedValue, +{ + fn from(report: HybridReport) -> Self { + match report { + HybridReport::Impression(r) => r.into(), + HybridReport::Conversion(r) => r.into(), + } + } +} + +impl From> for IndistinguishableHybridReport +where + BK: SharedValue, + V: SharedValue, +{ + fn from(impression_report: HybridImpressionReport) -> Self { + Self { + match_key: impression_report.match_key, + value: Replicated::ZERO, + breakdown_key: impression_report.breakdown_key, + } + } +} + +impl From> for IndistinguishableHybridReport +where + BK: SharedValue, + V: SharedValue, +{ + fn from(conversion_report: HybridConversionReport) -> Self { + Self { + match_key: conversion_report.match_key, + value: conversion_report.value, + breakdown_key: Replicated::ZERO, + } + } +} + +#[derive(Clone)] +pub struct EncryptedHybridReport { + bytes: Bytes, +} + +impl EncryptedHybridReport { + /// ## Errors + /// If the report fails to decrypt + pub fn decrypt( + &self, + key_registry: &P, + ) -> Result, InvalidReportError> + where + P: PrivateKeyRegistry, + BK: SharedValue, + V: SharedValue, + TS: SharedValue, + Replicated: Serializable, + Replicated: Serializable, + Replicated: Serializable, + as Serializable>::Size: Add< as Serializable>::Size>, + Sum< as Serializable>::Size, as Serializable>::Size>: + Add< as Serializable>::Size>, + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >: Add, + Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + U16, + >: ArrayLength, + { + let encrypted_oprf_report = + EncryptedOprfReport::::try_from(self.bytes.clone())?; + let oprf_report = encrypted_oprf_report.decrypt(key_registry)?; + match oprf_report.event_type { + EventType::Source => Ok(HybridReport::Impression(HybridImpressionReport { + match_key: oprf_report.match_key, + breakdown_key: oprf_report.breakdown_key, + })), + EventType::Trigger => Ok(HybridReport::Conversion(HybridConversionReport { + match_key: oprf_report.match_key, + value: oprf_report.trigger_value, + })), + } + } + + /// TODO: update these when we produce a proper encapsulation of + /// `EncryptedHybridReport`, rather than pigggybacking on `EncryptedOprfReport` + pub fn mk_ciphertext(&self) -> &[u8] { + let encap_key_mk_offset: usize = 0; + let ciphertext_mk_offset: usize = encap_key_mk_offset + EncapsulationSize::USIZE; + let encap_key_btt_offset: usize = + ciphertext_mk_offset + TagSize::USIZE + as Serializable>::Size::USIZE; + + &self.bytes[ciphertext_mk_offset..encap_key_btt_offset] + } +} + +impl TryFrom for EncryptedHybridReport { + type Error = InvalidReportError; + + fn try_from(bytes: Bytes) -> Result { + Ok(EncryptedHybridReport { bytes }) + } +} + +const TAG_SIZE: usize = TagSize::USIZE; + +#[derive(Clone, Debug)] +pub struct UniqueTag { + bytes: [u8; TAG_SIZE], +} + +pub trait UniqueBytes { + fn unique_bytes(&self) -> [u8; TAG_SIZE]; +} + +impl UniqueBytes for UniqueTag { + fn unique_bytes(&self) -> [u8; TAG_SIZE] { + self.bytes + } +} + +impl UniqueBytes for EncryptedHybridReport { + /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection + /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) + fn unique_bytes(&self) -> [u8; TAG_SIZE] { + let slice = &self.mk_ciphertext()[0..TAG_SIZE]; + let mut array = [0u8; TAG_SIZE]; + array.copy_from_slice(slice); + array + } +} + +impl UniqueTag { + // Function to attempt to create a UniqueTag from a UniqueBytes implementor + pub fn from_unique_bytes(item: &T) -> Self { + const_assert_eq!(16, TAG_SIZE); + UniqueTag { + bytes: item.unique_bytes(), + } + } + + /// Maps the tag into a consistent shard. + /// + /// ## Panics + /// if the `TAG_SIZE != 16` + /// note: ~10 below this, we have a compile time check that `TAG_SIZE = 16` + #[must_use] + pub fn shard_picker(&self, shard_count: ShardIndex) -> ShardIndex { + let num = u128::from_le_bytes(self.bytes); + let shard_count = u128::from(shard_count); + ShardIndex::try_from(num % shard_count).expect("Modulo a u32 will fit in u32") + } +} + +impl Serializable for UniqueTag { + type Size = U16; // This must match TAG_SIZE + type DeserializationError = Infallible; + + fn serialize(&self, buf: &mut GenericArray) { + buf.copy_from_slice(&self.bytes); + } + fn deserialize(buf: &GenericArray) -> Result { + let mut bytes = [0u8; TAG_SIZE]; + bytes.copy_from_slice(buf.as_slice()); + Ok(UniqueTag { bytes }) + } +} + +#[derive(Debug)] +pub struct UniqueTagValidator { + hash_set: HashSet<[u8; TAG_SIZE]>, + check_counter: usize, +} + +impl UniqueTagValidator { + #[must_use] + pub fn new(size: usize) -> Self { + UniqueTagValidator { + hash_set: HashSet::with_capacity(size), + check_counter: 0, + } + } + fn insert(&mut self, value: [u8; TAG_SIZE]) -> bool { + self.hash_set.insert(value) + } + /// Checks that item is unique among all checked thus far + /// + /// ## Errors + /// if the item inserted is not unique among all checked thus far + pub fn check_duplicate(&mut self, item: &U) -> Result<(), Error> { + self.check_counter += 1; + if self.insert(item.unique_bytes()) { + Ok(()) + } else { + Err(Error::DuplicateBytes(self.check_counter)) + } + } + /// Checks that an iter of items is unique among the iter and any other items checked thus far + /// + /// ## Errors + /// if the and item inserted is not unique among all in this batch and checked previously + pub fn check_duplicates(&mut self, items: &[U]) -> Result<(), Error> { + items + .iter() + .try_for_each(|item| self.check_duplicate(item))?; + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng}; + use typenum::Unsigned; + + use super::{ + EncryptedHybridImpressionReport, EncryptedHybridReport, GenericArray, + HybridConversionReport, HybridImpressionReport, HybridReport, + IndistinguishableHybridReport, UniqueTag, UniqueTagValidator, + }; + use crate::{ + error::Error, + ff::{ + boolean_array::{BA20, BA3, BA8}, + Serializable, + }, + hpke::{KeyPair, KeyRegistry}, + report::{ + hybrid::{NonAsciiStringError, BA64}, + hybrid_info::HybridImpressionInfo, + EventType, OprfReport, + }, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + }; + + fn build_oprf_report(event_type: EventType, rng: &mut ThreadRng) -> OprfReport { + OprfReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + timestamp: AdditiveShare::new(rng.gen(), rng.gen()), + breakdown_key: AdditiveShare::new(rng.gen(), rng.gen()), + trigger_value: AdditiveShare::new(rng.gen(), rng.gen()), + event_type, + epoch: rng.gen(), + site_domain: (rng) + .sample_iter(Alphanumeric) + .map(char::from) + .take(10) + .collect(), + } + } + + fn generate_random_tag() -> UniqueTag { + let mut rng = thread_rng(); + let mut bytes = [0u8; 16]; + rng.fill(&mut bytes[..]); + UniqueTag { bytes } + } + + #[test] + fn convert_to_hybrid_impression_report() { + let mut rng = thread_rng(); + + let b = EventType::Source; + + let oprf_report = build_oprf_report(b, &mut rng); + let hybrid_report = HybridReport::Impression::(HybridImpressionReport:: { + match_key: oprf_report.match_key.clone(), + breakdown_key: oprf_report.breakdown_key.clone(), + }); + + let key_registry = KeyRegistry::::random(1, &mut rng); + let key_id = 0; + + let enc_report_bytes = oprf_report + .encrypt(key_id, &key_registry, &mut rng) + .unwrap(); + let enc_report = EncryptedHybridReport { + bytes: enc_report_bytes.into(), + }; + + let hybrid_report2 = enc_report + .decrypt::<_, BA8, BA3, BA20>(&key_registry) + .unwrap(); + + assert_eq!(hybrid_report, hybrid_report2); + } + + #[test] + fn convert_to_hybrid_conversion_report() { + let mut rng = thread_rng(); + + let b = EventType::Trigger; + + let oprf_report = build_oprf_report(b, &mut rng); + let hybrid_report = HybridReport::Conversion::(HybridConversionReport:: { + match_key: oprf_report.match_key.clone(), + value: oprf_report.trigger_value.clone(), + }); + + let key_registry = KeyRegistry::::random(1, &mut rng); + let key_id = 0; + + let enc_report_bytes = oprf_report + .encrypt(key_id, &key_registry, &mut rng) + .unwrap(); + let enc_report = EncryptedHybridReport { + bytes: enc_report_bytes.into(), + }; + let hybrid_report2 = enc_report + .decrypt::<_, BA8, BA3, BA20>(&key_registry) + .unwrap(); + + assert_eq!(hybrid_report, hybrid_report2); + } + + /// We create a random `HybridConversionReport`, convert it into an + ///`IndistinguishableHybridReport`, and check that the field values are the same + /// (or zero, for the breakdown key, which doesn't exist on the conversion report.) + /// We then build a generic `HybridReport` from the conversion report, convert it + /// into an `IndistingushableHybridReport`, and validate that it has the same value + /// as the previous `IndistingushableHybridReport`. + #[test] + fn convert_hybrid_conversion_report_to_indistinguishable_report() { + let mut rng = thread_rng(); + + let conversion_report = HybridConversionReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + value: AdditiveShare::new(rng.gen(), rng.gen()), + }; + let indistinguishable_report: IndistinguishableHybridReport = + conversion_report.clone().into(); + assert_eq!( + conversion_report.match_key, + indistinguishable_report.match_key + ); + assert_eq!(conversion_report.value, indistinguishable_report.value); + assert_eq!(AdditiveShare::ZERO, indistinguishable_report.breakdown_key); + + let hybrid_report = HybridReport::Conversion::(conversion_report.clone()); + let indistinguishable_report2: IndistinguishableHybridReport = + hybrid_report.clone().into(); + assert_eq!(indistinguishable_report, indistinguishable_report2); + } + + /// We create a random `HybridImpressionReport`, convert it into an + ///`IndistinguishableHybridReport`, and check that the field values are the same + /// (or zero, for the value, which doesn't exist on the impression report.) + /// We then build a generic `HybridReport` from the impression report, convert it + /// into an `IndistingushableHybridReport`, and validate that it has the same value + /// as the previous `IndistingushableHybridReport`. + #[test] + fn convert_hybrid_impression_report_to_indistinguishable_report() { + let mut rng = thread_rng(); + + let impression_report = HybridImpressionReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + breakdown_key: AdditiveShare::new(rng.gen(), rng.gen()), + }; + let indistinguishable_report: IndistinguishableHybridReport = + impression_report.clone().into(); + assert_eq!( + impression_report.match_key, + indistinguishable_report.match_key + ); + assert_eq!(AdditiveShare::ZERO, indistinguishable_report.value); + assert_eq!( + impression_report.breakdown_key, + indistinguishable_report.breakdown_key + ); + + let hybrid_report = HybridReport::Impression::(impression_report.clone()); + let indistinguishable_report2: IndistinguishableHybridReport = + hybrid_report.clone().into(); + assert_eq!(indistinguishable_report, indistinguishable_report2); + } + + #[test] + fn unique_encrypted_hybrid_reports() { + let tag1 = generate_random_tag(); + let tag2 = generate_random_tag(); + let tag3 = generate_random_tag(); + let tag4 = generate_random_tag(); + + let mut unique_bytes = UniqueTagValidator::new(4); + + unique_bytes.check_duplicate(&tag1).unwrap(); + + unique_bytes + .check_duplicates(&[tag2.clone(), tag3.clone()]) + .unwrap(); + let expected_err = unique_bytes.check_duplicate(&tag2); + assert!(matches!(expected_err, Err(Error::DuplicateBytes(4)))); + + let expected_err = unique_bytes.check_duplicates(&[tag4, tag3]); + assert!(matches!(expected_err, Err(Error::DuplicateBytes(6)))); + } + + #[test] + fn serialization_hybrid_impression() { + let mut rng = thread_rng(); + let b = EventType::Source; + let oprf_report = build_oprf_report(b, &mut rng); + + let hybrid_impression_report = HybridImpressionReport:: { + match_key: oprf_report.match_key.clone(), + breakdown_key: oprf_report.breakdown_key.clone(), + }; + let mut hybrid_impression_report_bytes = + [0u8; as Serializable>::Size::USIZE]; + hybrid_impression_report.serialize(GenericArray::from_mut_slice( + &mut hybrid_impression_report_bytes[..], + )); + let hybrid_impression_report2 = HybridImpressionReport::::deserialize( + GenericArray::from_mut_slice(&mut hybrid_impression_report_bytes[..]), + ) + .unwrap(); + assert_eq!(hybrid_impression_report, hybrid_impression_report2); + } + + #[test] + fn deserialzation_from_constant() { + let hybrid_report = HybridImpressionReport::::deserialize(GenericArray::from_slice( + &hex::decode("4123a6e38ef1d6d9785c948797cb744d38f4").unwrap(), + )) + .unwrap(); + + let match_key = AdditiveShare::::deserialize(GenericArray::from_slice( + &hex::decode("4123a6e38ef1d6d9785c948797cb744d").unwrap(), + )) + .unwrap(); + let breakdown_key = AdditiveShare::::deserialize(GenericArray::from_slice( + &hex::decode("38f4").unwrap(), + )) + .unwrap(); + + assert_eq!( + hybrid_report, + HybridImpressionReport:: { + match_key, + breakdown_key + } + ); + + let mut hybrid_impression_report_bytes = + [0u8; as Serializable>::Size::USIZE]; + hybrid_report.serialize(GenericArray::from_mut_slice( + &mut hybrid_impression_report_bytes[..], + )); + + assert_eq!( + hybrid_impression_report_bytes.to_vec(), + hex::decode("4123a6e38ef1d6d9785c948797cb744d38f4").unwrap() + ); + } + + #[test] + fn enc_dec_roundtrip_hybrid_impression() { + let mut rng = thread_rng(); + let b = EventType::Source; + let oprf_report = build_oprf_report(b, &mut rng); + + let hybrid_impression_report = HybridImpressionReport:: { + match_key: oprf_report.match_key.clone(), + breakdown_key: oprf_report.breakdown_key.clone(), + }; + + let key_registry = KeyRegistry::::random(1, &mut rng); + let key_id = 0; + + let enc_report_bytes = hybrid_impression_report + .encrypt(key_id, &key_registry, &mut rng) + .unwrap(); + + let enc_report = + EncryptedHybridImpressionReport::::from_bytes(enc_report_bytes.as_slice()) + .unwrap(); + let dec_report: HybridImpressionReport = enc_report.decrypt(&key_registry).unwrap(); + + assert_eq!(dec_report, hybrid_impression_report); + } + + #[test] + fn non_ascii_string() { + let non_ascii_string = "☃️☃️☃️"; + let err = HybridImpressionInfo::new(0, non_ascii_string).unwrap_err(); + assert!(matches!(err, NonAsciiStringError(_))); + } +} diff --git a/ipa-core/src/report/hybrid_info.rs b/ipa-core/src/report/hybrid_info.rs new file mode 100644 index 000000000..c41849121 --- /dev/null +++ b/ipa-core/src/report/hybrid_info.rs @@ -0,0 +1,65 @@ +use crate::report::{hybrid::NonAsciiStringError, KeyIdentifier}; + +const DOMAIN: &str = "private-attribution"; + +#[derive(Debug)] +pub struct HybridImpressionInfo<'a> { + pub key_id: KeyIdentifier, + pub helper_origin: &'a str, +} + +#[allow(dead_code)] +pub struct HybridConversionInfo<'a> { + pub key_id: KeyIdentifier, + pub helper_origin: &'a str, + pub converion_site_domain: &'a str, + pub timestamp: u64, + pub epsilon: f64, + pub sensitivity: f64, +} + +#[allow(dead_code)] +pub enum HybridInfo<'a> { + Impression(HybridImpressionInfo<'a>), + Conversion(HybridConversionInfo<'a>), +} + +impl<'a> HybridImpressionInfo<'a> { + /// Creates a new instance. + /// + /// ## Errors + /// if helper or site origin is not a valid ASCII string. + pub fn new(key_id: KeyIdentifier, helper_origin: &'a str) -> Result { + // If the types of errors returned from this function change, then the validation in + // `EncryptedReport::from_bytes` may need to change as well. + if !helper_origin.is_ascii() { + return Err(helper_origin.into()); + } + + Ok(Self { + key_id, + helper_origin, + }) + } + + // Converts this instance into an owned byte slice that can further be used to create HPKE + // sender or receiver context. + pub(super) fn to_bytes(&self) -> Box<[u8]> { + let info_len = DOMAIN.len() + + self.helper_origin.len() + + 2 // delimiters(?) + + std::mem::size_of_val(&self.key_id); + let mut r = Vec::with_capacity(info_len); + + r.extend_from_slice(DOMAIN.as_bytes()); + r.push(0); + r.extend_from_slice(self.helper_origin.as_bytes()); + r.push(0); + + r.push(self.key_id); + + debug_assert_eq!(r.len(), info_len, "HPKE Info length estimation is incorrect and leads to extra allocation or wasted memory"); + + r.into_boxed_slice() + } +} diff --git a/ipa-core/src/report.rs b/ipa-core/src/report/ipa.rs similarity index 87% rename from ipa-core/src/report.rs rename to ipa-core/src/report/ipa.rs index c2411ce28..cfaf4349f 100644 --- a/ipa-core/src/report.rs +++ b/ipa-core/src/report/ipa.rs @@ -1,7 +1,34 @@ +//! Provides report types which are aggregated by the IPA protocol +//! +//! The `OprfReport` is the primary data type which each helpers use to aggreate in the IPA +//! protocol. +//! From each Helper's POV, the Report Collector POSTs a length delimited byte +//! stream, which is then processed as follows: +//! +//! `BodyStream` → `EncryptedOprfReport` → `OprfReport` +//! +//! From the Report Collectors's POV, there are two potential paths: +//! 1. In production, encrypted events are recieved from clients and accumulated out of band +//! as 3 files of newline delimited hex encoded enrypted events. +//! 2. For testing, simluated plaintext events are provided as a CSV. +//! +//! Path 1 is proccssed as follows: +//! +//! `files: [PathBuf; 3]` → `EncryptedOprfReportsFiles` → `helpers::BodyStream` +//! +//! Path 2 is processed as follows: +//! +//! `cli::playbook::InputSource` (`PathBuf` or `stdin()`) → +//! `test_fixture::ipa::TestRawDataRecord` → `OprfReport` → encrypted bytes +//! (via `Oprf.delmited_encrypt_to`) → `helpers::BodyStream` + use std::{ fmt::{Display, Formatter}, + fs::File, + io::{BufRead, BufReader}, marker::PhantomData, ops::{Add, Deref}, + path::PathBuf, }; use bytes::{BufMut, Bytes}; @@ -13,6 +40,7 @@ use typenum::{Sum, Unsigned, U1, U16}; use crate::{ error::BoxError, ff::{boolean_array::BA64, Serializable}, + helpers::BodyStream, hpke::{ open_in_place, seal_in_place, CryptError, EncapsulationSize, Info, PrivateKeyRegistry, PublicKeyRegistry, TagSize, @@ -159,6 +187,53 @@ pub enum InvalidReportError { Length(usize, usize), } +/// A struct intended for the Report Collector to hold the streams of underlying +/// `EncryptedOprfReports` represented as length delmited bytes. Helpers receive an +/// individual stream, which are unpacked into `EncryptedOprfReports` and decrypted +/// into `OprfReports`. +pub struct EncryptedOprfReportStreams { + pub streams: [BodyStream; 3], + pub query_size: usize, +} + +/// A trait to build an `EncryptedOprfReportStreams` struct from 3 files of +/// `EncryptedOprfReports` formated at newline delimited hex. +impl From<[&PathBuf; 3]> for EncryptedOprfReportStreams { + fn from(files: [&PathBuf; 3]) -> Self { + let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); + let mut query_sizes: [usize; 3] = [0, 0, 0]; + for (i, path) in files.iter().enumerate() { + let file = + File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}")); + let reader = BufReader::new(file); + for line in reader.lines() { + let encrypted_report_bytes = hex::decode( + line.expect("Unable to read line. {file:?} is likely corrupt") + .trim(), + ) + .expect("Unable to read line. {file:?} is likely corrupt"); + buffers[i].put_u16_le( + encrypted_report_bytes + .len() + .try_into() + .expect("Unable to read line. {file:?} is likely corrupt"), + ); + buffers[i].put_slice(encrypted_report_bytes.as_slice()); + query_sizes[i] += 1; + } + } + // Panic if input sizes are not the same + // Panic instead of returning an Error as this is non-recoverable + assert_eq!(query_sizes[0], query_sizes[1]); + assert_eq!(query_sizes[1], query_sizes[2]); + + Self { + streams: buffers.map(BodyStream::from), + // without loss of generality, set query length to length of first input size + query_size: query_sizes[0], + } + } +} // TODO: If we are parsing reports from CSV files, we may also want an owned version of EncryptedReport. /// A binary report as submitted by a report collector, containing encrypted `OprfReport` @@ -332,11 +407,14 @@ where let mut ct_mk: GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); - let plaintext_mk = open_in_place(key_registry, self.encap_key_mk(), &mut ct_mk, &info)?; + let sk = key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?; + let plaintext_mk = open_in_place(sk, self.encap_key_mk(), &mut ct_mk, &info.to_bytes())?; let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); - let plaintext_btt = open_in_place(key_registry, self.encap_key_btt(), &mut ct_btt, &info)?; + let plaintext_btt = open_in_place(sk, self.encap_key_btt(), &mut ct_btt, &info.to_bytes())?; Ok(OprfReport:: { timestamp: Replicated::::deserialize(GenericArray::from_slice( @@ -502,11 +580,15 @@ where ..(Self::TV_OFFSET + as Serializable>::Size::USIZE)], )); + let pk = key_registry + .public_key(key_id) + .ok_or(CryptError::NoSuchKey(key_id))?; + let (encap_key_mk, ciphertext_mk, tag_mk) = - seal_in_place(key_registry, plaintext_mk.as_mut(), &info, rng)?; + seal_in_place(pk, plaintext_mk.as_mut(), &info.to_bytes(), rng)?; let (encap_key_btt, ciphertext_btt, tag_btt) = - seal_in_place(key_registry, plaintext_btt.as_mut(), &info, rng)?; + seal_in_place(pk, plaintext_btt.as_mut(), &info.to_bytes(), rng)?; out.put_slice(&encap_key_mk.to_bytes()); out.put_slice(ciphertext_mk); @@ -752,9 +834,9 @@ mod test { fn check_compatibility_impressionmk_with_ios_encryption() { let enc_report_bytes1 = hex::decode( "12854879d86ef277cd70806a7f6bad269877adc95ee107380381caf15b841a7e995e41\ - 4c63a9d82f834796cdd6c40529189fca82720714d24200d8a916a1e090b123f27eaf24\ - f047f3930a77e5bcd33eeb823b73b0e9546c59d3d6e69383c74ae72b79645698fe1422\ - f83886bd3cbca9fbb63f7019e2139191dd000000007777772e6d6574612e636f6d", + 4c63a9d82f834796cdd6c40529189fca82720714d24200d8a916a1e090b123f27eaf24\ + f047f3930a77e5bcd33eeb823b73b0e9546c59d3d6e69383c74ae72b79645698fe1422\ + f83886bd3cbca9fbb63f7019e2139191dd000000007777772e6d6574612e636f6d", ) .unwrap(); let enc_report_bytes2 = hex::decode( diff --git a/ipa-core/src/report/mod.rs b/ipa-core/src/report/mod.rs new file mode 100644 index 000000000..192c87aca --- /dev/null +++ b/ipa-core/src/report/mod.rs @@ -0,0 +1,4 @@ +pub mod ipa; +pub use self::ipa::*; +pub mod hybrid; +pub mod hybrid_info; diff --git a/ipa-core/src/secret_sharing/into_shares.rs b/ipa-core/src/secret_sharing/into_shares.rs index ddd83ec3c..a7bde0764 100644 --- a/ipa-core/src/secret_sharing/into_shares.rs +++ b/ipa-core/src/secret_sharing/into_shares.rs @@ -41,7 +41,7 @@ where } } -#[cfg(all(test, unit_test))] +#[cfg(test)] impl IntoShares> for Result where U: IntoShares, diff --git a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs index faca5d570..78d393a07 100644 --- a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs @@ -76,6 +76,13 @@ impl, const N: usize> AdditiveShare { >::Array::ZERO_ARRAY, >::Array::ZERO_ARRAY, ); + + /// Returns the size this instance would occupy on the wire or disk. + /// In other words, it does not include padding/alignment. + #[must_use] + pub const fn size() -> usize { + 2 * <>::Array as Serializable>::Size::USIZE + } } impl AdditiveShare { @@ -636,6 +643,14 @@ mod tests { mult_by_constant_test_case((0, 0, 0), 2, 0); } + #[test] + fn test_size() { + const FP31_SZ: usize = AdditiveShare::::size(); + const VEC_FP32: usize = AdditiveShare::::size(); + assert_eq!(2, FP31_SZ); + assert_eq!(256, VEC_FP32); + } + impl Arbitrary for AdditiveShare where V: Vectorizable>, diff --git a/ipa-core/src/secret_sharing/vector/impls.rs b/ipa-core/src/secret_sharing/vector/impls.rs index 536840fa3..b5b043b4d 100644 --- a/ipa-core/src/secret_sharing/vector/impls.rs +++ b/ipa-core/src/secret_sharing/vector/impls.rs @@ -55,104 +55,6 @@ macro_rules! boolean_vector { AdditiveShare::new(*value.left_arr(), *value.right_arr()) } } - - #[cfg(all(test, unit_test))] - mod tests { - use std::iter::zip; - - use super::*; - use crate::{ - error::Error, - protocol::{ - basics::select, - context::{dzkp_validator::DZKPValidator, Context, UpgradableContext}, - RecordId, - }, - rand::{thread_rng, Rng}, - secret_sharing::into_shares::IntoShares, - test_fixture::{join3v, Reconstruct, TestWorld}, - }; - - #[tokio::test] - async fn simplest_circuit_malicious() { - let world = TestWorld::default(); - let context = world.malicious_contexts(); - let mut rng = thread_rng(); - - let bit = rng.gen::(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(1); - let m_ctx = v.context(); - - let result = select( - m_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - - #[tokio::test] - async fn simplest_circuit_semi_honest() { - let world = TestWorld::default(); - let context = world.contexts(); - let mut rng = thread_rng(); - - let bit = rng.gen::(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(1); - let sh_ctx = v.context(); - - let result = select( - sh_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - } } }; } diff --git a/ipa-core/src/secret_sharing/vector/traits.rs b/ipa-core/src/secret_sharing/vector/traits.rs index b44316b70..8194fb103 100644 --- a/ipa-core/src/secret_sharing/vector/traits.rs +++ b/ipa-core/src/secret_sharing/vector/traits.rs @@ -108,7 +108,7 @@ pub trait SharedValueArray: pub trait FieldArray: SharedValueArray + FromRandom - + for<'a> Mul + + Mul + for<'a> Mul<&'a F, Output = Self> + for<'a> Mul<&'a Self, Output = Self> { diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 625f724e6..afd051988 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -1,11 +1,88 @@ use std::{ fmt::{Debug, Display, Formatter}, num::TryFromIntError, + ops::{Index, IndexMut}, }; /// A unique zero-based index of the helper shard. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ShardIndex(u32); +pub struct ShardIndex(pub u32); + +impl ShardIndex { + pub const FIRST: Self = Self(0); + + /// Returns an iterator over all shard indices that precede this one, excluding this one. + pub fn iter(self) -> impl Iterator { + (0..self.0).map(Self) + } +} + +impl From for ShardIndex { + fn from(value: u32) -> Self { + Self(value) + } +} + +impl From for u64 { + fn from(value: ShardIndex) -> Self { + u64::from(value.0) + } +} + +impl From for u128 { + fn from(value: ShardIndex) -> Self { + Self::from(value.0) + } +} + +#[cfg(target_pointer_width = "64")] +impl From for usize { + fn from(value: ShardIndex) -> Self { + usize::try_from(value.0).unwrap() + } +} + +impl From for u32 { + fn from(value: ShardIndex) -> Self { + value.0 + } +} + +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: usize) -> Result { + u32::try_from(value).map(Self) + } +} + +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: u128) -> Result { + u32::try_from(value).map(Self) + } +} + +impl Display for ShardIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl Index for Vec { + type Output = T; + + fn index(&self, index: ShardIndex) -> &Self::Output { + self.as_slice().index(usize::from(index)) + } +} + +impl IndexMut for Vec { + fn index_mut(&mut self, index: ShardIndex) -> &mut Self::Output { + self.as_mut_slice().index_mut(usize::from(index)) + } +} #[derive(Debug, Copy, Clone)] pub struct Sharded { @@ -23,12 +100,6 @@ impl ShardConfiguration for Sharded { } } -impl Display for ShardIndex { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - /// Shard-specific configuration required by sharding API. Each shard must know its own index and /// the total number of shards in the system. pub trait ShardConfiguration { @@ -56,7 +127,7 @@ pub trait ShardConfiguration { } } -pub trait ShardBinding: Debug + Send + Sync + Clone {} +pub trait ShardBinding: Debug + Send + Sync + Clone + 'static {} #[derive(Debug, Copy, Clone)] pub struct NotSharded; @@ -64,48 +135,6 @@ pub struct NotSharded; impl ShardBinding for NotSharded {} impl ShardBinding for Sharded {} -impl ShardIndex { - pub const FIRST: Self = Self(0); - - /// Returns an iterator over all shard indices that precede this one, excluding this one. - pub fn iter(self) -> impl Iterator { - (0..self.0).map(Self) - } -} - -impl From for ShardIndex { - fn from(value: u32) -> Self { - Self(value) - } -} - -impl From for u64 { - fn from(value: ShardIndex) -> Self { - u64::from(value.0) - } -} - -impl From for u128 { - fn from(value: ShardIndex) -> Self { - Self::from(value.0) - } -} - -#[cfg(target_pointer_width = "64")] -impl From for usize { - fn from(value: ShardIndex) -> Self { - usize::try_from(value.0).unwrap() - } -} - -impl TryFrom for ShardIndex { - type Error = TryFromIntError; - - fn try_from(value: usize) -> Result { - u32::try_from(value).map(Self) - } -} - #[cfg(all(test, unit_test))] mod tests { use std::iter::empty; diff --git a/ipa-core/src/test_fixture/circuit.rs b/ipa-core/src/test_fixture/circuit.rs index 2e353d3c3..17920591f 100644 --- a/ipa-core/src/test_fixture/circuit.rs +++ b/ipa-core/src/test_fixture/circuit.rs @@ -1,4 +1,4 @@ -use std::{array, num::NonZeroUsize}; +use std::array; use futures::{future::join3, stream, StreamExt}; use ipa_step::StepNarrow; @@ -10,14 +10,14 @@ use crate::{ protocol::{ basics::SecureMul, context::{Context, SemiHonestContext}, - step::ProtocolStep, + step::{ProtocolStep, TestExecutionStep as Step}, Gate, RecordId, }, rand::thread_rng, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, IntoShares}, seq_join::seq_join, - test_fixture::{step::TestExecutionStep as Step, ReconstructArr, TestWorld, TestWorldConfig}, - utils::array::zip3, + test_fixture::{ReconstructArr, TestWorld, TestWorldConfig}, + utils::{array::zip3, NonZeroU32PowerOfTwo}, }; pub struct Inputs, const N: usize> { @@ -76,16 +76,16 @@ pub async fn arithmetic( [F; N]: IntoShares>, Standard: Distribution, { - let active = NonZeroUsize::new(active_work).unwrap(); + let active = NonZeroU32PowerOfTwo::try_from(active_work.next_power_of_two()).unwrap(); let config = TestWorldConfig { gateway_config: GatewayConfig { active, ..Default::default() }, - initial_gate: Some(Gate::default().narrow(&ProtocolStep::Test(0))), + initial_gate: Some(Gate::default().narrow(&ProtocolStep::Test)), ..Default::default() }; - let world = TestWorld::new_with(config); + let world = TestWorld::new_with(&config); // Re-use contexts for the entire execution because record identifiers are contiguous. let contexts = world.contexts(); @@ -96,7 +96,7 @@ pub async fn arithmetic( // accumulated. This gives the best performance for vectorized operation. let ctx = ctx.set_total_records(TotalRecords::Indeterminate); seq_join( - active, + config.gateway_config.active_work(), stream::iter((0..(width / u32::try_from(N).unwrap())).zip(col_data)).map( move |(record, Inputs { a, b })| { circuit(ctx.clone(), RecordId::from(record), depth, a, b) diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs new file mode 100644 index 000000000..63ecf73e5 --- /dev/null +++ b/ipa-core/src/test_fixture/hybrid.rs @@ -0,0 +1,150 @@ +use std::collections::{HashMap, HashSet}; + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq)] +pub enum TestHybridRecord { + TestImpression { match_key: u64, breakdown_key: u32 }, + TestConversion { match_key: u64, value: u32 }, +} + +struct HashmapEntry { + breakdown_key: u32, + total_value: u32, +} + +impl HashmapEntry { + pub fn new(breakdown_key: u32, value: u32) -> Self { + Self { + breakdown_key, + total_value: value, + } + } +} + +/// # Panics +/// It won't, so long as you can convert a u32 to a usize +#[must_use] +pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize) -> Vec { + let mut conversion_match_keys = HashSet::new(); + let mut impression_match_keys = HashSet::new(); + + for input in input_rows { + match input { + TestHybridRecord::TestImpression { match_key, .. } => { + impression_match_keys.insert(*match_key); + } + TestHybridRecord::TestConversion { match_key, .. } => { + conversion_match_keys.insert(*match_key); + } + } + } + + // The key is the "match key" and the value stores both the breakdown and total attributed value + let mut attributed_conversions = HashMap::new(); + + for input in input_rows { + match input { + TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + if conversion_match_keys.contains(match_key) { + let v = attributed_conversions + .entry(*match_key) + .or_insert(HashmapEntry::new(*breakdown_key, 0)); + v.breakdown_key = *breakdown_key; + } + } + TestHybridRecord::TestConversion { match_key, value } => { + if impression_match_keys.contains(match_key) { + attributed_conversions + .entry(*match_key) + .and_modify(|e| e.total_value += value) + .or_insert(HashmapEntry::new(0, *value)); + } + } + } + } + + let mut output = vec![0; max_breakdown]; + for (_, entry) in attributed_conversions { + output[usize::try_from(entry.breakdown_key).unwrap()] += entry.total_value; + } + + output +} + +#[cfg(all(test, unit_test))] +mod tests { + use rand::{seq::SliceRandom, thread_rng}; + + use super::TestHybridRecord; + use crate::test_fixture::hybrid::hybrid_in_the_clear; + + #[test] + fn basic() { + let mut test_data = vec![ + TestHybridRecord::TestImpression { + match_key: 12345, + breakdown_key: 2, + }, + TestHybridRecord::TestImpression { + match_key: 23456, + breakdown_key: 4, + }, + TestHybridRecord::TestConversion { + match_key: 23456, + value: 25, + }, // attributed + TestHybridRecord::TestImpression { + match_key: 34567, + breakdown_key: 1, + }, + TestHybridRecord::TestImpression { + match_key: 45678, + breakdown_key: 3, + }, + TestHybridRecord::TestConversion { + match_key: 45678, + value: 13, + }, // attributed + TestHybridRecord::TestImpression { + match_key: 56789, + breakdown_key: 5, + }, + TestHybridRecord::TestConversion { + match_key: 67890, + value: 14, + }, // NOT attributed + TestHybridRecord::TestImpression { + match_key: 78901, + breakdown_key: 2, + }, + TestHybridRecord::TestConversion { + match_key: 78901, + value: 12, + }, // attributed + TestHybridRecord::TestConversion { + match_key: 78901, + value: 31, + }, // attributed + TestHybridRecord::TestImpression { + match_key: 89012, + breakdown_key: 4, + }, + TestHybridRecord::TestConversion { + match_key: 89012, + value: 8, + }, // attributed + ]; + + let mut rng = thread_rng(); + test_data.shuffle(&mut rng); + let expected = vec![ + 0, 0, 43, // 12 + 31 + 13, 33, // 25 + 8 + 0, + ]; + let result = hybrid_in_the_clear(&test_data, 6); + assert_eq!(result, expected); + } +} diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs new file mode 100644 index 000000000..792b1cc37 --- /dev/null +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -0,0 +1,362 @@ +use std::num::NonZeroU32; + +use rand::Rng; + +use super::hybrid::TestHybridRecord; + +#[derive(Debug, Copy, Clone)] +#[cfg_attr(feature = "clap", derive(clap::ValueEnum))] +pub enum ConversionDistribution { + Default, + LotsOfConversionsPerImpression, + OnlyImpressions, + OnlyConversions, +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "clap", derive(clap::Args))] +pub struct Config { + #[cfg_attr(feature = "clap", arg(long, default_value = "5"))] + pub max_conversion_value: NonZeroU32, + #[cfg_attr(feature = "clap", arg(long, default_value = "20"))] + pub max_breakdown_key: NonZeroU32, + #[cfg_attr(feature = "clap", arg(long, default_value = "10"))] + pub max_convs_per_imp: NonZeroU32, + /// Indicates the distribution of impression to conversion reports. + #[cfg_attr(feature = "clap", arg(value_enum, long, default_value_t = ConversionDistribution::Default))] + pub conversion_distribution: ConversionDistribution, +} + +impl Default for Config { + fn default() -> Self { + Self::new(5, 20, 10, ConversionDistribution::Default) + } +} + +impl Config { + /// Creates a new instance of [`Self`] + /// + /// ## Panics + /// If any argument is 0. + #[must_use] + pub fn new( + max_conversion_value: u32, + max_breakdown_key: u32, + max_convs_per_imp: u32, + conversion_distribution: ConversionDistribution, + ) -> Self { + Self { + max_conversion_value: NonZeroU32::try_from(max_conversion_value).unwrap(), + max_breakdown_key: NonZeroU32::try_from(max_breakdown_key).unwrap(), + max_convs_per_imp: NonZeroU32::try_from(max_convs_per_imp).unwrap(), + conversion_distribution, + } + } +} + +pub struct EventGenerator { + config: Config, + rng: R, + in_flight: Vec, +} + +impl EventGenerator { + #[allow(dead_code)] + pub fn with_default_config(rng: R) -> Self { + Self::with_config(rng, Config::default()) + } + + /// # Panics + /// If the configuration is not valid. + #[allow(dead_code)] + pub fn with_config(rng: R, config: Config) -> Self { + let max_capacity = usize::try_from(config.max_convs_per_imp.get() + 1).unwrap(); + Self { + config, + rng, + in_flight: Vec::with_capacity(max_capacity), + } + } + + fn gen_batch(&mut self) { + match self.config.conversion_distribution { + ConversionDistribution::OnlyImpressions => { + self.gen_batch_with_params(0.0, 1.0, 0.0); + } + ConversionDistribution::OnlyConversions => { + self.gen_batch_with_params(1.0, 0.0, 0.0); + } + ConversionDistribution::Default => { + self.gen_batch_with_params(0.1, 0.7, 0.15); + } + ConversionDistribution::LotsOfConversionsPerImpression => { + self.gen_batch_with_params(0.3, 0.4, 0.8); + } + } + } + + fn gen_batch_with_params( + &mut self, + unmatched_conversions: f32, + unmatched_impressions: f32, + subsequent_conversion_prob: f32, + ) { + assert!(unmatched_conversions + unmatched_impressions <= 1.0); + let match_key = self.rng.gen::(); + let rand = self.rng.gen_range(0.0..1.0); + if rand < unmatched_conversions { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + } else if rand < unmatched_conversions + unmatched_impressions { + let imp = self.gen_impression(match_key); + self.in_flight.push(imp); + } else { + let imp = self.gen_impression(match_key); + let conv = self.gen_conversion(match_key); + self.in_flight.push(imp); + self.in_flight.push(conv); + let mut conv_count = 1; + // long-tailed distribution of # of conversions per impression + // will not exceed the configured maximum number of conversions per impression + while conv_count < self.config.max_convs_per_imp.get() + && self.rng.gen_range(0.0..1.0) < subsequent_conversion_prob + { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + conv_count += 1; + } + } + } + + fn gen_conversion(&mut self, match_key: u64) -> TestHybridRecord { + TestHybridRecord::TestConversion { + match_key, + value: self + .rng + .gen_range(1..self.config.max_conversion_value.get()), + } + } + + fn gen_impression(&mut self, match_key: u64) -> TestHybridRecord { + TestHybridRecord::TestImpression { + match_key, + breakdown_key: self.rng.gen_range(0..self.config.max_breakdown_key.get()), + } + } +} + +impl Iterator for EventGenerator { + type Item = TestHybridRecord; + + fn next(&mut self) -> Option { + if self.in_flight.is_empty() { + self.gen_batch(); + } + Some(self.in_flight.pop().unwrap()) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{ + collections::{HashMap, HashSet}, + iter::zip, + }; + + use rand::thread_rng; + + use super::*; + + #[test] + fn iter() { + let gen = EventGenerator::with_default_config(thread_rng()); + assert_eq!(10, gen.take(10).collect::>().len()); + + let gen = EventGenerator::with_default_config(thread_rng()); + assert_eq!(1000, gen.take(1000).collect::>().len()); + } + + #[test] + fn default_config() { + // Since there is randomness, the actual number will be a bit different + // from the expected value. + // The "tolerance" is used to compute the allowable range of values. + // It is multiplied by the expected value. So a tolerance of 0.05 means + // we will accept a value within 5% of the expected value + const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [ + (0, 0.0), + (647_634, 0.01), + (137_626, 0.02), + (20_652, 0.03), + (3_085, 0.05), + (463, 0.12), + (70, 0.5), + (10, 1.0), + (2, 1.0), + (0, 1.0), + (0, 1.0), + (0, 1.0), + ]; + const TEST_COUNT: usize = 1_000_000; + let gen = EventGenerator::with_default_config(thread_rng()); + let max_convs_per_imp = gen.config.max_convs_per_imp.get(); + let mut match_key_to_event_count = HashMap::new(); + for event in gen.take(TEST_COUNT) { + match event { + TestHybridRecord::TestImpression { match_key, .. } => { + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + TestHybridRecord::TestConversion { match_key, .. } => { + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + } + } + let histogram_size = usize::try_from(max_convs_per_imp + 2).unwrap(); + let mut histogram: Vec = vec![0; histogram_size]; + for (_, count) in match_key_to_event_count { + histogram[count] += 1; + } + + for (actual, (expected, tolerance)) in + zip(histogram, EXPECTED_HISTOGRAM_WITH_TOLERANCE.iter()) + { + // Adding a constant value of 10 is a way of dealing with the high variability small values + // which will vary a lot more (as a percent). Because 10 is an increasingly large percentage of + // A smaller and smaller expected value + let max_tolerance = f64::from(*expected) * tolerance + 10.0; + assert!( + f64::from((expected - actual).abs()) <= max_tolerance, + "{:?} is outside of the expected range: ({:?}..{:?})", + actual, + f64::from(*expected) - max_tolerance, + f64::from(*expected) + max_tolerance, + ); + } + } + + #[test] + fn lots_of_repeat_conversions() { + const EXPECTED_HISTOGRAM: [i32; 12] = [ + 0, 299_296, 25_640, 20_542, 16_421, 13_133, 10_503, 8_417, 6_730, 5_391, 4_289, 17_206, + ]; + const TEST_COUNT: usize = 1_000_000; + const MAX_CONVS_PER_IMP: u32 = 10; + const MAX_BREAKDOWN_KEY: u32 = 20; + const MAX_VALUE: u32 = 3; + let gen = EventGenerator::with_config( + thread_rng(), + Config::new( + MAX_VALUE, + MAX_BREAKDOWN_KEY, + MAX_CONVS_PER_IMP, + ConversionDistribution::LotsOfConversionsPerImpression, + ), + ); + let max_convs_per_imp = gen.config.max_convs_per_imp.get(); + let mut match_key_to_event_count = HashMap::new(); + for event in gen.take(TEST_COUNT) { + match event { + TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + assert!(breakdown_key <= MAX_BREAKDOWN_KEY); + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + TestHybridRecord::TestConversion { match_key, value } => { + assert!(value <= MAX_VALUE); + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + } + } + let histogram_size = usize::try_from(max_convs_per_imp + 2).unwrap(); + let mut histogram: Vec = vec![0; histogram_size]; + for (_, count) in match_key_to_event_count { + histogram[count] += 1; + } + + for (expected, actual) in zip(EXPECTED_HISTOGRAM.iter(), histogram) { + let max_tolerance = f64::from(*expected) * 0.05 + 10.0; + assert!( + f64::from((expected - actual).abs()) <= max_tolerance, + "{:?} is outside of the expected range: ({:?}..{:?})", + actual, + f64::from(*expected) - max_tolerance, + f64::from(*expected) + max_tolerance, + ); + } + } + + #[test] + fn only_impressions_config() { + const NUM_EVENTS: usize = 100; + const MAX_CONVS_PER_IMP: u32 = 1; + const MAX_BREAKDOWN_KEY: u32 = 10; + let gen = EventGenerator::with_config( + thread_rng(), + Config::new( + 10, + MAX_BREAKDOWN_KEY, + MAX_CONVS_PER_IMP, + ConversionDistribution::OnlyImpressions, + ), + ); + let mut match_keys = HashSet::new(); + for event in gen.take(NUM_EVENTS) { + match event { + TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + assert!(breakdown_key <= MAX_BREAKDOWN_KEY); + match_keys.insert(match_key); + } + TestHybridRecord::TestConversion { .. } => { + panic!("No conversions should be generated"); + } + } + } + assert_eq!(match_keys.len(), NUM_EVENTS); + } + + #[test] + fn only_conversions_config() { + const NUM_EVENTS: usize = 100; + const MAX_CONVS_PER_IMP: u32 = 1; + const MAX_VALUE: u32 = 10; + let gen = EventGenerator::with_config( + thread_rng(), + Config::new( + MAX_VALUE, + 10, + MAX_CONVS_PER_IMP, + ConversionDistribution::OnlyConversions, + ), + ); + let mut match_keys = HashSet::new(); + for event in gen.take(NUM_EVENTS) { + match event { + TestHybridRecord::TestConversion { match_key, value } => { + assert!(value <= MAX_VALUE); + match_keys.insert(match_key); + } + TestHybridRecord::TestImpression { .. } => { + panic!("No impressions should be generated"); + } + } + } + assert_eq!(match_keys.len(), NUM_EVENTS); + } +} diff --git a/ipa-core/src/test_fixture/ipa.rs b/ipa-core/src/test_fixture/ipa.rs index b14324c8e..c38a4bc9f 100644 --- a/ipa-core/src/test_fixture/ipa.rs +++ b/ipa-core/src/test_fixture/ipa.rs @@ -27,10 +27,6 @@ pub enum IpaSecurityModel { Malicious, } -pub enum IpaQueryStyle { - Oprf, -} - #[derive(Debug, Clone, Ord, PartialEq, PartialOrd, Eq)] pub struct TestRawDataRecord { pub timestamp: u64, @@ -219,7 +215,7 @@ pub async fn test_oprf_ipa( world.semi_honest( records.into_iter(), |ctx, input_rows: Vec>| async move { - oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + oprf_ipa::<_, BA5, BA8, BA32, BA20, 8, 32>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap() }, @@ -231,19 +227,19 @@ pub async fn test_oprf_ipa( |ctx, input_rows: Vec>| async move { match config.per_user_credit_cap { - 8 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 8 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 3, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 16 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 16 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 4, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 32 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 32 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 5, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 64 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 64 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 6, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 128 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 128 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 7, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), _ => diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 999f327d6..38d12eb21 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -11,19 +11,23 @@ mod app; #[cfg(feature = "in-memory-infra")] pub mod circuit; mod event_gen; +pub mod hybrid; +pub mod hybrid_event_gen; pub mod ipa; pub mod logging; pub mod metrics; -pub(crate) mod step; #[cfg(feature = "in-memory-infra")] mod test_gate; -use std::fmt::Debug; +use std::{fmt::Debug, future::Future}; #[cfg(feature = "in-memory-infra")] pub use app::TestApp; pub use event_gen::{Config as EventGeneratorConfig, EventGenerator}; -use futures::TryFuture; +use futures::{FutureExt, TryFuture}; +pub use hybrid_event_gen::{ + Config as HybridGeneratorConfig, EventGenerator as HybridEventGenerator, +}; use rand::{distributions::Standard, prelude::Distribution, rngs::mock::StepRng}; use rand_core::{CryptoRng, RngCore}; pub use sharing::{get_bits, into_bits, Reconstruct, ReconstructArr}; @@ -102,30 +106,32 @@ pub fn permutation_valid(permutation: &[u32]) -> bool { /// Wrapper for joining three things into an array. /// # Errors /// If one of the futures returned an error. -pub async fn try_join3_array([f0, f1, f2]: [T; 3]) -> Result<[T::Ok; 3], T::Error> { - futures::future::try_join3(f0, f1, f2) - .await - .map(|(a, b, c)| [a, b, c]) +pub fn try_join3_array( + [f0, f1, f2]: [T; 3], +) -> impl Future> { + futures::future::try_join3(f0, f1, f2).map(|res| res.map(|(a, b, c)| [a, b, c])) } /// Wrapper for joining three things into an array. /// # Panics /// If the tasks return `Err`. -pub async fn join3(a: T, b: T, c: T) -> [T::Ok; 3] +pub fn join3(a: T, b: T, c: T) -> impl Future where T: TryFuture, T::Output: Debug, T::Ok: Debug, T::Error: Debug, { - let (a, b, c) = futures::future::try_join3(a, b, c).await.unwrap(); - [a, b, c] + futures::future::try_join3(a, b, c).map(|res| { + let (a, b, c) = res.unwrap(); + [a, b, c] + }) } /// Wrapper for joining three things from an iterator into an array. /// # Panics /// If the tasks return `Err` or if `a` is the wrong length. -pub async fn join3v(a: V) -> [T::Ok; 3] +pub fn join3v(a: V) -> impl Future where V: IntoIterator, T: TryFuture, @@ -134,9 +140,42 @@ where T::Error: Debug, { let mut it = a.into_iter(); - let res = join3(it.next().unwrap(), it.next().unwrap(), it.next().unwrap()).await; + let fut0 = it.next().unwrap(); + let fut1 = it.next().unwrap(); + let fut2 = it.next().unwrap(); + assert!(it.next().is_none()); + join3(fut0, fut1, fut2) +} + +/// Wrapper for flattening 3 vecs of vecs into a single future +/// # Panics +/// If the tasks return `Err` or if `a` is the wrong length. +pub fn flatten3v(a: V) -> impl Future::Output>> +where + V: IntoIterator, + I: IntoIterator, + T: TryFuture, + T::Output: Debug, + T::Ok: Debug, + T::Error: Debug, +{ + let mut it = a.into_iter(); + + let outer0 = it.next().unwrap().into_iter(); + let outer1 = it.next().unwrap().into_iter(); + let outer2 = it.next().unwrap().into_iter(); + assert!(it.next().is_none()); - res + + // only used for tests + #[allow(clippy::disallowed_methods)] + futures::future::join_all( + outer0 + .zip(outer1) + .zip(outer2) + .flat_map(|((fut0, fut1), fut2)| vec![fut0, fut1, fut2]) + .collect::>(), + ) } /// Take a slice of bits in `{0,1} ⊆ F_p`, and reconstruct the integer in `Z` diff --git a/ipa-core/src/test_fixture/step.rs b/ipa-core/src/test_fixture/step.rs deleted file mode 100644 index a0881c2a6..000000000 --- a/ipa-core/src/test_fixture/step.rs +++ /dev/null @@ -1,8 +0,0 @@ -use ipa_step_derive::CompactStep; - -/// Provides a unique per-iteration context in tests. -#[derive(CompactStep)] -pub(crate) enum TestExecutionStep { - #[step(count = 999)] - Iter(usize), -} diff --git a/ipa-core/src/test_fixture/test_gate.rs b/ipa-core/src/test_fixture/test_gate.rs index a79a7dfca..a59802ff0 100644 --- a/ipa-core/src/test_fixture/test_gate.rs +++ b/ipa-core/src/test_fixture/test_gate.rs @@ -2,7 +2,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use ipa_step::StepNarrow; -use crate::{protocol::Gate, test_fixture::step::TestExecutionStep}; +use crate::protocol::{step::TestExecutionStep, Gate}; /// This manages the gate information for test runs. Most unit tests want to have multiple runs /// using the same instance of [`TestWorld`], but they don't care about the name of that particular diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 6fba504cf..5c23cb2d8 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -23,8 +23,8 @@ use crate::{ context::{ dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, DZKPUpgradedMaliciousContext, MaliciousContext, SemiHonestContext, - ShardedSemiHonestContext, UpgradableContext, UpgradedContext, UpgradedMaliciousContext, - UpgradedSemiHonestContext, Validator, + ShardedMaliciousContext, ShardedSemiHonestContext, UpgradableContext, UpgradedContext, + UpgradedMaliciousContext, UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS, }, prss::Endpoint as PrssEndpoint, Gate, QueryId, RecordId, @@ -207,6 +207,42 @@ impl TestWorld> { .ok() .unwrap() } + + /// Creates protocol contexts for 3 helpers across all shards + /// + /// # Panics + /// Panics if world has more or less than 3 gateways/participants + #[must_use] + pub fn contexts(&self) -> [Vec>; 3] { + let gate = &self.next_gate(); + self.shards().iter().map(|shard| shard.contexts(gate)).fold( + [Vec::new(), Vec::new(), Vec::new()], + |mut acc, contexts| { + // Distribute contexts into the respective vectors. + for (vec, context) in acc.iter_mut().zip(contexts.iter()) { + vec.push(context.clone()); + } + acc + }, + ) + } + /// Creates malicious protocol contexts for 3 helpers across all shards + /// + /// # Panics + /// Panics if world has more or less than 3 gateways/participants + #[must_use] + pub fn malicious_contexts(&self) -> [Vec>; 3] { + self.shards() + .iter() + .map(|shard| shard.malicious_contexts(&self.next_gate())) + .fold([Vec::new(), Vec::new(), Vec::new()], |mut acc, contexts| { + // Distribute contexts into the respective vectors. + for (vec, context) in acc.iter_mut().zip(contexts.iter()) { + vec.push(context.clone()); + } + acc + }) + } } /// Backward-compatible API for tests that don't use sharding. @@ -369,6 +405,10 @@ where pub trait Runner { /// This could be also derived from [`S`], but maybe that's too much for that trait. type SemiHonestContext<'ctx>: Context; + /// The type of context used to run protocols that are secure against + /// active adversaries. It varies depending on whether sharding is used or not. + type MaliciousContext<'ctx>: Context; + /// Run with a context that can be upgraded, but is only good for semi-honest. async fn semi_honest<'a, I, A, O, H, R>( &'a self, @@ -396,12 +436,12 @@ pub trait Runner { R: Future + Send; /// Run with a context that can be upgraded to malicious. - async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] + async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> S::Container<[O; 3]> where - I: IntoShares + Send + 'static, + I: RunnerInput, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(Self::MaliciousContext<'a>, S::Container) -> R + Send + Sync, R: Future + Send; /// Run with a context that has already been upgraded to malicious. @@ -428,7 +468,7 @@ pub trait Runner { I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send; } @@ -444,6 +484,7 @@ impl Runner> for TestWorld> { type SemiHonestContext<'ctx> = ShardedSemiHonestContext<'ctx>; + type MaliciousContext<'ctx> = ShardedMaliciousContext<'ctx>; async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> where I: RunnerInput, A>, @@ -494,15 +535,39 @@ impl Runner> unimplemented!() } - async fn malicious<'a, I, A, O, H, R>(&'a self, _input: I, _helper_fn: H) -> [O; 3] + async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> where - I: IntoShares + Send + 'static, + I: RunnerInput, A>, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn( + Self::MaliciousContext<'a>, + as ShardingScheme>::Container, + ) -> R + + Send + + Sync, R: Future + Send, { - unimplemented!() + let shards = self.shards(); + let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share().map(D::distribute); + let gate = self.next_gate(); + // todo!() + + // No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy` + #[allow(clippy::redundant_closure)] + let shard_fn = |ctx, input| helper_fn(ctx, input); + zip(shards.into_iter(), zip(zip(h1, h2), h3)) + .map(|(shard, ((h1, h2), h3))| { + ShardWorld::::run_either( + shard.malicious_contexts(&gate), + self.metrics_handle.span(), + [h1, h2, h3], + shard_fn, + ) + }) + .collect::>() + .collect::>() + .await } async fn upgraded_malicious<'a, F, I, A, M, O, H, R, P>( @@ -531,7 +596,7 @@ impl Runner> I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send, { unimplemented!() @@ -541,6 +606,7 @@ impl Runner> #[async_trait] impl Runner for TestWorld { type SemiHonestContext<'ctx> = SemiHonestContext<'ctx>; + type MaliciousContext<'ctx> = MaliciousContext<'ctx>; async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] where @@ -583,10 +649,10 @@ impl Runner for TestWorld { async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] where - I: IntoShares + Send + 'static, + I: RunnerInput, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(Self::MaliciousContext<'a>, A) -> R + Send + Sync, R: Future + Send, { ShardWorld::::run_either( @@ -672,11 +738,11 @@ impl Runner for TestWorld { I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: (Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R) + Send + Sync, + H: (Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R) + Send + Sync, R: Future + Send, { self.malicious(input, |ctx, share| async { - let v = ctx.dzkp_validator(10); + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let m_ctx = v.context(); let m_result = helper_fn(m_ctx, share).await; v.validate().await.unwrap(); @@ -778,9 +844,14 @@ impl ShardWorld { /// # Panics /// Panics if world has more or less than 3 gateways/participants #[must_use] - pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_>; 3] { + pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_, B>; 3] { zip3_ref(&self.participants, &self.gateways).map(|(participant, gateway)| { - MaliciousContext::new_with_gate(participant, gateway, gate.clone()) + MaliciousContext::new_with_gate( + participant, + gateway, + gate.clone(), + self.shard_info.clone(), + ) }) } } @@ -826,12 +897,20 @@ mod tests { use futures_util::future::try_join4; use crate::{ - ff::{boolean_array::BA3, Field, Fp31, U128Conversions}, + ff::{boolean::Boolean, boolean_array::BA3, Field, Fp31, U128Conversions}, helpers::{ in_memory_config::{MaliciousHelper, MaliciousHelperContext}, - Direction, Role, + Direction, Role, TotalRecords, + }, + protocol::{ + basics::SecureMul, + context::{ + dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, UpgradableContext, + UpgradedContext, Validator, TEST_DZKP_STEPS, + }, + prss::SharedRandomness, + RecordId, }, - protocol::{context::Context, prss::SharedRandomness, RecordId}, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, @@ -961,4 +1040,65 @@ mod tests { assert_eq!(shares[1].right(), shares[2].left()); }); } + + #[test] + fn zkp_malicious_sharded() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input = vec![Boolean::truncate_from(0_u32), Boolean::truncate_from(1_u32)]; + let r = world + .malicious(input.clone().into_iter(), |ctx, input| async move { + assert_eq!(1, input.iter().len()); + let ctx = ctx.set_total_records(TotalRecords::ONE); + let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let ctx = validator.context(); + let r = input[0] + .multiply(&input[0], ctx, RecordId::FIRST) + .await + .unwrap(); + validator.validate().await.unwrap(); + + vec![r] + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + fn mac_malicious_sharded() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input = vec![Fp31::truncate_from(0_u32), Fp31::truncate_from(1_u32)]; + let r = world + .malicious(input.clone().into_iter(), |ctx, input| async move { + assert_eq!(1, input.iter().len()); + let validator = ctx.set_total_records(1).validator(); + let ctx = validator.context(); + let (a_upgraded, b_upgraded) = (input[0].clone(), input[0].clone()) + .upgrade(ctx.clone(), RecordId::FIRST) + .await + .unwrap(); + let _ = a_upgraded + .multiply(&b_upgraded, ctx.narrow("multiply"), RecordId::FIRST) + .await + .unwrap(); + ctx.validate_record(RecordId::FIRST).await.unwrap(); + + input + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } } diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index a3600e899..e8dfd95ae 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -1,2 +1,7 @@ pub mod array; pub mod arraychunks; +#[cfg(target_pointer_width = "64")] +mod power_of_two; + +#[cfg(target_pointer_width = "64")] +pub use power_of_two::NonZeroU32PowerOfTwo; diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs new file mode 100644 index 000000000..a84455c92 --- /dev/null +++ b/ipa-core/src/utils/power_of_two.rs @@ -0,0 +1,110 @@ +use std::{fmt::Display, num::NonZeroUsize, str::FromStr}; + +#[derive(Debug, thiserror::Error)] +#[error("{0} is not a power of two or not within the 1..u32::MAX range")] +pub struct ConvertError(I); + +impl PartialEq for ConvertError { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +/// This construction guarantees the value to be a power of two and +/// within the range 0..2^32-1 +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +pub struct NonZeroU32PowerOfTwo(u32); + +impl Display for NonZeroU32PowerOfTwo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", u32::from(*self)) + } +} + +impl TryFrom for NonZeroU32PowerOfTwo { + type Error = ConvertError; + + fn try_from(value: usize) -> Result { + if value > 0 && value < usize::try_from(u32::MAX).unwrap() && value.is_power_of_two() { + Ok(NonZeroU32PowerOfTwo(u32::try_from(value).unwrap())) + } else { + Err(ConvertError(value)) + } + } +} + +impl From for usize { + fn from(value: NonZeroU32PowerOfTwo) -> Self { + // we are using 64 bit registers + usize::try_from(value.0).unwrap() + } +} + +impl From for u32 { + fn from(value: NonZeroU32PowerOfTwo) -> Self { + value.0 + } +} + +impl FromStr for NonZeroU32PowerOfTwo { + type Err = ConvertError; + + fn from_str(s: &str) -> Result { + let v = s.parse::().map_err(|_| ConvertError(s.to_owned()))?; + NonZeroU32PowerOfTwo::try_from(v).map_err(|_| ConvertError(s.to_owned())) + } +} + +impl NonZeroU32PowerOfTwo { + #[must_use] + pub fn to_non_zero_usize(self) -> NonZeroUsize { + let v = usize::from(self); + NonZeroUsize::new(v).unwrap_or_else(|| unreachable!()) + } + + #[must_use] + pub fn get(self) -> usize { + usize::from(self) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use super::{ConvertError, NonZeroU32PowerOfTwo}; + + #[test] + fn rejects_invalid_values() { + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(0), + Err(ConvertError(0)) + )); + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(3), + Err(ConvertError(3)) + )); + + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(1_usize << 33), + Err(ConvertError(_)) + )); + } + + #[test] + fn accepts_valid() { + assert_eq!(4, u32::from(NonZeroU32PowerOfTwo::try_from(4).unwrap())); + assert_eq!(16, u32::from(NonZeroU32PowerOfTwo::try_from(16).unwrap())); + } + + #[test] + fn parse_from_str() { + assert_eq!(NonZeroU32PowerOfTwo(4), "4".parse().unwrap()); + assert_eq!( + ConvertError("0".to_owned()), + "0".parse::().unwrap_err() + ); + assert_eq!( + ConvertError("3".to_owned()), + "3".parse::().unwrap_err() + ); + } +} diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index 56cdc37ce..ca1d5e08a 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -24,6 +24,7 @@ pub mod tempdir; pub const HELPER_BIN: &str = env!("CARGO_BIN_EXE_helper"); pub const TEST_MPC_BIN: &str = env!("CARGO_BIN_EXE_test_mpc"); pub const TEST_RC_BIN: &str = env!("CARGO_BIN_EXE_report_collector"); +pub const CRYPTO_UTIL_BIN: &str = env!("CARGO_BIN_EXE_crypto_util"); pub trait UnwrapStatusExt { fn unwrap_status(self); @@ -216,17 +217,27 @@ pub fn test_network(https: bool) { T::execute(path, https); } -pub fn test_ipa(mode: IpaSecurityModel, https: bool) { +pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) { test_ipa_with_config( mode, https, IpaQueryConfig { ..Default::default() }, + encrypted_inputs, ); } -pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQueryConfig) { +pub fn test_ipa_with_config( + mode: IpaSecurityModel, + https: bool, + config: IpaQueryConfig, + encrypted_inputs: bool, +) { + if encrypted_inputs & !https { + panic!("encrypted_input requires https") + }; + const INPUT_SIZE: usize = 100; // set to true to always keep the temp dir after test finishes let dir = TempDir::new_delete_on_drop(); @@ -250,11 +261,25 @@ pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQuer .stdin(Stdio::piped()); command.status().unwrap_status(); + if encrypted_inputs { + // Encrypt Input + let mut command = Command::new(CRYPTO_UTIL_BIN); + command + .arg("encrypt") + .args(["--input-file".as_ref(), inputs_file.as_os_str()]) + .args(["--output-dir".as_ref(), path.as_os_str()]) + .args(["--network".into(), dir.path().join("network.toml")]) + .stdin(Stdio::piped()); + command.status().unwrap_status(); + } + // Run IPA let mut command = Command::new(TEST_RC_BIN); + if !encrypted_inputs { + command.args(["--input-file".as_ref(), inputs_file.as_os_str()]); + } command .args(["--network".into(), dir.path().join("network.toml")]) - .args(["--input-file".as_ref(), inputs_file.as_os_str()]) .args(["--output-file".as_ref(), output_file.as_os_str()]) .args(["--wait", "2"]) .silent(); @@ -263,12 +288,23 @@ pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQuer command.arg("--disable-https"); } - let protocol = match mode { - IpaSecurityModel::SemiHonest => "oprf-ipa", - IpaSecurityModel::Malicious => "malicious-ipa", + let protocol = match (mode, encrypted_inputs) { + (IpaSecurityModel::SemiHonest, true) => "semi-honest-oprf-ipa", + (IpaSecurityModel::SemiHonest, false) => "semi-honest-oprf-ipa-test", + (IpaSecurityModel::Malicious, true) => "malicious-oprf-ipa", + (IpaSecurityModel::Malicious, false) => "malicious-oprf-ipa-test", }; + command.arg(protocol); + if encrypted_inputs { + let enc1 = dir.path().join("helper1.enc"); + let enc2 = dir.path().join("helper2.enc"); + let enc3 = dir.path().join("helper3.enc"); + command + .args(["--enc-input-file1".as_ref(), enc1.as_os_str()]) + .args(["--enc-input-file2".as_ref(), enc2.as_os_str()]) + .args(["--enc-input-file3".as_ref(), enc3.as_os_str()]); + } command - .arg(protocol) .args(["--max-breakdown-key", &config.max_breakdown_key.to_string()]) .args([ "--per-user-credit-cap", diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs index f847275f3..354ad438c 100644 --- a/ipa-core/tests/compact_gate.rs +++ b/ipa-core/tests/compact_gate.rs @@ -12,6 +12,7 @@ fn test_compact_gate>( mode: IpaSecurityModel, per_user_credit_cap: u32, attribution_window_seconds: I, + encrypted_input: bool, ) { let config = IpaQueryConfig { per_user_credit_cap, @@ -20,25 +21,82 @@ fn test_compact_gate>( ..Default::default() }; - test_ipa_with_config(mode, false, config); + // test https with encrypted input + // and http with plaintest input + test_ipa_with_config(mode, encrypted_input, config, encrypted_input); } #[test] -fn compact_gate_cap_8_no_window_semi_honest() { - test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0); +fn compact_gate_cap_8_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0, true); } #[test] -fn compact_gate_cap_8_with_window_semi_honest() { - test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400); +fn compact_gate_cap_1_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 1, 0, true); } #[test] -fn compact_gate_cap_16_no_window_semi_honest() { - test_compact_gate(IpaSecurityModel::SemiHonest, 16, 0); +fn compact_gate_cap_2_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 2, 0, true); } #[test] -fn compact_gate_cap_16_with_window_semi_honest() { - test_compact_gate(IpaSecurityModel::SemiHonest, 16, 86400); +fn compact_gate_cap_4_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 4, 0, true); +} + +#[test] +fn compact_gate_cap_8_no_window_semi_honest_plaintext_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0, false); +} + +#[test] +/// This test is turned off because of [`issue`]. +/// +/// This test will hang without `relaxed-dp` feature turned out until it is fixed +/// [`issue`]: https://github.com/private-attribution/ipa/issues/1298 +#[ignore] +fn compact_gate_cap_8_no_window_malicious_encrypted_input() { + test_compact_gate(IpaSecurityModel::Malicious, 8, 0, true); +} + +#[test] +/// This test is turned off because of [`issue`]. +/// +/// This test will hang without `relaxed-dp` feature turned out until it is fixed +/// [`issue`]: https://github.com/private-attribution/ipa/issues/1298 +#[ignore] +fn compact_gate_cap_8_no_window_malicious_plaintext_input() { + test_compact_gate(IpaSecurityModel::Malicious, 8, 0, false); +} + +#[test] +fn compact_gate_cap_8_with_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400, true); +} + +#[test] +fn compact_gate_cap_8_with_window_semi_honest_plaintext_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400, false); +} + +#[test] +fn compact_gate_cap_16_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 16, 0, true); +} + +#[test] +fn compact_gate_cap_16_no_window_semi_honest_plaintext_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 16, 0, false); +} + +#[test] +fn compact_gate_cap_16_with_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 16, 86400, true); +} + +#[test] +fn compact_gate_cap_16_with_window_semi_honest_plaintext_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 16, 86400, false); } diff --git a/ipa-core/tests/encrypted_input.rs b/ipa-core/tests/encrypted_input.rs deleted file mode 100644 index 8a7853344..000000000 --- a/ipa-core/tests/encrypted_input.rs +++ /dev/null @@ -1,200 +0,0 @@ -#[cfg(all( - feature = "test-fixture", - feature = "web-app", - feature = "cli", - feature = "in-memory-infra" -))] -mod tests { - - use std::{ - fs::File, - io::{BufRead, BufReader, Write}, - path::Path, - sync::Arc, - }; - - use bytes::BufMut; - use clap::Parser; - use hpke::Deserializable; - use ipa_core::{ - cli::{ - crypto::{encrypt, EncryptArgs}, - CsvSerializer, - }, - ff::{boolean_array::BA16, U128Conversions}, - helpers::{ - query::{IpaQueryConfig, QuerySize}, - BodyStream, - }, - hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, - query::OprfIpaQuery, - test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, - }; - use tempfile::{tempdir, NamedTempFile}; - - fn build_encrypt_args( - input_file: &Path, - output_dir: &Path, - network_file: &Path, - ) -> EncryptArgs { - EncryptArgs::try_parse_from([ - "test_encrypt", - "--input-file", - input_file.to_str().unwrap(), - "--output-dir", - output_dir.to_str().unwrap(), - "--network", - network_file.to_str().unwrap(), - ]) - .unwrap() - } - - fn write_network_file() -> NamedTempFile { - let network_data = r#" -[[peers]] -url = "helper1.test" -[peers.hpke] -public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" -[[peers]] -url = "helper2.test" -[peers.hpke] -public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" -[[peers]] -url = "helper3.test" -[peers.hpke] -public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e" -"#; - let mut network = NamedTempFile::new().unwrap(); - writeln!(network.as_file_mut(), "{network_data}").unwrap(); - network - } - - #[tokio::test] - async fn encrypt_and_execute_query() { - const EXPECTED: &[u128] = &[0, 8, 5]; - - let records: Vec = vec![ - TestRawDataRecord { - timestamp: 0, - user_id: 12345, - is_trigger_report: false, - breakdown_key: 2, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 4, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 10, - user_id: 12345, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 5, - }, - TestRawDataRecord { - timestamp: 12, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 2, - }, - TestRawDataRecord { - timestamp: 20, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 30, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 1, - trigger_value: 7, - }, - ]; - let query_size = QuerySize::try_from(records.len()).unwrap(); - let mut input_file = NamedTempFile::new().unwrap(); - - for event in records { - let _ = event.to_csv(input_file.as_file_mut()); - writeln!(input_file.as_file()).unwrap(); - } - input_file.as_file_mut().flush().unwrap(); - - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - - let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); - for (i, path) in [enc1, enc2, enc3].iter().enumerate() { - let file = File::open(path).unwrap(); - let reader = BufReader::new(file); - for line in reader.lines() { - let line = line.unwrap(); - let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); - println!("{}", encrypted_report_bytes.len()); - buffers[i].put_u16_le(encrypted_report_bytes.len().try_into().unwrap()); - buffers[i].put_slice(encrypted_report_bytes.as_slice()); - } - } - - let world = TestWorld::default(); - let contexts = world.contexts(); - - let mk_private_keys = vec![ - hex::decode("53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff") - .expect("manually provided for test"), - hex::decode("3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569") - .expect("manually provided for test"), - hex::decode("1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7") - .expect("manually provided for test"), - ]; - - #[allow(clippy::large_futures)] - let results = join3v(buffers.into_iter().zip(contexts).zip(mk_private_keys).map( - |((buffer, ctx), mk_private_key)| { - let query_config = IpaQueryConfig { - per_user_credit_cap: 8, - attribution_window_seconds: None, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 1.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - let private_registry = - Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( - IpaPrivateKey::from_bytes(&mk_private_key) - .expect("manually constructed for test"), - )])); - - OprfIpaQuery::>::new( - query_config, - private_registry, - ) - .execute(ctx, query_size, input) - }, - )) - .await; - - assert_eq!( - results.reconstruct()[0..3] - .iter() - .map(U128Conversions::as_u128) - .collect::>(), - EXPECTED - ); - } -} diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index d8e8fedd0..7775ffba4 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -45,13 +45,20 @@ fn http_network_large_input() { #[test] #[cfg(all(test, web_test))] fn http_semi_honest_ipa() { - test_ipa(IpaSecurityModel::SemiHonest, false); + test_ipa(IpaSecurityModel::SemiHonest, false, false); } #[test] #[cfg(all(test, web_test))] fn https_semi_honest_ipa() { - test_ipa(IpaSecurityModel::SemiHonest, true); + test_ipa(IpaSecurityModel::SemiHonest, true, true); +} + +#[test] +#[cfg(all(test, web_test))] +#[ignore] +fn https_malicious_ipa() { + test_ipa(IpaSecurityModel::Malicious, true, true); } /// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config diff --git a/ipa-core/tests/hybrid.rs b/ipa-core/tests/hybrid.rs new file mode 100644 index 000000000..06caabbce --- /dev/null +++ b/ipa-core/tests/hybrid.rs @@ -0,0 +1,48 @@ +// some pub functions in `common` to be compiled, and rust complains about dead code. +#[allow(dead_code)] +mod common; + +use std::process::{Command, Stdio}; + +use common::{tempdir::TempDir, CommandExt, UnwrapStatusExt, TEST_RC_BIN}; +use rand::thread_rng; +use rand_core::RngCore; + +pub const IN_THE_CLEAR_BIN: &str = env!("CARGO_BIN_EXE_in_the_clear"); + +// this currently only generates data and runs in the clear +// eventaully we'll want to add the MPC as well +#[test] +fn test_hybrid() { + const INPUT_SIZE: usize = 100; + const MAX_CONVERSION_VALUE: usize = 5; + const MAX_BREAKDOWN_KEY: usize = 20; + const MAX_CONVS_PER_IMP: usize = 10; + + let dir = TempDir::new_delete_on_drop(); + + // Gen inputs + let input_file = dir.path().join("ipa_inputs.txt"); + let output_file = dir.path().join("ipa_output.json"); + + let mut command = Command::new(TEST_RC_BIN); + command + .args(["--output-file".as_ref(), input_file.as_os_str()]) + .arg("gen-hybrid-inputs") + .args(["--count", &INPUT_SIZE.to_string()]) + .args(["--max-conversion-value", &MAX_CONVERSION_VALUE.to_string()]) + .args(["--max-breakdown-key", &MAX_BREAKDOWN_KEY.to_string()]) + .args(["--max-convs-per-imp", &MAX_CONVS_PER_IMP.to_string()]) + .args(["--seed", &thread_rng().next_u64().to_string()]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); + + let mut command = Command::new(IN_THE_CLEAR_BIN); + command + .args(["--input-file".as_ref(), input_file.as_os_str()]) + .args(["--output-file".as_ref(), output_file.as_os_str()]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); +} diff --git a/ipa-core/tests/ipa_with_relaxed_dp.rs b/ipa-core/tests/ipa_with_relaxed_dp.rs new file mode 100644 index 000000000..84c4c2a7b --- /dev/null +++ b/ipa-core/tests/ipa_with_relaxed_dp.rs @@ -0,0 +1,48 @@ +#[allow(dead_code)] +mod common; + +use std::num::NonZeroU32; + +use common::{test_ipa, test_ipa_with_config}; +use ipa_core::{helpers::query::IpaQueryConfig, test_fixture::ipa::IpaSecurityModel}; + +fn build_config() -> IpaQueryConfig { + IpaQueryConfig { + per_user_credit_cap: 8, + attribution_window_seconds: NonZeroU32::new(0), + with_dp: 0, + ..Default::default() + } +} + +#[test] +fn relaxed_dp_semi_honest() { + let encrypted_input = false; + let config = build_config(); + + test_ipa_with_config( + IpaSecurityModel::SemiHonest, + encrypted_input, + config, + encrypted_input, + ); +} + +#[test] +fn relaxed_dp_malicious() { + let encrypted_input = false; + let config = build_config(); + + test_ipa_with_config( + IpaSecurityModel::Malicious, + encrypted_input, + config, + encrypted_input, + ); +} + +#[test] +#[cfg(all(test, web_test))] +fn relaxed_dp_https_malicious_ipa() { + test_ipa(IpaSecurityModel::Malicious, true, true); +} diff --git a/ipa-metrics-tracing/Cargo.toml b/ipa-metrics-tracing/Cargo.toml new file mode 100644 index 000000000..061bfe4ff --- /dev/null +++ b/ipa-metrics-tracing/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "ipa-metrics-tracing" +version = "0.1.0" +edition = "2021" + +[dependencies] +# requires partitions feature because without it, it does not make sense to use +ipa-metrics = { version = "*", path = "../ipa-metrics", features = ["partitions"] } +tracing = "0.1" +tracing-subscriber = "0.3" diff --git a/ipa-metrics-tracing/src/layer.rs b/ipa-metrics-tracing/src/layer.rs new file mode 100644 index 000000000..85d07d910 --- /dev/null +++ b/ipa-metrics-tracing/src/layer.rs @@ -0,0 +1,123 @@ +use std::fmt::Debug; + +use ipa_metrics::{CurrentThreadPartitionContext, MetricPartition, MetricsCurrentThreadContext}; +use tracing::{ + field::{Field, Visit}, + span::{Attributes, Record}, + Id, Subscriber, +}; +use tracing_subscriber::{ + layer::Context, + registry::{Extensions, ExtensionsMut, LookupSpan}, + Layer, +}; + +pub const FIELD: &str = concat!(env!("CARGO_PKG_NAME"), "-", "metrics-partition"); + +/// This layer allows partitioning metric stores. +/// This can be used in tests, where each unit test +/// creates its own unique root span. Upon entering +/// this span, this layer sets a unique partition key +#[derive(Default)] +pub struct MetricsPartitioningLayer; + +impl LookupSpan<'s>> Layer for MetricsPartitioningLayer { + fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { + #[derive(Default)] + struct MaybeMetricPartition(Option); + + impl Visit for MaybeMetricPartition { + fn record_u64(&mut self, field: &Field, value: u64) { + if field.name() == FIELD { + self.0 = Some(value); + } + } + + fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) { + // not interested in anything else except MetricPartition values. + } + } + + let record = Record::new(attrs.values()); + let mut metric_partition = MaybeMetricPartition::default(); + record.record(&mut metric_partition); + if let Some(v) = metric_partition.0 { + let span = ctx.span(id).expect("Span should exists upon entering"); + span.extensions_mut().insert(MetricPartitionExt { + prev: None, + current: v, + }); + } + } + + fn on_enter(&self, id: &Id, ctx: Context<'_, S>) { + let span = ctx.span(id).expect("Span should exists upon entering"); + MetricPartitionExt::span_enter(span.extensions_mut()); + } + + fn on_exit(&self, id: &Id, ctx: Context<'_, S>) { + let span = ctx.span(id).expect("Span should exists upon exiting"); + MetricPartitionExt::span_exit(span.extensions_mut()); + } + + fn on_close(&self, id: Id, ctx: Context<'_, S>) { + let span = ctx.span(&id).expect("Span should exists before closing it"); + MetricPartitionExt::span_close(&span.extensions()); + } +} + +struct MetricPartitionExt { + // Partition active before span is entered. + prev: Option, + // Partition that must be set when this span is entered. + current: MetricPartition, +} + +impl MetricPartitionExt { + fn span_enter(mut span_ext: ExtensionsMut<'_>) { + if let Some(MetricPartitionExt { current, prev }) = span_ext.get_mut() { + *prev = CurrentThreadPartitionContext::get(); + CurrentThreadPartitionContext::set(*current); + } + } + + fn span_exit(mut span_ext: ExtensionsMut) { + if let Some(MetricPartitionExt { prev, .. }) = span_ext.get_mut() { + CurrentThreadPartitionContext::toggle(prev.take()); + } + } + + fn span_close(span_ext: &Extensions) { + if let Some(MetricPartitionExt { .. }) = span_ext.get() { + MetricsCurrentThreadContext::flush(); + } + } +} + +#[cfg(test)] +mod tests { + use ipa_metrics::CurrentThreadPartitionContext; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + + use crate::{layer::FIELD, MetricsPartitioningLayer}; + + #[test] + fn basic() { + CurrentThreadPartitionContext::set(0); + tracing_subscriber::registry() + .with(MetricsPartitioningLayer) + .init(); + let span1 = tracing::info_span!("", { FIELD } = 1_u64); + let span2 = tracing::info_span!("", { FIELD } = 2_u64); + { + let _guard1 = span1.enter(); + assert_eq!(Some(1), CurrentThreadPartitionContext::get()); + { + let _guard2 = span2.enter(); + assert_eq!(Some(2), CurrentThreadPartitionContext::get()); + } + assert_eq!(Some(1), CurrentThreadPartitionContext::get()); + } + assert_eq!(Some(0), CurrentThreadPartitionContext::get()); + } +} diff --git a/ipa-metrics-tracing/src/lib.rs b/ipa-metrics-tracing/src/lib.rs new file mode 100644 index 000000000..c72bb9e54 --- /dev/null +++ b/ipa-metrics-tracing/src/lib.rs @@ -0,0 +1,7 @@ +#![deny(clippy::pedantic)] +#![allow(clippy::similar_names)] +#![allow(clippy::module_name_repetitions)] + +mod layer; + +pub use layer::{MetricsPartitioningLayer, FIELD as PARTITION_FIELD}; diff --git a/ipa-metrics/Cargo.toml b/ipa-metrics/Cargo.toml new file mode 100644 index 000000000..ebaeb9473 --- /dev/null +++ b/ipa-metrics/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "ipa-metrics" +version = "0.1.0" +edition = "2021" + +[features] +default = [] +# support metric partitioning +partitions = [] + +[dependencies] +# crossbeam channels are faster than std +crossbeam-channel = "0.5" +# This crate uses raw entry API that is unstable in stdlib +hashbrown = "0.15" +# Fast non-collision-resistant hashing +rustc-hash = "2.0.0" +# logging +tracing = "0.1" + diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs new file mode 100644 index 000000000..94022e340 --- /dev/null +++ b/ipa-metrics/src/collector.rs @@ -0,0 +1,203 @@ +use std::cell::RefCell; + +use crossbeam_channel::{Receiver, Select}; + +use crate::{ + controller::{Command, Status}, + ControllerCommand, MetricsStore, +}; + +thread_local! { + /// Collector that is installed in a thread. It is responsible for receiving metrics from + /// all threads and aggregating them. + static COLLECTOR: RefCell> = const { RefCell::new(None) } +} + +/// Convenience struct to block the current thread on metric collection +pub struct Installed; + +impl Installed { + #[allow(clippy::unused_self)] + pub fn block_until_shutdown(&self) -> MetricsStore { + MetricsCollector::with_current_mut(|c| { + c.event_loop(); + + std::mem::take(&mut c.local_store) + }) + } +} + +pub struct MetricsCollector { + pub(super) rx: Receiver, + pub(super) local_store: MetricsStore, + pub(super) command_rx: Receiver, +} + +impl MetricsCollector { + /// This installs metrics collection mechanism to current thread. + /// + /// ## Panics + /// It panics if there is another collector system already installed. + #[allow(clippy::must_use_candidate)] + pub fn install(self) -> Installed { + COLLECTOR.with_borrow_mut(|c| { + assert!(c.replace(self).is_none(), "Already initialized"); + }); + + Installed + } + + fn event_loop(&mut self) { + let mut select = Select::new(); + let data_idx = select.recv(&self.rx); + let command_idx = select.recv(&self.command_rx); + let mut state = Status::Active; + + loop { + let next_op = select.select(); + match next_op.index() { + i if i == data_idx => match next_op.recv(&self.rx) { + Ok(store) => { + tracing::trace!("Collector received more data: {store:?}"); + self.local_store.merge(store); + } + Err(e) => { + tracing::debug!("No more threads collecting metrics. Disconnected: {e}"); + select.remove(data_idx); + state = Status::Disconnected; + } + }, + i if i == command_idx => match next_op.recv(&self.command_rx) { + Ok(ControllerCommand::Snapshot(tx)) => { + tracing::trace!("Snapshot request received"); + tx.send(self.local_store.clone()).unwrap(); + } + Ok(ControllerCommand::Stop(tx)) => { + tracing::trace!("Stop signal received"); + tx.send(()).unwrap(); + break; + } + Ok(Command::Status(tx)) => { + tx.send(state).unwrap(); + } + Err(e) => { + tracing::debug!("Metric controller is disconnected: {e}"); + break; + } + }, + _ => unreachable!(), + } + } + } + + fn with_current_mut T, T>(f: F) -> T { + COLLECTOR.with_borrow_mut(|c| { + let collector = c.as_mut().expect("Collector is installed"); + f(collector) + }) + } +} + +impl Drop for MetricsCollector { + fn drop(&mut self) { + tracing::debug!("Collector is dropped"); + } +} + +#[cfg(test)] +mod tests { + use std::{ + thread, + thread::{Scope, ScopedJoinHandle}, + }; + + use crate::{ + controller::Status, counter, install, install_new_thread, producer::Producer, + MetricChannelType, + }; + + struct MeteredScope<'scope, 'env: 'scope>(&'scope Scope<'scope, 'env>, Producer); + + impl<'scope, 'env: 'scope> MeteredScope<'scope, 'env> { + fn spawn(&self, f: F) -> ScopedJoinHandle<'scope, T> + where + F: FnOnce() -> T + Send + 'scope, + T: Send + 'scope, + { + let producer = self.1.clone(); + + self.0.spawn(move || { + producer.install(); + let r = f(); + let _ = producer.drop_handle(); + + r + }) + } + } + + trait IntoMetered<'scope, 'env: 'scope> { + fn metered(&'scope self, meter: Producer) -> MeteredScope<'scope, 'env>; + } + + impl<'scope, 'env: 'scope> IntoMetered<'scope, 'env> for Scope<'scope, 'env> { + fn metered(&'scope self, meter: Producer) -> MeteredScope<'scope, 'env> { + MeteredScope(self, meter) + } + } + + #[test] + fn start_stop() { + let (collector, producer, controller) = install(MetricChannelType::Unbounded); + let handle = thread::spawn(|| { + let store = collector.install().block_until_shutdown(); + store.counter_val(counter!("foo")) + }); + + thread::scope(move |s| { + let s = s.metered(producer); + s.spawn(|| counter!("foo", 3)).join().unwrap(); + s.spawn(|| counter!("foo", 5)).join().unwrap(); + drop(s); // this causes collector to eventually stop receiving signals + while controller.status().unwrap() == Status::Active {} + controller.stop().unwrap(); + }); + + assert_eq!(8, handle.join().unwrap()); + } + + #[test] + fn with_thread() { + let (producer, controller, handle) = + install_new_thread(MetricChannelType::Unbounded).unwrap(); + thread::scope(move |s| { + let s = s.metered(producer); + s.spawn(|| counter!("baz", 4)); + s.spawn(|| counter!("bar", 1)); + s.spawn(|| { + let snapshot = controller.snapshot().unwrap(); + println!("snapshot: {snapshot:?}"); + controller.stop().unwrap(); + }); + }); + + handle.join().unwrap(); // Collector thread should be terminated by now + } + + #[test] + fn with_thread_rendezvous() { + let (producer, controller, _handle) = + install_new_thread(MetricChannelType::Rendezvous).unwrap(); + let counter = thread::scope(move |s| { + let s = s.metered(producer); + s.spawn(|| counter!("foo", 3)).join().unwrap(); + s.spawn(|| counter!("foo", 5)).join().unwrap(); + // we don't need to check the status because producer threads are now + // blocked until the collector receives their stores. This means that + // the snapshot must be up to date by now. + controller.snapshot().unwrap().counter_val(counter!("foo")) + }); + + assert_eq!(8, counter); + } +} diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs new file mode 100644 index 000000000..f166d610b --- /dev/null +++ b/ipa-metrics/src/context.rs @@ -0,0 +1,186 @@ +use std::{cell::RefCell, mem}; + +use crossbeam_channel::Sender; + +use crate::MetricsStore; + +thread_local! { + pub(crate) static METRICS_CTX: RefCell = const { RefCell::new(MetricsContext::new()) } +} + +#[macro_export] +macro_rules! counter { + ($metric:expr, $val:expr $(, $l:expr => $v:expr)*) => {{ + let name = $crate::metric_name!($metric $(, $l => $v)*); + $crate::MetricsCurrentThreadContext::store_mut(|store| store.counter(&name).inc($val)) + }}; + ($metric:expr $(, $l:expr => $v:expr)*) => {{ + $crate::metric_name!($metric $(, $l => $v)*) + }}; +} + +/// Provides access to the metric store associated with the current thread. +/// If there is no store associated with the current thread, it will create a new one. +pub struct CurrentThreadContext; + +impl CurrentThreadContext { + pub fn init(tx: Sender) { + METRICS_CTX.with_borrow_mut(|ctx| ctx.init(tx)); + } + + pub fn flush() { + METRICS_CTX.with_borrow_mut(MetricsContext::flush); + } + + pub fn store T, T>(f: F) -> T { + METRICS_CTX.with_borrow(|ctx| f(ctx.store())) + } + + pub fn store_mut T, T>(f: F) -> T { + METRICS_CTX.with_borrow_mut(|ctx| f(ctx.store_mut())) + } + + #[must_use] + pub fn is_connected() -> bool { + METRICS_CTX.with_borrow(|ctx| ctx.tx.is_some()) + } +} + +/// This context is used inside thread-local storage, +/// so it must be wrapped inside [`std::cell::RefCell`]. +/// +/// For single-threaded applications, it is possible +/// to use it w/o connecting to the collector thread. +pub struct MetricsContext { + store: MetricsStore, + /// Handle to send metrics to the collector thread + tx: Option>, +} + +impl Default for MetricsContext { + fn default() -> Self { + Self::new() + } +} + +impl MetricsContext { + #[must_use] + pub const fn new() -> Self { + Self { + store: MetricsStore::new(), + tx: None, + } + } + + /// Connects this context to the collector thread. + /// Sender will be used to send data from this thread + fn init(&mut self, tx: Sender) { + assert!(self.tx.is_none(), "Already connected"); + + self.tx = Some(tx); + } + + #[must_use] + pub fn store(&self) -> &MetricsStore { + &self.store + } + + pub fn store_mut(&mut self) -> &mut MetricsStore { + &mut self.store + } + + fn flush(&mut self) { + if self.store.is_empty() { + return; + } + + if let Some(tx) = self.tx.as_ref() { + let store = mem::take(&mut self.store); + match tx.send(store) { + Ok(()) => {} + Err(e) => { + // Note that the store is dropped at this point. + // If it becomes a problem with collector threads disconnecting + // somewhat randomly, we can keep the old store around + // and clone it when sending. + tracing::warn!("MetricsContext is disconnected from the collector: {e}"); + } + } + } else { + tracing::warn!("MetricsContext is not connected"); + } + } +} + +impl Drop for MetricsContext { + fn drop(&mut self) { + if !self.store.is_empty() { + tracing::warn!( + "Non-empty metric store is dropped: {} metrics lost", + self.store.len() + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::thread; + + use crate::{context::CurrentThreadContext, MetricsContext}; + + /// Each thread has its local store by default, and it is exclusive to it + #[test] + #[cfg(feature = "partitions")] + fn local_store() { + use crate::{context::CurrentThreadContext, CurrentThreadPartitionContext}; + + CurrentThreadPartitionContext::set(0xdead_beef); + counter!("foo", 7); + + std::thread::spawn(|| { + counter!("foo", 1); + counter!("foo", 5); + assert_eq!( + 5, + CurrentThreadContext::store(|store| store.counter_val(counter!("foo"))) + ); + }); + + assert_eq!( + 7, + CurrentThreadContext::store(|store| store.counter_val(counter!("foo"))) + ); + } + + #[test] + fn default() { + assert_eq!(0, MetricsContext::default().store().len()); + } + + #[test] + fn ignore_empty_store_on_flush() { + let (tx, rx) = crossbeam_channel::unbounded(); + let mut ctx = MetricsContext::new(); + ctx.init(tx); + let handle = + thread::spawn(move || assert!(rx.recv().is_err(), "Context sent non-empty store")); + + ctx.flush(); + drop(ctx); + handle.join().unwrap(); + } + + #[test] + fn is_connected() { + assert!(!CurrentThreadContext::is_connected()); + let (tx, rx) = crossbeam_channel::unbounded(); + + CurrentThreadContext::init(tx); + CurrentThreadContext::store_mut(|store| store.counter(counter!("foo")).inc(1)); + CurrentThreadContext::flush(); + + assert!(CurrentThreadContext::is_connected()); + assert_eq!(1, rx.recv().unwrap().counter_val(counter!("foo"))); + } +} diff --git a/ipa-metrics/src/controller.rs b/ipa-metrics/src/controller.rs new file mode 100644 index 000000000..52deed853 --- /dev/null +++ b/ipa-metrics/src/controller.rs @@ -0,0 +1,98 @@ +use crossbeam_channel::Sender; + +use crate::MetricsStore; + +/// Indicates the current status of collector thread +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Status { + /// There are at least one active thread that can send + /// the store snapshots to the collector. Collector is actively + /// listening for new snapshots. + Active, + /// All threads have been disconnected from this collector, + /// and it is currently awaiting shutdown via [`Command::Stop`] + Disconnected, +} + +pub enum Command { + Snapshot(Sender), + Stop(Sender<()>), + Status(Sender), +} + +/// Handle to communicate with centralized metrics collection system. +pub struct Controller { + pub(super) tx: Sender, +} + +impl Controller { + /// Request new metric snapshot from the collector thread. + /// Blocks current thread until the snapshot is received + /// + /// ## Errors + /// If collector thread is disconnected or an error occurs during snapshot request + /// + /// ## Example + /// ```rust + /// use ipa_metrics::{install_new_thread, MetricChannelType, MetricsStore}; + /// + /// let (_, controller, _handle) = install_new_thread(MetricChannelType::Unbounded).unwrap(); + /// let snapshot = controller.snapshot().unwrap(); + /// println!("Current metrics: {snapshot:?}"); + /// ``` + #[inline] + pub fn snapshot(&self) -> Result { + let (tx, rx) = crossbeam_channel::bounded(0); + self.tx + .send(Command::Snapshot(tx)) + .map_err(|e| format!("An error occurred while requesting metrics snapshot: {e}"))?; + rx.recv().map_err(|e| format!("Disconnected channel: {e}")) + } + + /// Send request to terminate the collector thread. + /// Blocks current thread until the snapshot is received. + /// If this request is successful, any subsequent snapshot + /// or stop requests will return an error. + /// + /// ## Errors + /// If collector thread is disconnected or an error occurs while sending + /// or receiving data from the collector thread. + /// + /// ## Example + /// ```rust + /// use ipa_metrics::{install_new_thread, MetricChannelType, MetricsStore}; + /// + /// let (_, controller, _handle) = install_new_thread(MetricChannelType::Unbounded).unwrap(); + /// controller.stop().unwrap(); + /// ``` + pub fn stop(self) -> Result<(), String> { + let (tx, rx) = crossbeam_channel::bounded(0); + self.tx + .send(Command::Stop(tx)) + .map_err(|e| format!("An error occurred while requesting termination: {e}"))?; + rx.recv().map_err(|e| format!("Disconnected channel: {e}")) + } + + /// Request current collector status. + /// + /// ## Errors + /// If collector thread is disconnected or an error occurs while sending + /// or receiving data from the collector thread. + /// + /// ## Example + /// ```rust + /// use ipa_metrics::{install_new_thread, ControllerStatus, MetricChannelType}; + /// + /// let (_, controller, _handle) = install_new_thread(MetricChannelType::Unbounded).unwrap(); + /// let status = controller.status().unwrap(); + /// println!("Collector status: {status:?}"); + /// ``` + #[inline] + pub fn status(&self) -> Result { + let (tx, rx) = crossbeam_channel::bounded(0); + self.tx + .send(Command::Status(tx)) + .map_err(|e| format!("An error occurred while requesting status: {e}"))?; + rx.recv().map_err(|e| format!("Disconnected channel: {e}")) + } +} diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs new file mode 100644 index 000000000..620e193e3 --- /dev/null +++ b/ipa-metrics/src/key.rs @@ -0,0 +1,293 @@ +//! Metric names supported by this crate. +//! +//! Providing a good use for metrics is a tradeoff between +//! performance and ergonomics. Flexible metric engines support +//! dynamic names, like "bytes.sent.{ip}" or "cpu.{core}.instructions" +//! but that comes with a significant performance cost. +//! String interning helps to mitigate this on the storage site +//! but callsites need to allocate those at every call. +//! +//! IPA metrics can be performance sensitive. There are counters +//! incremented on every send and receive operation, so they need +//! to be fast. For this reason, dynamic metric names are not supported. +//! Metric name can only be a string, known at compile time. +//! +//! However, it is not flexible enough. Most metrics have dimensions +//! attached to them. IPA example is `bytes.sent` metric with step breakdown. +//! It is very useful to know the required throughput per circuit. +//! +//! This metric engine supports up to 5 dimensions attached to every metric, +//! again trying to strike a good balance between performance and usability. + +use std::{ + array, + hash::{Hash, Hasher}, + iter, + iter::repeat, +}; + +pub use Name as MetricName; +pub(super) use OwnedName as OwnedMetricName; + +use crate::label::{Label, OwnedLabel, MAX_LABELS}; + +#[macro_export] +macro_rules! metric_name { + // Match when two key-value pairs are provided + // TODO: enforce uniqueness at compile time + ($metric:expr, $l1:expr => $v1:expr, $l2:expr => $v2:expr) => {{ + use $crate::UniqueElements; + let labels = [ + $crate::Label { + name: $l1, + val: $v1, + }, + $crate::Label { + name: $l2, + val: $v2, + }, + ] + .enforce_unique(); + $crate::MetricName::from_parts($metric, labels) + }}; + // Match when one key-value pair is provided + ($metric:expr, $l1:expr => $v1:expr) => {{ + $crate::MetricName::from_parts( + $metric, + [$crate::Label { + name: $l1, + val: $v1, + }], + ) + }}; + // Match when no key-value pairs are provided + ($metric:expr) => {{ + $crate::MetricName::from_parts($metric, []) + }}; +} + +/// Metric name that is created at callsite on each metric invocation. +/// For this reason, it is performance sensitive - it tries to borrow +/// whatever it can from callee stack. +#[derive(Debug, PartialEq)] +pub struct Name<'lv, const LABELS: usize = 0> { + pub(super) key: &'static str, + labels: [Label<'lv>; LABELS], +} + +impl<'lv, const LABELS: usize> Name<'lv, LABELS> { + /// Constructs this instance from key and labels. + /// ## Panics + /// If number of labels exceeds `MAX_LABELS`. + pub fn from_parts>(key: I, labels: [Label<'lv>; LABELS]) -> Self { + assert!( + LABELS <= MAX_LABELS, + "Maximum 5 labels per metric is supported" + ); + + Self { + key: key.into(), + labels, + } + } + + /// [`ToOwned`] trait does not work because of + /// extra [`Borrow`] requirement + pub(super) fn to_owned(&self) -> OwnedName { + let labels: [_; 5] = array::from_fn(|i| { + if i < self.labels.len() { + Some(self.labels[i].to_owned()) + } else { + None + } + }); + + OwnedName { + key: self.key, + labels, + } + } +} + +/// Same as [`Name`], but intended for internal use. This is an owned +/// version of it, that does not borrow anything from outside. +/// This is the key inside metric stores which are simple hashmaps. +#[derive(Debug, Clone, Eq)] +pub struct OwnedName { + pub key: &'static str, + labels: [Option; 5], +} + +impl OwnedName { + pub fn labels(&self) -> impl Iterator { + self.labels.iter().filter_map(|l| l.as_ref()) + } + + /// Checks that a subset of labels in `self` matches all values in `other`. + #[must_use] + pub fn partial_match(&self, other: &Name<'_, LABELS>) -> bool { + if self.key == other.key { + other.labels.iter().all(|l| self.find_label(l)) + } else { + false + } + } + + fn find_label(&self, label: &Label<'_>) -> bool { + self.labels().any(|l| l.as_borrowed().eq(label)) + } +} + +impl Hash for Name<'_, LABELS> { + fn hash(&self, state: &mut H) { + Hash::hash(&self.key, state); + // to be consistent with `OwnedName` hashing, we need to + // serialize labels without slice length prefix. + for label in &self.labels { + label.hash(state); + } + } +} + +impl From<&'static str> for Name<'_, 0> { + fn from(value: &'static str) -> Self { + Self { + key: value, + labels: [], + } + } +} + +pub trait UniqueElements { + #[must_use] + fn enforce_unique(self) -> Self; +} + +impl UniqueElements for [Label<'_>; 2] { + fn enforce_unique(self) -> Self { + assert_ne!(self[0].name, self[1].name, "label names must be unique"); + + self + } +} + +impl<'a, const LABELS: usize> PartialEq> for OwnedName { + fn eq(&self, other: &Name<'a, LABELS>) -> bool { + self.key == other.key + && iter::zip( + &self.labels, + other.labels.iter().map(Some).chain(repeat(None)), + ) + .all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a.as_borrowed() == *b, + (None, None) => true, + _ => false, + }) + } +} + +impl PartialEq for OwnedName { + fn eq(&self, other: &OwnedName) -> bool { + self.key == other.key && self.labels.eq(&other.labels) + } +} + +impl Hash for OwnedName { + fn hash(&self, state: &mut H) { + Hash::hash(self.key, state); + for label in self.labels.iter().flatten() { + label.hash(state); + } + } +} + +#[cfg(test)] +pub fn compute_hash(value: V) -> u64 { + let mut hasher = crate::label_hasher(); + value.hash(&mut hasher); + + hasher.finish() +} + +#[cfg(test)] +mod tests { + use crate::{ + key::{compute_hash, Name}, + label::Label, + }; + + #[test] + fn eq() { + let name = Name::from("foo"); + assert_eq!(name.to_owned(), name); + } + + #[test] + fn hash_eq() { + let a = Name::from("foo"); + let b = Name::from("foo"); + assert_eq!(compute_hash(&a), compute_hash(b)); + assert_eq!(compute_hash(&a), compute_hash(a.to_owned())); + } + + #[test] + fn not_eq() { + let foo = Name::from("foo"); + let bar = Name::from("bar"); + assert_ne!(foo.to_owned(), bar); + } + + #[test] + fn hash_not_eq() { + let foo = Name::from("foo"); + let bar = Name::from("bar"); + assert_ne!(compute_hash(&foo), compute_hash(&bar)); + assert_ne!(compute_hash(foo), compute_hash(bar.to_owned())); + } + + #[test] + #[should_panic(expected = "Maximum 5 labels per metric is supported")] + fn more_than_5_labels() { + let _ = Name::from_parts( + "foo", + [ + Label { + name: "label_1", + val: &1, + }, + Label { + name: "label_2", + val: &1, + }, + Label { + name: "label_3", + val: &1, + }, + Label { + name: "label_4", + val: &1, + }, + Label { + name: "label_5", + val: &1, + }, + Label { + name: "label_6", + val: &1, + }, + ], + ); + } + + #[test] + fn eq_is_consistent() { + let a_name = metric_name!("foo", "label_1" => &1); + let b_name = metric_name!("foo", "label_1" => &1, "label_2" => &2); + + assert_eq!(a_name, a_name); + assert_eq!(a_name.to_owned(), a_name); + + assert_ne!(b_name.to_owned(), a_name); + assert_ne!(a_name.to_owned(), b_name); + } +} diff --git a/ipa-metrics/src/kind.rs b/ipa-metrics/src/kind.rs new file mode 100644 index 000000000..3a48d105b --- /dev/null +++ b/ipa-metrics/src/kind.rs @@ -0,0 +1,5 @@ +//! Different metric types supported by this crate. +//! Currently, only counters are supported. + +/// Counters are simple 8 byte values. +pub type CounterValue = u64; diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs new file mode 100644 index 000000000..dd822be86 --- /dev/null +++ b/ipa-metrics/src/label.rs @@ -0,0 +1,183 @@ +use std::{ + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, +}; + +use rustc_hash::FxHasher; + +pub const MAX_LABELS: usize = 5; + +/// Provides a fast, non-collision resistant implementation of [`Hasher`] +/// for label values. +/// +/// [`Hasher`]: std::hash::Hasher +#[must_use] +pub fn label_hasher() -> impl Hasher { + FxHasher::default() +} + +/// Dimension value (or label value) must be sendable to another thread +/// and there must be a way to show it +pub trait LabelValue: Display + Send { + /// Creates a unique hash for this value. + /// It is easy to create collisions, so better avoid them, + /// by assigning a unique integer to each value + /// + /// Note that this value is used for uniqueness check inside + /// metric stores + fn hash(&self) -> u64; + + /// Creates an owned copy of this value. Dynamic dispatch + /// is required, because values are stored in a generic store + /// that can't be specialized for value types. + fn boxed(&self) -> Box; +} + +impl LabelValue for u32 { + fn hash(&self) -> u64 { + u64::from(*self) + } + + fn boxed(&self) -> Box { + Box::new(*self) + } +} + +pub struct Label<'lv> { + pub name: &'static str, + pub val: &'lv dyn LabelValue, +} + +impl Label<'_> { + #[must_use] + pub fn to_owned(&self) -> OwnedLabel { + OwnedLabel { + name: self.name, + val: self.val.boxed(), + } + } +} + +impl Debug for Label<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Label") + .field("name", &self.name) + .field("val", &format!("{}", self.val)) + .finish() + } +} + +impl Hash for Label<'_> { + fn hash(&self, state: &mut H) { + Hash::hash(&self.name, state); + Hash::hash(&self.val.hash(), state); + } +} + +impl PartialEq for Label<'_> { + fn eq(&self, other: &Self) -> bool { + // name check should be fast - just pointer comparison. + // val check is more involved with dynamic dispatch, so we can consider + // making label immutable and storing a hash of the value in place + self.name == other.name && self.val.hash() == other.val.hash() + } +} + +/// Same as [`Label`] but owns the values. This instance is stored +/// inside metric hashmaps as they need to own the keys. +pub struct OwnedLabel { + pub name: &'static str, + pub val: Box, +} + +impl Clone for OwnedLabel { + fn clone(&self) -> Self { + Self { + name: self.name, + val: self.val.boxed(), + } + } +} + +impl OwnedLabel { + pub fn as_borrowed(&self) -> Label<'_> { + Label { + name: self.name, + val: self.val.as_ref(), + } + } +} + +impl Hash for OwnedLabel { + fn hash(&self, state: &mut H) { + self.as_borrowed().hash(state); + } +} + +impl Debug for OwnedLabel { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OwnedLabel") + .field("name", &self.name) + .field("val", &format!("{}", self.val)) + .finish() + } +} + +impl PartialEq for OwnedLabel { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.val.hash() == other.val.hash() + } +} + +impl Eq for OwnedLabel {} + +#[cfg(test)] +mod tests { + + use crate::{key::compute_hash, metric_name}; + + #[test] + fn one_label() { + let foo_1 = metric_name!("foo", "l1" => &1); + let foo_2 = metric_name!("foo", "l1" => &2); + + assert_ne!(foo_1.to_owned(), foo_2); + assert_ne!(compute_hash(&foo_1), compute_hash(&foo_2)); + assert_ne!(foo_2.to_owned(), foo_1); + + assert_eq!(compute_hash(&foo_1), compute_hash(foo_1.to_owned())); + } + + #[test] + #[should_panic(expected = "label names must be unique")] + fn unique() { + metric_name!("foo", "l1" => &1, "l1" => &0); + } + + #[test] + fn non_commutative() { + assert_ne!( + compute_hash(&metric_name!("foo", "l1" => &1, "l2" => &0)), + compute_hash(&metric_name!("foo", "l1" => &0, "l2" => &1)), + ); + assert_ne!( + compute_hash(&metric_name!("foo", "l1" => &1)), + compute_hash(&metric_name!("foo", "l1" => &1, "l2" => &1)), + ); + } + + #[test] + fn clone() { + let metric = metric_name!("foo", "l1" => &1).to_owned(); + assert_eq!(&metric.labels().next(), &metric.labels().next().clone()); + } + + #[test] + fn fields() { + let metric = metric_name!("foo", "l1" => &1).to_owned(); + let label = metric.labels().next().unwrap().to_owned(); + + assert_eq!(label.name, "l1"); + assert_eq!(label.val.to_string(), "1"); + } +} diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs new file mode 100644 index 000000000..2449d41a3 --- /dev/null +++ b/ipa-metrics/src/lib.rs @@ -0,0 +1,121 @@ +#![deny(clippy::pedantic)] +#![allow(clippy::similar_names)] +#![allow(clippy::module_name_repetitions)] + +mod collector; +mod context; +mod controller; +mod key; +mod kind; +mod label; +#[cfg(feature = "partitions")] +mod partitioned; +mod producer; +mod store; + +use std::{io, thread::JoinHandle}; + +pub use collector::MetricsCollector; +pub use context::{CurrentThreadContext as MetricsCurrentThreadContext, MetricsContext}; +pub use controller::{ + Command as ControllerCommand, Controller as MetricsCollectorController, + Status as ControllerStatus, +}; +pub use key::{MetricName, OwnedName, UniqueElements}; +pub use label::{label_hasher, Label, LabelValue}; +#[cfg(feature = "partitions")] +pub use partitioned::{ + CurrentThreadContext as CurrentThreadPartitionContext, Partition as MetricPartition, + PartitionedStore as MetricsStore, +}; +pub use producer::Producer as MetricsProducer; +#[cfg(not(feature = "partitions"))] +pub use store::Store as MetricsStore; + +/// Type of the communication channel between metric producers +/// and the collector. +#[derive(Copy, Clone)] +pub enum MetricChannelType { + /// Each send message must be paired with receive. Sends that + /// don't get a pair block the thread until collector processes + /// the request. This mode is suitable for unit tests where metric + /// consistency is important and gets more priority than availability. + Rendezvous, + /// Each channel between producer and collector gets unlimited capacity. + Unbounded, +} + +/// Creates metric infrastructure that is ready to use +/// in the application code. It consists a triple of +/// [`MetricsCollector`], [`MetricsProducer`], and +/// [`MetricsCollectorController`]. +/// +/// Collector is used in the centralized place (a dedicated thread) +/// to collect metrics coming from thread local stores. +/// +/// Metric producer must be installed on every thread that is used +/// to emit telemetry, and it connects that thread to the collector. +/// +/// Controller provides command-line API interface to the collector. +/// A thread that owns the controller, can request current snapshot. +/// For more information about API, see [`Command`]. +/// +/// The communication channel between producers and collector is configured +/// via `channel_type` parameter. See [`MetricChannelType`] for details +/// +/// ## Example 1 (Rendezvous channels) +/// ```rust +/// use ipa_metrics::MetricChannelType; +/// let (collector, producer, controller) = ipa_metrics::install(MetricChannelType::Rendezvous); +/// ``` +/// +/// ## Example 2 (unbounded) +/// ```rust +/// use ipa_metrics::MetricChannelType; +/// let (collector, producer, controller) = ipa_metrics::install(MetricChannelType::Unbounded); +/// ``` +/// +/// [`MetricsCollector`]: crate::MetricsCollector +/// [`MetricsProducer`]: crate::MetricsProducer +/// [`MetricsCollectorController`]: crate::MetricsCollectorController +/// [`Command`]: crate::ControllerCommand +#[must_use] +pub fn install( + channel_type: MetricChannelType, +) -> ( + MetricsCollector, + MetricsProducer, + MetricsCollectorController, +) { + let (command_tx, command_rx) = crossbeam_channel::unbounded(); + let (tx, rx) = match channel_type { + MetricChannelType::Rendezvous => crossbeam_channel::bounded(0), + MetricChannelType::Unbounded => crossbeam_channel::unbounded(), + }; + ( + MetricsCollector { + rx, + local_store: MetricsStore::default(), + command_rx, + }, + MetricsProducer { tx }, + MetricsCollectorController { tx: command_tx }, + ) +} + +/// Same as [`install`] but spawns a new thread to run the collector. +/// +/// ## Errors +/// if thread cannot be started +pub fn install_new_thread( + channel_type: MetricChannelType, +) -> io::Result<(MetricsProducer, MetricsCollectorController, JoinHandle<()>)> { + let (collector, producer, controller) = install(channel_type); + let handle = std::thread::Builder::new() + .name("metric-collector".to_string()) + .spawn(|| { + collector.install().block_until_shutdown(); + })?; + + Ok((producer, controller, handle)) +} diff --git a/ipa-metrics/src/partitioned.rs b/ipa-metrics/src/partitioned.rs new file mode 100644 index 000000000..0f71d0e28 --- /dev/null +++ b/ipa-metrics/src/partitioned.rs @@ -0,0 +1,253 @@ +//! This module enables metric partitioning that can be useful +//! when threads that emit metrics are shared across multiple executions. +//! A typical example for it are unit tests in Rust that share threads. +//! Having a global per-thread store would mean that it is not possible +//! to distinguish between different runs. +//! +//! Partitioning attempts to solve this with a global 16 byte identifier that +//! is set in thread local storage and read automatically by [`PartitionedStore`] +//! +//! Note that this module does not provide means to automatically set and unset +//! partitions. `ipa-metrics-tracing` defines a way to do it via tracing context +//! that is good enough for the vast majority of use cases. +//! +//! Because partitioned stores carry additional cost of extra lookup (partition -> store), +//! it is disabled by default and requires explicit opt-in via `partitioning` feature. + +use std::{borrow::Borrow, cell::Cell}; + +use hashbrown::hash_map::Entry; +use rustc_hash::FxBuildHasher; + +use crate::{ + kind::CounterValue, + store::{CounterHandle, Store}, + MetricName, +}; + +thread_local! { + static PARTITION: Cell> = const { Cell::new(None) } +} + +/// Each partition is a unique 8 byte value, meaning roughly 1B partitions +/// can be supported and the limiting factor is birthday bound. +pub type Partition = u64; + +pub struct CurrentThreadContext; + +impl CurrentThreadContext { + pub fn set(new: Partition) { + Self::toggle(Some(new)); + } + + pub fn toggle(new: Option) { + PARTITION.set(new); + } + + #[must_use] + pub fn get() -> Option { + PARTITION.get() + } +} + +/// Provides the same functionality as [`Store`], but partitioned +/// across many dimensions. There is an extra price for it, so +/// don't use it, unless you need it. +/// The dimension is set through [`std::thread::LocalKey`], so +/// each thread can set only one dimension at a time. +/// +/// The API of this struct will match [`Store`] as they +/// can be used interchangeably. +#[derive(Clone, Debug)] +pub struct PartitionedStore { + /// Set of stores partitioned by [`Partition`] + inner: hashbrown::HashMap, + /// We don't want to lose metrics that are emitted when partitions are not set. + /// So we provide a default store for those + default_store: Store, +} + +impl Default for PartitionedStore { + fn default() -> Self { + Self::new() + } +} + +impl PartitionedStore { + #[must_use] + pub const fn new() -> Self { + Self { + inner: hashbrown::HashMap::with_hasher(FxBuildHasher), + default_store: Store::new(), + } + } + + pub fn with_partition T, T>( + &self, + partition: Partition, + f: F, + ) -> Option { + let store = self.inner.get(&partition); + store.map(f) + } + + pub fn merge(&mut self, other: Self) { + for (partition, store) in other.inner { + self.get_mut(Some(partition)).merge(store); + } + self.default_store.merge(other.default_store); + } + + pub fn counter_val<'a, const LABELS: usize, B: Borrow>>( + &'a self, + key: B, + ) -> CounterValue { + let name = key.borrow(); + if let Some(partition) = CurrentThreadContext::get() { + self.inner + .get(&partition) + .map(|store| store.counter_val(name)) + .unwrap_or_default() + } else { + self.default_store.counter_val(name) + } + } + + pub fn counter<'a, const LABELS: usize, B: Borrow>>( + &'a mut self, + key: B, + ) -> CounterHandle<'a, LABELS> { + self.get_mut(CurrentThreadContext::get()).counter(key) + } + + #[must_use] + pub fn len(&self) -> usize { + self.inner.len() + self.default_store.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[allow(dead_code)] + fn with_partition_mut T, T>( + &mut self, + partition: Partition, + f: F, + ) -> T { + let store = self.get_mut(Some(partition)); + f(store) + } + + fn get_mut(&mut self, partition: Option) -> &mut Store { + if let Some(v) = partition { + match self.inner.entry(v) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => entry.insert(Store::default()), + } + } else { + &mut self.default_store + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + counter, metric_name, + partitioned::{CurrentThreadContext, PartitionedStore}, + }; + + #[test] + fn unique_partition() { + let metric = metric_name!("foo"); + let mut store = PartitionedStore::new(); + store.with_partition_mut(1, |store| { + store.counter(&metric).inc(1); + }); + store.with_partition_mut(5, |store| { + store.counter(&metric).inc(5); + }); + + assert_eq!( + 5, + store.with_partition_mut(5, |store| store.counter(&metric).get()) + ); + assert_eq!( + 1, + store.with_partition_mut(1, |store| store.counter(&metric).get()) + ); + assert_eq!( + 0, + store.with_partition_mut(10, |store| store.counter(&metric).get()) + ); + } + + #[test] + fn current_partition() { + let metric = metric_name!("foo"); + let mut store = PartitionedStore::new(); + store.counter(&metric).inc(7); + + CurrentThreadContext::set(4); + + store.counter(&metric).inc(1); + store.counter(&metric).inc(5); + + assert_eq!(6, store.counter_val(&metric)); + CurrentThreadContext::toggle(None); + assert_eq!(7, store.counter_val(&metric)); + } + + #[test] + fn empty() { + let mut store = PartitionedStore::default(); + assert!(store.is_empty()); + store.counter(&metric_name!("foo")).inc(1); + + assert!(!store.is_empty()); + } + + #[test] + fn len() { + let mut store = PartitionedStore::new(); + assert_eq!(0, store.len()); + + store.counter(metric_name!("foo")).inc(1); + CurrentThreadContext::set(4); + store.counter(metric_name!("foo")).inc(1); + + // one metric in partition 4, another one in default. Even that they are the same, + // partitioned store cannot distinguish between them + assert_eq!(2, store.len()); + } + + #[test] + fn merge() { + let mut store1 = PartitionedStore::new(); + let mut store2 = PartitionedStore::new(); + store1.with_partition_mut(1, |store| store.counter(counter!("foo")).inc(1)); + store2.with_partition_mut(1, |store| store.counter(counter!("foo")).inc(1)); + store1.with_partition_mut(2, |store| store.counter(counter!("foo")).inc(2)); + store2.with_partition_mut(2, |store| store.counter(counter!("foo")).inc(2)); + + store1.counter(counter!("foo")).inc(3); + store2.counter(counter!("foo")).inc(3); + + store1.merge(store2); + assert_eq!( + 2, + store1 + .with_partition(1, |store| store.counter_val(counter!("foo"))) + .unwrap() + ); + assert_eq!( + 4, + store1 + .with_partition(2, |store| store.counter_val(counter!("foo"))) + .unwrap() + ); + assert_eq!(6, store1.counter_val(counter!("foo"))); + } +} diff --git a/ipa-metrics/src/producer.rs b/ipa-metrics/src/producer.rs new file mode 100644 index 000000000..f9ee42cc3 --- /dev/null +++ b/ipa-metrics/src/producer.rs @@ -0,0 +1,48 @@ +use crossbeam_channel::Sender; + +use crate::{context::CurrentThreadContext, MetricsStore}; + +/// A handle to enable centralized metrics collection from the current thread. +/// +/// This is a cloneable handle, so it can be installed in multiple threads. +/// The handle is installed by calling [`install`], which returns a drop handle. +/// When the drop handle is dropped, the context of local store is flushed +/// to the collector thread. +/// +/// Thread local store is always enabled by [`MetricsContext`], so it is always +/// possible to have a local view of metrics emitted by this thread. +/// +/// [`install`]: Producer::install +#[derive(Clone)] +pub struct Producer { + pub(super) tx: Sender, +} + +impl Producer { + pub fn install(&self) { + CurrentThreadContext::init(self.tx.clone()); + } + + /// Returns a drop handle that should be used when thread is stopped. + /// One may think destructor on [`MetricsContext`] could do this, + /// but as pointed in [`LocalKey`] documentation, deadlocks are possible + /// if another TLS storage is accessed at destruction time. + /// + /// I actually ran into this problem with crossbeam channels. Send operation + /// requires access to `thread::current` and that panics at runtime if called + /// from inside `Drop`. + /// + /// [`LocalKey`]: + pub fn drop_handle(&self) -> ProducerDropHandle { + ProducerDropHandle + } +} + +#[must_use] +pub struct ProducerDropHandle; + +impl Drop for ProducerDropHandle { + fn drop(&mut self) { + CurrentThreadContext::flush(); + } +} diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs new file mode 100644 index 000000000..e893ffd84 --- /dev/null +++ b/ipa-metrics/src/store.rs @@ -0,0 +1,243 @@ +use std::{borrow::Borrow, hash::BuildHasher}; + +use hashbrown::hash_map::RawEntryMut; +use rustc_hash::FxBuildHasher; + +use crate::{key::OwnedMetricName, kind::CounterValue, MetricName}; + +/// A basic store. Currently only supports counters. +/// Counters and other metrics are stored to optimize writes. That means, one lookup +/// per write. The cost of assembling the total count across all dimensions is absorbed +/// by readers +#[derive(Clone, Debug)] +pub struct Store { + counters: hashbrown::HashMap, +} + +impl Default for Store { + fn default() -> Self { + Self::new() + } +} + +impl Store { + #[must_use] + pub const fn new() -> Self { + Self { + counters: hashbrown::HashMap::with_hasher(FxBuildHasher), + } + } + + pub fn merge(&mut self, other: Self) { + for (k, v) in other.counters { + let hash_builder = self.counters.hasher(); + let hash = hash_builder.hash_one(&k); + *self + .counters + .raw_entry_mut() + .from_hash(hash, |other| other.eq(&k)) + .or_insert(k, 0) + .1 += v; + } + } + + pub fn counter<'a, const LABELS: usize, B: Borrow>>( + &'a mut self, + key: B, + ) -> CounterHandle<'a, LABELS> { + let key = key.borrow(); + let hash_builder = self.counters.hasher(); + let hash = hash_builder.hash_one(key); + let entry = self + .counters + .raw_entry_mut() + .from_hash(hash, |key_found| key_found.eq(key)); + match entry { + RawEntryMut::Occupied(slot) => CounterHandle { + val: slot.into_mut(), + }, + RawEntryMut::Vacant(slot) => { + let (_, val) = slot.insert_hashed_nocheck(hash, key.to_owned(), Default::default()); + CounterHandle { val } + } + } + } + + /// Returns the value for the specified metric, limited by any specified dimensions, + /// but not by any unspecified dimensions. If metric foo has dimensions dim1 and dim2, + /// a query for (foo, dim1 = 1) will sum the counter values having dim1 = 1 + /// and any value of dim2. + /// The cost of this operation is `O(N*M)` where `N` - number of unique metrics + /// registered in this store and `M` number of dimensions. + /// + /// Note that the cost can be improved if it ever becomes a bottleneck by + /// creating a specialized two-level map (metric -> label -> value). + pub fn counter_val<'a, const LABELS: usize, B: Borrow>>( + &'a self, + key: B, + ) -> CounterValue { + let key = key.borrow(); + + self.counters + .iter() + .filter(|(counter, _)| counter.partial_match(key)) + .map(|(_, val)| val) + .sum() + } + + #[must_use] + pub fn len(&self) -> usize { + self.counters.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the counters in the store. + /// + /// The iterator item is a tuple of the metric name and the counter value. + pub fn counters(&self) -> impl Iterator { + self.counters.iter().map(|(key, value)| (key, *value)) + } +} + +pub struct CounterHandle<'a, const LABELS: usize> { + val: &'a mut CounterValue, +} + +impl CounterHandle<'_, LABELS> { + pub fn inc(&mut self, inc: CounterValue) { + *self.val += inc; + } + + pub fn get(&self) -> CounterValue { + *self.val + } +} + +#[cfg(test)] +mod tests { + use std::hash::{DefaultHasher, Hash, Hasher}; + + use crate::{counter, metric_name, store::Store, LabelValue}; + + impl LabelValue for &'static str { + fn hash(&self) -> u64 { + // TODO: use fast hashing here + let mut hasher = DefaultHasher::default(); + Hash::hash(self, &mut hasher); + + hasher.finish() + } + + fn boxed(&self) -> Box { + Box::new(*self) + } + } + + #[test] + fn counter() { + let mut store = Store::default(); + let name = metric_name!("foo"); + { + let mut handle = store.counter(&name); + assert_eq!(0, handle.get()); + handle.inc(3); + assert_eq!(3, handle.get()); + } + + { + store.counter(&name).inc(0); + assert_eq!(3, store.counter(&name).get()); + } + } + + #[test] + fn with_labels() { + let mut store = Store::default(); + let valid_name = metric_name!("foo", "h1" => &1, "h2" => &"2"); + let wrong_name = metric_name!("foo", "h1" => &2, "h2" => &"2"); + store.counter(&valid_name).inc(2); + + assert_eq!(2, store.counter(&valid_name).get()); + assert_eq!(0, store.counter(&wrong_name).get()); + } + + #[test] + fn merge() { + let mut store1 = Store::default(); + let mut store2 = Store::default(); + let foo = metric_name!("foo", "h1" => &1, "h2" => &"2"); + let bar = metric_name!("bar", "h2" => &"2"); + let baz = metric_name!("baz"); + store1.counter(&foo).inc(2); + store2.counter(&foo).inc(1); + + store1.counter(&bar).inc(7); + store2.counter(&baz).inc(3); + + store1.merge(store2); + + assert_eq!(3, store1.counter(&foo).get()); + assert_eq!(7, store1.counter(&bar).get()); + assert_eq!(3, store1.counter(&baz).get()); + } + + #[test] + fn counter_value() { + let mut store = Store::default(); + store + .counter(counter!("foo", "h1" => &1, "h2" => &"1")) + .inc(1); + store + .counter(counter!("foo", "h1" => &1, "h2" => &"2")) + .inc(1); + store + .counter(counter!("foo", "h1" => &2, "h2" => &"1")) + .inc(1); + store + .counter(counter!("foo", "h1" => &2, "h2" => &"2")) + .inc(1); + store + .counter(counter!("bar", "h1" => &1, "h2" => &"1")) + .inc(3); + + assert_eq!(4, store.counter_val(counter!("foo"))); + assert_eq!( + 1, + store.counter_val(&counter!("foo", "h1" => &1, "h2" => &"2")) + ); + assert_eq!(2, store.counter_val(&counter!("foo", "h1" => &1))); + assert_eq!(2, store.counter_val(&counter!("foo", "h2" => &"2"))); + } + + #[test] + fn len_empty() { + let mut store = Store::default(); + assert!(store.is_empty()); + assert_eq!(0, store.len()); + + store.counter(counter!("foo")).inc(1); + assert!(!store.is_empty()); + assert_eq!(1, store.len()); + + store.counter(counter!("foo")).inc(1); + assert_eq!(1, store.len()); + + store.counter(counter!("bar")).inc(1); + assert_eq!(2, store.len()); + } + + #[test] + fn counters() { + let mut store = Store::default(); + store.counter(counter!("foo")).inc(1); + store.counter(counter!("foo", "h1" => &1)).inc(1); + store.counter(counter!("foo", "h2" => &2)).inc(1); + store.counter(counter!("bar")).inc(1); + + assert_eq!((4, Some(4)), store.counters().size_hint()); + } +} diff --git a/ipa-step-derive/src/lib.rs b/ipa-step-derive/src/lib.rs index 0ffdd7b1c..2298e3969 100644 --- a/ipa-step-derive/src/lib.rs +++ b/ipa-step-derive/src/lib.rs @@ -118,7 +118,7 @@ fn derive_step_impl(ast: &DeriveInput) -> Result { let mut g = Generator::default(); let attr = match &ast.data { Data::Enum(data) => { - for v in VariantAttribute::parse_variants(data)? { + for v in VariantAttribute::parse_variants(ident, data)? { g.add_variant(&v); } VariantAttribute::parse_outer(ident, &ast.attrs, None)? @@ -165,6 +165,18 @@ fn derive_gate_impl(ast: &DeriveInput) -> TokenStream { ::fmt(self, f) } } + + impl #name { + /// Returns the current index. It matches the index of the latest step + /// this gate has been narrowed to. + /// + /// If gate hasn't been narrowed yet, it returns the index of the default value. + #[must_use] + pub fn index(&self) -> ::ipa_step::CompactGateIndex { + self.0 + } + } + }; // This environment variable is set by build scripts, @@ -402,6 +414,16 @@ mod test { "e! { impl ::ipa_step::Step for ManyArms {} + impl ManyArms { + pub fn arm(v: u8) -> Self { + assert!( + v < u8::try_from(3usize).unwrap(), + "Step index {v} out of bounds for ManyArms::Arm with count 3.", + ); + Self::Arm(v) + } + } + #[allow( clippy::useless_conversion, clippy::unnecessary_fallible_conversions, @@ -424,7 +446,8 @@ mod test { const STEP_COUNT: ::ipa_step::CompactGateIndex = 3; fn base_index (& self) -> ::ipa_step::CompactGateIndex { match self { - Self::Arm (i) => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Arm (i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Arm (i) => panic!("Step index {i} out of bounds for ManyArms::Arm with count 3. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -451,6 +474,16 @@ mod test { "e! { impl ::ipa_step::Step for ManyArms {} + impl ManyArms { + pub fn arm(v: u8) -> Self { + assert!( + v < u8::try_from(3usize).unwrap(), + "Step index {v} out of bounds for ManyArms::Arm with count 3.", + ); + Self::Arm(v) + } + } + #[allow( clippy::useless_conversion, clippy::unnecessary_fallible_conversions, @@ -473,7 +506,8 @@ mod test { const STEP_COUNT: ::ipa_step::CompactGateIndex = 3; fn base_index (& self) -> ::ipa_step::CompactGateIndex { match self { - Self::Arm (i) => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Arm (i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Arm (i) => panic!("Step index {i} out of bounds for ManyArms::Arm with count 3. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -642,6 +676,15 @@ mod test { "e! { impl ::ipa_step::Step for Parent {} + impl Parent { + pub fn offspring(v: u8) -> Self { + assert!( + v < u8::try_from(5usize).unwrap(), + "Step index {v} out of bounds for Parent::Offspring with count 5.", + ); + Self::Offspring(v) + } + } #[allow( clippy::useless_conversion, @@ -667,7 +710,8 @@ mod test { const STEP_COUNT: ::ipa_step::CompactGateIndex = (::STEP_COUNT + 1) * 5; fn base_index(&self) -> ::ipa_step::CompactGateIndex { match self { - Self::Offspring(i) => (::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Offspring(i) if *i < u8::try_from(5usize).unwrap() => (::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Offspring(i) => panic!("Step index {i} out of bounds for Parent::Offspring with count 5. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -726,6 +770,16 @@ mod test { "e! { impl ::ipa_step::Step for AllArms {} + impl AllArms { + pub fn int(v: usize) -> Self { + assert!( + v < usize::try_from(3usize).unwrap(), + "Step index {v} out of bounds for AllArms::Int with count 3.", + ); + Self::Int(v) + } + } + #[allow( clippy::useless_conversion, clippy::unnecessary_fallible_conversions, @@ -752,7 +806,8 @@ mod test { fn base_index(&self) -> ::ipa_step::CompactGateIndex { match self { Self::Empty => 0, - Self::Int(i) => ::ipa_step::CompactGateIndex::try_from(*i).unwrap() + 1, + Self::Int(i) if *i < usize::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap() + 1, + Self::Int(i) => panic!("Step index {i} out of bounds for AllArms::Int with count 3. Consider using bounds-checked step constructors."), Self::Child => 4, Self::Final => <::some::other::StepEnum as ::ipa_step::CompactStep>::STEP_COUNT + 5, } @@ -854,6 +909,66 @@ mod test { ); } + #[test] + fn struct_int() { + derive_success( + quote! { + #[derive(CompactStep)] + #[step(count = 3)] + struct StructInt(u8); + }, + "e! { + impl ::ipa_step::Step for StructInt {} + + impl From for StructInt { + fn from(v: u8) -> Self { + assert!( + v < u8::try_from(3usize).unwrap(), + "Step index {v} out of bounds for StructInt with count 3.", + ); + Self(v) + } + } + + #[allow( + clippy::useless_conversion, + clippy::unnecessary_fallible_conversions, + )] + impl ::std::convert::AsRef for StructInt { + fn as_ref(&self) -> &str { + const STRUCT_INT_NAMES: [&str; 3] = ["struct_int0" , "struct_int1" , "struct_int2"]; + match self { + Self(i) => STRUCT_INT_NAMES[usize::try_from(*i).unwrap()], + } + } + } + + #[allow( + clippy::useless_conversion, + clippy::unnecessary_fallible_conversions, + clippy::identity_op, + )] + impl ::ipa_step::CompactStep for StructInt { + const STEP_COUNT: ::ipa_step::CompactGateIndex = 3; + + fn base_index(&self) -> ::ipa_step::CompactGateIndex { + match self { + Self(i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self(i) => panic!("Step index {i} out of bounds for StructInt with count 3. Consider using bounds-checked step constructors."), + } + } + + fn step_string(i: ::ipa_step::CompactGateIndex) -> String { + match i { + _ if i < 3 => Self(u8::try_from(i - (0)).unwrap()).as_ref().to_owned(), + _ => panic!("step {i} is not valid for {t}", t = ::std::any::type_name::()), + } + } + } + }, + ); + } + #[test] fn struct_missing_count() { derive_failure( diff --git a/ipa-step-derive/src/variant.rs b/ipa-step-derive/src/variant.rs index 1b72bdfe1..792aa8fd1 100644 --- a/ipa-step-derive/src/variant.rs +++ b/ipa-step-derive/src/variant.rs @@ -9,6 +9,7 @@ use syn::{ use crate::{sum::ExtendedSum, IntoSpan}; struct VariantAttrParser<'a> { + full_name: String, ident: &'a Ident, name: Option, count: Option, @@ -17,8 +18,9 @@ struct VariantAttrParser<'a> { } impl<'a> VariantAttrParser<'a> { - fn new(ident: &'a Ident) -> Self { + fn new(full_name: String, ident: &'a Ident) -> Self { Self { + full_name, ident, name: None, count: None, @@ -161,6 +163,7 @@ impl<'a> VariantAttrParser<'a> { ) } else { Ok(VariantAttribute { + full_name: self.full_name, ident: self.ident.clone(), name: self .name @@ -173,6 +176,7 @@ impl<'a> VariantAttrParser<'a> { } pub struct VariantAttribute { + full_name: String, ident: Ident, name: String, integer: Option<(usize, TypePath)>, @@ -188,10 +192,11 @@ impl VariantAttribute { } /// Parse a set of attributes out from a representation of an enum. - pub fn parse_variants(data: &DataEnum) -> Result, syn::Error> { + pub fn parse_variants(enum_ident: &Ident, data: &DataEnum) -> Result, syn::Error> { let mut steps = Vec::with_capacity(data.variants.len()); for v in &data.variants { - steps.push(VariantAttrParser::new(&v.ident).parse_variant(v)?); + let full_name = format!("{}::{}", enum_ident, v.ident); + steps.push(VariantAttrParser::new(full_name, &v.ident).parse_variant(v)?); } Ok(steps) } @@ -202,7 +207,7 @@ impl VariantAttribute { attrs: &[Attribute], fields: Option<&Fields>, ) -> Result { - VariantAttrParser::new(ident).parse_outer(attrs, fields) + VariantAttrParser::new(ident.to_string(), ident).parse_outer(attrs, fields) } } @@ -214,6 +219,8 @@ pub struct Generator { arm_count: ExtendedSum, // This tracks the index of each item. index_arms: TokenStream, + // This tracks integer variant constructors. + int_variant_constructors: TokenStream, // This tracks the arrays of names that are used for integer variants. name_arrays: TokenStream, // This tracks the arms of the `AsRef` match implementation. @@ -254,6 +261,7 @@ impl Generator { fn add_empty(&mut self, v: &VariantAttribute, is_variant: bool) { // Unpack so that we can use `quote!()`. let VariantAttribute { + full_name: _, ident: step_ident, name: step_name, integer: None, @@ -323,6 +331,7 @@ impl Generator { fn add_int(&mut self, v: &VariantAttribute, is_variant: bool) { // Unpack so that we can use `quote!()`. let VariantAttribute { + full_name: step_full_name, ident: step_ident, name: step_name, integer: Some((step_count, step_integer)), @@ -339,6 +348,22 @@ impl Generator { quote!(Self) }; + if is_variant { + let constructor = format_ident!("{}", step_ident.to_string().to_snake_case()); + let out_of_bounds_msg = format!( + "Step index {{v}} out of bounds for {step_full_name} with count {step_count}." + ); + self.int_variant_constructors.extend(quote! { + pub fn #constructor(v: #step_integer) -> Self { + assert!( + v < #step_integer::try_from(#step_count).unwrap(), + #out_of_bounds_msg, + ); + Self::#step_ident(v) + } + }); + } + // Construct some nice names for each integer value in the range. let array_name = format_ident!("{}_NAMES", step_ident.to_string().to_shouting_case()); let skip_zeros = match *step_count - 1 { @@ -362,8 +387,11 @@ impl Generator { if let Some(child) = step_child { let idx = self.arm_count.clone() + quote!((<#child as ::ipa_step::CompactStep>::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap()); + let out_of_bounds_msg = + format!("Step index {{i}} out of bounds for {step_full_name} with count {step_count}. Consider using bounds-checked step constructors."); self.index_arms.extend(quote! { - #arm(i) => #idx, + #arm(i) if *i < #step_integer::try_from(#step_count).unwrap() => #idx, + #arm(i) => panic!(#out_of_bounds_msg), }); // With `step_count` variations present, each has a name. @@ -403,8 +431,11 @@ impl Generator { } else { let idx = self.arm_count.clone() + quote!(::ipa_step::CompactGateIndex::try_from(*i).unwrap()); + let out_of_bounds_msg = + format!("Step index {{i}} out of bounds for {step_full_name} with count {step_count}. Consider using bounds-checked step constructors."); self.index_arms.extend(quote! { - #arm(i) => #idx, + #arm(i) if *i < #step_integer::try_from(#step_count).unwrap() => #idx, + #arm(i) => panic!(#out_of_bounds_msg), }); let range_end = arm_count.clone() + *step_count; @@ -415,6 +446,7 @@ impl Generator { } } + #[allow(clippy::too_many_lines)] pub fn generate(mut self, ident: &Ident, attr: &VariantAttribute) -> TokenStream { self.add_outer(attr); @@ -422,6 +454,33 @@ impl Generator { impl ::ipa_step::Step for #ident {} }; + // Generate a bounds-checking `impl From` if this is an integer unit struct step. + if let &Some((count, ref type_path)) = &attr.integer { + let out_of_bounds_msg = + format!("Step index {{v}} out of bounds for {ident} with count {count}."); + result.extend(quote! { + impl From<#type_path> for #ident { + fn from(v: #type_path) -> Self { + assert!( + v < #type_path::try_from(#count).unwrap(), + #out_of_bounds_msg, + ); + Self(v) + } + } + }); + } + + // Generate bounds-checking variant constructors if there are integer variants. + if !self.int_variant_constructors.is_empty() { + let constructors = self.int_variant_constructors; + result.extend(quote! { + impl #ident { + #constructors + } + }); + } + assert_eq!(self.index_arms.is_empty(), self.as_ref_arms.is_empty()); let (index_arms, as_ref_arms) = if self.index_arms.is_empty() { let n = attr.name(); diff --git a/ipa-step-test/src/lib.rs b/ipa-step-test/src/lib.rs index c2c9abfed..3789e6fcf 100644 --- a/ipa-step-test/src/lib.rs +++ b/ipa-step-test/src/lib.rs @@ -17,15 +17,21 @@ mod tests { #[test] fn narrows() { + assert_eq!(ComplexGate::default().index(), 0); assert_eq!(ComplexGate::default().as_ref(), "/"); assert_eq!( ComplexGate::default().narrow(&ComplexStep::One).as_ref(), "/one" ); + assert_eq!(ComplexGate::default().narrow(&ComplexStep::One).index(), 1,); assert_eq!( ComplexGate::default().narrow(&ComplexStep::Two(2)).as_ref(), "/two2" ); + assert_eq!( + ComplexGate::default().narrow(&ComplexStep::Two(2)).index(), + 10, + ); assert_eq!( ComplexGate::default() .narrow(&ComplexStep::Two(2)) @@ -33,6 +39,13 @@ mod tests { .as_ref(), "/two2/one" ); + assert_eq!( + ComplexGate::default() + .narrow(&ComplexStep::Two(2)) + .narrow(&BasicStep::One) + .index(), + 11, + ); assert_eq!( ComplexGate::from("/two2/one"), ComplexGate::default() @@ -55,6 +68,16 @@ mod tests { _ = ComplexGate::from("/two2/one").narrow(&BasicStep::Two); } + /// Attempts to narrow with an out-of-range index should panic + /// (rather than produce an incorrect output gate). + #[test] + #[should_panic( + expected = "Step index 10 out of bounds for ComplexStep::Two with count 10. Consider using bounds-checked step constructors." + )] + fn index_out_of_range() { + _ = ComplexGate::default().narrow(&ComplexStep::Two(10)); + } + /// Test that the alpha and beta gates work. #[test] fn alpha_and_beta() { diff --git a/ipa-step/src/gate.rs b/ipa-step/src/gate.rs index 0eac78393..0aa45cc46 100644 --- a/ipa-step/src/gate.rs +++ b/ipa-step/src/gate.rs @@ -29,7 +29,7 @@ fn build_narrows( let short_name = t.rsplit_once("::").map_or_else(|| t.as_ref(), |(_a, b)| b); let msg = format!("unexpected narrow for {gate_name}({{s}}) => {short_name}({{ss}})"); syntax.extend(quote! { - #[allow(clippy::too_many_lines)] + #[allow(clippy::too_many_lines, clippy::unreadable_literal)] impl ::ipa_step::StepNarrow<#ty> for #ident { fn narrow(&self, step: &#ty) -> Self { match self.0 { @@ -87,6 +87,7 @@ fn compact_gate_impl(gate_name: &str) -> TokenStream { let gate_lookup_type = step_hasher.lookup_type(); let mut syntax = quote! { + #[allow(clippy::unreadable_literal)] static STR_LOOKUP: [&str; #step_count] = [#(#gate_names),*]; static GATE_LOOKUP: #gate_lookup_type = #step_hasher diff --git a/ipa-step/src/hash.rs b/ipa-step/src/hash.rs index 6b036dc74..133d3713a 100644 --- a/ipa-step/src/hash.rs +++ b/ipa-step/src/hash.rs @@ -62,6 +62,7 @@ impl ToTokens for HashingSteps { }; struct #lookup_type { + #[allow(clippy::unreadable_literal)] inner: [(u64, u32); #sz] } diff --git a/pre-commit b/pre-commit index 308164245..b5319ac91 100755 --- a/pre-commit +++ b/pre-commit @@ -82,43 +82,51 @@ check() { fi } -check "Benchmark compilation" \ - cargo build --benches --no-default-features --features "enable-benches compact-gate" - check "Clippy checks" \ - cargo clippy --tests -- -D warnings + cargo clippy --features="cli test-fixture" --tests -- -D warnings -check "Clippy concurrency checks" \ - cargo clippy --tests --features shuttle -- -D warnings +check "Tests" \ + cargo test --features="cli test-fixture relaxed-dp" -check "Clippy web checks" \ - cargo clippy --tests --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" -- -D warnings +if [ -n "$EXEC_SLOW_TESTS" ] +then -# The tests here need to be kept in sync with scripts/coverage-ci. + check "Benchmark compilation" \ + cargo build --benches --no-default-features --features "enable-benches compact-gate" -check "Tests" \ - cargo test + check "Clippy concurrency checks" \ + cargo clippy --tests --features shuttle -- -D warnings -check "Web tests" \ - cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + check "Clippy web checks" \ + cargo clippy --tests --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" -- -D warnings -check "Web tests (descriptive gate)" \ - cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture" + # The tests here need to be kept in sync with scripts/coverage-ci. -check "Concurrency tests" \ - cargo test -p ipa-core --release --features "shuttle multi-threading" + check "Web tests" \ + cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" -check "Encrypted Input Tests" \ - cargo test --test encrypted_input --features "cli test-fixture web-app in-memory-infra" + check "Web tests (descriptive gate)" \ + cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" -check "IPA benchmark" \ - cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches compact-gate" -- -n 62 -c 16 + check "Concurrency tests" \ + cargo test -p ipa-core --release --features "shuttle multi-threading" -check "Arithmetic circuit benchmark" \ - cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" + check "IPA benchmark" \ + cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches compact-gate" -- -n 62 -c 16 -if [ -z "$EXEC_SLOW_TESTS" ] -then - check "Slow tests" \ - cargo test --release --test "*" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + check "Arithmetic circuit benchmark" \ + cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" + + check "Slow tests: Compact Gate" \ + cargo test --release --test "compact_gate" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + check "Slow tests: Helper Networks" \ + cargo test --release --test "helper_networks" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + check "Slow tests: Hybrid tests" \ + cargo test --release --test "hybrid" --features "cli test-fixture" + + + check "Slow tests: IPA with Relaxed DP" \ + cargo test --release --test "ipa_with_relaxed_dp" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate relaxed-dp" fi diff --git a/scripts/coverage-ci b/scripts/coverage-ci index b34bc8920..6f79c7599 100755 --- a/scripts/coverage-ci +++ b/scripts/coverage-ci @@ -9,14 +9,22 @@ cargo llvm-cov clean --workspace cargo build --all-targets # Need to be kept in sync manually with tests we run inside check.yml. -cargo test +cargo test --features "cli test-fixture relaxed-dp" -# descriptive-gate does not require a feature flag. -for gate in "compact-gate" ""; do +# Provide code coverage stats for ipa-metrics crate with partitions enabled +cargo test -p ipa-metrics --features "partitions" + +for gate in "compact-gate" "descriptive-gate"; do cargo test --no-default-features --features "cli web-app real-world-infra test-fixture $gate" done +# integration tests run without relaxed dp, except for these +cargo test --release --test "ipa_with_relaxed_dp" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate relaxed-dp" + cargo test --bench oneshot_ipa --no-default-features --features "enable-benches compact-gate" -- -n 62 -c 16 cargo test --bench criterion_arithmetic --no-default-features --features "enable-benches compact-gate" +# compact gate + in-memory-infra +cargo test --features "compact-gate" + cargo llvm-cov report "$@"