Skip to content

Commit

Permalink
Improve performance of hybrid encrypt CLI.
Browse files Browse the repository at this point in the history
There were quite a few bottlenecks here:
* Writes were done serially, writing one file at a time.
* Shares were encrypted on a single CPU core

I almost used `rayon` to parallelize encryption, but the problem is that we need to get the output sorted to maintain total order across files. Rayon can do that, but requires collecting `ParallelIterator` which would be bad for generating 100M+ reports.

Our goal is to be able to share and encrypt 1B, so streaming and manual fiddling with thread pools is justified imo.

The way this CLI works right now: it keeps a compute pool for encryption (thread-per-core) and a separate pool of 3 threads to write data for each helper in parallel

I also made a few tweaks to improve code re-usability in this module.

## Benchmarks
Done locally on M1 Mac Pro (10 cores)

Before this change:
Encryption process is completed. 442.15834075s

After this change
Encryption process is completed. 55.63269625s
  • Loading branch information
akoshelev committed Dec 7, 2024
1 parent 215cd86 commit 6caf63f
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 41 deletions.
5 changes: 2 additions & 3 deletions ipa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ default = [
# by default remove all TRACE, DEBUG spans from release builds
"tracing/max_level_trace",
"tracing/release_max_level_info",
"aggregate-circuit",
"stall-detection",
"aggregate-circuit",
"ipa-prf",
"descriptive-gate",
]
cli = ["comfy-table", "clap"]
cli = ["comfy-table", "clap", "num_cpus"]
# Enable compact gate optimization
compact-gate = []
# mutually exclusive with compact-gate and disables compact gate optimization.
Expand Down Expand Up @@ -130,6 +128,7 @@ hyper-util = { version = "0.1.3", optional = true, features = ["http2"] }
http-body-util = { version = "0.1.1", optional = true }
http-body = { version = "1", optional = true }
iai = { version = "0.1.1", optional = true }
num_cpus = { version = "1.0", optional = true }
once_cell = "1.18"
pin-project = "1.0"
rand = "0.8"
Expand Down
10 changes: 9 additions & 1 deletion ipa-core/src/bin/crypto_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ use std::fmt::Debug;

use clap::{Parser, Subcommand};
use ipa_core::{
cli::crypto::{DecryptArgs, EncryptArgs, HybridDecryptArgs, HybridEncryptArgs},
cli::{
crypto::{DecryptArgs, EncryptArgs, HybridDecryptArgs, HybridEncryptArgs},
Verbosity,
},
error::BoxError,
};

