Temp commit to run on draft
akoshelev committed Oct 30, 2024
1 parent d80386e commit d2a6e96
Showing 4 changed files with 156 additions and 63 deletions.
65 changes: 58 additions & 7 deletions ipa-core/src/cli/metric_collector.rs
@@ -1,15 +1,13 @@
use std::{io, thread, thread::JoinHandle};

use ipa_metrics::{
MetricChannelType, MetricsCollectorController, MetricsCurrentThreadContext, MetricsProducer,
};
use std::collections::HashMap;
use std::time::Duration;
use ipa_metrics::{counter, MetricChannelType, MetricsCollectorController, MetricsCurrentThreadContext, MetricsProducer};
use tokio::runtime::Builder;
use crate::ff::curve_points::{COMPRESS_OP, COMPRESS_SER_OP, COMPRESS_FROM_FP_OP, COMPRESS_FROM_SCALAR_OP, COMPRESS_HASH_OP, DECOMPRESS_ADD_OP, DECOMPRESS_DESER_OP, DECOMPRESS_MUL_OP, DECOMPRESS_OP};

/// Holds a reference to metrics controller and producer
pub struct CollectorHandle {
thread_handle: JoinHandle<()>,
/// This will be used once we start consuming metrics
_controller: MetricsCollectorController,
producer: MetricsProducer,
}

@@ -23,10 +21,63 @@ pub fn install_collector() -> io::Result<CollectorHandle> {
let (producer, controller, handle) =
ipa_metrics::install_new_thread(MetricChannelType::Unbounded)?;
tracing::info!("Metrics engine is enabled");
thread::spawn(|| {
tracing::info!("metric observer is started");
const METRICS_OF_INTEREST: [&'static str; 9] = [
COMPRESS_OP,
COMPRESS_SER_OP,
COMPRESS_HASH_OP,
COMPRESS_FROM_SCALAR_OP,
COMPRESS_FROM_FP_OP,
DECOMPRESS_OP,
DECOMPRESS_ADD_OP,
DECOMPRESS_MUL_OP,
DECOMPRESS_DESER_OP
];

struct Watcher(MetricsCollectorController, HashMap<&'static str, u64>);

impl Watcher {
fn dump(&mut self) {
if let Ok(snapshot) = self.0.snapshot() {
let mut dump = String::new();
let mut needs_dump = false;
// Report every non-zero metric of interest, but only log when at least
// one of them changed since the previous dump.
for metric in METRICS_OF_INTEREST {
let value = snapshot.counter_val(counter!(metric));
if value > 0 {
if self.1.get(metric) != Some(&value) {
self.1.insert(metric, value);
needs_dump = true;
}
dump += &format!("{metric}={value}\n");
}
}

if needs_dump {
tracing::info!("Metrics dump:\n{dump}");
}
} else {
tracing::error!("Failed to dump metrics");
}
}
}

impl Drop for Watcher {
fn drop(&mut self) {
tracing::info!("metric watcher is being dropped");
self.dump();
}
}

let mut w = Watcher(controller, HashMap::default());
loop {
thread::sleep(Duration::from_secs(10));
w.dump();
}
});

Ok(CollectorHandle {
thread_handle: handle,
_controller: controller,
producer,
})
}
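A minimal caller-side sketch of how this handle might be used; `run_query` is a hypothetical workload, everything else is the code above:

fn run_with_metrics() -> std::io::Result<()> {
    // Spawns the collector thread and the watcher thread shown above.
    let _handle = install_collector()?;
    // Counters emitted by threads attached to the metrics producer show up in
    // the watcher's periodic "Metrics dump" log lines every 10 seconds.
    run_query(); // hypothetical workload
    Ok(())
}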
144 changes: 91 additions & 53 deletions ipa-core/src/ff/curve_points.rs
@@ -1,10 +1,7 @@
use curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
Scalar,
};
use curve25519_dalek::{constants, ristretto::{CompressedRistretto, RistrettoPoint}, Scalar};
use generic_array::GenericArray;
use ipa_metrics::counter;
use typenum::U32;
use typenum::{U128, U32};

use crate::{
ff::{ec_prime_field::Fp25519, Serializable},
@@ -29,25 +26,37 @@ impl Block for CompressedRistretto {
/// only deserialize previously serialized valid points, panics will not occur
/// However, we still added a debug assert to deserialize since values are sent by other servers
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct RP25519(CompressedRistretto);
// make an enum and play
pub struct RP25519(RistrettoPoint);

impl Default for RP25519 {
fn default() -> Self {
Self::ZERO
}
}

impl Block for RistrettoPoint {
type Size = U128;
}

/// Implementing trait for secret sharing
impl SharedValue for RP25519 {
type Storage = CompressedRistretto;
const BITS: u32 = 256;
const ZERO: Self = Self(CompressedRistretto([0_u8; 32]));
type Storage = RistrettoPoint;
const BITS: u32 = 1024;
const ZERO: Self = Self(constants::RISTRETTO_BASEPOINT_POINT);

impl_shared_value_common!();
}

pub const DECOMPRESS_OP: &str = "RP25519.decompress";
pub const DECOMPRESS_ADD_OP: &str = "RP25519.decompress.add";
pub const DECOMPRESS_MUL_OP: &str = "RP25519.decompress.mul";
pub const DECOMPRESS_DESER_OP: &str = "RP25519.decompress.deserialize";
pub const COMPRESS_OP: &str = "RP25519.compress";
pub const COMPRESS_HASH_OP: &str = "RP25519.compress.hash";
pub const COMPRESS_SER_OP: &str = "RP25519.compress.serialize";
pub const COMPRESS_FROM_SCALAR_OP: &str = "RP25519.compress.from.scalar";
pub const COMPRESS_FROM_FP_OP: &str = "RP25519.compress.from.fp";
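These constants follow a single pattern: one string name per operation, bumped with `counter!` at the call site and read back by the same name. A sketch of how a further counter could be added (the name and function below are hypothetical, not part of this commit); it would also need to be listed in `METRICS_OF_INTEREST` in metric_collector.rs to show up in the periodic dumps:

pub const DECOMPRESS_NEG_OP: &str = "RP25519.decompress.neg"; // hypothetical name

fn neg_counted(p: RP25519) -> RP25519 {
    counter!(DECOMPRESS_NEG_OP, 1); // same macro used elsewhere in this file
    -p // std::ops::Neg impl defined later in this file
}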

impl Vectorizable<1> for RP25519 {
type Array = StdArray<Self, 1>;
@@ -61,8 +70,14 @@ impl Vectorizable<PRF_CHUNK> for RP25519 {
#[error("{0:?} is not the canonical encoding of a Ristretto point.")]
pub struct NonCanonicalEncoding(CompressedRistretto);

impl Serializable for RP25519 {
type Size = <<RP25519 as SharedValue>::Storage as Block>::Size;
#[derive(Copy, Clone, Debug)]
pub struct CompressedRp25519(CompressedRistretto);
impl Block for CompressedRp25519 {
type Size = U32;
}

impl Serializable for CompressedRp25519 {
type Size = <Self as Block>::Size;
type DeserializationError = NonCanonicalEncoding;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
@@ -71,12 +86,29 @@ impl Serializable for RP25519 {

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
let point = CompressedRistretto((*buf).into());
if cfg!(debug_assertions) && point.decompress().is_none() {
counter!(DECOMPRESS_OP, 1);
return Err(NonCanonicalEncoding(point));
}
Ok(CompressedRp25519(point))
}
}

impl Serializable for RP25519 {
type Size = <CompressedRp25519 as Block>::Size;
type DeserializationError = NonCanonicalEncoding;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
counter!(COMPRESS_OP, 1);
counter!(COMPRESS_SER_OP, 1);
let compressed = CompressedRp25519(self.0.compress());
*buf.as_mut() = compressed.0.to_bytes();
}

Ok(RP25519(point))
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
let point = CompressedRistretto((*buf).into());
counter!(DECOMPRESS_OP, 1);
counter!(DECOMPRESS_DESER_OP, 1);
match point.decompress() {
Some(v) => Ok(Self(v)),
None => Err(NonCanonicalEncoding(point)),
}
}
}
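The arithmetic impls below are where the switch from `CompressedRistretto` to `RistrettoPoint` pays off: group operations no longer need a decompress/compress round-trip, and compression is deferred to serialization and hashing. A sketch of that difference using only the public curve25519-dalek API (illustrative, not code from this commit):

use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};

// Old representation: every addition pays two decompressions and one compression.
fn add_compressed(a: CompressedRistretto, b: CompressedRistretto) -> CompressedRistretto {
    (a.decompress().unwrap() + b.decompress().unwrap()).compress()
}

// New representation: addition is a plain group operation; compression happens
// only when the value is serialized or hashed.
fn add_points(a: RistrettoPoint, b: RistrettoPoint) -> RistrettoPoint {
    a + b
}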

@@ -87,8 +119,11 @@ impl std::ops::Add for RP25519 {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
counter!(DECOMPRESS_OP, 2);
Self((self.0.decompress().unwrap() + rhs.0.decompress().unwrap()).compress())
// counter!(DECOMPRESS_OP, 2);
// counter!(COMPRESS_OP, 1);
// counter!(COMPRESS_ADD_OP, 1);
// counter!(DECOMPRESS_ADD_OP, 2);
Self(self.0 + rhs.0)
}
}

@@ -106,9 +141,9 @@ impl std::ops::Neg for RP25519 {
type Output = Self;

fn neg(self) -> Self::Output {
counter!(DECOMPRESS_OP, 1);
counter!(COMPRESS_OP, 1);
Self(self.0.decompress().unwrap().neg().compress())
// counter!(DECOMPRESS_OP, 1);
// counter!(COMPRESS_OP, 1);
Self(self.0.neg())
}
}

@@ -119,9 +154,9 @@ impl std::ops::Sub for RP25519 {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
counter!(DECOMPRESS_OP, 2);
counter!(COMPRESS_OP, 1);
Self((self.0.decompress().unwrap() - rhs.0.decompress().unwrap()).compress())
// counter!(DECOMPRESS_OP, 2);
// counter!(COMPRESS_OP, 1);
Self(self.0 - rhs.0)
}
}

@@ -140,12 +175,12 @@ impl std::ops::SubAssign for RP25519 {
impl std::ops::Mul<Fp25519> for RP25519 {
type Output = Self;

fn mul(self, rhs: Fp25519) -> RP25519 {
counter!(DECOMPRESS_OP, 1);
counter!(COMPRESS_OP, 1);
(self.0.decompress().unwrap() * Scalar::from(rhs))
.compress()
.into()
fn mul(self, rhs: Fp25519) -> Self {
// counter!(DECOMPRESS_OP, 1);
// counter!(DECOMPRESS_MUL_OP, 1);
// counter!(COMPRESS_OP, 1);
// counter!(COMPRESS_MUL_OP, 1);
Self(self.0 * Scalar::from(rhs))
}
}

@@ -158,29 +193,31 @@ impl std::ops::MulAssign<Fp25519> for RP25519 {

impl From<Scalar> for RP25519 {
fn from(s: Scalar) -> Self {
counter!(COMPRESS_OP, 1);
RP25519(RistrettoPoint::mul_base(&s).compress())
// counter!(COMPRESS_OP, 1);
// counter!(COMPRESS_FROM_SCALAR_OP, 1);
Self(RistrettoPoint::mul_base(&s))
}
}

impl From<Fp25519> for RP25519 {
fn from(s: Fp25519) -> Self {
counter!(COMPRESS_OP, 1);
RP25519(RistrettoPoint::mul_base(&s.into()).compress())
// counter!(COMPRESS_OP, 1);
// counter!(COMPRESS_FROM_FP_OP, 1);
Self(RistrettoPoint::mul_base(&s.into()))
}
}

impl From<CompressedRistretto> for RP25519 {
fn from(s: CompressedRistretto) -> Self {
RP25519(s)
}
}
// impl From<CompressedRistretto> for RP25519 {
// fn from(s: CompressedRistretto) -> Self {
// RP25519(s)
// }
// }

impl From<RP25519> for CompressedRistretto {
fn from(s: RP25519) -> Self {
s.0
}
}
// impl From<RP25519> for CompressedRistretto {
// fn from(s: RP25519) -> Self {
// s.0
// }
// }

///allows to convert curve points into unsigned integers, preserving high entropy
macro_rules! cp_hash_impl {
@@ -189,7 +226,9 @@ macro_rules! cp_hash_impl {
fn from(s: RP25519) -> Self {
use hkdf::Hkdf;
use sha2::Sha256;
let hk = Hkdf::<Sha256>::new(None, s.0.as_bytes());
ipa_metrics::counter!(crate::ff::curve_points::COMPRESS_OP, 1);
ipa_metrics::counter!(crate::ff::curve_points::COMPRESS_HASH_OP, 1);
let hk = Hkdf::<Sha256>::new(None, s.0.compress().as_bytes());
let mut okm = <$u_type>::MIN.to_le_bytes();
//error invalid length from expand only happens when okm is very large
hk.expand(&[], &mut okm).unwrap();
Expand All @@ -199,8 +238,8 @@ macro_rules! cp_hash_impl {
};
}

cp_hash_impl!(u128);

// cp_hash_impl!(u128);
//
cp_hash_impl!(u64);
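A small usage sketch of the conversion generated by `cp_hash_impl!(u64)` (the function name is made up):

fn tag_from_point(point: RP25519) -> u64 {
    // Compresses the point, feeds the 32 bytes through HKDF-SHA256 and takes
    // the first 8 bytes of output, as in the macro body above.
    point.into()
}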

/// implementing random curve point generation for testing purposes,
@@ -210,8 +249,7 @@ impl rand::distributions::Distribution<RP25519> for rand::distributions::Standar
fn sample<R: crate::rand::Rng + ?Sized>(&self, rng: &mut R) -> RP25519 {
let mut scalar_bytes = [0u8; 64];
rng.fill_bytes(&mut scalar_bytes);
counter!(COMPRESS_OP, 1);
RP25519(RistrettoPoint::from_uniform_bytes(&scalar_bytes).compress())
RP25519(RistrettoPoint::from_uniform_bytes(&scalar_bytes))
}
}

@@ -247,8 +285,8 @@ mod test {
let b: RP25519 = a.into();
let d: Fp25519 = a.into();
let c: RP25519 = RP25519::from(d);
assert_eq!(b, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_eq!(c, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_eq!(b, RP25519(constants::RISTRETTO_BASEPOINT_POINT));
assert_eq!(c, RP25519(constants::RISTRETTO_BASEPOINT_POINT));
}

///testing simple curve arithmetics to check that `curve25519_dalek` library is used correctly
@@ -260,13 +298,13 @@ mod test {
let fp_c = fp_a + fp_b;
let fp_d = RP25519::from(fp_a) + RP25519::from(fp_b);
assert_eq!(fp_d, RP25519::from(fp_c));
assert_ne!(fp_d, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_ne!(fp_d, RP25519(constants::RISTRETTO_BASEPOINT_POINT));
let fp_e = rng.gen::<Fp25519>();
let fp_f = rng.gen::<Fp25519>();
let fp_g = fp_e * fp_f;
let fp_h = RP25519::from(fp_e) * fp_f;
assert_eq!(fp_h, RP25519::from(fp_g));
assert_ne!(fp_h, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_ne!(fp_h, RP25519(constants::RISTRETTO_BASEPOINT_POINT));
assert_eq!(RP25519::ZERO, fp_h * Scalar::ZERO.into());
}

2 changes: 1 addition & 1 deletion ipa-core/src/protocol/ipa_prf/mod.rs
@@ -317,7 +317,7 @@ where
// multiplications per batch
const CONV_PROOF_CHUNK: usize = 256;

#[tracing::instrument(name = "compute_prf_for_inputs", skip_all)]
#[tracing::instrument(name = "compute_prf_for_inputs", skip_all, fields(sz = input_rows.len()))]
async fn compute_prf_for_inputs<C, BK, TV, TS>(
ctx: C,
input_rows: &[OPRFIPAInputRow<BK, TV, TS>],
8 changes: 6 additions & 2 deletions ipa-core/src/secret_sharing/vector/array.rs
@@ -353,7 +353,9 @@ macro_rules! impl_serializable {
type DeserializationError = <V as Serializable>::DeserializationError;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
let sz: usize = (<V as SharedValue>::BITS / 8).try_into().unwrap();
use typenum::Unsigned;
let sz: usize = <V as Serializable>::Size::USIZE;

for i in 0..$width {
self.0[i].serialize(
GenericArray::try_from_mut_slice(&mut buf[sz * i..sz * (i + 1)]).unwrap(),
@@ -364,7 +366,9 @@ macro_rules! impl_serializable {
fn deserialize(
buf: &GenericArray<u8, Self::Size>,
) -> Result<Self, Self::DeserializationError> {
let sz: usize = (<V as SharedValue>::BITS / 8).try_into().unwrap();
use typenum::Unsigned;
let sz: usize = <V as Serializable>::Size::USIZE;

let mut res = [V::ZERO; $width];
for i in 0..$width {
res[i] = V::deserialize(GenericArray::from_slice(&buf[sz * i..sz * (i + 1)]))?;
Expand Down
