diff --git a/ipa-core/src/cli/metric_collector.rs b/ipa-core/src/cli/metric_collector.rs index 8f9a374b4..2ff6627a1 100644 --- a/ipa-core/src/cli/metric_collector.rs +++ b/ipa-core/src/cli/metric_collector.rs @@ -1,15 +1,13 @@ use std::{io, thread, thread::JoinHandle}; - -use ipa_metrics::{ - MetricChannelType, MetricsCollectorController, MetricsCurrentThreadContext, MetricsProducer, -}; +use std::collections::HashMap; +use std::time::Duration; +use ipa_metrics::{counter, MetricChannelType, MetricsCollectorController, MetricsCurrentThreadContext, MetricsProducer}; use tokio::runtime::Builder; +use crate::ff::curve_points::{COMPRESS_OP, COMPRESS_SER_OP, COMPRESS_FROM_FP_OP, COMPRESS_FROM_SCALAR_OP, COMPRESS_HASH_OP, DECOMPRESS_ADD_OP, DECOMPRESS_DESER_OP, DECOMPRESS_MUL_OP, DECOMPRESS_OP}; /// Holds a reference to metrics controller and producer pub struct CollectorHandle { thread_handle: JoinHandle<()>, - /// This will be used once we start consuming metrics - _controller: MetricsCollectorController, producer: MetricsProducer, } @@ -23,10 +21,63 @@ pub fn install_collector() -> io::Result { let (producer, controller, handle) = ipa_metrics::install_new_thread(MetricChannelType::Unbounded)?; tracing::info!("Metrics engine is enabled"); + thread::spawn(|| { + tracing::info!("metric observer is started"); + const METRICS_OF_INTEREST: [&'static str; 9] = [ + COMPRESS_OP, + COMPRESS_SER_OP, + COMPRESS_HASH_OP, + COMPRESS_FROM_SCALAR_OP, + COMPRESS_FROM_FP_OP, + DECOMPRESS_OP, + DECOMPRESS_ADD_OP, + DECOMPRESS_MUL_OP, + DECOMPRESS_DESER_OP + ]; + + struct Watcher(MetricsCollectorController, HashMap<&'static str, u64>); + + impl Watcher { + fn dump(&mut self) { + if let Ok(snapshot) = self.0.snapshot() { + let mut dump = String::new(); + let mut needs_dump = false; + for metric in METRICS_OF_INTEREST { + let value = snapshot.counter_val(counter!(metric)); + if value > 0 && self.1.get(metric) != Some(&value) { + self.1.insert(metric, value); + needs_dump = true; + } + if needs_dump { + dump += &format!("{metric}={value}\n"); + } + } + + if dump.len() > 0 { + tracing::info!("Metrics dump:\n{dump}") + } + } else { + tracing::error!("Failed to dump metrics") + } + } + } + + impl Drop for Watcher { + fn drop(&mut self) { + tracing::info!("metric watcher is being dropped"); + self.dump(); + } + } + + let mut w = Watcher(controller, HashMap::default()); + loop { + thread::sleep(Duration::from_secs(10)); + w.dump(); + } + }); Ok(CollectorHandle { thread_handle: handle, - _controller: controller, producer, }) } diff --git a/ipa-core/src/ff/curve_points.rs b/ipa-core/src/ff/curve_points.rs index 07567caa2..7dd837f5d 100644 --- a/ipa-core/src/ff/curve_points.rs +++ b/ipa-core/src/ff/curve_points.rs @@ -1,10 +1,7 @@ -use curve25519_dalek::{ - ristretto::{CompressedRistretto, RistrettoPoint}, - Scalar, -}; +use curve25519_dalek::{constants, ristretto::{CompressedRistretto, RistrettoPoint}, Scalar}; use generic_array::GenericArray; use ipa_metrics::counter; -use typenum::U32; +use typenum::{U128, U32}; use crate::{ ff::{ec_prime_field::Fp25519, Serializable}, @@ -29,7 +26,8 @@ impl Block for CompressedRistretto { /// only deserialize previously serialized valid points, panics will not occur /// However, we still added a debug assert to deserialize since values are sent by other servers #[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct RP25519(CompressedRistretto); +// make an enum and play +pub struct RP25519(RistrettoPoint); impl Default for RP25519 { fn default() -> Self { @@ -37,17 +35,28 @@ impl Default for RP25519 { } } +impl Block for RistrettoPoint { + type Size = U128; +} + /// Implementing trait for secret sharing impl SharedValue for RP25519 { - type Storage = CompressedRistretto; - const BITS: u32 = 256; - const ZERO: Self = Self(CompressedRistretto([0_u8; 32])); + type Storage = RistrettoPoint; + const BITS: u32 = 1024; + const ZERO: Self = Self(constants::RISTRETTO_BASEPOINT_POINT); impl_shared_value_common!(); } pub const DECOMPRESS_OP: &str = "RP25519.decompress"; +pub const DECOMPRESS_ADD_OP: &str = "RP25519.decompress.add"; +pub const DECOMPRESS_MUL_OP: &str = "RP25519.decompress.mul"; +pub const DECOMPRESS_DESER_OP: &str = "RP25519.decompress.deserialize"; pub const COMPRESS_OP: &str = "RP25519.compress"; +pub const COMPRESS_HASH_OP: &str = "RP25519.compress.hash"; +pub const COMPRESS_SER_OP: &str = "RP25519.compress.serialize"; +pub const COMPRESS_FROM_SCALAR_OP: &str = "RP25519.compress.from.scalar"; +pub const COMPRESS_FROM_FP_OP: &str = "RP25519.compress.from.fp"; impl Vectorizable<1> for RP25519 { type Array = StdArray; @@ -61,8 +70,14 @@ impl Vectorizable for RP25519 { #[error("{0:?} is not the canonical encoding of a Ristretto point.")] pub struct NonCanonicalEncoding(CompressedRistretto); -impl Serializable for RP25519 { - type Size = <::Storage as Block>::Size; +#[derive(Copy, Clone, Debug)] +pub struct CompressedRp25519(CompressedRistretto); +impl Block for CompressedRp25519 { + type Size = U32; +} + +impl Serializable for CompressedRp25519 { + type Size = ::Size; type DeserializationError = NonCanonicalEncoding; fn serialize(&self, buf: &mut GenericArray) { @@ -71,12 +86,29 @@ impl Serializable for RP25519 { fn deserialize(buf: &GenericArray) -> Result { let point = CompressedRistretto((*buf).into()); - if cfg!(debug_assertions) && point.decompress().is_none() { - counter!(DECOMPRESS_OP, 1); - return Err(NonCanonicalEncoding(point)); - } + Ok(CompressedRp25519(point)) + } +} + +impl Serializable for RP25519 { + type Size = ::Size; + type DeserializationError = NonCanonicalEncoding; + + fn serialize(&self, buf: &mut GenericArray) { + counter!(COMPRESS_OP, 1); + counter!(COMPRESS_SER_OP, 1); + let compressed = CompressedRp25519(self.0.compress()); + *buf.as_mut() = compressed.0.to_bytes(); + } - Ok(RP25519(point)) + fn deserialize(buf: &GenericArray) -> Result { + let point = CompressedRistretto((*buf).into()); + counter!(DECOMPRESS_OP, 1); + counter!(DECOMPRESS_DESER_OP, 1); + match point.decompress() { + Some(v) => Ok(Self(v)), + None => Err(NonCanonicalEncoding(point)), + } } } @@ -87,8 +119,11 @@ impl std::ops::Add for RP25519 { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - counter!(DECOMPRESS_OP, 2); - Self((self.0.decompress().unwrap() + rhs.0.decompress().unwrap()).compress()) + // counter!(DECOMPRESS_OP, 2); + // counter!(COMPRESS_OP, 1); + // counter!(COMPRESS_ADD_OP, 1); + // counter!(DECOMPRESS_ADD_OP, 2); + Self(self.0 + rhs.0) } } @@ -106,9 +141,9 @@ impl std::ops::Neg for RP25519 { type Output = Self; fn neg(self) -> Self::Output { - counter!(DECOMPRESS_OP, 1); - counter!(COMPRESS_OP, 1); - Self(self.0.decompress().unwrap().neg().compress()) + // counter!(DECOMPRESS_OP, 1); + // counter!(COMPRESS_OP, 1); + Self(self.0.neg()) } } @@ -119,9 +154,9 @@ impl std::ops::Sub for RP25519 { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - counter!(DECOMPRESS_OP, 2); - counter!(COMPRESS_OP, 1); - Self((self.0.decompress().unwrap() - rhs.0.decompress().unwrap()).compress()) + // counter!(DECOMPRESS_OP, 2); + // counter!(COMPRESS_OP, 1); + Self(self.0 - rhs.0) } } @@ -140,12 +175,12 @@ impl std::ops::SubAssign for RP25519 { impl std::ops::Mul for RP25519 { type Output = Self; - fn mul(self, rhs: Fp25519) -> RP25519 { - counter!(DECOMPRESS_OP, 1); - counter!(COMPRESS_OP, 1); - (self.0.decompress().unwrap() * Scalar::from(rhs)) - .compress() - .into() + fn mul(self, rhs: Fp25519) -> Self { + // counter!(DECOMPRESS_OP, 1); + // counter!(DECOMPRESS_MUL_OP, 1); + // counter!(COMPRESS_OP, 1); + // counter!(COMPRESS_MUL_OP, 1); + Self(self.0 * Scalar::from(rhs)) } } @@ -158,29 +193,31 @@ impl std::ops::MulAssign for RP25519 { impl From for RP25519 { fn from(s: Scalar) -> Self { - counter!(COMPRESS_OP, 1); - RP25519(RistrettoPoint::mul_base(&s).compress()) + // counter!(COMPRESS_OP, 1); + // counter!(COMPRESS_FROM_SCALAR_OP, 1); + Self(RistrettoPoint::mul_base(&s)) } } impl From for RP25519 { fn from(s: Fp25519) -> Self { - counter!(COMPRESS_OP, 1); - RP25519(RistrettoPoint::mul_base(&s.into()).compress()) + // counter!(COMPRESS_OP, 1); + // counter!(COMPRESS_FROM_FP_OP, 1); + Self(RistrettoPoint::mul_base(&s.into())) } } -impl From for RP25519 { - fn from(s: CompressedRistretto) -> Self { - RP25519(s) - } -} +// impl From for RP25519 { +// fn from(s: CompressedRistretto) -> Self { +// RP25519(s) +// } +// } -impl From for CompressedRistretto { - fn from(s: RP25519) -> Self { - s.0 - } -} +// impl From for CompressedRistretto { +// fn from(s: RP25519) -> Self { +// s.0 +// } +// } ///allows to convert curve points into unsigned integers, preserving high entropy macro_rules! cp_hash_impl { @@ -189,7 +226,9 @@ macro_rules! cp_hash_impl { fn from(s: RP25519) -> Self { use hkdf::Hkdf; use sha2::Sha256; - let hk = Hkdf::::new(None, s.0.as_bytes()); + ipa_metrics::counter!(crate::ff::curve_points::COMPRESS_OP, 1); + ipa_metrics::counter!(crate::ff::curve_points::COMPRESS_HASH_OP, 1); + let hk = Hkdf::::new(None, s.0.compress().as_bytes()); let mut okm = <$u_type>::MIN.to_le_bytes(); //error invalid length from expand only happens when okm is very large hk.expand(&[], &mut okm).unwrap(); @@ -199,8 +238,8 @@ macro_rules! cp_hash_impl { }; } -cp_hash_impl!(u128); - +// cp_hash_impl!(u128); +// cp_hash_impl!(u64); /// implementing random curve point generation for testing purposes, @@ -210,8 +249,7 @@ impl rand::distributions::Distribution for rand::distributions::Standar fn sample(&self, rng: &mut R) -> RP25519 { let mut scalar_bytes = [0u8; 64]; rng.fill_bytes(&mut scalar_bytes); - counter!(COMPRESS_OP, 1); - RP25519(RistrettoPoint::from_uniform_bytes(&scalar_bytes).compress()) + RP25519(RistrettoPoint::from_uniform_bytes(&scalar_bytes)) } } @@ -247,8 +285,8 @@ mod test { let b: RP25519 = a.into(); let d: Fp25519 = a.into(); let c: RP25519 = RP25519::from(d); - assert_eq!(b, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED)); - assert_eq!(c, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED)); + assert_eq!(b, RP25519(constants::RISTRETTO_BASEPOINT_POINT)); + assert_eq!(c, RP25519(constants::RISTRETTO_BASEPOINT_POINT)); } ///testing simple curve arithmetics to check that `curve25519_dalek` library is used correctly @@ -260,13 +298,13 @@ mod test { let fp_c = fp_a + fp_b; let fp_d = RP25519::from(fp_a) + RP25519::from(fp_b); assert_eq!(fp_d, RP25519::from(fp_c)); - assert_ne!(fp_d, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED)); + assert_ne!(fp_d, RP25519(constants::RISTRETTO_BASEPOINT_POINT)); let fp_e = rng.gen::(); let fp_f = rng.gen::(); let fp_g = fp_e * fp_f; let fp_h = RP25519::from(fp_e) * fp_f; assert_eq!(fp_h, RP25519::from(fp_g)); - assert_ne!(fp_h, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED)); + assert_ne!(fp_h, RP25519(constants::RISTRETTO_BASEPOINT_POINT)); assert_eq!(RP25519::ZERO, fp_h * Scalar::ZERO.into()); } diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 55a02f9f8..0c7d223a0 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -317,7 +317,7 @@ where // multiplications per batch const CONV_PROOF_CHUNK: usize = 256; -#[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] +#[tracing::instrument(name = "compute_prf_for_inputs", skip_all, fields(sz = input_rows.len()))] async fn compute_prf_for_inputs( ctx: C, input_rows: &[OPRFIPAInputRow], diff --git a/ipa-core/src/secret_sharing/vector/array.rs b/ipa-core/src/secret_sharing/vector/array.rs index ebc0c1947..a01ca4219 100644 --- a/ipa-core/src/secret_sharing/vector/array.rs +++ b/ipa-core/src/secret_sharing/vector/array.rs @@ -353,7 +353,9 @@ macro_rules! impl_serializable { type DeserializationError = ::DeserializationError; fn serialize(&self, buf: &mut GenericArray) { - let sz: usize = (::BITS / 8).try_into().unwrap(); + use typenum::Unsigned; + let sz: usize = ::Size::USIZE; + for i in 0..$width { self.0[i].serialize( GenericArray::try_from_mut_slice(&mut buf[sz * i..sz * (i + 1)]).unwrap(), @@ -364,7 +366,9 @@ macro_rules! impl_serializable { fn deserialize( buf: &GenericArray, ) -> Result { - let sz: usize = (::BITS / 8).try_into().unwrap(); + use typenum::Unsigned; + let sz: usize = ::Size::USIZE; + let mut res = [V::ZERO; $width]; for i in 0..$width { res[i] = V::deserialize(GenericArray::from_slice(&buf[sz * i..sz * (i + 1)]))?;