#[derive(Debug, Parser)]
#[clap(name = "crypto-util", about = "Crypto Util CLI")]
#[command(about)]
struct Args {
// Configure logging.
#[clap(flatten)]
logging: Verbosity,

#[command(subcommand)]
action: CryptoUtilCommand,
}
Expand All @@ -25,6 +32,7 @@ enum CryptoUtilCommand {
#[tokio::main]
async fn main() -> Result<(), BoxError> {
let args = Args::parse();
let _handle = args.logging.setup_logging();
match args.action {
CryptoUtilCommand::Encrypt(encrypt_args) => encrypt_args.encrypt()?,
CryptoUtilCommand::HybridEncrypt(hybrid_encrypt_args) => hybrid_encrypt_args.encrypt()?,
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ async fn ipa_test(
r
};

let mut key_registries = KeyRegistries::default();
let key_registries = KeyRegistries::default();
let Some(key_registries) = key_registries.init_from(network) else {
panic!("could not load network file")
};
Expand All @@ -546,7 +546,7 @@ async fn ipa_test(
helper_clients,
query_id,
ipa_query_config,
Some((DEFAULT_KEY_ID, key_registries)),
Some((DEFAULT_KEY_ID, key_registries.each_ref())),
)
.await;

Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/cli/crypto/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl EncryptArgs {
let input = InputSource::from_file(&self.input_file);

let mut rng = thread_rng();
let mut key_registries = KeyRegistries::default();
let key_registries = KeyRegistries::default();

let network =
NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| {
Expand Down Expand Up @@ -84,7 +84,7 @@ impl EncryptArgs {

for share in shares {
let output = share
.encrypt(DEFAULT_KEY_ID, key_registry, &mut rng)
.encrypt(DEFAULT_KEY_ID, &key_registry, &mut rng)
.unwrap();
let hex_output = hex::encode(&output);
writeln!(writer, "{hex_output}")?;
Expand Down
231 changes: 210 additions & 21 deletions ipa-core/src/cli/crypto/hybrid_encrypt.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::{
fs::{read_to_string, OpenOptions},
io::Write,
iter::zip,
array,
collections::BTreeMap,
fs::{read_to_string, File, OpenOptions},
io::{BufWriter, Write},
path::{Path, PathBuf},
sync::mpsc::{channel, Sender},
thread,
thread::JoinHandle,
time::Instant,
};

use clap::Parser;
Expand All @@ -15,11 +20,23 @@ use crate::{
},
config::{KeyRegistries, NetworkConfig},
error::BoxError,
hpke::{KeyRegistry, PublicKeyOnly},
report::hybrid::{HybridReport, DEFAULT_KEY_ID},
secret_sharing::IntoShares,
test_fixture::hybrid::TestHybridRecord,
};

/// Encryptor takes 3 arguments: `report_id`, helper that the shares must be encrypted towards
/// and the actual share to encrypt.
type EncryptorInput = (usize, usize, HybridReport<BreakdownKey, TriggerValue>);
/// Encryptor sends report id and encrypted bytes down to file worker to write those bytes
/// down
type EncryptorOutput = (usize, Vec<u8>);
type FileWorkerInput = EncryptorOutput;

/// This type is used quite often in this module
type UnitResult = Result<(), BoxError>;

#[derive(Debug, Parser)]
#[clap(name = "test_hybrid_encrypt", about = "Test Hybrid Encrypt")]
#[command(about)]
Expand Down Expand Up @@ -51,11 +68,12 @@ impl HybridEncryptArgs {
/// if input file or network file are not correctly formatted
/// # Errors
/// if it cannot open the files
pub fn encrypt(&self) -> Result<(), BoxError> {
pub fn encrypt(&self) -> UnitResult {
tracing::info!("encrypting input from {:?}", self.input_file);
let start = Instant::now();
let input = InputSource::from_file(&self.input_file);

let mut rng = thread_rng();
let mut key_registries = KeyRegistries::default();
let key_registries = KeyRegistries::default();

let network =
NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| {
Expand All @@ -71,28 +89,199 @@ impl HybridEncryptArgs {
panic!("could not load network file")
};

let shares: [Vec<HybridReport<BreakdownKey, TriggerValue>>; 3] =
input.iter::<TestHybridRecord>().share();
let mut worker_pool = ReportWriter::new(key_registries, &self.output_dir);
for (report_id, record) in input.iter::<TestHybridRecord>().enumerate() {
worker_pool.submit(report_id, record.share())?;
}

worker_pool.join()?;

let elapsed = start.elapsed();
tracing::info!(
"Encryption process is completed. {}s",
elapsed.as_secs_f64()
);

Ok(())
}
}

for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() {
let output_filename = format!("helper{}.enc", index + 1);
let mut writer = OpenOptions::new()
/// A thread-per-core pool responsible for encrypting reports in parallel.
/// This pool is shared across all writers to reduce the number of context switches.
struct EncryptorPool {
pool: Vec<(Sender<EncryptorInput>, JoinHandle<UnitResult>)>,
next_worker: usize,
}

impl EncryptorPool {
pub fn with_worker_threads(
thread_count: usize,
file_writer: [Sender<EncryptorOutput>; 3],
key_registries: [KeyRegistry<PublicKeyOnly>; 3],
) -> Self {
Self {
pool: (0..thread_count)
.map(move |i| {
let (tx, rx) = channel::<EncryptorInput>();
let key_registries = key_registries.clone();
let file_writer = file_writer.clone();
(
tx,
std::thread::Builder::new()
.name(format!("encryptor-{i}"))
.spawn(move || {
for (i, helper_id, report) in rx {
let key_registry = &key_registries[helper_id];
let output = report.encrypt(
DEFAULT_KEY_ID,
key_registry,
&mut thread_rng(),
)?;
file_writer[helper_id].send((i, output))?;
}

Ok(())
})
.unwrap(),
)
})
.collect(),
next_worker: 0,
}
}

pub fn encrypt_share(&mut self, report: EncryptorInput) -> UnitResult {
let tx = &self.pool[self.next_worker].0;
tx.send(report)?;
self.next_worker = (self.next_worker + 1) % self.pool.len();

Ok(())
}

pub fn stop(self) -> UnitResult {
for (tx, handle) in self.pool {
drop(tx);
handle.join().unwrap()?;
}

Ok(())
}
}

/// Performs end-to-end encryption, taking individual shares as input
/// (see [`ReportWriter::submit`]), encrypting them in parallel and writing
/// encrypted shares into 3 separate files. This optimizes for memory usage,
/// and maximizes CPU utilization.
struct ReportWriter {
encryptor_pool: EncryptorPool,
workers: Option<[FileWriteWorker; 3]>,
}

impl ReportWriter {
pub fn new(key_registries: [KeyRegistry<PublicKeyOnly>; 3], output_dir: &Path) -> Self {
// create 3 worker threads to write data into 3 files
let workers = array::from_fn(|i| {
let output_filename = format!("helper{}.enc", i + 1);
let file = OpenOptions::new()
.write(true)
.create_new(true)
.open(self.output_dir.join(&output_filename))
.unwrap_or_else(|e| panic!("unable write to {}. {}", &output_filename, e));

for share in shares {
let output = share
.encrypt(DEFAULT_KEY_ID, key_registry, &mut rng)
.unwrap();
let hex_output = hex::encode(&output);
writeln!(writer, "{hex_output}")?;
}
.open(output_dir.join(&output_filename))
.unwrap_or_else(|e| panic!("unable write to {:?}. {}", &output_filename, e));

FileWriteWorker::new(file)
});
let encryptor_pool = EncryptorPool::with_worker_threads(
num_cpus::get(),
workers.each_ref().map(|x| x.sender.clone()),
key_registries,
);

Self {
encryptor_pool,
workers: Some(workers),
}
}

pub fn submit(
&mut self,
report_id: usize,
shares: [HybridReport<BreakdownKey, TriggerValue>; 3],
) -> UnitResult {
for (i, share) in shares.into_iter().enumerate() {
self.encryptor_pool.encrypt_share((report_id, i, share))?;
}

Ok(())
}

pub fn join(mut self) -> UnitResult {
self.encryptor_pool.stop()?;
self.workers
.take()
.unwrap()
.map(|worker| {
let FileWriteWorker { handle, sender } = worker;
drop(sender);
handle.join().unwrap()
})
.into_iter()
.collect()
}
}

/// This takes a file and writes all encrypted reports to it,
/// ensuring the same total order based on `report_id`. Report id is
/// just the index of file input row that guarantees consistency
/// of shares written across 3 files
struct FileWriteWorker {
sender: Sender<FileWorkerInput>,
handle: JoinHandle<UnitResult>,
}

impl FileWriteWorker {
pub fn new(file: File) -> Self {
let (tx, rx) = std::sync::mpsc::channel();
Self {
sender: tx,
handle: thread::spawn(move || {
fn write_report<W: Write>(writer: &mut W, report: &[u8]) -> Result<(), BoxError> {
let hex_output = hex::encode(report);
writeln!(writer, "{hex_output}")?;
Ok(())
}

// write low watermark. All reports below this line have been written
let mut lw = 0;
let mut pending_reports = BTreeMap::new();

// Buffered writes should improve IO, but it is likely not the bottleneck here.
let mut writer = BufWriter::new(file);
for (report_id, report) in rx {
// Because reports are encrypted in parallel, it is possible
// to receive report_id = X+1 before X. To mitigate that, we keep
// a buffer, ordered by report_id and always write from low watermark.
// This ensures consistent order of reports written to files. Any misalignment
// will result in broken shares and garbage output.
assert!(
report_id >= lw,
"Internal error: received a report {report_id} below low watermark"
);
assert!(
pending_reports.insert(report_id, report).is_none(),
"Internal error: received a duplicate report {report_id}"
);
while let Some(report) = pending_reports.remove(&lw) {
write_report(&mut writer, &report)?;
lw += 1;
if lw % 1_000_000 == 0 {
tracing::info!("Encrypted {} reports", lw / 1_000_000);
}
}
}
Ok(())
}),
}
}
}

#[cfg(all(test, unit_test))]
Expand Down
14 changes: 11 additions & 3 deletions ipa-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,11 @@ pub struct KeyRegistries(Vec<KeyRegistry<PublicKeyOnly>>);
impl KeyRegistries {
/// # Panics
/// If network file is improperly formatted
#[must_use]
pub fn init_from(
&mut self,
mut self,
network: &NetworkConfig<Helper>,
) -> Option<[&KeyRegistry<PublicKeyOnly>; 3]> {
) -> Option<[KeyRegistry<PublicKeyOnly>; 3]> {
// Get the configs, if all three peers have one
let peers = network.peers();
let configs = peers.iter().try_fold(Vec::new(), |acc, peer| {
Expand All @@ -487,7 +488,14 @@ impl KeyRegistries {
.map(|hpke| KeyRegistry::from_keys([PublicKeyOnly(hpke.public_key.clone())]))
.collect::<Vec<KeyRegistry<PublicKeyOnly>>>();

Some(self.0.iter().collect::<Vec<_>>().try_into().ok().unwrap())
Some(
self.0
.into_iter()
.collect::<Vec<_>>()
.try_into()
.ok()
.unwrap(),
)
}
}

Expand Down
Loading

0 comments on commit 6caf63f

Please sign in to comment.