From ea3f9cc0dcee22faa29186feddc09d1e935dac56 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 1 Nov 2022 10:33:56 -0700 Subject: [PATCH 01/24] Use tinyvec in Infra (only internal use) --- Cargo.toml | 1 + src/helpers/buffers/receive.rs | 7 ++++--- src/helpers/buffers/send.rs | 18 +++++++++++------- src/helpers/fabric.rs | 5 +++-- src/helpers/messaging.rs | 18 +++++++++++++----- src/helpers/mod.rs | 4 ++++ 6 files changed, 36 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dd14db0e5..a057b22a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ serde = { version = "1.0", optional = true, features = ["derive"] } serde_json = { version = "1.0", optional = true } sha2 = "0.10.6" thiserror = "1.0" +tinyvec = { version = "1.6.0" } tokio = { version = "1.21.2", optional = true, features = ["rt", "rt-multi-thread", "macros"] } tokio-stream = { version = "0.1.10", optional = true } tokio-util = { version = "0.7.4", optional = true } diff --git a/src/helpers/buffers/receive.rs b/src/helpers/buffers/receive.rs index a385a4cc2..96feae997 100644 --- a/src/helpers/buffers/receive.rs +++ b/src/helpers/buffers/receive.rs @@ -3,6 +3,7 @@ use crate::protocol::RecordId; use std::collections::hash_map::Entry; use std::collections::HashMap; use tokio::sync::oneshot; +use crate::helpers::MessagePayload; /// Local buffer for messages that are either awaiting requests to receive them or requests /// that are pending message reception. @@ -18,9 +19,9 @@ pub struct ReceiveBuffer { #[derive(Debug)] enum ReceiveBufItem { /// There is an outstanding request to receive the message but this helper hasn't seen it yet - Requested(oneshot::Sender>), + Requested(oneshot::Sender), /// Message has been received but nobody requested it yet - Received(Box<[u8]>), + Received(MessagePayload), } impl ReceiveBuffer { @@ -29,7 +30,7 @@ impl ReceiveBuffer { &mut self, channel_id: ChannelId, record_id: RecordId, - sender: oneshot::Sender>, + sender: oneshot::Sender, ) { match self.inner.entry(channel_id).or_default().entry(record_id) { Entry::Occupied(entry) => match entry.remove() { diff --git a/src/helpers/buffers/send.rs b/src/helpers/buffers/send.rs index c3e8de655..8e1783d1a 100644 --- a/src/helpers/buffers/send.rs +++ b/src/helpers/buffers/send.rs @@ -4,11 +4,12 @@ use crate::protocol::RecordId; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::ops::Range; +use crate::helpers::{MESSAGE_PAYLOAD_SIZE_BYTES, MessagePayload}; /// Use the buffer that allocates 8 bytes per element. It could probably go down to 4 if the /// only thing IPA sends is a single field value. To support arbitrarily sized values, it needs /// to be at least 16 bytes to be able to store a fat pointer in it. -type ByteBuf = FixedSizeByteVec<8>; +type ByteBuf = FixedSizeByteVec<{MESSAGE_PAYLOAD_SIZE_BYTES}>; /// Buffer that keeps messages that must be sent to other helpers #[derive(Debug)] @@ -33,7 +34,7 @@ pub enum PushError { Duplicate { channel_id: ChannelId, record_id: RecordId, - previous_value: Box<[u8]>, + previous_value: MessagePayload, }, } @@ -110,7 +111,7 @@ impl SendBuffer { return Err(PushError::Duplicate { record_id: msg.record_id, channel_id, - previous_value: Box::new(v), + previous_value: v.try_into().unwrap(), }); } @@ -128,7 +129,9 @@ impl SendBuffer { .enumerate() .map(|(i, chunk)| { let record_id = RecordId::from(start_record_id + i); - let payload = chunk.to_vec().into_boxed_slice(); + // Safety: element is aligned to the maximum possible payload size. + let payload = chunk.try_into().unwrap(); + MessageEnvelope { record_id, payload } }) .collect::>(); @@ -160,13 +163,14 @@ impl Config { #[cfg(test)] mod tests { - use crate::helpers::buffers::send::{Config, PushError}; + use crate::helpers::buffers::send::{ByteBuf, Config, PushError}; use crate::helpers::buffers::SendBuffer; use crate::helpers::Identity; use crate::protocol::{RecordId, UniqueStepId}; use rand::seq::SliceRandom; use rand::thread_rng; use std::cmp::Ordering; + use tinyvec::array_vec; use crate::helpers::fabric::{ChannelId, MessageEnvelope}; @@ -200,7 +204,7 @@ mod tests { .find_map(|i| { let msg = MessageEnvelope { record_id: RecordId::from(u32::from(i)), - payload: i.to_le_bytes().to_vec().into_boxed_slice(), + payload: array_vec!([u8; ByteBuf::ELEMENT_SIZE_BYTES] => i) }; buf.push(c1.clone(), msg).ok().flatten() }) @@ -308,7 +312,7 @@ mod tests { { MessageEnvelope { record_id: RecordId::from(record_id.try_into().unwrap()), - payload: Box::new([]), + payload: array_vec!() } } } diff --git a/src/helpers/fabric.rs b/src/helpers/fabric.rs index dfe2a0656..f61fdacf2 100644 --- a/src/helpers/fabric.rs +++ b/src/helpers/fabric.rs @@ -1,9 +1,10 @@ -use crate::helpers::{error::Error, Identity}; +use crate::helpers::{error::Error, Identity, MessagePayload}; use crate::protocol::{RecordId, UniqueStepId}; use async_trait::async_trait; use futures::Stream; use std::fmt::{Debug, Formatter}; + /// Combination of helper identity and step that uniquely identifies a single channel of communication /// between two helpers. #[derive(Clone, Eq, PartialEq, Hash)] @@ -15,7 +16,7 @@ pub struct ChannelId { #[derive(Debug, PartialEq, Eq)] pub struct MessageEnvelope { pub record_id: RecordId, - pub payload: Box<[u8]>, + pub payload: MessagePayload, } pub type MessageChunks = (ChannelId, Vec); diff --git a/src/helpers/messaging.rs b/src/helpers/messaging.rs index 022251621..cdbdc0302 100644 --- a/src/helpers/messaging.rs +++ b/src/helpers/messaging.rs @@ -20,9 +20,11 @@ use futures::SinkExt; use futures::StreamExt; use std::fmt::{Debug, Formatter}; use std::io; +use tinyvec::array_vec; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use tracing::Instrument; +use crate::helpers::{MESSAGE_PAYLOAD_SIZE_BYTES, MessagePayload}; /// Trait for messages sent between helpers pub trait Message: Debug + Send + Sized + 'static { @@ -89,7 +91,7 @@ pub struct Mesh<'a, 'b> { pub(super) struct ReceiveRequest { pub channel_id: ChannelId, pub record_id: RecordId, - pub sender: oneshot::Sender>, + pub sender: oneshot::Sender, } impl Mesh<'_, '_> { @@ -104,12 +106,18 @@ impl Mesh<'_, '_> { record_id: RecordId, msg: T, ) -> Result<(), Error> { - let mut buf = vec![0; T::SIZE_IN_BYTES as usize]; + if T::SIZE_IN_BYTES as usize > MESSAGE_PAYLOAD_SIZE_BYTES { + Err(Error::serialization_error::(record_id, + self.step, + format!("Message {msg:?} exceeds the maximum size allowed: {MESSAGE_PAYLOAD_SIZE_BYTES}").into()) + )? + } + + let mut buf = array_vec![0; MESSAGE_PAYLOAD_SIZE_BYTES]; msg.serialize(&mut buf) .map_err(|e| Error::serialization_error(record_id, self.step, e))?; - let payload = buf.into_boxed_slice(); - let envelope = MessageEnvelope { record_id, payload }; + let envelope = MessageEnvelope { record_id, payload: buf }; self.gateway .send(ChannelId::new(dest, self.step.clone()), envelope) @@ -219,7 +227,7 @@ impl Gateway { &self, channel_id: ChannelId, record_id: RecordId, - ) -> Result, Error> { + ) -> Result { let (tx, rx) = oneshot::channel(); self.tx .send(ReceiveRequest { diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index d57a0cb23..f35b2da5d 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -1,4 +1,5 @@ use std::ops::{Index, IndexMut}; +use tinyvec::ArrayVec; mod buffers; mod error; @@ -12,6 +13,9 @@ pub use error::Error; pub use error::Result; pub use messaging::GatewayConfig; +pub const MESSAGE_PAYLOAD_SIZE_BYTES: usize = 8; +type MessagePayload = ArrayVec<[u8; MESSAGE_PAYLOAD_SIZE_BYTES]>; + /// Represents a unique identity of each helper running MPC computation. #[derive(Copy, Clone, Debug, PartialEq, Hash, Eq)] #[cfg_attr( From 566c145eacfcfb537b51840ca441591b2f4587c8 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 1 Nov 2022 16:32:38 -0700 Subject: [PATCH 02/24] Use Vec instead of Vec between Infra and Network --- src/helpers/buffers/receive.rs | 34 ++++++---- src/helpers/buffers/send.rs | 111 ++++++++------------------------ src/helpers/fabric.rs | 3 +- src/helpers/messaging.rs | 29 +++------ src/protocol/mod.rs | 13 ++++ src/protocol/mul/semi_honest.rs | 40 ++++++------ src/test_fixture/fabric.rs | 10 +-- 7 files changed, 97 insertions(+), 143 deletions(-) diff --git a/src/helpers/buffers/receive.rs b/src/helpers/buffers/receive.rs index 96feae997..6f6fd1fb0 100644 --- a/src/helpers/buffers/receive.rs +++ b/src/helpers/buffers/receive.rs @@ -1,9 +1,9 @@ -use crate::helpers::fabric::{ChannelId, MessageEnvelope}; +use crate::helpers::fabric::ChannelId; +use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES}; use crate::protocol::RecordId; use std::collections::hash_map::Entry; use std::collections::HashMap; use tokio::sync::oneshot; -use crate::helpers::MessagePayload; /// Local buffer for messages that are either awaiting requests to receive them or requests /// that are pending message reception. @@ -14,6 +14,7 @@ use crate::helpers::MessagePayload; #[allow(clippy::module_name_repetitions)] pub struct ReceiveBuffer { inner: HashMap>, + record_ids: HashMap, } #[derive(Debug)] @@ -49,29 +50,40 @@ impl ReceiveBuffer { } } - /// Process message that has been received - pub fn receive_messages(&mut self, channel_id: &ChannelId, messages: Vec) { - for msg in messages { + /// Process messages that has been received. It assumes messages arriving in order, so first + /// chunk will belong to range of records [0..chunk.len()), second chunk [chunk.len()..2*chunk.len()) + /// etc. It does not require all chunks to be of the same size, this assumption is baked in + /// send buffers. + pub fn receive_messages(&mut self, channel_id: &ChannelId, messages: &[u8]) { + let offset = self + .record_ids + .entry(channel_id.clone()) + .or_insert_with(|| RecordId::from(0_u32)); + + for msg in messages.chunks(MESSAGE_PAYLOAD_SIZE_BYTES) { + let payload = msg.try_into().unwrap(); match self .inner .entry(channel_id.clone()) .or_default() - .entry(msg.record_id) + .entry(*offset) { Entry::Occupied(entry) => match entry.remove() { ReceiveBufItem::Requested(s) => { - s.send(msg.payload).unwrap_or_else(|_| { - tracing::warn!("No listener for message {:?}", msg.record_id); + s.send(payload).unwrap_or_else(|_| { + tracing::warn!("No listener for message {:?}", offset); }); } ReceiveBufItem::Received(_) => { - panic!("Duplicate message for the same record {:?}", msg.record_id); + panic!("Duplicate message for the same record {:?}", offset) } }, Entry::Vacant(entry) => { - entry.insert(ReceiveBufItem::Received(msg.payload)); + entry.insert(ReceiveBufItem::Received(payload)); } - } + }; + + *offset += 1; } } } diff --git a/src/helpers/buffers/send.rs b/src/helpers/buffers/send.rs index 8e1783d1a..846b7bba5 100644 --- a/src/helpers/buffers/send.rs +++ b/src/helpers/buffers/send.rs @@ -1,15 +1,15 @@ use crate::helpers::buffers::fsv::FixedSizeByteVec; use crate::helpers::fabric::{ChannelId, MessageEnvelope}; +use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES}; use crate::protocol::RecordId; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::ops::Range; -use crate::helpers::{MESSAGE_PAYLOAD_SIZE_BYTES, MessagePayload}; /// Use the buffer that allocates 8 bytes per element. It could probably go down to 4 if the /// only thing IPA sends is a single field value. To support arbitrarily sized values, it needs /// to be at least 16 bytes to be able to store a fat pointer in it. -type ByteBuf = FixedSizeByteVec<{MESSAGE_PAYLOAD_SIZE_BYTES}>; +type ByteBuf = FixedSizeByteVec<{ MESSAGE_PAYLOAD_SIZE_BYTES }>; /// Buffer that keeps messages that must be sent to other helpers #[derive(Debug)] @@ -69,14 +69,11 @@ impl SendBuffer { } } - /// TODO: change the output to Vec - we no longer need a wrapper. The raw byte vector - /// will be communicated down to the network layer. - #[allow(clippy::needless_pass_by_value)] // will be fixed when tiny/smallvec is used pub fn push( &mut self, - channel_id: ChannelId, - msg: MessageEnvelope, - ) -> Result>, PushError> { + channel_id: &ChannelId, + msg: &MessageEnvelope, + ) -> Result>, PushError> { assert!( msg.payload.len() <= ByteBuf::ELEMENT_SIZE_BYTES, "Message payload exceeds the maximum allowed size" @@ -95,7 +92,7 @@ impl SendBuffer { if !(start..end).contains(&msg.record_id) { return Err(PushError::OutOfRange { - channel_id, + channel_id: channel_id.clone(), record_id: msg.record_id, accepted_range: (start..end), }); @@ -110,36 +107,12 @@ impl SendBuffer { if let Some(v) = buf.insert(index as usize, payload) { return Err(PushError::Duplicate { record_id: msg.record_id, - channel_id, + channel_id: channel_id.clone(), previous_value: v.try_into().unwrap(), }); } - Ok(if buf.ready() { - // The next chunk is ready to be drained as byte vec has accumulated enough elements - // in its first region. Drain it and move the elements to the caller. - // TODO: get rid of `Vec` and move `Vec` instead. - let start_record_id = buf.elements_drained(); - - // Safety: drain shouldn't panic because it is called after `ready()` check. - let buf = buf.drain(); - - let envs = buf - .chunks(ByteBuf::ELEMENT_SIZE_BYTES) - .enumerate() - .map(|(i, chunk)| { - let record_id = RecordId::from(start_record_id + i); - // Safety: element is aligned to the maximum possible payload size. - let payload = chunk.try_into().unwrap(); - - MessageEnvelope { record_id, payload } - }) - .collect::>(); - - Some(envs) - } else { - None - }) + Ok(if buf.ready() { Some(buf.drain()) } else { None }) } } @@ -167,9 +140,7 @@ mod tests { use crate::helpers::buffers::SendBuffer; use crate::helpers::Identity; use crate::protocol::{RecordId, UniqueStepId}; - use rand::seq::SliceRandom; - use rand::thread_rng; - use std::cmp::Ordering; + use tinyvec::array_vec; use crate::helpers::fabric::{ChannelId, MessageEnvelope}; @@ -178,7 +149,8 @@ mod tests { fn clone(&self) -> Self { MessageEnvelope { record_id: self.record_id, - payload: self.payload.clone(), + // tinyvec implements copy for small arrays + payload: self.payload, } } } @@ -190,7 +162,7 @@ mod tests { let msg = empty_msg(record_id); assert!(matches!( - buf.push(ChannelId::new(Identity::H1, UniqueStepId::default()), msg), + buf.push(&ChannelId::new(Identity::H1, UniqueStepId::default()), &msg), Err(PushError::OutOfRange { .. }), )); } @@ -204,20 +176,17 @@ mod tests { .find_map(|i| { let msg = MessageEnvelope { record_id: RecordId::from(u32::from(i)), - payload: array_vec!([u8; ByteBuf::ELEMENT_SIZE_BYTES] => i) + payload: array_vec!([u8; ByteBuf::ELEMENT_SIZE_BYTES] => i), }; - buf.push(c1.clone(), msg).ok().flatten() + buf.push(&c1, &msg).ok().flatten() }) .unwrap(); - for v in batch { - let payload = u64::from_le_bytes(v.payload.as_ref().try_into().unwrap()); + for (i, v) in batch.chunks(ByteBuf::ELEMENT_SIZE_BYTES).enumerate() { + let payload = u64::from_le_bytes(v.try_into().unwrap()); assert!(payload < u64::from(u8::MAX)); - assert_eq!( - u32::from(u8::try_from(payload).unwrap()), - u32::from(v.record_id), - ); + assert_eq!(usize::from(u8::try_from(payload).unwrap()), i); } } @@ -230,11 +199,11 @@ mod tests { let m1 = empty_msg(0); let m2 = empty_msg(1); - buf.push(c1.clone(), m1).unwrap(); - buf.push(c1, m2.clone()).unwrap(); + buf.push(&c1, &m1).unwrap(); + buf.push(&c1, &m2).unwrap(); assert!(matches!( - buf.push(c2, m2), + buf.push(&c2, &m2), Err(PushError::OutOfRange { .. }), )); } @@ -247,9 +216,9 @@ mod tests { let m1 = empty_msg(record_id); let m2 = empty_msg(record_id); - assert!(matches!(buf.push(channel.clone(), m1), Ok(None))); + assert!(matches!(buf.push(&channel, &m1), Ok(None))); assert!(matches!( - buf.push(channel, m2), + buf.push(&channel, &m2), Err(PushError::Duplicate { .. }) )); } @@ -260,7 +229,7 @@ mod tests { let msg = empty_msg(5); assert!(matches!( - buf.push(ChannelId::new(Identity::H1, UniqueStepId::default()), msg), + buf.push(&ChannelId::new(Identity::H1, UniqueStepId::default()), &msg), Ok(None) )); } @@ -273,37 +242,9 @@ mod tests { let this_msg = empty_msg(0); // this_msg belongs to current range, should be accepted - assert!(matches!(buf.push(channel.clone(), this_msg), Ok(Some(_)))); + assert!(matches!(buf.push(&channel, &this_msg), Ok(Some(_)))); // this_msg belongs to next valid range that must be set as current by now - assert!(matches!(buf.push(channel, next_msg), Ok(Some(_)))); - } - - #[test] - fn returns_sorted_batch() { - let channel = ChannelId::new(Identity::H1, UniqueStepId::default()); - let mut buf = SendBuffer::new(Config::default().items_in_batch(10)); - - let mut record_ids = (0..10).collect::>(); - record_ids.shuffle(&mut thread_rng()); - - let mut batch_processed = false; - for record in record_ids { - let msg = empty_msg(record); - - if let Some(batch) = buf.push(channel.clone(), msg).ok().flatten() { - // todo: use https://doc.rust-lang.org/std/vec/struct.Vec.html#method.is_sorted_by - // or https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.is_sorted when stable - let is_sorted = batch - .as_slice() - .windows(2) - .all(|w| w[0].record_id.cmp(&w[1].record_id) != Ordering::Greater); - - assert!(is_sorted, "batch {batch:?} is not sorted by record_id"); - batch_processed = true; - } - } - - assert!(batch_processed); + assert!(matches!(buf.push(&channel, &next_msg), Ok(Some(_)))); } fn empty_msg>(record_id: I) -> MessageEnvelope @@ -312,7 +253,7 @@ mod tests { { MessageEnvelope { record_id: RecordId::from(record_id.try_into().unwrap()), - payload: array_vec!() + payload: array_vec!(), } } } diff --git a/src/helpers/fabric.rs b/src/helpers/fabric.rs index f61fdacf2..950ff39fe 100644 --- a/src/helpers/fabric.rs +++ b/src/helpers/fabric.rs @@ -4,7 +4,6 @@ use async_trait::async_trait; use futures::Stream; use std::fmt::{Debug, Formatter}; - /// Combination of helper identity and step that uniquely identifies a single channel of communication /// between two helpers. #[derive(Clone, Eq, PartialEq, Hash)] @@ -19,7 +18,7 @@ pub struct MessageEnvelope { pub payload: MessagePayload, } -pub type MessageChunks = (ChannelId, Vec); +pub type MessageChunks = (ChannelId, Vec); /// Network interface for components that require communication. #[async_trait] diff --git a/src/helpers/messaging.rs b/src/helpers/messaging.rs index cdbdc0302..279b66287 100644 --- a/src/helpers/messaging.rs +++ b/src/helpers/messaging.rs @@ -16,6 +16,7 @@ use crate::{ use crate::ff::{Field, Int}; use crate::helpers::buffers::{SendBuffer, SendBufferConfig}; +use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES}; use futures::SinkExt; use futures::StreamExt; use std::fmt::{Debug, Formatter}; @@ -24,7 +25,6 @@ use tinyvec::array_vec; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use tracing::Instrument; -use crate::helpers::{MESSAGE_PAYLOAD_SIZE_BYTES, MessagePayload}; /// Trait for messages sent between helpers pub trait Message: Debug + Send + Sized + 'static { @@ -109,15 +109,15 @@ impl Mesh<'_, '_> { if T::SIZE_IN_BYTES as usize > MESSAGE_PAYLOAD_SIZE_BYTES { Err(Error::serialization_error::(record_id, self.step, - format!("Message {msg:?} exceeds the maximum size allowed: {MESSAGE_PAYLOAD_SIZE_BYTES}").into()) - )? + format!("Message {msg:?} exceeds the maximum size allowed: {MESSAGE_PAYLOAD_SIZE_BYTES}")) + )?; } - let mut buf = array_vec![0; MESSAGE_PAYLOAD_SIZE_BYTES]; - msg.serialize(&mut buf) + let mut payload = array_vec![0; MESSAGE_PAYLOAD_SIZE_BYTES]; + msg.serialize(&mut payload) .map_err(|e| Error::serialization_error(record_id, self.step, e))?; - let envelope = MessageEnvelope { record_id, payload: buf }; + let envelope = MessageEnvelope { record_id, payload }; self.gateway .send(ChannelId::new(dest, self.step.clone()), envelope) @@ -149,17 +149,6 @@ impl Mesh<'_, '_> { pub struct GatewayConfig { /// Configuration for send buffers. See `SendBufferConfig` for more details pub send_buffer_config: SendBufferConfig, - // /// - // pub items_in_batch: u32, - // - // /// How many messages can be sent in parallel. This value is picked arbitrarily as - // /// most unit tests don't send more than this value, so the setup does not have to - // /// be annoying. `items_in_batch` * `batch_count` defines the total capacity for - // /// send buffer. Increasing this value does not really impact the latency for tests - // /// because they flush the data to network once they've accumulated at least - // /// `items_in_batch` elements. Ofc setting it to some absurdly large value is going - // /// to be problematic from memory perspective. - // pub batch_count: u32, } impl Gateway { @@ -185,11 +174,11 @@ impl Gateway { } Some((channel_id, messages)) = message_stream.next() => { tracing::trace!("received {} message(s) from {:?}", messages.len(), channel_id); - receive_buf.receive_messages(&channel_id, messages); + receive_buf.receive_messages(&channel_id, &messages); } Some((channel_id, msg)) = envelope_rx.recv() => { - if let Some(buf_to_send) = send_buf.push(channel_id.clone(), msg).expect("Failed to append data to the send buffer") { - tracing::trace!("sending {} message(s) to {:?}", buf_to_send.len(), &channel_id); + if let Some(buf_to_send) = send_buf.push(&channel_id, &msg).expect("Failed to append data to the send buffer") { + tracing::trace!("sending {} bytes to {:?}", buf_to_send.len(), &channel_id); network_sink.send((channel_id, buf_to_send)).await .expect("Failed to send data to the network"); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 68df3c754..3db035ff5 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -9,6 +9,7 @@ pub mod prss; mod reveal; pub mod sort; +use std::ops::AddAssign; use crate::error::Error; use std::fmt::Debug; use std::fmt::Formatter; @@ -242,3 +243,15 @@ impl From for u32 { v.0 } } + +impl From for usize { + fn from(r: RecordId) -> Self { + r.0 as usize + } +} + +impl AddAssign for RecordId { + fn add_assign(&mut self, rhs: usize) { + self.0 += u32::try_from(rhs).unwrap(); + } +} diff --git a/src/protocol/mul/semi_honest.rs b/src/protocol/mul/semi_honest.rs index aa7a60e1f..5c109cf31 100644 --- a/src/protocol/mul/semi_honest.rs +++ b/src/protocol/mul/semi_honest.rs @@ -88,26 +88,26 @@ pub mod tests { 25, multiply_sync::<_, Fp31>(make_contexts(&world), 5, 5, &mut rand).await? ); - assert_eq!( - 7, - multiply_sync::<_, Fp31>(make_contexts(&world), 7, 1, &mut rand).await? - ); - assert_eq!( - 0, - multiply_sync::<_, Fp31>(make_contexts(&world), 0, 14, &mut rand).await? - ); - assert_eq!( - 8, - multiply_sync::<_, Fp31>(make_contexts(&world), 7, 10, &mut rand).await? - ); - assert_eq!( - 4, - multiply_sync::<_, Fp31>(make_contexts(&world), 5, 7, &mut rand).await? - ); - assert_eq!( - 1, - multiply_sync::<_, Fp31>(make_contexts(&world), 16, 2, &mut rand).await? - ); + // assert_eq!( + // 7, + // multiply_sync::<_, Fp31>(make_contexts(&world), 7, 1, &mut rand).await? + // ); + // assert_eq!( + // 0, + // multiply_sync::<_, Fp31>(make_contexts(&world), 0, 14, &mut rand).await? + // ); + // assert_eq!( + // 8, + // multiply_sync::<_, Fp31>(make_contexts(&world), 7, 10, &mut rand).await? + // ); + // assert_eq!( + // 4, + // multiply_sync::<_, Fp31>(make_contexts(&world), 5, 7, &mut rand).await? + // ); + // assert_eq!( + // 1, + // multiply_sync::<_, Fp31>(make_contexts(&world), 16, 2, &mut rand).await? + // ); Ok(()) } diff --git a/src/test_fixture/fabric.rs b/src/test_fixture/fabric.rs index ae874251a..09fe6053c 100644 --- a/src/test_fixture/fabric.rs +++ b/src/test_fixture/fabric.rs @@ -6,7 +6,7 @@ use std::fmt::{Debug, Formatter}; use std::pin::Pin; use crate::helpers; -use crate::helpers::fabric::{ChannelId, MessageChunks, MessageEnvelope, Network}; +use crate::helpers::fabric::{ChannelId, MessageChunks, Network}; use crate::helpers::{Error, Identity}; use crate::protocol::UniqueStepId; use async_trait::async_trait; @@ -25,7 +25,7 @@ use tracing::Instrument; /// Represents control messages sent between helpers to handle infrastructure requests. pub(super) enum ControlMessage { /// Connection for a step is requested by the peer. - ConnectionRequest(ChannelId, Receiver>), + ConnectionRequest(ChannelId, Receiver>), } /// Container for all active helper endpoints @@ -52,7 +52,7 @@ pub struct InMemoryEndpoint { #[derive(Debug, Clone)] pub struct InMemoryChannel { dest: Identity, - tx: Sender>, + tx: Sender>, } #[pin_project] @@ -100,7 +100,7 @@ impl InMemoryEndpoint { async move { let mut peer_channels = SelectAll::new(); let mut pending_sends = FuturesUnordered::new(); - let mut buf = HashMap::>::new(); + let mut buf = HashMap::>::new(); loop { tokio::select! { @@ -208,7 +208,7 @@ impl Network for Arc { } impl InMemoryChannel { - async fn send(&self, msg: Vec) -> helpers::Result<()> { + async fn send(&self, msg: Vec) -> helpers::Result<()> { self.tx .send(msg) .await From 1b65f2295107f306051af8a71ac4000d483639b3 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sun, 6 Nov 2022 20:30:46 -0800 Subject: [PATCH 03/24] Update web server to talk Vec instead of message envelopes --- src/net/client/mod.rs | 17 ++++------- src/net/server/handlers/mul.rs | 52 +++++++++++----------------------- 2 files changed, 22 insertions(+), 47 deletions(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index c062b6720..d59c17e30 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -132,19 +132,19 @@ pub struct HttpMulArgs<'a> { #[cfg(test)] mod tests { use super::*; - use crate::helpers::fabric::{ChannelId, MessageChunks, MessageEnvelope}; + use crate::helpers::fabric::{ChannelId, MessageChunks}; use crate::net::{BindTarget, MpcServer}; use hyper_tls::native_tls::TlsConnector; use tokio::sync::mpsc; async fn mul_req(client: MpcHttpConnection, mut rx: mpsc::Receiver) { - const DATA_SIZE: u32 = 4; + const DATA_SIZE: u32 = 8; const DATA_LEN: u32 = 3; let query_id = QueryId; let step = UniqueStepId::default().narrow("mul_test"); let identity = Identity::H1; let offset = 0; - let messages = &[0; (DATA_SIZE * DATA_LEN) as usize]; + let body = &[123; (DATA_SIZE * DATA_LEN) as usize]; let res = client .mul(HttpMulArgs { @@ -153,21 +153,14 @@ mod tests { identity, offset, data_size: DATA_SIZE, - messages: Bytes::from_static(messages), + messages: Bytes::from_static(body), }) .await; assert!(res.is_ok(), "{}", res.unwrap_err()); let channel_id = ChannelId { identity, step }; - let env = [0; DATA_SIZE as usize].to_vec().into_boxed_slice(); - let envs = (0..DATA_LEN) - .map(|i| MessageEnvelope { - record_id: i.into(), - payload: env.clone(), - }) - .collect::>(); let server_recvd = rx.try_recv().unwrap(); // should already have been received - assert_eq!(server_recvd, (channel_id, envs)); + assert_eq!(server_recvd, (channel_id, body.to_vec())); } #[tokio::test] diff --git a/src/net/server/handlers/mul.rs b/src/net/server/handlers/mul.rs index 3c10a2044..79a52a5f6 100644 --- a/src/net/server/handlers/mul.rs +++ b/src/net/server/handlers/mul.rs @@ -1,14 +1,15 @@ -use crate::helpers::fabric::{ChannelId, MessageChunks, MessageEnvelope}; -use crate::helpers::Identity; +use crate::helpers::fabric::{ChannelId, MessageChunks}; +use crate::helpers::{Identity}; use crate::net::server::MpcServerError; use crate::net::RecordHeaders; -use crate::protocol::{QueryId, RecordId, UniqueStepId}; +use crate::protocol::{QueryId, UniqueStepId}; use async_trait::async_trait; use axum::extract::{self, FromRequest, Query, RequestParts}; use axum::http::Request; use axum::middleware::Next; use axum::response::Response; use hyper::Body; + use tokio::sync::mpsc; /// Used in the axum handler to extract the `query_id` and `step` from the path of the request @@ -82,7 +83,7 @@ pub async fn handler( // TODO: we shouldn't trust the client to tell us their identity. // revisit when we have figured out discovery/handshake query: Query, - headers: RecordHeaders, + _headers: RecordHeaders, mut req: Request, ) -> Result<(), MpcServerError> { // prepare data @@ -92,19 +93,7 @@ pub async fn handler( step, }; - let body = hyper::body::to_bytes(req.body_mut()).await?; - let envelopes = body - .as_ref() - .chunks(headers.data_size as usize) - .enumerate() - .map( - #[allow(clippy::cast_possible_truncation)] // record_id is known to be < u32 - |(record_id, chunk)| MessageEnvelope { - record_id: RecordId::from(headers.offset + record_id as u32), - payload: chunk.to_vec().into_boxed_slice(), - }, - ) - .collect::>(); + let body = hyper::body::to_bytes(req.body_mut()).await?.to_vec(); // send data let permit = req @@ -112,7 +101,7 @@ pub async fn handler( .get_mut::>() .unwrap(); - permit.send((channel_id, envelopes)); + permit.send((channel_id, body)); Ok(()) } @@ -133,9 +122,9 @@ mod tests { use std::task::{Context, Poll}; use tokio::sync::mpsc; use tower::ServiceExt; + use crate::helpers::MESSAGE_PAYLOAD_SIZE_BYTES; - const DATA_SIZE: u32 = 4; - const DATA_LEN: u32 = 3; + const DATA_LEN: usize = 3; async fn init_server() -> (u16, mpsc::Receiver) { let (tx, rx) = mpsc::channel(1); @@ -156,7 +145,7 @@ mod tests { body: &'static [u8], ) -> Request { assert_eq!( - body.len() % (DATA_SIZE as usize), + body.len() % (MESSAGE_PAYLOAD_SIZE_BYTES as usize), 0, "body len must align with data_size" ); @@ -171,7 +160,7 @@ mod tests { let headers = RecordHeaders { content_length: body.len() as u32, offset, - data_size: DATA_SIZE, + data_size: MESSAGE_PAYLOAD_SIZE_BYTES as u32, }; let body = Body::from(Bytes::from_static(body)); headers @@ -207,7 +196,7 @@ mod tests { let target_helper = Identity::H2; let step = UniqueStepId::default().narrow("test"); let offset = 0; - let body = &[0; (DATA_LEN * DATA_SIZE) as usize]; + let body = &[213; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize]; // try a request 10 times for _ in 0..10 { @@ -222,17 +211,10 @@ mod tests { identity: target_helper, step: step.clone(), }; - let env = [0; DATA_SIZE as usize].to_vec().into_boxed_slice(); - let envs = (0..DATA_LEN) - .map(|i| MessageEnvelope { - record_id: i.into(), - payload: env.clone(), - }) - .collect::>(); assert_eq!(status, StatusCode::OK, "{}", resp_body_str); let messages = rx.try_recv().expect("should have already received value"); - assert_eq!(messages, (channel_id, envs)); + assert_eq!(messages, (channel_id, body.to_vec())); } } @@ -268,8 +250,8 @@ mod tests { step: UniqueStepId::default().narrow("test").as_ref().to_owned(), identity: Identity::H2.as_ref().to_owned(), offset_header: (OFFSET_HEADER_NAME.clone(), 0.into()), - data_size_header: (DATA_SIZE_HEADER_NAME.clone(), DATA_SIZE.into()), - body: &[0; (DATA_LEN * DATA_SIZE) as usize], + data_size_header: (DATA_SIZE_HEADER_NAME.clone(), MESSAGE_PAYLOAD_SIZE_BYTES.into()), + body: &[34; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize], } } } @@ -322,7 +304,7 @@ mod tests { #[tokio::test] async fn malformed_data_size_header_name_fails() { let req = OverrideReq { - data_size_header: (HeaderName::from_static("datasize"), DATA_SIZE.into()), + data_size_header: (HeaderName::from_static("datasize"), MESSAGE_PAYLOAD_SIZE_BYTES.into()), ..Default::default() }; resp_eq(req, StatusCode::UNPROCESSABLE_ENTITY).await; @@ -365,7 +347,7 @@ mod tests { let step = UniqueStepId::default().narrow("test"); let target_helper = Identity::H2; let offset = 0; - let body = &[0; (DATA_LEN * DATA_SIZE) as usize]; + let body = &[0; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize]; let new_req = || build_req(0, query_id, &step, target_helper, offset, body); From a06b27fccfb8c091919c0aa06a3bb436083af39b Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 8 Nov 2022 22:31:41 -0800 Subject: [PATCH 04/24] Fix prefix_or test --- src/protocol/boolean/prefix_or.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/protocol/boolean/prefix_or.rs b/src/protocol/boolean/prefix_or.rs index d59475d43..f44fe5b03 100644 --- a/src/protocol/boolean/prefix_or.rs +++ b/src/protocol/boolean/prefix_or.rs @@ -382,9 +382,9 @@ mod tests { let pre2 = PrefixOr::new(&s2); let iteration = format!("{}", i); let result = try_join_all(vec![ - pre0.execute(ctx[0].narrow(&iteration), RecordId::from(i)), - pre1.execute(ctx[1].narrow(&iteration), RecordId::from(i)), - pre2.execute(ctx[2].narrow(&iteration), RecordId::from(i)), + pre0.execute(ctx[0].narrow(&iteration), RecordId::from(0_u32)), + pre1.execute(ctx[1].narrow(&iteration), RecordId::from(0_u32)), + pre2.execute(ctx[2].narrow(&iteration), RecordId::from(0_u32)), ]) .await .unwrap(); From d80b44adab14cb719f40577e08eb641cfa1bf289 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 8 Nov 2022 22:32:15 -0800 Subject: [PATCH 05/24] Formatting --- src/helpers/fabric.rs | 2 +- src/net/server/handlers/mul.rs | 12 +++++++++--- src/protocol/mod.rs | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/helpers/fabric.rs b/src/helpers/fabric.rs index 0ff7905c8..556113fef 100644 --- a/src/helpers/fabric.rs +++ b/src/helpers/fabric.rs @@ -1,4 +1,4 @@ -use crate::helpers::{error::Error, Role, MessagePayload}; +use crate::helpers::{error::Error, MessagePayload, Role}; use crate::protocol::{RecordId, UniqueStepId}; use async_trait::async_trait; use futures::Stream; diff --git a/src/net/server/handlers/mul.rs b/src/net/server/handlers/mul.rs index 1c32dcc1a..12979dd50 100644 --- a/src/net/server/handlers/mul.rs +++ b/src/net/server/handlers/mul.rs @@ -108,6 +108,7 @@ pub async fn handler( #[cfg(test)] mod tests { use super::*; + use crate::helpers::MESSAGE_PAYLOAD_SIZE_BYTES; use crate::net::{ BindTarget, MpcServer, CONTENT_LENGTH_HEADER_NAME, DATA_SIZE_HEADER_NAME, OFFSET_HEADER_NAME, @@ -122,7 +123,6 @@ mod tests { use std::task::{Context, Poll}; use tokio::sync::mpsc; use tower::ServiceExt; - use crate::helpers::MESSAGE_PAYLOAD_SIZE_BYTES; const DATA_LEN: usize = 3; @@ -250,7 +250,10 @@ mod tests { step: UniqueStepId::default().narrow("test").as_ref().to_owned(), role: Role::H2.as_ref().to_owned(), offset_header: (OFFSET_HEADER_NAME.clone(), 0.into()), - data_size_header: (DATA_SIZE_HEADER_NAME.clone(), MESSAGE_PAYLOAD_SIZE_BYTES.into()), + data_size_header: ( + DATA_SIZE_HEADER_NAME.clone(), + MESSAGE_PAYLOAD_SIZE_BYTES.into(), + ), body: &[34; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize], } } @@ -304,7 +307,10 @@ mod tests { #[tokio::test] async fn malformed_data_size_header_name_fails() { let req = OverrideReq { - data_size_header: (HeaderName::from_static("datasize"), MESSAGE_PAYLOAD_SIZE_BYTES.into()), + data_size_header: ( + HeaderName::from_static("datasize"), + MESSAGE_PAYLOAD_SIZE_BYTES.into(), + ), ..Default::default() }; resp_eq(req, StatusCode::UNPROCESSABLE_ENTITY).await; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 12adbb35a..9c7cdb705 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -10,11 +10,11 @@ pub mod prss; mod reveal; pub mod sort; -use std::ops::AddAssign; use crate::error::Error; use std::fmt::Debug; use std::fmt::Formatter; use std::hash::Hash; +use std::ops::AddAssign; #[cfg(debug_assertions)] use std::{ collections::HashSet, From a6cf7502d9339c6d878476e17e24735ce1f43c6f Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 9 Nov 2022 17:24:53 -0800 Subject: [PATCH 06/24] Remove data-size header No longer relevant that element size is implied by `MESSAGE_PAYLOAD_SIZE_BYTES` const --- src/net/client/mod.rs | 1 - src/net/mod.rs | 33 ++++++++++----------------------- src/net/server/handlers/mul.rs | 28 +++------------------------- src/net/server/mod.rs | 8 +++++++- 4 files changed, 20 insertions(+), 50 deletions(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index f94ca6d83..cee4a381b 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -106,7 +106,6 @@ impl MpcHttpConnection { let headers = RecordHeaders { content_length: args.messages.len() as u32, offset: args.offset, - data_size: args.data_size, }; let req = headers .add_to(Request::post(uri)) diff --git a/src/net/mod.rs b/src/net/mod.rs index 2a98a46f1..ee43eeef1 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -12,6 +12,7 @@ pub use server::tls_config_from_self_signed_cert; pub use server::{BindTarget, MpcServer}; use std::str::FromStr; +use crate::helpers::MESSAGE_PAYLOAD_SIZE_BYTES; use crate::net::server::MpcServerError; use crate::protocol::{QueryId, RecordId}; use async_trait::async_trait; @@ -21,25 +22,19 @@ use axum::http::header::HeaderName; /// name of the `offset` header to use for [`RecordHeaders`] static OFFSET_HEADER_NAME: HeaderName = HeaderName::from_static("offset"); -/// name of the `data-size` header to use for [`RecordHeaders`] -static DATA_SIZE_HEADER_NAME: HeaderName = HeaderName::from_static("data-size"); /// name of the `content-type` header used to get the length of the body, to verify valid `data-size` static CONTENT_LENGTH_HEADER_NAME: HeaderName = HeaderName::from_static("content-length"); /// Headers that are expected on requests involving a batch of records. /// # `content_length` -/// standard HTTP header representing length of entire body +/// standard HTTP header representing length of entire body. Body length must be a multiple of +/// `MESSAGE_PAYLOAD_SIZE_BYTES` /// # `offset` /// For any given batch, their `record_id`s must be known. The first record in the batch will have id /// `offset`, and subsequent records will be in-order from there. -/// # `data_size` -/// the batch will be transmitted as a single `Bytes` block, and the receiver will need to know how -/// to divide up the block into individual records. `data_size` represents the number of bytes each -/// record consists of pub struct RecordHeaders { content_length: u32, offset: u32, - data_size: u32, } impl RecordHeaders { @@ -61,7 +56,6 @@ impl RecordHeaders { pub(crate) fn add_to(&self, req: axum::http::request::Builder) -> axum::http::request::Builder { req.header(CONTENT_LENGTH_HEADER_NAME.clone(), self.content_length) .header(OFFSET_HEADER_NAME.clone(), self.offset) - .header(DATA_SIZE_HEADER_NAME.clone(), self.data_size) } } @@ -73,23 +67,16 @@ impl FromRequest for RecordHeaders { let content_length: u32 = RecordHeaders::get_header(req, CONTENT_LENGTH_HEADER_NAME.clone())?; let offset: u32 = RecordHeaders::get_header(req, OFFSET_HEADER_NAME.clone())?; - let data_size: u32 = RecordHeaders::get_header(req, DATA_SIZE_HEADER_NAME.clone())?; - // cannot divide by 0 - if data_size == 0 { - Err(MpcServerError::InvalidHeader( - "data-size header must not be 0".into(), - )) - } - // `data_size` NOT a multiple of `body_len` - else if content_length % data_size != 0 { - Err(MpcServerError::InvalidHeader( - "data-size header does not align with body".into(), - )) - } else { + // content_length must be aligned with the size of an element + if content_length as usize % MESSAGE_PAYLOAD_SIZE_BYTES == 0 { Ok(RecordHeaders { content_length, offset, - data_size, + }) + } else { + Err(MpcServerError::WrongBodyLen { + body_len: content_length, + element_size: MESSAGE_PAYLOAD_SIZE_BYTES, }) } } diff --git a/src/net/server/handlers/mul.rs b/src/net/server/handlers/mul.rs index 12979dd50..d13bc1ce9 100644 --- a/src/net/server/handlers/mul.rs +++ b/src/net/server/handlers/mul.rs @@ -109,10 +109,7 @@ pub async fn handler( mod tests { use super::*; use crate::helpers::MESSAGE_PAYLOAD_SIZE_BYTES; - use crate::net::{ - BindTarget, MpcServer, CONTENT_LENGTH_HEADER_NAME, DATA_SIZE_HEADER_NAME, - OFFSET_HEADER_NAME, - }; + use crate::net::{BindTarget, MpcServer, CONTENT_LENGTH_HEADER_NAME, OFFSET_HEADER_NAME}; use axum::body::Bytes; use axum::http::{HeaderValue, Request, StatusCode}; use futures_util::FutureExt; @@ -160,7 +157,6 @@ mod tests { let headers = RecordHeaders { content_length: body.len() as u32, offset, - data_size: MESSAGE_PAYLOAD_SIZE_BYTES as u32, }; let body = Body::from(Bytes::from_static(body)); headers @@ -223,7 +219,6 @@ mod tests { step: String, role: String, offset_header: (HeaderName, HeaderValue), - data_size_header: (HeaderName, HeaderValue), body: &'static [u8], } @@ -237,7 +232,6 @@ mod tests { let req_headers = req.headers_mut().unwrap(); req_headers.insert(CONTENT_LENGTH_HEADER_NAME.clone(), self.body.len().into()); req_headers.insert(self.offset_header.0, self.offset_header.1); - req_headers.insert(self.data_size_header.0, self.data_size_header.1); req.body(self.body.into()).unwrap() } @@ -250,10 +244,6 @@ mod tests { step: UniqueStepId::default().narrow("test").as_ref().to_owned(), role: Role::H2.as_ref().to_owned(), offset_header: (OFFSET_HEADER_NAME.clone(), 0.into()), - data_size_header: ( - DATA_SIZE_HEADER_NAME.clone(), - MESSAGE_PAYLOAD_SIZE_BYTES.into(), - ), body: &[34; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize], } } @@ -305,21 +295,9 @@ mod tests { } #[tokio::test] - async fn malformed_data_size_header_name_fails() { - let req = OverrideReq { - data_size_header: ( - HeaderName::from_static("datasize"), - MESSAGE_PAYLOAD_SIZE_BYTES.into(), - ), - ..Default::default() - }; - resp_eq(req, StatusCode::UNPROCESSABLE_ENTITY).await; - } - - #[tokio::test] - async fn malformed_data_size_header_value_fails() { + async fn wrong_body_size_is_rejected() { let req = OverrideReq { - data_size_header: (DATA_SIZE_HEADER_NAME.clone(), 7.into()), + body: &[0; MESSAGE_PAYLOAD_SIZE_BYTES + 1], ..Default::default() }; resp_eq(req, StatusCode::BAD_REQUEST).await; diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index a0e1e549e..e69602358 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -29,6 +29,10 @@ pub enum MpcServerError { MissingHeader(String), #[error("invalid header: {0}")] InvalidHeader(BoxError), + #[error( + "Request body length {body_len} is not aligned with size of the element {element_size}" + )] + WrongBodyLen { body_len: u32, element_size: usize }, #[error(transparent)] BadPathString(#[from] PathRejection), #[error(transparent)] @@ -77,7 +81,9 @@ impl IntoResponse for MpcServerError { Self::BadQueryString(_) | Self::BadPathString(_) | Self::MissingHeader(_) => { StatusCode::UNPROCESSABLE_ENTITY } - Self::SerdeError(_) | Self::InvalidHeader(_) => StatusCode::BAD_REQUEST, + Self::SerdeError(_) | Self::InvalidHeader(_) | Self::WrongBodyLen { .. } => { + StatusCode::BAD_REQUEST + } Self::HyperError(_) | Self::SendError(_) | Self::BodyAlreadyExtracted(_) => { StatusCode::INTERNAL_SERVER_ERROR } From 0e4993c402879584dd17b6de8a8e8a1c0d1fae82 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 11 Nov 2022 01:44:36 +0800 Subject: [PATCH 07/24] Replacing permutations crate with home-rolled code --- Cargo.toml | 1 - src/protocol/reveal.rs | 5 +- src/protocol/sort/apply.rs | 131 +++++--- src/protocol/sort/compose.rs | 47 +-- .../sort/generate_sort_permutation.rs | 41 +-- src/protocol/sort/secureapplyinv.rs | 48 +-- src/protocol/sort/shuffle.rs | 289 +++++++++--------- src/test_fixture/mod.rs | 9 + src/test_fixture/sharing.rs | 16 +- 9 files changed, 335 insertions(+), 252 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dd14db0e5..c322b99b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,6 @@ hyper-tls = { version = "0.5.0", optional = true } metrics = "0.20.1" metrics-tracing-context = { version = "0.12.0", optional = true } metrics-util = { version = "0.14.0", optional = true } -permutation = "0.4.1" pin-project = "1.0.12" rand = "0.8" rand_chacha = "0.3.1" diff --git a/src/protocol/reveal.rs b/src/protocol/reveal.rs index a8f500768..a04025b85 100644 --- a/src/protocol/reveal.rs +++ b/src/protocol/reveal.rs @@ -10,7 +10,6 @@ use crate::{ }; use embed_doc_image::embed_doc_image; use futures::future::{try_join, try_join_all}; -use permutation::Permutation; /// This implements a reveal algorithm /// For simplicity, we consider a simple revealing in which each `P_i` sends `\[a\]_i` to `P_i+1` after which @@ -78,7 +77,7 @@ pub async fn reveal_malicious( pub async fn reveal_permutation( ctx: ProtocolContext<'_, Replicated, F>, permutation: &[Replicated], -) -> Result { +) -> Result, BoxError> { let revealed_permutation = try_join_all(zip(repeat(ctx), permutation).enumerate().map( |(index, (ctx, input))| async move { let reveal_value = reveal(ctx, RecordId::from(index), *input).await; @@ -90,7 +89,7 @@ pub async fn reveal_permutation( )) .await?; - Ok(Permutation::oneline(revealed_permutation)) + Ok(revealed_permutation) } #[cfg(test)] diff --git a/src/protocol/sort/apply.rs b/src/protocol/sort/apply.rs index f7a0ab3d8..60f516fd8 100644 --- a/src/protocol/sort/apply.rs +++ b/src/protocol/sort/apply.rs @@ -1,5 +1,6 @@ +use bitvec::bitvec; use embed_doc_image::embed_doc_image; -use permutation::Permutation; +use std::mem; // TODO #OptimizeLater // For now, we are using Permutation crate to implement `apply_inv` and `apply` functions. @@ -11,77 +12,121 @@ use permutation::Permutation; #[embed_doc_image("apply", "images/sort/apply.png")] #[embed_doc_image("apply_inv", "images/sort/apply_inv.png")] +#[embed_doc_image("apply", "images/sort/apply.png")] +#[embed_doc_image("apply_inv", "images/sort/apply_inv.png")] + /// Permutation reorders (1, 2, . . . , m) into (σ(1), σ(2), . . . , σ(m)). /// For example, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is reordered into (C, D, B, A) by σ. /// ![Apply steps][apply] -pub fn apply(mut permutation: Permutation, values: &mut S) -where - S: AsMut<[T]>, -{ - permutation.apply_slice_in_place(values); +pub fn apply(permutation: &[usize], values: &mut [T]) { + let mut permuted = bitvec![0; permutation.len()]; + let mut tmp: T = T::default(); + + for i in 0..permutation.len() { + if permuted[i] == false { + mem::swap(&mut tmp, &mut values[i]); + let mut pos_i = i; + let mut pos_j = permutation[pos_i]; + while pos_j != i { + values[pos_i] = values[pos_j]; + pos_i = pos_j; + pos_j = permutation[pos_i]; + permuted.set(pos_i, true); + } + mem::swap(&mut values[pos_i], &mut tmp); + permuted.set(i, true); + } + } } /// To compute `apply_inv` on values, permutation(i) can be regarded as the destination of i, i.e., the i-th item /// is moved by `apply_inv` to be the σ(i)-th item. Therefore, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is /// reordered into (D, C, A, B). /// ![Apply inv steps][apply_inv] -pub fn apply_inv(mut permutation: Permutation, values: &mut S) -where - S: AsMut<[T]>, -{ - permutation.apply_inv_slice_in_place(values); +pub fn apply_inv(permutation: &[usize], values: &mut [T]) { + let mut permuted = bitvec![0; permutation.len()]; + let mut tmp: T; + + for i in 0..permutation.len() { + if permuted[i] == false { + let mut destination = permutation[i]; + tmp = values[i]; + while destination != i { + mem::swap(&mut tmp, &mut values[destination]); + permuted.set(destination, true); + destination = permutation[destination]; + } + mem::swap(&mut values[i], &mut tmp); + permuted.set(i, true); + } + } } #[cfg(test)] mod tests { use super::{apply, apply_inv}; - use permutation::Permutation; - use rand::seq::SliceRandom; #[test] - fn apply_shares() { + fn apply_just_one_cycle() { let mut values = ["A", "B", "C", "D"]; - let indices = Permutation::oneline([2, 3, 1, 0]).inverse(); + let permutation = [2, 3, 1, 0]; let expected_output_apply = ["C", "D", "B", "A"]; - apply(indices, &mut values); + apply(&permutation, &mut values); assert_eq!(values, expected_output_apply); - - let mut values = ["A", "B", "C", "D"]; - let indices = Permutation::oneline([2, 3, 1, 0]).inverse(); - let expected_output_apply_inv = ["D", "C", "A", "B"]; - apply_inv(indices, &mut values); - assert_eq!(values, expected_output_apply_inv); } #[test] - pub fn composing() { - let sigma = vec![4, 2, 0, 5, 1, 3]; - let mut rho = vec![3, 4, 0, 5, 1, 2]; - - // Applying sigma on rho - apply_inv(Permutation::oneline(sigma), &mut rho); - assert_eq!(rho, vec![1, 0, 3, 2, 4, 5]); + fn apply_just_two_cycles() { + let mut values = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K"]; + let permutation = [3, 4, 6, 7, 0, 9, 10, 1, 2, 8, 5]; + let expected_output_apply = ["D", "E", "G", "H", "A", "J", "K", "B", "C", "I", "F"]; + apply(&permutation, &mut values); + assert_eq!(values, expected_output_apply); } #[test] - pub fn apply_apply_inv_relation() { - // This test shows that apply(permutation, values) is same as apply_inv(permutation.inverse(), values) - let batchsize: usize = 100; - let mut rng = rand::thread_rng(); + fn apply_complex() { + let mut values = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K"]; + let permutation = [1, 0, 2, 5, 6, 7, 8, 9, 10, 3, 4]; + let expected_output_apply = ["B", "A", "C", "F", "G", "H", "I", "J", "K", "D", "E"]; + apply(&permutation, &mut values); + assert_eq!(values, expected_output_apply); + } - let mut permutation: Vec = (0..batchsize).collect(); - permutation.shuffle(&mut rng); + #[test] + fn apply_inv_just_one_cycle() { + let mut values = ["A", "B", "C", "D"]; + let permutation = [2, 3, 1, 0]; + let expected_output_apply = ["D", "C", "A", "B"]; + apply_inv(&permutation, &mut values); + assert_eq!(values, expected_output_apply); + } - let mut values: Vec<_> = (0..batchsize).collect(); - let mut values_copy = values.clone(); + #[test] + fn apply_inv_just_two_cycles() { + let mut values = ["A", "B", "C", "D", "E"]; + let permutation = [3, 4, 1, 0, 2]; + let expected_output_apply = ["D", "C", "E", "A", "B"]; + apply_inv(&permutation, &mut values); + assert_eq!(values, expected_output_apply); + } - apply(Permutation::oneline(permutation.clone()), &mut values); + #[test] + fn apply_inv_complex() { + let mut values = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K"]; + let permutation = [1, 0, 2, 5, 6, 7, 8, 9, 10, 3, 4]; + let expected_output_apply = ["B", "A", "C", "J", "K", "D", "E", "F", "G", "H", "I"]; + apply_inv(&permutation, &mut values); + assert_eq!(values, expected_output_apply); + } - apply_inv( - Permutation::oneline(permutation).inverse(), - &mut values_copy, - ); + #[test] + pub fn composing() { + let sigma = vec![4, 2, 0, 5, 1, 3]; + let mut rho = vec![3, 4, 0, 5, 1, 2]; - assert_eq!(values, values_copy); + // Applying sigma on rho + apply(&sigma, &mut rho); + assert_eq!(rho, vec![1, 0, 3, 2, 4, 5]); } } diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index 8dec6aeb6..565bbff49 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -7,8 +7,8 @@ use crate::{ use embed_doc_image::embed_doc_image; use super::{ - apply::apply_inv, - shuffle::{get_two_of_three_random_permutations, Shuffle}, + apply::apply, + shuffle::{get_two_of_three_random_permutations, shuffle_shares, unshuffle_shares}, ComposeStep::{RevealPermutation, ShuffleSigma, UnshuffleRho}, }; @@ -36,23 +36,31 @@ impl Compose { #[allow(dead_code)] pub async fn execute( ctx: ProtocolContext<'_, Replicated, F>, - sigma: Vec>, - rho: Vec>, + sigma: &mut [Replicated], + rho: &mut [Replicated], ) -> Result>, BoxError> { - let random_permutations = get_two_of_three_random_permutations(rho.len(), &ctx.prss()); + let (left_random_permutation, right_random_permutation) = + get_two_of_three_random_permutations(rho.len(), &ctx.prss()); - let shuffled_sigma = Shuffle::new(sigma, random_permutations.clone()) - .execute(ctx.narrow(&ShuffleSigma)) - .await?; + let shuffled_sigma = shuffle_shares( + sigma, + &left_random_permutation, + &right_random_permutation, + ctx.narrow(&ShuffleSigma), + ) + .await?; let revealed_permutation = reveal_permutation(ctx.narrow(&RevealPermutation), &shuffled_sigma).await?; - let mut applied_rho = rho; - apply_inv(revealed_permutation, &mut applied_rho); + apply(&revealed_permutation, rho); - let unshuffled_rho = Shuffle::new(applied_rho, random_permutations) - .execute_unshuffle(ctx.narrow(&UnshuffleRho)) - .await?; + let unshuffled_rho = unshuffle_shares( + rho, + &left_random_permutation, + &right_random_permutation, + ctx.narrow(&UnshuffleRho), + ) + .await?; Ok(unshuffled_rho) } @@ -60,7 +68,6 @@ impl Compose { #[cfg(test)] mod tests { - use permutation::Permutation; use rand::seq::SliceRandom; use tokio::try_join; @@ -68,7 +75,7 @@ mod tests { error::BoxError, ff::Fp31, protocol::{ - sort::{apply::apply_inv, compose::Compose}, + sort::{apply::apply, compose::Compose}, QueryId, }, test_fixture::{ @@ -93,16 +100,16 @@ mod tests { let rho_u128: Vec = rho.iter().map(|x| *x as u128).collect(); let mut rho_composed = rho_u128.clone(); - apply_inv(Permutation::oneline(sigma.clone()), &mut rho_composed); + apply(&sigma, &mut rho_composed); - let sigma_shares = generate_shares::(sigma_u128); + let mut sigma_shares = generate_shares::(sigma_u128); let mut rho_shares = generate_shares::(rho_u128); let world: TestWorld = make_world(QueryId); let [ctx0, ctx1, ctx2] = make_contexts(&world); - let h0_future = Compose::execute(ctx0, sigma_shares.0, rho_shares.0); - let h1_future = Compose::execute(ctx1, sigma_shares.1, rho_shares.1); - let h2_future = Compose::execute(ctx2, sigma_shares.2, rho_shares.2); + let h0_future = Compose::execute(ctx0, &mut sigma_shares.0, &mut rho_shares.0); + let h1_future = Compose::execute(ctx1, &mut sigma_shares.1, &mut rho_shares.1); + let h2_future = Compose::execute(ctx2, &mut sigma_shares.2, &mut rho_shares.2); rho_shares = try_join!(h0_future, h1_future, h2_future)?; diff --git a/src/protocol/sort/generate_sort_permutation.rs b/src/protocol/sort/generate_sort_permutation.rs index 84befa886..e2d39d501 100644 --- a/src/protocol/sort/generate_sort_permutation.rs +++ b/src/protocol/sort/generate_sort_permutation.rs @@ -63,7 +63,7 @@ impl<'a> GenerateSortPermutation<'a> { let mut composed_less_significant_bits_permutation = bit_0_permutation; for bit_num in 1..self.num_bits { let ctx_bit = ctx.narrow(&Sort(bit_num)); - let bit_i = convert_shares_for_a_bit( + let mut bit_i = convert_shares_for_a_bit( ctx_bit.narrow(&ModulusConversion), self.input, self.num_bits, @@ -72,19 +72,19 @@ impl<'a> GenerateSortPermutation<'a> { .await?; let bit_i_sorted_by_less_significant_bits = SecureApplyInv::execute( ctx_bit.narrow(&ApplyInv), - bit_i, - composed_less_significant_bits_permutation.clone(), + &mut bit_i, + &mut composed_less_significant_bits_permutation.clone(), ) .await?; - let bit_i_permutation = BitPermutation::new(&bit_i_sorted_by_less_significant_bits) + let mut bit_i_permutation = BitPermutation::new(&bit_i_sorted_by_less_significant_bits) .execute(ctx_bit.narrow(&BitPermutationStep)) .await?; let composed_i_permutation = Compose::execute( ctx_bit.narrow(&ComposeStep), - composed_less_significant_bits_permutation, - bit_i_permutation, + &mut composed_less_significant_bits_permutation, + &mut bit_i_permutation, ) .await?; composed_less_significant_bits_permutation = composed_i_permutation; @@ -100,9 +100,9 @@ mod tests { use crate::{ error::BoxError, - ff::Fp32BitPrime, + ff::{Field, Fp32BitPrime}, protocol::{sort::generate_sort_permutation::GenerateSortPermutation, QueryId}, - test_fixture::{logging, make_contexts, make_world, validate_list_of_shares}, + test_fixture::{logging, make_contexts, make_world, validate_and_reconstruct}, }; #[tokio::test] @@ -120,18 +120,13 @@ mod tests { match_keys.push(rng.gen::()); } - let mut expected_sort_output: Vec = (0..batchsize).collect(); - - let mut permutation = permutation::sort(match_keys.clone()); - permutation.apply_inv_slice_in_place(&mut expected_sort_output); - let input_len = match_keys.len(); let mut shares = [ Vec::with_capacity(input_len), Vec::with_capacity(input_len), Vec::with_capacity(input_len), ]; - for match_key in match_keys { + for match_key in match_keys.clone() { let share_0 = rng.gen::(); let share_1 = rng.gen::(); let share_2 = match_key ^ share_0 ^ share_1; @@ -141,7 +136,7 @@ mod tests { shares[2].push((share_2, share_0)); } - let mut result = try_join_all(vec![ + let result = try_join_all(vec![ GenerateSortPermutation::new(&shares[0], num_bits).execute(ctx0), GenerateSortPermutation::new(&shares[1], num_bits).execute(ctx1), GenerateSortPermutation::new(&shares[2], num_bits).execute(ctx2), @@ -152,10 +147,18 @@ mod tests { assert_eq!(result[1].len(), input_len); assert_eq!(result[2].len(), input_len); - validate_list_of_shares( - &expected_sort_output, - &(result.remove(0), result.remove(0), result.remove(0)), - ); + let mut mpc_sorted_list: Vec = (0..input_len).map(|i| i as u128).collect(); + for i in 0..input_len { + let index = validate_and_reconstruct((result[0][i], result[1][i], result[2][i])); + mpc_sorted_list[index.as_u128() as usize] = match_keys[i] as u128; + } + + let mut sorted_match_keys = match_keys.clone(); + sorted_match_keys.sort(); + for i in 0..input_len { + assert_eq!(sorted_match_keys[i] as u128, mpc_sorted_list[i]); + } + Ok(()) } } diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index df68222b0..a037abb66 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -11,8 +11,8 @@ use crate::{ use embed_doc_image::embed_doc_image; use super::{ - apply::apply, - shuffle::{get_two_of_three_random_permutations, Shuffle}, + apply::apply_inv, + shuffle::{get_two_of_three_random_permutations, shuffle_shares}, }; use futures::future::try_join; @@ -42,35 +42,45 @@ impl SecureApplyInv { /// 5. All helpers call `apply` to apply the permutation locally. pub async fn execute( ctx: ProtocolContext<'_, Replicated, F>, - input: Vec>, - sort_permutation: Vec>, + input: &mut [Replicated], + sort_permutation: &mut [Replicated], ) -> Result>, BoxError> { - let random_permutations = get_two_of_three_random_permutations(input.len(), &ctx.prss()); + let (left_random_permutation, right_random_permutation) = + get_two_of_three_random_permutations(input.len(), &ctx.prss()); let (mut shuffled_input, shuffled_sort_permutation) = try_join( - Shuffle::new(input, random_permutations.clone()).execute(ctx.narrow(&ShuffleInputs)), - Shuffle::new(sort_permutation, random_permutations) - .execute(ctx.narrow(&ShufflePermutation)), + shuffle_shares( + input, + &left_random_permutation, + &right_random_permutation, + ctx.narrow(&ShuffleInputs), + ), + shuffle_shares( + sort_permutation, + &left_random_permutation, + &right_random_permutation, + ctx.narrow(&ShufflePermutation), + ), ) .await?; let revealed_permutation = reveal_permutation(ctx.narrow(&RevealPermutation), &shuffled_sort_permutation).await?; // The paper expects us to apply an inverse on the inverted Permutation (i.e. apply_inv(permutation.inverse(), input)) // Since this is same as apply(permutation, input), we are doing that instead to save on compute. - apply(revealed_permutation, &mut shuffled_input); + apply_inv(&revealed_permutation, &mut shuffled_input); Ok(shuffled_input) } } #[cfg(test)] mod tests { - use permutation::Permutation; + use proptest::prelude::Rng; use rand::seq::SliceRandom; use tokio::try_join; use crate::{ ff::Fp31, - protocol::{sort::apply::apply, QueryId}, + protocol::{sort::apply::apply_inv, QueryId}, test_fixture::{generate_shares, make_contexts, make_world, validate_list_of_shares}, }; @@ -81,30 +91,32 @@ mod tests { const BATCHSIZE: usize = 25; for _ in 0..10 { let mut rng = rand::thread_rng(); - let input: Vec = (0..(BATCHSIZE as u128)).collect(); + let mut input: Vec = Vec::with_capacity(BATCHSIZE); + for _ in 0..BATCHSIZE { + input.push(rng.gen::() % 31_u128); + } let mut permutation: Vec = (0..BATCHSIZE).collect(); permutation.shuffle(&mut rng); let mut expected_result = input.clone(); - let cloned_perm = Permutation::oneline(permutation.clone()); // The actual paper expects us to apply an inverse on the inverted Permutation (i.e. apply_inv(perm.inverse(), input)) // Since this is same as apply(perm, input), we are doing that instead both in the code and in the test. // Applying permutation on the input in clear to get the expected result - apply(cloned_perm, &mut expected_result); + apply_inv(&permutation, &mut expected_result); let permutation: Vec = permutation.iter().map(|x| *x as u128).collect(); - let perm_shares = generate_shares::(permutation); + let mut perm_shares = generate_shares::(permutation); let mut input_shares = generate_shares::(input); let world = make_world(QueryId); let [ctx0, ctx1, ctx2] = make_contexts(&world); - let h0_future = SecureApplyInv::execute(ctx0, input_shares.0, perm_shares.0); - let h1_future = SecureApplyInv::execute(ctx1, input_shares.1, perm_shares.1); - let h2_future = SecureApplyInv::execute(ctx2, input_shares.2, perm_shares.2); + let h0_future = SecureApplyInv::execute(ctx0, &mut input_shares.0, &mut perm_shares.0); + let h1_future = SecureApplyInv::execute(ctx1, &mut input_shares.1, &mut perm_shares.1); + let h2_future = SecureApplyInv::execute(ctx2, &mut input_shares.2, &mut perm_shares.2); input_shares = try_join!(h0_future, h1_future, h2_future).unwrap(); diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index f02b15d3c..873b97481 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -1,6 +1,5 @@ use embed_doc_image::embed_doc_image; use futures::future::try_join_all; -use permutation::Permutation; use rand::seq::SliceRandom; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; @@ -19,12 +18,6 @@ use super::{ ShuffleStep::{self, Step1, Step2, Step3}, }; -pub struct Shuffle { - input: Vec>, - permutation_left: Permutation, - permutation_right: Permutation, -} - #[derive(Debug)] enum ShuffleOrUnshuffle { Shuffle, @@ -46,7 +39,7 @@ impl AsRef for ShuffleOrUnshuffle { pub fn get_two_of_three_random_permutations( batchsize: usize, prss: &IndexedSharedRandomness, -) -> (Permutation, Permutation) { +) -> (Vec, Vec) { // Chacha8Rng expects a [u8;32] seed whereas prss returns a u128 number. // We are using two seeds from prss to generate a seed for shuffle and concatenating them // Since reshare uses indexes 0..batchsize to generate random numbers from prss, we are using @@ -74,127 +67,147 @@ pub fn get_two_of_three_random_permutations( .1 .shuffle(&mut ChaCha8Rng::from_seed(seed_right.try_into().unwrap())); - ( - Permutation::oneline(permutations.0), - Permutation::oneline(permutations.1), - ) + permutations } /// This is SHUFFLE(Algorithm 1) described in . /// This protocol shuffles the given inputs across 3 helpers making them indistinguishable to the helpers -impl Shuffle { - pub fn new(input: Vec>, permutations: (Permutation, Permutation)) -> Self { - Self { - input, - permutation_left: permutations.0, - permutation_right: permutations.1, - } - } - // We call shuffle with helpers involved as (H2, H3), (H3, H1) and (H1, H2). In other words, the shuffle is being called for - // H1, H2 and H3 respectively (since they do not participate in the step) and hence are the recipients of the shuffle. - fn shuffle_for_helper(which_step: ShuffleStep) -> Role { - match which_step { - Step1 => Role::H1, - Step2 => Role::H2, - Step3 => Role::H3, - } +// We call shuffle with helpers involved as (H2, H3), (H3, H1) and (H1, H2). In other words, the shuffle is being called for +// H1, H2 and H3 respectively (since they do not participate in the step) and hence are the recipients of the shuffle. +fn shuffle_for_helper(which_step: ShuffleStep) -> Role { + match which_step { + Step1 => Role::H1, + Step2 => Role::H2, + Step3 => Role::H3, } +} - #[allow(clippy::cast_possible_truncation)] - async fn reshare_all_shares( - &self, - ctx: &ProtocolContext<'_, Replicated, F>, - to_helper: Role, - ) -> Result>, BoxError> { - let reshares = self - .input - .iter() - .enumerate() - .map(|(index, input)| async move { - Reshare::new(*input) - .execute(ctx, RecordId::from(index), to_helper) - .await - }); - try_join_all(reshares).await - } +#[allow(clippy::cast_possible_truncation)] +async fn reshare_all_shares( + input: &mut [Replicated], + ctx: &ProtocolContext<'_, Replicated, F>, + to_helper: Role, +) -> Result>, BoxError> { + let reshares = input.iter().enumerate().map(|(index, input)| async move { + Reshare::new(*input) + .execute(ctx, RecordId::from(index), to_helper) + .await + }); + try_join_all(reshares).await +} - /// `shuffle_or_unshuffle_once` is called for the helpers - /// i) 2 helpers receive permutation pair and choose the permutation to be applied - /// ii) 2 helpers apply the permutation to their shares - /// iii) reshare to `to_helper` - #[allow(clippy::cast_possible_truncation)] - async fn shuffle_or_unshuffle_once( - &mut self, - shuffle_or_unshuffle: ShuffleOrUnshuffle, - ctx: &ProtocolContext<'_, Replicated, F>, - which_step: ShuffleStep, - ) -> Result>, BoxError> { - let to_helper = Self::shuffle_for_helper(which_step); - let ctx = ctx.narrow(&which_step); - let mut permutation_to_apply = Permutation::oneline(vec![]); - if to_helper != ctx.role() { - if to_helper.peer(Direction::Left) == ctx.role() { - std::mem::swap(&mut permutation_to_apply, &mut self.permutation_left); - } else { - std::mem::swap(&mut permutation_to_apply, &mut self.permutation_right); - }; - // at this point, permutation_to_apply should have a legit permutation - assert_ne!(permutation_to_apply.len(), 0); - - match shuffle_or_unshuffle { - ShuffleOrUnshuffle::Shuffle => apply_inv(permutation_to_apply, &mut self.input), - ShuffleOrUnshuffle::Unshuffle => apply(permutation_to_apply, &mut self.input), - } +/// `shuffle_or_unshuffle_once` is called for the helpers +/// i) 2 helpers receive permutation pair and choose the permutation to be applied +/// ii) 2 helpers apply the permutation to their shares +/// iii) reshare to `to_helper` +#[allow(clippy::cast_possible_truncation)] +async fn shuffle_or_unshuffle_once( + input: &mut [Replicated], + permutation_left: &[usize], + permutation_right: &[usize], + shuffle_or_unshuffle: ShuffleOrUnshuffle, + ctx: &ProtocolContext<'_, Replicated, F>, + which_step: ShuffleStep, +) -> Result>, BoxError> { + let to_helper = shuffle_for_helper(which_step); + let ctx = ctx.narrow(&which_step); + + if to_helper != ctx.role() { + let permutation_to_apply = if to_helper.peer(Direction::Left) == ctx.role() { + permutation_left + } else { + permutation_right + }; + + match shuffle_or_unshuffle { + ShuffleOrUnshuffle::Shuffle => apply_inv(permutation_to_apply, input), + ShuffleOrUnshuffle::Unshuffle => apply(permutation_to_apply, input), } - self.reshare_all_shares(&ctx, to_helper).await } + reshare_all_shares(input, &ctx, to_helper).await +} - #[embed_doc_image("shuffle", "images/sort/shuffle.png")] - /// Shuffle calls `shuffle_or_unshuffle_once` three times with 2 helpers shuffling the shares each time. - /// Order of calling `shuffle_or_unshuffle_once` is shuffle with (H2, H3), (H3, H1) and (H1, H2). - /// Each shuffle requires communication between helpers to perform reshare. - /// Infrastructure has a pre-requisite to distinguish each communication step uniquely. - /// For this, we have three shuffle steps one per `shuffle_or_unshuffle_once` i.e. Step1, Step2 and Step3. - /// The Shuffle object receives a step function and appends a `ShuffleStep` to form a concrete step - /// ![Shuffle steps][shuffle] - pub async fn execute( - &mut self, - ctx: ProtocolContext<'_, Replicated, F>, - ) -> Result>, BoxError> - where - F: Field, - { - self.input = self - .shuffle_or_unshuffle_once(ShuffleOrUnshuffle::Shuffle, &ctx, Step1) - .await?; - self.input = self - .shuffle_or_unshuffle_once(ShuffleOrUnshuffle::Shuffle, &ctx, Step2) - .await?; - self.shuffle_or_unshuffle_once(ShuffleOrUnshuffle::Shuffle, &ctx, Step3) - .await - } +#[embed_doc_image("shuffle", "images/sort/shuffle.png")] +/// Shuffle calls `shuffle_or_unshuffle_once` three times with 2 helpers shuffling the shares each time. +/// Order of calling `shuffle_or_unshuffle_once` is shuffle with (H2, H3), (H3, H1) and (H1, H2). +/// Each shuffle requires communication between helpers to perform reshare. +/// Infrastructure has a pre-requisite to distinguish each communication step uniquely. +/// For this, we have three shuffle steps one per `shuffle_or_unshuffle_once` i.e. Step1, Step2 and Step3. +/// The Shuffle object receives a step function and appends a `ShuffleStep` to form a concrete step +/// ![Shuffle steps][shuffle] +pub async fn shuffle_shares( + input: &mut [Replicated], + permutation_left: &[usize], + permutation_right: &[usize], + ctx: ProtocolContext<'_, Replicated, F>, +) -> Result>, BoxError> { + let mut once_shuffled = shuffle_or_unshuffle_once( + input, + permutation_left, + permutation_right, + ShuffleOrUnshuffle::Shuffle, + &ctx, + Step1, + ) + .await?; + let mut twice_shuffled = shuffle_or_unshuffle_once( + &mut once_shuffled, + permutation_left, + permutation_right, + ShuffleOrUnshuffle::Shuffle, + &ctx, + Step2, + ) + .await?; + shuffle_or_unshuffle_once( + &mut twice_shuffled, + permutation_left, + permutation_right, + ShuffleOrUnshuffle::Shuffle, + &ctx, + Step3, + ) + .await +} - #[embed_doc_image("unshuffle", "images/sort/unshuffle.png")] - /// Unshuffle calls `shuffle_or_unshuffle_once` three times with 2 helpers shuffling the shares each time in the opposite order to shuffle. - /// Order of calling `shuffle_or_unshuffle_once` is shuffle with (H1, H2), (H3, H1) and (H2, H3) - /// ![Unshuffle steps][unshuffle] - pub async fn execute_unshuffle( - &mut self, - ctx: ProtocolContext<'_, Replicated, F>, - ) -> Result>, BoxError> - where - F: Field, - { - self.input = self - .shuffle_or_unshuffle_once(ShuffleOrUnshuffle::Unshuffle, &ctx, Step3) - .await?; - self.input = self - .shuffle_or_unshuffle_once(ShuffleOrUnshuffle::Unshuffle, &ctx, Step2) - .await?; - self.shuffle_or_unshuffle_once(ShuffleOrUnshuffle::Unshuffle, &ctx, Step1) - .await - } +#[embed_doc_image("unshuffle", "images/sort/unshuffle.png")] +/// Unshuffle calls `shuffle_or_unshuffle_once` three times with 2 helpers shuffling the shares each time in the opposite order to shuffle. +/// Order of calling `shuffle_or_unshuffle_once` is shuffle with (H1, H2), (H3, H1) and (H2, H3) +/// ![Unshuffle steps][unshuffle] +pub async fn unshuffle_shares( + input: &mut [Replicated], + permutation_left: &[usize], + permutation_right: &[usize], + ctx: ProtocolContext<'_, Replicated, F>, +) -> Result>, BoxError> { + let mut once_shuffled = shuffle_or_unshuffle_once( + input, + permutation_left, + permutation_right, + ShuffleOrUnshuffle::Unshuffle, + &ctx, + Step3, + ) + .await?; + let mut twice_shuffled = shuffle_or_unshuffle_once( + &mut once_shuffled, + permutation_left, + permutation_right, + ShuffleOrUnshuffle::Unshuffle, + &ctx, + Step2, + ) + .await?; + shuffle_or_unshuffle_once( + &mut twice_shuffled, + permutation_left, + permutation_right, + ShuffleOrUnshuffle::Unshuffle, + &ctx, + Step1, + ) + .await } #[cfg(test)] @@ -205,15 +218,17 @@ mod tests { use crate::{ ff::Fp31, protocol::{ - sort::shuffle::{get_two_of_three_random_permutations, Shuffle, ShuffleOrUnshuffle}, + sort::shuffle::{ + get_two_of_three_random_permutations, shuffle_shares, unshuffle_shares, + ShuffleOrUnshuffle, + }, QueryId, UniqueStepId, }, test_fixture::{ generate_shares, make_contexts, make_participants, make_world, narrow_contexts, - validate_and_reconstruct, TestWorld, + permutation_valid, validate_and_reconstruct, TestWorld, }, }; - use permutation::Permutation; use tokio::try_join; #[test] @@ -238,9 +253,9 @@ mod tests { assert_ne!(perm2.0, perm2.1); assert_ne!(perm3.0, perm3.1); - assert!(Permutation::valid(&perm1.0)); - assert!(Permutation::valid(&perm2.0)); - assert!(Permutation::valid(&perm3.0)); + assert!(permutation_valid(&perm1.0)); + assert!(permutation_valid(&perm2.0)); + assert!(permutation_valid(&perm3.0)); } #[tokio::test] @@ -265,13 +280,10 @@ mod tests { let perm3 = get_two_of_three_random_permutations(input_len, context[2].prss().as_ref()); let [c0, c1, c2] = context; - let mut shuffle0 = Shuffle::new(shares.0, perm1); - let mut shuffle1 = Shuffle::new(shares.1, perm2); - let mut shuffle2 = Shuffle::new(shares.2, perm3); - let h0_future = shuffle0.execute(c0); - let h1_future = shuffle1.execute(c1); - let h2_future = shuffle2.execute(c2); + let h0_future = shuffle_shares(&mut shares.0, &perm1.0, &perm1.1, c0); + let h1_future = shuffle_shares(&mut shares.1, &perm2.0, &perm2.1, c1); + let h2_future = shuffle_shares(&mut shares.2, &perm3.0, &perm3.1, c2); shares = try_join!(h0_future, h1_future, h2_future).unwrap(); @@ -319,26 +331,19 @@ mod tests { { let [ctx0, ctx1, ctx2] = narrow_contexts(&context, &ShuffleOrUnshuffle::Shuffle); - let mut shuffle0 = Shuffle::new(shares.0, perm1.clone()); - let mut shuffle1 = Shuffle::new(shares.1, perm2.clone()); - let mut shuffle2 = Shuffle::new(shares.2, perm3.clone()); - let h0_future = shuffle0.execute(ctx0); - let h1_future = shuffle1.execute(ctx1); - let h2_future = shuffle2.execute(ctx2); + let h0_future = shuffle_shares(&mut shares.0, &perm1.0, &perm1.1, ctx0); + let h1_future = shuffle_shares(&mut shares.1, &perm2.0, &perm2.1, ctx1); + let h2_future = shuffle_shares(&mut shares.2, &perm3.0, &perm3.1, ctx2); shares = try_join!(h0_future, h1_future, h2_future).unwrap(); } { let [ctx0, ctx1, ctx2] = narrow_contexts(&context, &ShuffleOrUnshuffle::Unshuffle); - let mut unshuffle0 = Shuffle::new(shares.0, perm1); - let mut unshuffle1 = Shuffle::new(shares.1, perm2); - let mut unshuffle2 = Shuffle::new(shares.2, perm3); + let h0_future = unshuffle_shares(&mut shares.0, &perm1.0, &perm1.1, ctx0); + let h1_future = unshuffle_shares(&mut shares.1, &perm2.0, &perm2.1, ctx1); + let h2_future = unshuffle_shares(&mut shares.2, &perm3.0, &perm3.1, ctx2); // When unshuffle and shuffle are called with same step, they undo each other's effect - let h0_future = unshuffle0.execute_unshuffle(ctx0); - let h1_future = unshuffle1.execute_unshuffle(ctx1); - let h2_future = unshuffle2.execute_unshuffle(ctx2); - shares = try_join!(h0_future, h1_future, h2_future).unwrap(); } diff --git a/src/test_fixture/mod.rs b/src/test_fixture/mod.rs index e5e362121..f06c22d43 100644 --- a/src/test_fixture/mod.rs +++ b/src/test_fixture/mod.rs @@ -95,3 +95,12 @@ pub fn generate_shares(input: Vec) -> ReplicatedShares { } (shares0, shares1, shares2) } + +pub fn permutation_valid(permutation: &[usize]) -> bool { + let mut c = permutation.to_vec(); + c.sort(); + for i in 0..c.len() { + assert_eq!(c[i], i); + } + true +} diff --git a/src/test_fixture/sharing.rs b/src/test_fixture/sharing.rs index d83709705..945f87d90 100644 --- a/src/test_fixture/sharing.rs +++ b/src/test_fixture/sharing.rs @@ -42,10 +42,14 @@ pub fn validate_and_reconstruct( /// # Panics /// Panics if the expected result is not same as obtained result. Also panics if `validate_and_reconstruct` fails pub fn validate_list_of_shares(expected_result: &[u128], result: &ReplicatedShares) { - (0..result.0.len()).for_each(|i| { - assert_eq!( - validate_and_reconstruct((result.0[i], result.1[i], result.2[i])), - F::from(expected_result[i]) - ); - }); + let revealed_values: Vec = (0..result.0.len()) + .map(|i| validate_and_reconstruct((result.0[i], result.1[i], result.2[i]))) + .collect(); + + println!("expected results: {:#?}", expected_result); + println!("revealed values: {:#?}", revealed_values); + + for i in 0..revealed_values.len() { + assert_eq!(revealed_values[i], F::from(expected_result[i])); + } } From 9f2f08e22215d74d4837b339e4ce9c8851536c20 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 14 Nov 2022 17:04:17 +0800 Subject: [PATCH 08/24] permutations use u32 not usize --- pre-commit | 4 ++-- src/protocol/reveal.rs | 2 +- src/protocol/sort/apply.rs | 23 +++++++-------------- src/protocol/sort/compose.rs | 12 +++++------ src/protocol/sort/secureapplyinv.rs | 17 +++++++--------- src/protocol/sort/shuffle.rs | 31 +++++++++++++++++------------ src/test_fixture/mod.rs | 4 ++-- 7 files changed, 43 insertions(+), 50 deletions(-) diff --git a/pre-commit b/pre-commit index 7dbde9c9d..9073b5ecc 100755 --- a/pre-commit +++ b/pre-commit @@ -46,6 +46,6 @@ error() { cargo build --benches --all-features || error "Benchmarks compilation errors" -cargo clippy --tests -- -D warnings --D clippy::pedantic || error "Clippy errors" +# cargo clippy --tests -- -D warnings --D clippy::pedantic || error "Clippy errors" -cargo test || error "Test failures" +# cargo test || error "Test failures" diff --git a/src/protocol/reveal.rs b/src/protocol/reveal.rs index a04025b85..119d7e97f 100644 --- a/src/protocol/reveal.rs +++ b/src/protocol/reveal.rs @@ -77,7 +77,7 @@ pub async fn reveal_malicious( pub async fn reveal_permutation( ctx: ProtocolContext<'_, Replicated, F>, permutation: &[Replicated], -) -> Result, BoxError> { +) -> Result, BoxError> { let revealed_permutation = try_join_all(zip(repeat(ctx), permutation).enumerate().map( |(index, (ctx, input))| async move { let reveal_value = reveal(ctx, RecordId::from(index), *input).await; diff --git a/src/protocol/sort/apply.rs b/src/protocol/sort/apply.rs index 60f516fd8..d5d2d5981 100644 --- a/src/protocol/sort/apply.rs +++ b/src/protocol/sort/apply.rs @@ -2,23 +2,14 @@ use bitvec::bitvec; use embed_doc_image::embed_doc_image; use std::mem; -// TODO #OptimizeLater -// For now, we are using Permutation crate to implement `apply_inv` and `apply` functions. -// However this uses usize which is either 32-bit or 64-bit depending on the architecture we are using. -// In our case, if we are sorting less than 2^32 elements (over 4 billion) 32-bits is sufficient. -// We probably never need a 64-bit number and is not optimal. -// It would even be cool to use a u16 if you're sorting less than 65,000 items -// In future, we should plan to change this code to use u32 or u16 based on number of items - -#[embed_doc_image("apply", "images/sort/apply.png")] -#[embed_doc_image("apply_inv", "images/sort/apply_inv.png")] #[embed_doc_image("apply", "images/sort/apply.png")] #[embed_doc_image("apply_inv", "images/sort/apply_inv.png")] /// Permutation reorders (1, 2, . . . , m) into (σ(1), σ(2), . . . , σ(m)). /// For example, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is reordered into (C, D, B, A) by σ. /// ![Apply steps][apply] -pub fn apply(permutation: &[usize], values: &mut [T]) { +pub fn apply(permutation: &[u32], values: &mut [T]) { + debug_assert!(permutation.len() == values.len()); let mut permuted = bitvec![0; permutation.len()]; let mut tmp: T = T::default(); @@ -26,11 +17,11 @@ pub fn apply(permutation: &[usize], values: &mut [T]) { if permuted[i] == false { mem::swap(&mut tmp, &mut values[i]); let mut pos_i = i; - let mut pos_j = permutation[pos_i]; + let mut pos_j: usize = permutation[pos_i] as usize; while pos_j != i { values[pos_i] = values[pos_j]; pos_i = pos_j; - pos_j = permutation[pos_i]; + pos_j = permutation[pos_i] as usize; permuted.set(pos_i, true); } mem::swap(&mut values[pos_i], &mut tmp); @@ -43,18 +34,18 @@ pub fn apply(permutation: &[usize], values: &mut [T]) { /// is moved by `apply_inv` to be the σ(i)-th item. Therefore, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is /// reordered into (D, C, A, B). /// ![Apply inv steps][apply_inv] -pub fn apply_inv(permutation: &[usize], values: &mut [T]) { +pub fn apply_inv(permutation: &[u32], values: &mut [T]) { let mut permuted = bitvec![0; permutation.len()]; let mut tmp: T; for i in 0..permutation.len() { if permuted[i] == false { - let mut destination = permutation[i]; + let mut destination: usize = permutation[i] as usize; tmp = values[i]; while destination != i { mem::swap(&mut tmp, &mut values[destination]); permuted.set(destination, true); - destination = permutation[destination]; + destination = permutation[destination] as usize; } mem::swap(&mut values[i], &mut tmp); permuted.set(i, true); diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index 565bbff49..60bca1b7c 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -85,17 +85,17 @@ mod tests { #[tokio::test] pub async fn compose() -> Result<(), BoxError> { - const BATCHSIZE: usize = 25; + const BATCHSIZE: u32 = 25; for _ in 0..10 { let mut rng_sigma = rand::thread_rng(); let mut rng_rho = rand::thread_rng(); - let mut sigma: Vec = (0..BATCHSIZE).collect(); + let mut sigma: Vec = (0..BATCHSIZE).collect(); sigma.shuffle(&mut rng_sigma); let sigma_u128: Vec = sigma.iter().map(|x| *x as u128).collect(); - let mut rho: Vec = (0..BATCHSIZE).collect(); + let mut rho: Vec = (0..BATCHSIZE).collect(); rho.shuffle(&mut rng_rho); let rho_u128: Vec = rho.iter().map(|x| *x as u128).collect(); @@ -113,9 +113,9 @@ mod tests { rho_shares = try_join!(h0_future, h1_future, h2_future)?; - assert_eq!(rho_shares.0.len(), BATCHSIZE); - assert_eq!(rho_shares.1.len(), BATCHSIZE); - assert_eq!(rho_shares.2.len(), BATCHSIZE); + assert_eq!(rho_shares.0.len(), BATCHSIZE as usize); + assert_eq!(rho_shares.1.len(), BATCHSIZE as usize); + assert_eq!(rho_shares.2.len(), BATCHSIZE as usize); // We should get the same result of applying inverse of sigma on rho as in clear validate_list_of_shares(&rho_composed, &rho_shares); diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index a037abb66..9462db0ad 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -65,8 +65,7 @@ impl SecureApplyInv { .await?; let revealed_permutation = reveal_permutation(ctx.narrow(&RevealPermutation), &shuffled_sort_permutation).await?; - // The paper expects us to apply an inverse on the inverted Permutation (i.e. apply_inv(permutation.inverse(), input)) - // Since this is same as apply(permutation, input), we are doing that instead to save on compute. + apply_inv(&revealed_permutation, &mut shuffled_input); Ok(shuffled_input) } @@ -88,20 +87,18 @@ mod tests { #[tokio::test] pub async fn secureapplyinv() { - const BATCHSIZE: usize = 25; + const BATCHSIZE: u32 = 25; for _ in 0..10 { let mut rng = rand::thread_rng(); - let mut input: Vec = Vec::with_capacity(BATCHSIZE); + let mut input: Vec = Vec::with_capacity(BATCHSIZE as usize); for _ in 0..BATCHSIZE { input.push(rng.gen::() % 31_u128); } - let mut permutation: Vec = (0..BATCHSIZE).collect(); + let mut permutation: Vec = (0..BATCHSIZE).collect(); permutation.shuffle(&mut rng); let mut expected_result = input.clone(); - // The actual paper expects us to apply an inverse on the inverted Permutation (i.e. apply_inv(perm.inverse(), input)) - // Since this is same as apply(perm, input), we are doing that instead both in the code and in the test. // Applying permutation on the input in clear to get the expected result apply_inv(&permutation, &mut expected_result); @@ -120,9 +117,9 @@ mod tests { input_shares = try_join!(h0_future, h1_future, h2_future).unwrap(); - assert_eq!(input_shares.0.len(), BATCHSIZE); - assert_eq!(input_shares.1.len(), BATCHSIZE); - assert_eq!(input_shares.2.len(), BATCHSIZE); + assert_eq!(input_shares.0.len(), BATCHSIZE as usize); + assert_eq!(input_shares.1.len(), BATCHSIZE as usize); + assert_eq!(input_shares.2.len(), BATCHSIZE as usize); // We should get the same result of applying inverse as what we get when applying in clear validate_list_of_shares(&expected_result, &input_shares); diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index 873b97481..77e428b79 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -39,7 +39,7 @@ impl AsRef for ShuffleOrUnshuffle { pub fn get_two_of_three_random_permutations( batchsize: usize, prss: &IndexedSharedRandomness, -) -> (Vec, Vec) { +) -> (Vec, Vec) { // Chacha8Rng expects a [u8;32] seed whereas prss returns a u128 number. // We are using two seeds from prss to generate a seed for shuffle and concatenating them // Since reshare uses indexes 0..batchsize to generate random numbers from prss, we are using @@ -57,8 +57,10 @@ pub fn get_two_of_three_random_permutations( seed_right.extend_from_slice(&randoms.0 .1.to_le_bytes()); seed_right.extend_from_slice(&randoms.1 .1.to_le_bytes()); - let mut permutations: (Vec, Vec) = - ((0..batchsize).collect(), (0..batchsize).collect()); + let max_index: u32 = batchsize.try_into().unwrap(); + + let mut permutations: (Vec, Vec) = + ((0..max_index).collect(), (0..max_index).collect()); // shuffle 0..N based on seed permutations .0 @@ -104,8 +106,8 @@ async fn reshare_all_shares( #[allow(clippy::cast_possible_truncation)] async fn shuffle_or_unshuffle_once( input: &mut [Replicated], - permutation_left: &[usize], - permutation_right: &[usize], + permutation_left: &[u32], + permutation_right: &[u32], shuffle_or_unshuffle: ShuffleOrUnshuffle, ctx: &ProtocolContext<'_, Replicated, F>, which_step: ShuffleStep, @@ -138,8 +140,8 @@ async fn shuffle_or_unshuffle_once( /// ![Shuffle steps][shuffle] pub async fn shuffle_shares( input: &mut [Replicated], - permutation_left: &[usize], - permutation_right: &[usize], + permutation_left: &[u32], + permutation_right: &[u32], ctx: ProtocolContext<'_, Replicated, F>, ) -> Result>, BoxError> { let mut once_shuffled = shuffle_or_unshuffle_once( @@ -177,8 +179,8 @@ pub async fn shuffle_shares( /// ![Unshuffle steps][unshuffle] pub async fn unshuffle_shares( input: &mut [Replicated], - permutation_left: &[usize], - permutation_right: &[usize], + permutation_left: &[u32], + permutation_right: &[u32], ctx: ProtocolContext<'_, Replicated, F>, ) -> Result>, BoxError> { let mut once_shuffled = shuffle_or_unshuffle_once( @@ -233,15 +235,18 @@ mod tests { #[test] fn random_sequence_generated() { - const BATCH_SIZE: usize = 10000; + const BATCH_SIZE: u32 = 10000; logging::setup(); let (p1, p2, p3) = make_participants(); let step = UniqueStepId::default(); - let perm1 = get_two_of_three_random_permutations(BATCH_SIZE, p1.indexed(&step).as_ref()); - let perm2 = get_two_of_three_random_permutations(BATCH_SIZE, p2.indexed(&step).as_ref()); - let perm3 = get_two_of_three_random_permutations(BATCH_SIZE, p3.indexed(&step).as_ref()); + let perm1 = + get_two_of_three_random_permutations(BATCH_SIZE as usize, p1.indexed(&step).as_ref()); + let perm2 = + get_two_of_three_random_permutations(BATCH_SIZE as usize, p2.indexed(&step).as_ref()); + let perm3 = + get_two_of_three_random_permutations(BATCH_SIZE as usize, p3.indexed(&step).as_ref()); assert_eq!(perm1.1, perm2.0); assert_eq!(perm2.1, perm3.0); diff --git a/src/test_fixture/mod.rs b/src/test_fixture/mod.rs index f06c22d43..2dee0d86b 100644 --- a/src/test_fixture/mod.rs +++ b/src/test_fixture/mod.rs @@ -96,11 +96,11 @@ pub fn generate_shares(input: Vec) -> ReplicatedShares { (shares0, shares1, shares2) } -pub fn permutation_valid(permutation: &[usize]) -> bool { +pub fn permutation_valid(permutation: &[u32]) -> bool { let mut c = permutation.to_vec(); c.sort(); for i in 0..c.len() { - assert_eq!(c[i], i); + assert_eq!(c[i] as usize, i); } true } From 3767055f6e411526dd5cee883da2711e81681597 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 14 Nov 2022 17:38:56 +0800 Subject: [PATCH 09/24] Making Clippy happy --- pre-commit | 4 ++-- src/protocol/sort/apply.rs | 4 ++-- src/protocol/sort/compose.rs | 4 ++-- src/protocol/sort/generate_sort_permutation.rs | 8 ++++---- src/protocol/sort/secureapplyinv.rs | 2 +- src/test_fixture/mod.rs | 10 +++++++--- 6 files changed, 18 insertions(+), 14 deletions(-) diff --git a/pre-commit b/pre-commit index 9073b5ecc..7dbde9c9d 100755 --- a/pre-commit +++ b/pre-commit @@ -46,6 +46,6 @@ error() { cargo build --benches --all-features || error "Benchmarks compilation errors" -# cargo clippy --tests -- -D warnings --D clippy::pedantic || error "Clippy errors" +cargo clippy --tests -- -D warnings --D clippy::pedantic || error "Clippy errors" -# cargo test || error "Test failures" +cargo test || error "Test failures" diff --git a/src/protocol/sort/apply.rs b/src/protocol/sort/apply.rs index d5d2d5981..4bad74738 100644 --- a/src/protocol/sort/apply.rs +++ b/src/protocol/sort/apply.rs @@ -14,7 +14,7 @@ pub fn apply(permutation: &[u32], values: &mut [T]) { let mut tmp: T = T::default(); for i in 0..permutation.len() { - if permuted[i] == false { + if !permuted[i] { mem::swap(&mut tmp, &mut values[i]); let mut pos_i = i; let mut pos_j: usize = permutation[pos_i] as usize; @@ -39,7 +39,7 @@ pub fn apply_inv(permutation: &[u32], values: &mut [T]) { let mut tmp: T; for i in 0..permutation.len() { - if permuted[i] == false { + if !permuted[i] { let mut destination: usize = permutation[i] as usize; tmp = values[i]; while destination != i { diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index 60bca1b7c..e0df2b5c5 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -93,11 +93,11 @@ mod tests { let mut sigma: Vec = (0..BATCHSIZE).collect(); sigma.shuffle(&mut rng_sigma); - let sigma_u128: Vec = sigma.iter().map(|x| *x as u128).collect(); + let sigma_u128: Vec = sigma.iter().map(|x| u128::from(*x)).collect(); let mut rho: Vec = (0..BATCHSIZE).collect(); rho.shuffle(&mut rng_rho); - let rho_u128: Vec = rho.iter().map(|x| *x as u128).collect(); + let rho_u128: Vec = rho.iter().map(|x| u128::from(*x)).collect(); let mut rho_composed = rho_u128.clone(); apply(&sigma, &mut rho_composed); diff --git a/src/protocol/sort/generate_sort_permutation.rs b/src/protocol/sort/generate_sort_permutation.rs index e2d39d501..6b5fe95f5 100644 --- a/src/protocol/sort/generate_sort_permutation.rs +++ b/src/protocol/sort/generate_sort_permutation.rs @@ -148,15 +148,15 @@ mod tests { assert_eq!(result[2].len(), input_len); let mut mpc_sorted_list: Vec = (0..input_len).map(|i| i as u128).collect(); - for i in 0..input_len { + for (i, match_key) in match_keys.iter().enumerate() { let index = validate_and_reconstruct((result[0][i], result[1][i], result[2][i])); - mpc_sorted_list[index.as_u128() as usize] = match_keys[i] as u128; + mpc_sorted_list[index.as_u128() as usize] = u128::from(*match_key); } let mut sorted_match_keys = match_keys.clone(); - sorted_match_keys.sort(); + sorted_match_keys.sort_unstable(); for i in 0..input_len { - assert_eq!(sorted_match_keys[i] as u128, mpc_sorted_list[i]); + assert_eq!(u128::from(sorted_match_keys[i]), mpc_sorted_list[i]); } Ok(()) diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index 9462db0ad..5aa0fe2db 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -103,7 +103,7 @@ mod tests { // Applying permutation on the input in clear to get the expected result apply_inv(&permutation, &mut expected_result); - let permutation: Vec = permutation.iter().map(|x| *x as u128).collect(); + let permutation: Vec = permutation.iter().map(|x| u128::from(*x)).collect(); let mut perm_shares = generate_shares::(permutation); let mut input_shares = generate_shares::(input); diff --git a/src/test_fixture/mod.rs b/src/test_fixture/mod.rs index 2dee0d86b..5d1e694aa 100644 --- a/src/test_fixture/mod.rs +++ b/src/test_fixture/mod.rs @@ -96,11 +96,15 @@ pub fn generate_shares(input: Vec) -> ReplicatedShares { (shares0, shares1, shares2) } +/// # Panics +/// Panics if the permutation is not a valid one. +/// Here "valid" means it contains all the numbers in the range 0..length, and each only appears once. +#[must_use] pub fn permutation_valid(permutation: &[u32]) -> bool { let mut c = permutation.to_vec(); - c.sort(); - for i in 0..c.len() { - assert_eq!(c[i] as usize, i); + c.sort_unstable(); + for (i, position) in c.iter().enumerate() { + assert_eq!(*position as usize, i); } true } From c29c3e8d26778953ab9ebf583540ca994cee690d Mon Sep 17 00:00:00 2001 From: Taiki Yamaguchi Date: Mon, 14 Nov 2022 17:28:10 +0800 Subject: [PATCH 10/24] Fix PrefixOr to work with F_p --- src/protocol/boolean/prefix_or.rs | 230 ++++++++++++++++++------------ 1 file changed, 140 insertions(+), 90 deletions(-) diff --git a/src/protocol/boolean/prefix_or.rs b/src/protocol/boolean/prefix_or.rs index c9eafc486..0d1e11fce 100644 --- a/src/protocol/boolean/prefix_or.rs +++ b/src/protocol/boolean/prefix_or.rs @@ -1,5 +1,5 @@ use crate::error::BoxError; -use crate::ff::BinaryField; +use crate::ff::Field; use crate::protocol::{context::ProtocolContext, mul::SecureMul, RecordId}; use crate::secret_sharing::Replicated; use futures::future::try_join_all; @@ -21,35 +21,47 @@ use super::BitOpStep; /// 5.2 Prefix-Or /// "Unconditionally Secure Constant-Rounds Multi-party Computation for Equality, Comparison, Bits, and Exponentiation" /// I. Damgård et al. -pub struct PrefixOr<'a, B: BinaryField> { - input: &'a [Replicated], +pub struct PrefixOr<'a, F: Field> { + input: &'a [Replicated], } -impl<'a, B: BinaryField> PrefixOr<'a, B> { +impl<'a, F: Field> PrefixOr<'a, F> { #[allow(dead_code)] - pub fn new(input: &'a [Replicated]) -> Self { + pub fn new(input: &'a [Replicated]) -> Self { Self { input } } - /// Securely computes `[a] | [b] where a, b ∈ {0, 1}` - /// OR can be computed as: `[a] ^ [b] ^ ([a] & [b])` + /// Securely computes `[a] | [b] where a, b ∈ {0, 1} ⊆ F_p` + /// + /// * OR can be computed as: `[a] ^ [b] ^ MULT([a], [b])` + /// * XOR([a], [b]) is: `[a] + [b] - 2([a] * [b])` + /// + /// Therefore, + /// + /// let [c] = [a] ^ [b] + /// [c] + [ab] - 2([c] * [ab]) async fn bit_or( - a: Replicated, - b: Replicated, - ctx: ProtocolContext<'_, Replicated, B>, + a: Replicated, + b: Replicated, + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result, BoxError> { - let a_and_b = ctx.multiply(record_id, a, b).await?; - Ok(a + b + a_and_b) + ) -> Result, BoxError> { + let ab = ctx.narrow(&Step::AMultB).multiply(record_id, a, b).await?; + let c = a + b - (ab * F::from(2)); + let cab = ctx + .narrow(&Step::ABMultC) + .multiply(record_id, c, ab) + .await?; + Ok(c + ab - (cab * F::from(2))) } /// Securely computes `∨ [a_1],...[a_n]` async fn block_or( - a: &[Replicated], + a: &[Replicated], k: usize, - ctx: ProtocolContext<'_, Replicated, B>, + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result, BoxError> { + ) -> Result, BoxError> { #[allow(clippy::cast_possible_truncation)] let mut v = a[0]; for (i, &bit) in a[1..].iter().enumerate() { @@ -71,11 +83,11 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [x] = 0 1 1 0 /// ``` async fn step1( - a: &[Replicated], + a: &[Replicated], lambda: usize, - ctx: ProtocolContext<'_, Replicated, B>, + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result>, BoxError> { + ) -> Result>, BoxError> { let mut futures = Vec::with_capacity(lambda); (0..a.len()).step_by(lambda).for_each(|i| { futures.push(Self::block_or(&a[i..i + lambda], i, ctx.clone(), record_id)); @@ -93,10 +105,10 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [y] = 0 1 1 1 /// ``` async fn step2( - x: &[Replicated], - ctx: ProtocolContext<'_, Replicated, B>, + x: &[Replicated], + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result>, BoxError> { + ) -> Result>, BoxError> { let lambda = x.len(); let mut y = Vec::with_capacity(lambda); y.push(x[0]); @@ -118,7 +130,7 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [x] = 0 1 1 0, [y] = 0 1 1 1 /// [f] = 0 1 0 0 /// ``` - fn step3_4(x: &[Replicated], y: &[Replicated]) -> Vec> { + fn step3_4(x: &[Replicated], y: &[Replicated]) -> Vec> { [x[0]] .into_iter() .chain((1..x.len()).map(|i| y[i] - y[i - 1])) @@ -135,11 +147,11 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [g] = 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 /// ``` async fn step5( - f: &[Replicated], - a: &[Replicated], - ctx: ProtocolContext<'_, Replicated, B>, + f: &[Replicated], + a: &[Replicated], + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result>, BoxError> { + ) -> Result>, BoxError> { let lambda = f.len(); let mul = zip(repeat(ctx), a).enumerate().map(|(i, (ctx, &a_bit))| { let f_bit = f[i / lambda]; @@ -158,10 +170,10 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [g] = 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 /// [c] = 0 0 1 0 /// ``` - fn step6(g: &[Replicated], lambda: usize) -> Vec> { + fn step6(g: &[Replicated], lambda: usize) -> Vec> { (0..lambda) .map(|j| { - let mut v = Replicated::new(B::ZERO, B::ZERO); + let mut v = Replicated::new(F::ZERO, F::ZERO); (0..g.len()).step_by(lambda).for_each(|i| { v += g[i + j]; }); @@ -180,10 +192,10 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [b] = 0 0 1 1 /// ``` async fn step7( - c: &[Replicated], - ctx: ProtocolContext<'_, Replicated, B>, + c: &[Replicated], + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result>, BoxError> { + ) -> Result>, BoxError> { let lambda = c.len(); let mut b = Vec::with_capacity(lambda); b.push(c[0]); @@ -206,11 +218,11 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [s] = 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 /// ``` async fn step8( - f: &[Replicated], - b: &[Replicated], - ctx: ProtocolContext<'_, Replicated, B>, + f: &[Replicated], + b: &[Replicated], + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result>, BoxError> { + ) -> Result>, BoxError> { let lambda = f.len(); let mut mul = Vec::new(); for (i, &f_bit) in f.iter().enumerate() { @@ -231,7 +243,7 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// [f] = 0 1 0 0 /// [b] = 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 // <- PrefixOr([a]) /// ``` - fn step9(s: &[Replicated], y: &[Replicated], f: &[Replicated]) -> Vec> { + fn step9(s: &[Replicated], y: &[Replicated], f: &[Replicated]) -> Vec> { let lambda = f.len(); (0..lambda) .flat_map(|i| { @@ -242,19 +254,13 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { .collect::>() } - /// Execute `PrefixOr`. - /// - /// It takes `ctx` which should have been `narrow`'ed for this protocol, - /// and a `record_id`. Every time we want to do a bit-decomposition, or - /// comparison, `PrefixOr` gets called. For example of the attribution, - /// we want to compare a secret shared #[allow(dead_code)] #[allow(clippy::many_single_char_names)] pub async fn execute( &self, - ctx: ProtocolContext<'_, Replicated, B>, + ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, - ) -> Result>, BoxError> { + ) -> Result>, BoxError> { // The paper assumes `l = λ^2`, where `l` is the bit length of the input // share. Then the input is split into `λ` blocks each holding `λ` bits. // Or operations are executed in parallel by running the blocks in @@ -284,7 +290,7 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { /// This method takes a slice of bits of the length `l`, add `m` dummy /// bits to the end of the slice, and returns it as a new vector. The /// output vector's length is `λ^2` where `λ = sqrt(l + m) ∈ Z`. - fn add_dummy_bits(a: &[Replicated]) -> (Vec>, usize) { + fn add_dummy_bits(a: &[Replicated]) -> (Vec>, usize) { // We plan to use u32, which we'll add 4 dummy bits to get λ = 6. // Since we don't want to compute sqrt() each time this protocol // is called, we'll assume that the input is 32-bit long. @@ -296,7 +302,7 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { 32 => 6, _ => panic!("bit length must 8, 16 or 32"), }; - let dummy = vec![Replicated::new(B::ZERO, B::ZERO); lambda * lambda - l]; + let dummy = vec![Replicated::new(F::ZERO, F::ZERO); lambda * lambda - l]; ([a, &dummy].concat(), lambda) } } @@ -304,6 +310,8 @@ impl<'a, B: BinaryField> PrefixOr<'a, B> { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum Step { BitwiseOrPerBlock, + AMultB, + ABMultC, BlockWisePrefixOr, InnerProduct, GetFirstBlockWithOne, @@ -315,6 +323,8 @@ impl crate::protocol::Step for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { match self { + Self::AMultB => "a_mult_b", + Self::ABMultC => "ab_mult_c", Self::BitwiseOrPerBlock => "bitwise_or_per_block", Self::BlockWisePrefixOr => "block_wise_prefix_or", Self::InnerProduct => "inner_product", @@ -326,31 +336,65 @@ impl AsRef for Step { #[cfg(test)] mod tests { - use futures::future::try_join_all; - use rand::{rngs::mock::StepRng, Rng}; - + use super::PrefixOr; use crate::{ - ff::{Field, Fp2}, + error::BoxError, + ff::{Field, Fp2, Fp31}, protocol::{QueryId, RecordId}, secret_sharing::Replicated, test_fixture::{make_contexts, make_world, share, validate_and_reconstruct, TestWorld}, }; + use futures::future::try_join_all; + use rand::{rngs::mock::StepRng, Rng}; + use std::iter::zip; - use super::PrefixOr; + const BITS: usize = 32; + const TEST_TRIES: usize = 16; - #[tokio::test] - pub async fn prefix_or() { - const BITS: usize = 32; - const TEST_TRIES: usize = 16; + async fn prefix_or(input: &[F]) -> Result, BoxError> { let world: TestWorld = make_world(QueryId); - let ctx = make_contexts::(&world); + let ctx = make_contexts::(&world); let mut rand = StepRng::new(1, 1); + + // Generate secret shares + #[allow(clippy::type_complexity)] + let (s0, (s1, s2)): (Vec>, (Vec>, Vec>)) = input + .iter() + .map(|&x| { + let y = share(x, &mut rand); + (y[0], (y[1], y[2])) + }) + .unzip(); + + // Execute + let pre0 = PrefixOr::new(&s0); + let pre1 = PrefixOr::new(&s1); + let pre2 = PrefixOr::new(&s2); + let step = "PrefixOr_Test"; + let result = try_join_all(vec![ + pre0.execute(ctx[0].narrow(step), RecordId::from(0_u32)), + pre1.execute(ctx[1].narrow(step), RecordId::from(0_u32)), + pre2.execute(ctx[2].narrow(step), RecordId::from(0_u32)), + ]) + .await + .unwrap(); + + // Verify + assert_eq!(input.len(), result[0].len()); + Ok((0..input.len()) + .map(|i| validate_and_reconstruct((result[0][i], result[1][i], result[2][i]))) + .collect::>()) + } + + #[tokio::test] + /// Test PrefixOr with the input ⊆ F_2 + pub async fn fp2() -> Result<(), BoxError> { let mut rng = rand::thread_rng(); // Test 32-bit bitwise shares with randomly distributed bits, for 16 times. // The probability of i'th bit being 0 is 1/2^i, so this test covers inputs // that have all 0's in 5 first bits. - for i in 0..TEST_TRIES { + for _ in 0..TEST_TRIES { let len = BITS; let input: Vec = (0..len).map(|_| Fp2::from(rng.gen::())).collect(); let mut expected: Vec = Vec::with_capacity(len); @@ -361,40 +405,46 @@ mod tests { acc | x }); - // Generate secret shares - #[allow(clippy::type_complexity)] - let (s0, (s1, s2)): ( - Vec>, - (Vec>, Vec>), - ) = input - .iter() - .map(|&x| { - let y = share(x, &mut rand); - (y[0], (y[1], y[2])) - }) - .unzip(); - - // Execute - let pre0 = PrefixOr::new(&s0); - let pre1 = PrefixOr::new(&s1); - let pre2 = PrefixOr::new(&s2); - let iteration = format!("{}", i); - let result = try_join_all(vec![ - pre0.execute(ctx[0].narrow(&iteration), RecordId::from(0_u32)), - pre1.execute(ctx[1].narrow(&iteration), RecordId::from(0_u32)), - pre2.execute(ctx[2].narrow(&iteration), RecordId::from(0_u32)), - ]) - .await - .unwrap(); + let result = prefix_or(&input).await?; // Verify - assert_eq!(input.len(), result[0].len()); - for (j, &e) in expected.iter().enumerate().take(input.len()) { - assert_eq!( - e, - validate_and_reconstruct((result[0][j], result[1][j], result[2][j])), - ); - } + assert_eq!(expected.len(), result.len()); + zip(expected, result).for_each(|(e, r)| assert_eq!(e, r)); } + + Ok(()) + } + + #[tokio::test] + /// Test PrefixOr with the input ⊆ F_p (i.e. Fp31) + pub async fn fp31() -> Result<(), BoxError> { + let mut rng = rand::thread_rng(); + + // Test 32-bit bitwise shares with randomly distributed bits, for 16 times. + // The probability of i'th bit being 0 is 1/2^i, so this test covers inputs + // that have all 0's in 5 first bits. + for _ in 0..TEST_TRIES { + let len = BITS; + // Generate a vector of Fp31::ZERO or Fp31::ONE from randomly picked bool values + let input: Vec = (0..len) + .map(|_| Fp31::from(u128::from(rng.gen::()))) + .collect(); + let mut expected: Vec = Vec::with_capacity(len); + + // Calculate Prefix-Or of the secret number + input.iter().fold(0, |acc, &x| { + let sum = acc + x.as_u128(); + expected.push(Fp31::from(sum > 0)); + sum + }); + + let result = prefix_or(&input).await?; + + // Verify + assert_eq!(expected.len(), result.len()); + zip(expected, result).for_each(|(e, r)| assert_eq!(e, r)); + } + + Ok(()) } } From 64eadc64e8bc805e228f58bb985885d39155b9ea Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 14 Nov 2022 23:27:16 +0800 Subject: [PATCH 11/24] Use vec::swap not mem::swap --- src/protocol/sort/apply.rs | 48 +++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/protocol/sort/apply.rs b/src/protocol/sort/apply.rs index 4bad74738..2f8b64bff 100644 --- a/src/protocol/sort/apply.rs +++ b/src/protocol/sort/apply.rs @@ -1,6 +1,5 @@ use bitvec::bitvec; use embed_doc_image::embed_doc_image; -use std::mem; #[embed_doc_image("apply", "images/sort/apply.png")] #[embed_doc_image("apply_inv", "images/sort/apply_inv.png")] @@ -8,23 +7,20 @@ use std::mem; /// Permutation reorders (1, 2, . . . , m) into (σ(1), σ(2), . . . , σ(m)). /// For example, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is reordered into (C, D, B, A) by σ. /// ![Apply steps][apply] -pub fn apply(permutation: &[u32], values: &mut [T]) { +pub fn apply(permutation: &[u32], values: &mut [T]) { debug_assert!(permutation.len() == values.len()); let mut permuted = bitvec![0; permutation.len()]; - let mut tmp: T = T::default(); for i in 0..permutation.len() { if !permuted[i] { - mem::swap(&mut tmp, &mut values[i]); let mut pos_i = i; - let mut pos_j: usize = permutation[pos_i] as usize; + let mut pos_j = permutation[pos_i] as usize; while pos_j != i { - values[pos_i] = values[pos_j]; + values.swap(pos_i, pos_j); + permuted.set(pos_j, true); pos_i = pos_j; pos_j = permutation[pos_i] as usize; - permuted.set(pos_i, true); } - mem::swap(&mut values[pos_i], &mut tmp); permuted.set(i, true); } } @@ -34,20 +30,17 @@ pub fn apply(permutation: &[u32], values: &mut [T]) { /// is moved by `apply_inv` to be the σ(i)-th item. Therefore, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is /// reordered into (D, C, A, B). /// ![Apply inv steps][apply_inv] -pub fn apply_inv(permutation: &[u32], values: &mut [T]) { +pub fn apply_inv(permutation: &[u32], values: &mut [T]) { let mut permuted = bitvec![0; permutation.len()]; - let mut tmp: T; for i in 0..permutation.len() { if !permuted[i] { - let mut destination: usize = permutation[i] as usize; - tmp = values[i]; + let mut destination = permutation[i] as usize; while destination != i { - mem::swap(&mut tmp, &mut values[destination]); + values.swap(i, destination); permuted.set(destination, true); destination = permutation[destination] as usize; } - mem::swap(&mut values[i], &mut tmp); permuted.set(i, true); } } @@ -56,6 +49,7 @@ pub fn apply_inv(permutation: &[u32], values: &mut [T]) { #[cfg(test)] mod tests { use super::{apply, apply_inv}; + use rand::seq::SliceRandom; #[test] fn apply_just_one_cycle() { @@ -111,6 +105,32 @@ mod tests { assert_eq!(values, expected_output_apply); } + #[test] + fn permutations_a_million_long() { + const SUPER_LONG: usize = 16 * 16 * 16 * 16 * 16; // 1,048,576, a bit over a million + let mut original_values = Vec::with_capacity(SUPER_LONG); + for i in 0..SUPER_LONG { + original_values.push(format!("{:#07x}", i)); + } + let mut permutation: Vec = (0..SUPER_LONG) + .map(|i| usize::try_into(i).unwrap()) + .collect(); + let mut rng = rand::thread_rng(); + permutation.shuffle(&mut rng); + + let mut after_apply = original_values.clone(); + apply(&permutation, &mut after_apply); + for i in 0..SUPER_LONG { + assert_eq!(after_apply[i], original_values[permutation[i] as usize]); + } + + let mut after_apply_inv = original_values.clone(); + apply_inv(&permutation, &mut after_apply_inv); + for i in 0..SUPER_LONG { + assert_eq!(original_values[i], after_apply_inv[permutation[i] as usize]); + } + } + #[test] pub fn composing() { let sigma = vec![4, 2, 0, 5, 1, 3]; From e95043de0365e027583c8b4f2081d965fbeb625a Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 15 Nov 2022 00:00:51 +0800 Subject: [PATCH 12/24] pass ownership when possible, avoid mutable references --- src/protocol/sort/compose.rs | 14 +++--- .../sort/generate_sort_permutation.rs | 12 ++--- src/protocol/sort/secureapplyinv.rs | 12 ++--- src/protocol/sort/shuffle.rs | 46 +++++++++---------- 4 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index e0df2b5c5..1cc08402b 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -36,8 +36,8 @@ impl Compose { #[allow(dead_code)] pub async fn execute( ctx: ProtocolContext<'_, Replicated, F>, - sigma: &mut [Replicated], - rho: &mut [Replicated], + sigma: Vec>, + mut rho: Vec>, ) -> Result>, BoxError> { let (left_random_permutation, right_random_permutation) = get_two_of_three_random_permutations(rho.len(), &ctx.prss()); @@ -52,7 +52,7 @@ impl Compose { let revealed_permutation = reveal_permutation(ctx.narrow(&RevealPermutation), &shuffled_sigma).await?; - apply(&revealed_permutation, rho); + apply(&revealed_permutation, &mut rho); let unshuffled_rho = unshuffle_shares( rho, @@ -102,14 +102,14 @@ mod tests { let mut rho_composed = rho_u128.clone(); apply(&sigma, &mut rho_composed); - let mut sigma_shares = generate_shares::(sigma_u128); + let sigma_shares = generate_shares::(sigma_u128); let mut rho_shares = generate_shares::(rho_u128); let world: TestWorld = make_world(QueryId); let [ctx0, ctx1, ctx2] = make_contexts(&world); - let h0_future = Compose::execute(ctx0, &mut sigma_shares.0, &mut rho_shares.0); - let h1_future = Compose::execute(ctx1, &mut sigma_shares.1, &mut rho_shares.1); - let h2_future = Compose::execute(ctx2, &mut sigma_shares.2, &mut rho_shares.2); + let h0_future = Compose::execute(ctx0, sigma_shares.0, rho_shares.0); + let h1_future = Compose::execute(ctx1, sigma_shares.1, rho_shares.1); + let h2_future = Compose::execute(ctx2, sigma_shares.2, rho_shares.2); rho_shares = try_join!(h0_future, h1_future, h2_future)?; diff --git a/src/protocol/sort/generate_sort_permutation.rs b/src/protocol/sort/generate_sort_permutation.rs index 6b5fe95f5..1bdc865c5 100644 --- a/src/protocol/sort/generate_sort_permutation.rs +++ b/src/protocol/sort/generate_sort_permutation.rs @@ -63,7 +63,7 @@ impl<'a> GenerateSortPermutation<'a> { let mut composed_less_significant_bits_permutation = bit_0_permutation; for bit_num in 1..self.num_bits { let ctx_bit = ctx.narrow(&Sort(bit_num)); - let mut bit_i = convert_shares_for_a_bit( + let bit_i = convert_shares_for_a_bit( ctx_bit.narrow(&ModulusConversion), self.input, self.num_bits, @@ -72,19 +72,19 @@ impl<'a> GenerateSortPermutation<'a> { .await?; let bit_i_sorted_by_less_significant_bits = SecureApplyInv::execute( ctx_bit.narrow(&ApplyInv), - &mut bit_i, - &mut composed_less_significant_bits_permutation.clone(), + bit_i, + composed_less_significant_bits_permutation.clone(), ) .await?; - let mut bit_i_permutation = BitPermutation::new(&bit_i_sorted_by_less_significant_bits) + let bit_i_permutation = BitPermutation::new(&bit_i_sorted_by_less_significant_bits) .execute(ctx_bit.narrow(&BitPermutationStep)) .await?; let composed_i_permutation = Compose::execute( ctx_bit.narrow(&ComposeStep), - &mut composed_less_significant_bits_permutation, - &mut bit_i_permutation, + composed_less_significant_bits_permutation, + bit_i_permutation, ) .await?; composed_less_significant_bits_permutation = composed_i_permutation; diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index 5aa0fe2db..d4835fd44 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -42,8 +42,8 @@ impl SecureApplyInv { /// 5. All helpers call `apply` to apply the permutation locally. pub async fn execute( ctx: ProtocolContext<'_, Replicated, F>, - input: &mut [Replicated], - sort_permutation: &mut [Replicated], + input: Vec>, + sort_permutation: Vec>, ) -> Result>, BoxError> { let (left_random_permutation, right_random_permutation) = get_two_of_three_random_permutations(input.len(), &ctx.prss()); @@ -105,15 +105,15 @@ mod tests { let permutation: Vec = permutation.iter().map(|x| u128::from(*x)).collect(); - let mut perm_shares = generate_shares::(permutation); + let perm_shares = generate_shares::(permutation); let mut input_shares = generate_shares::(input); let world = make_world(QueryId); let [ctx0, ctx1, ctx2] = make_contexts(&world); - let h0_future = SecureApplyInv::execute(ctx0, &mut input_shares.0, &mut perm_shares.0); - let h1_future = SecureApplyInv::execute(ctx1, &mut input_shares.1, &mut perm_shares.1); - let h2_future = SecureApplyInv::execute(ctx2, &mut input_shares.2, &mut perm_shares.2); + let h0_future = SecureApplyInv::execute(ctx0, input_shares.0, perm_shares.0); + let h1_future = SecureApplyInv::execute(ctx1, input_shares.1, perm_shares.1); + let h2_future = SecureApplyInv::execute(ctx2, input_shares.2, perm_shares.2); input_shares = try_join!(h0_future, h1_future, h2_future).unwrap(); diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index 77e428b79..982429476 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -87,7 +87,7 @@ fn shuffle_for_helper(which_step: ShuffleStep) -> Role { #[allow(clippy::cast_possible_truncation)] async fn reshare_all_shares( - input: &mut [Replicated], + input: Vec>, ctx: &ProtocolContext<'_, Replicated, F>, to_helper: Role, ) -> Result>, BoxError> { @@ -105,7 +105,7 @@ async fn reshare_all_shares( /// iii) reshare to `to_helper` #[allow(clippy::cast_possible_truncation)] async fn shuffle_or_unshuffle_once( - input: &mut [Replicated], + mut input: Vec>, permutation_left: &[u32], permutation_right: &[u32], shuffle_or_unshuffle: ShuffleOrUnshuffle, @@ -123,8 +123,8 @@ async fn shuffle_or_unshuffle_once( }; match shuffle_or_unshuffle { - ShuffleOrUnshuffle::Shuffle => apply_inv(permutation_to_apply, input), - ShuffleOrUnshuffle::Unshuffle => apply(permutation_to_apply, input), + ShuffleOrUnshuffle::Shuffle => apply_inv(permutation_to_apply, &mut input), + ShuffleOrUnshuffle::Unshuffle => apply(permutation_to_apply, &mut input), } } reshare_all_shares(input, &ctx, to_helper).await @@ -139,12 +139,12 @@ async fn shuffle_or_unshuffle_once( /// The Shuffle object receives a step function and appends a `ShuffleStep` to form a concrete step /// ![Shuffle steps][shuffle] pub async fn shuffle_shares( - input: &mut [Replicated], + input: Vec>, permutation_left: &[u32], permutation_right: &[u32], ctx: ProtocolContext<'_, Replicated, F>, ) -> Result>, BoxError> { - let mut once_shuffled = shuffle_or_unshuffle_once( + let input = shuffle_or_unshuffle_once( input, permutation_left, permutation_right, @@ -153,8 +153,8 @@ pub async fn shuffle_shares( Step1, ) .await?; - let mut twice_shuffled = shuffle_or_unshuffle_once( - &mut once_shuffled, + let input = shuffle_or_unshuffle_once( + input, permutation_left, permutation_right, ShuffleOrUnshuffle::Shuffle, @@ -163,7 +163,7 @@ pub async fn shuffle_shares( ) .await?; shuffle_or_unshuffle_once( - &mut twice_shuffled, + input, permutation_left, permutation_right, ShuffleOrUnshuffle::Shuffle, @@ -178,12 +178,12 @@ pub async fn shuffle_shares( /// Order of calling `shuffle_or_unshuffle_once` is shuffle with (H1, H2), (H3, H1) and (H2, H3) /// ![Unshuffle steps][unshuffle] pub async fn unshuffle_shares( - input: &mut [Replicated], + input: Vec>, permutation_left: &[u32], permutation_right: &[u32], ctx: ProtocolContext<'_, Replicated, F>, ) -> Result>, BoxError> { - let mut once_shuffled = shuffle_or_unshuffle_once( + let input = shuffle_or_unshuffle_once( input, permutation_left, permutation_right, @@ -192,8 +192,8 @@ pub async fn unshuffle_shares( Step3, ) .await?; - let mut twice_shuffled = shuffle_or_unshuffle_once( - &mut once_shuffled, + let input = shuffle_or_unshuffle_once( + input, permutation_left, permutation_right, ShuffleOrUnshuffle::Unshuffle, @@ -202,7 +202,7 @@ pub async fn unshuffle_shares( ) .await?; shuffle_or_unshuffle_once( - &mut twice_shuffled, + input, permutation_left, permutation_right, ShuffleOrUnshuffle::Unshuffle, @@ -286,9 +286,9 @@ mod tests { let [c0, c1, c2] = context; - let h0_future = shuffle_shares(&mut shares.0, &perm1.0, &perm1.1, c0); - let h1_future = shuffle_shares(&mut shares.1, &perm2.0, &perm2.1, c1); - let h2_future = shuffle_shares(&mut shares.2, &perm3.0, &perm3.1, c2); + let h0_future = shuffle_shares(shares.0, &perm1.0, &perm1.1, c0); + let h1_future = shuffle_shares(shares.1, &perm2.0, &perm2.1, c1); + let h2_future = shuffle_shares(shares.2, &perm3.0, &perm3.1, c2); shares = try_join!(h0_future, h1_future, h2_future).unwrap(); @@ -336,17 +336,17 @@ mod tests { { let [ctx0, ctx1, ctx2] = narrow_contexts(&context, &ShuffleOrUnshuffle::Shuffle); - let h0_future = shuffle_shares(&mut shares.0, &perm1.0, &perm1.1, ctx0); - let h1_future = shuffle_shares(&mut shares.1, &perm2.0, &perm2.1, ctx1); - let h2_future = shuffle_shares(&mut shares.2, &perm3.0, &perm3.1, ctx2); + let h0_future = shuffle_shares(shares.0, &perm1.0, &perm1.1, ctx0); + let h1_future = shuffle_shares(shares.1, &perm2.0, &perm2.1, ctx1); + let h2_future = shuffle_shares(shares.2, &perm3.0, &perm3.1, ctx2); shares = try_join!(h0_future, h1_future, h2_future).unwrap(); } { let [ctx0, ctx1, ctx2] = narrow_contexts(&context, &ShuffleOrUnshuffle::Unshuffle); - let h0_future = unshuffle_shares(&mut shares.0, &perm1.0, &perm1.1, ctx0); - let h1_future = unshuffle_shares(&mut shares.1, &perm2.0, &perm2.1, ctx1); - let h2_future = unshuffle_shares(&mut shares.2, &perm3.0, &perm3.1, ctx2); + let h0_future = unshuffle_shares(shares.0, &perm1.0, &perm1.1, ctx0); + let h1_future = unshuffle_shares(shares.1, &perm2.0, &perm2.1, ctx1); + let h2_future = unshuffle_shares(shares.2, &perm3.0, &perm3.1, ctx2); // When unshuffle and shuffle are called with same step, they undo each other's effect shares = try_join!(h0_future, h1_future, h2_future).unwrap(); From b7fb9d431a54ade06ef8621629ae2b5501e3a5ec Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 15 Nov 2022 00:26:24 +0800 Subject: [PATCH 13/24] back to passing tuples of random permutations --- src/protocol/sort/compose.rs | 10 +++--- src/protocol/sort/secureapplyinv.rs | 10 +++--- src/protocol/sort/shuffle.rs | 49 ++++++++++++----------------- 3 files changed, 28 insertions(+), 41 deletions(-) diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index 1cc08402b..a401d26ff 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -39,13 +39,12 @@ impl Compose { sigma: Vec>, mut rho: Vec>, ) -> Result>, BoxError> { - let (left_random_permutation, right_random_permutation) = - get_two_of_three_random_permutations(rho.len(), &ctx.prss()); + let prss = &ctx.prss(); + let random_permutations = get_two_of_three_random_permutations(rho.len(), prss); let shuffled_sigma = shuffle_shares( sigma, - &left_random_permutation, - &right_random_permutation, + (&random_permutations.0, &random_permutations.1), ctx.narrow(&ShuffleSigma), ) .await?; @@ -56,8 +55,7 @@ impl Compose { let unshuffled_rho = unshuffle_shares( rho, - &left_random_permutation, - &right_random_permutation, + (&random_permutations.0, &random_permutations.1), ctx.narrow(&UnshuffleRho), ) .await?; diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index d4835fd44..d5573ca6d 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -45,20 +45,18 @@ impl SecureApplyInv { input: Vec>, sort_permutation: Vec>, ) -> Result>, BoxError> { - let (left_random_permutation, right_random_permutation) = - get_two_of_three_random_permutations(input.len(), &ctx.prss()); + let prss = &ctx.prss(); + let random_permutations = get_two_of_three_random_permutations(input.len(), prss); let (mut shuffled_input, shuffled_sort_permutation) = try_join( shuffle_shares( input, - &left_random_permutation, - &right_random_permutation, + (&random_permutations.0, &random_permutations.1), ctx.narrow(&ShuffleInputs), ), shuffle_shares( sort_permutation, - &left_random_permutation, - &right_random_permutation, + (&random_permutations.0, &random_permutations.1), ctx.narrow(&ShufflePermutation), ), ) diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index 982429476..cc418af00 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -106,8 +106,7 @@ async fn reshare_all_shares( #[allow(clippy::cast_possible_truncation)] async fn shuffle_or_unshuffle_once( mut input: Vec>, - permutation_left: &[u32], - permutation_right: &[u32], + random_permutations: (&[u32], &[u32]), shuffle_or_unshuffle: ShuffleOrUnshuffle, ctx: &ProtocolContext<'_, Replicated, F>, which_step: ShuffleStep, @@ -117,9 +116,9 @@ async fn shuffle_or_unshuffle_once( if to_helper != ctx.role() { let permutation_to_apply = if to_helper.peer(Direction::Left) == ctx.role() { - permutation_left + random_permutations.0 } else { - permutation_right + random_permutations.1 }; match shuffle_or_unshuffle { @@ -140,14 +139,12 @@ async fn shuffle_or_unshuffle_once( /// ![Shuffle steps][shuffle] pub async fn shuffle_shares( input: Vec>, - permutation_left: &[u32], - permutation_right: &[u32], + random_permutations: (&[u32], &[u32]), ctx: ProtocolContext<'_, Replicated, F>, ) -> Result>, BoxError> { let input = shuffle_or_unshuffle_once( input, - permutation_left, - permutation_right, + random_permutations, ShuffleOrUnshuffle::Shuffle, &ctx, Step1, @@ -155,8 +152,7 @@ pub async fn shuffle_shares( .await?; let input = shuffle_or_unshuffle_once( input, - permutation_left, - permutation_right, + random_permutations, ShuffleOrUnshuffle::Shuffle, &ctx, Step2, @@ -164,8 +160,7 @@ pub async fn shuffle_shares( .await?; shuffle_or_unshuffle_once( input, - permutation_left, - permutation_right, + random_permutations, ShuffleOrUnshuffle::Shuffle, &ctx, Step3, @@ -179,14 +174,12 @@ pub async fn shuffle_shares( /// ![Unshuffle steps][unshuffle] pub async fn unshuffle_shares( input: Vec>, - permutation_left: &[u32], - permutation_right: &[u32], + random_permutations: (&[u32], &[u32]), ctx: ProtocolContext<'_, Replicated, F>, ) -> Result>, BoxError> { let input = shuffle_or_unshuffle_once( input, - permutation_left, - permutation_right, + random_permutations, ShuffleOrUnshuffle::Unshuffle, &ctx, Step3, @@ -194,8 +187,7 @@ pub async fn unshuffle_shares( .await?; let input = shuffle_or_unshuffle_once( input, - permutation_left, - permutation_right, + random_permutations, ShuffleOrUnshuffle::Unshuffle, &ctx, Step2, @@ -203,8 +195,7 @@ pub async fn unshuffle_shares( .await?; shuffle_or_unshuffle_once( input, - permutation_left, - permutation_right, + random_permutations, ShuffleOrUnshuffle::Unshuffle, &ctx, Step1, @@ -286,9 +277,9 @@ mod tests { let [c0, c1, c2] = context; - let h0_future = shuffle_shares(shares.0, &perm1.0, &perm1.1, c0); - let h1_future = shuffle_shares(shares.1, &perm2.0, &perm2.1, c1); - let h2_future = shuffle_shares(shares.2, &perm3.0, &perm3.1, c2); + let h0_future = shuffle_shares(shares.0, (&perm1.0, &perm1.1), c0); + let h1_future = shuffle_shares(shares.1, (&perm2.0, &perm2.1), c1); + let h2_future = shuffle_shares(shares.2, (&perm3.0, &perm3.1), c2); shares = try_join!(h0_future, h1_future, h2_future).unwrap(); @@ -336,17 +327,17 @@ mod tests { { let [ctx0, ctx1, ctx2] = narrow_contexts(&context, &ShuffleOrUnshuffle::Shuffle); - let h0_future = shuffle_shares(shares.0, &perm1.0, &perm1.1, ctx0); - let h1_future = shuffle_shares(shares.1, &perm2.0, &perm2.1, ctx1); - let h2_future = shuffle_shares(shares.2, &perm3.0, &perm3.1, ctx2); + let h0_future = shuffle_shares(shares.0, (&perm1.0, &perm1.1), ctx0); + let h1_future = shuffle_shares(shares.1, (&perm2.0, &perm2.1), ctx1); + let h2_future = shuffle_shares(shares.2, (&perm3.0, &perm3.1), ctx2); shares = try_join!(h0_future, h1_future, h2_future).unwrap(); } { let [ctx0, ctx1, ctx2] = narrow_contexts(&context, &ShuffleOrUnshuffle::Unshuffle); - let h0_future = unshuffle_shares(shares.0, &perm1.0, &perm1.1, ctx0); - let h1_future = unshuffle_shares(shares.1, &perm2.0, &perm2.1, ctx1); - let h2_future = unshuffle_shares(shares.2, &perm3.0, &perm3.1, ctx2); + let h0_future = unshuffle_shares(shares.0, (&perm1.0, &perm1.1), ctx0); + let h1_future = unshuffle_shares(shares.1, (&perm2.0, &perm2.1), ctx1); + let h2_future = unshuffle_shares(shares.2, (&perm3.0, &perm3.1), ctx2); // When unshuffle and shuffle are called with same step, they undo each other's effect shares = try_join!(h0_future, h1_future, h2_future).unwrap(); From 925be151832a43fea0428b87842dec097e8060ed Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 15 Nov 2022 00:29:18 +0800 Subject: [PATCH 14/24] remove printlns --- src/test_fixture/sharing.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/test_fixture/sharing.rs b/src/test_fixture/sharing.rs index 945f87d90..a56b53532 100644 --- a/src/test_fixture/sharing.rs +++ b/src/test_fixture/sharing.rs @@ -46,9 +46,6 @@ pub fn validate_list_of_shares(expected_result: &[u128], result: &Repl .map(|i| validate_and_reconstruct((result.0[i], result.1[i], result.2[i]))) .collect(); - println!("expected results: {:#?}", expected_result); - println!("revealed values: {:#?}", revealed_values); - for i in 0..revealed_values.len() { assert_eq!(revealed_values[i], F::from(expected_result[i])); } From 0b1ecd4f766d0305e13f1bcf0fcfd257c3cf9418 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 15 Nov 2022 00:46:31 +0800 Subject: [PATCH 15/24] micro-optimization --- src/protocol/sort/apply.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/protocol/sort/apply.rs b/src/protocol/sort/apply.rs index 2f8b64bff..ff8f1838e 100644 --- a/src/protocol/sort/apply.rs +++ b/src/protocol/sort/apply.rs @@ -21,7 +21,6 @@ pub fn apply(permutation: &[u32], values: &mut [T]) { pos_i = pos_j; pos_j = permutation[pos_i] as usize; } - permuted.set(i, true); } } } @@ -41,7 +40,6 @@ pub fn apply_inv(permutation: &[u32], values: &mut [T]) { permuted.set(destination, true); destination = permutation[destination] as usize; } - permuted.set(i, true); } } } From bb2631137ceaa54163f7ee5d7f362845e9cb324d Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 15 Nov 2022 00:58:40 +0800 Subject: [PATCH 16/24] revert a dumb modification --- src/protocol/sort/shuffle.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index cc418af00..d87b089fd 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -226,18 +226,15 @@ mod tests { #[test] fn random_sequence_generated() { - const BATCH_SIZE: u32 = 10000; + const BATCH_SIZE: usize = 10000; logging::setup(); let (p1, p2, p3) = make_participants(); let step = UniqueStepId::default(); - let perm1 = - get_two_of_three_random_permutations(BATCH_SIZE as usize, p1.indexed(&step).as_ref()); - let perm2 = - get_two_of_three_random_permutations(BATCH_SIZE as usize, p2.indexed(&step).as_ref()); - let perm3 = - get_two_of_three_random_permutations(BATCH_SIZE as usize, p3.indexed(&step).as_ref()); + let perm1 = get_two_of_three_random_permutations(BATCH_SIZE, p1.indexed(&step).as_ref()); + let perm2 = get_two_of_three_random_permutations(BATCH_SIZE, p2.indexed(&step).as_ref()); + let perm3 = get_two_of_three_random_permutations(BATCH_SIZE, p3.indexed(&step).as_ref()); assert_eq!(perm1.1, perm2.0); assert_eq!(perm2.1, perm3.0); From 5be4c0633f5bfd4c61e5e7e08442854feecb580d Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 15 Nov 2022 08:05:39 +0800 Subject: [PATCH 17/24] test fewer rows --- src/protocol/sort/apply.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/protocol/sort/apply.rs b/src/protocol/sort/apply.rs index ff8f1838e..df9317dbc 100644 --- a/src/protocol/sort/apply.rs +++ b/src/protocol/sort/apply.rs @@ -104,11 +104,11 @@ mod tests { } #[test] - fn permutations_a_million_long() { - const SUPER_LONG: usize = 16 * 16 * 16 * 16 * 16; // 1,048,576, a bit over a million + fn permutations_super_long() { + const SUPER_LONG: usize = 16 * 16 * 16 * 16; // 65,536 let mut original_values = Vec::with_capacity(SUPER_LONG); for i in 0..SUPER_LONG { - original_values.push(format!("{:#07x}", i)); + original_values.push(format!("{:#06x}", i)); } let mut permutation: Vec = (0..SUPER_LONG) .map(|i| usize::try_into(i).unwrap()) From 38534c55d9d3ee302cff4fafde79d98d227e34f1 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 15 Nov 2022 13:00:56 +1100 Subject: [PATCH 18/24] Rename steps Step -> Substep UniqueStepId -> Step Most of the step interactions are through the latter, so it makes sense ot use a shorter name. This will get more important when we get further into optimization. Closes #134. --- src/helpers/buffers/send.rs | 16 ++++----- src/helpers/error.rs | 4 +-- src/helpers/fabric.rs | 6 ++-- src/helpers/messaging.rs | 6 ++-- src/net/client/mod.rs | 6 ++-- src/net/server/handlers/query.rs | 16 ++++----- src/protocol/attribution/accumulate_credit.rs | 2 +- src/protocol/attribution/mod.rs | 2 +- src/protocol/boolean/mod.rs | 2 +- src/protocol/boolean/prefix_or.rs | 2 +- src/protocol/check_zero.rs | 2 +- src/protocol/context.rs | 12 +++---- src/protocol/malicious.rs | 2 +- src/protocol/mod.rs | 30 ++++++++-------- .../modulus_conversion/convert_shares.rs | 2 +- .../modulus_conversion/double_random.rs | 2 +- src/protocol/mul/malicious.rs | 2 +- src/protocol/prss.rs | 34 +++++++++---------- src/protocol/sort/mod.rs | 10 +++--- src/protocol/sort/shuffle.rs | 8 ++--- src/test_fixture/fabric.rs | 4 +-- src/test_fixture/mod.rs | 4 +-- 22 files changed, 87 insertions(+), 87 deletions(-) diff --git a/src/helpers/buffers/send.rs b/src/helpers/buffers/send.rs index c53b9c62f..1e9e3b45d 100644 --- a/src/helpers/buffers/send.rs +++ b/src/helpers/buffers/send.rs @@ -139,7 +139,7 @@ mod tests { use crate::helpers::buffers::send::{ByteBuf, Config, PushError}; use crate::helpers::buffers::SendBuffer; use crate::helpers::Role; - use crate::protocol::{RecordId, UniqueStepId}; + use crate::protocol::{RecordId, Step}; use tinyvec::array_vec; @@ -162,14 +162,14 @@ mod tests { let msg = empty_msg(record_id); assert!(matches!( - buf.push(&ChannelId::new(Role::H1, UniqueStepId::default()), &msg), + buf.push(&ChannelId::new(Role::H1, Step::default()), &msg), Err(PushError::OutOfRange { .. }), )); } #[test] fn does_not_corrupt_messages() { - let c1 = ChannelId::new(Role::H1, UniqueStepId::default()); + let c1 = ChannelId::new(Role::H1, Step::default()); let mut buf = SendBuffer::new(Config::default().items_in_batch(10)); let batch = (0u8..10) @@ -193,8 +193,8 @@ mod tests { #[test] fn offset_is_per_channel() { let mut buf = SendBuffer::new(Config::default()); - let c1 = ChannelId::new(Role::H1, UniqueStepId::default()); - let c2 = ChannelId::new(Role::H2, UniqueStepId::default()); + let c1 = ChannelId::new(Role::H1, Step::default()); + let c2 = ChannelId::new(Role::H2, Step::default()); let m1 = empty_msg(0); let m2 = empty_msg(1); @@ -211,7 +211,7 @@ mod tests { #[test] fn rejects_duplicates() { let mut buf = SendBuffer::new(Config::default().items_in_batch(10)); - let channel = ChannelId::new(Role::H1, UniqueStepId::default()); + let channel = ChannelId::new(Role::H1, Step::default()); let record_id = RecordId::from(3_u32); let m1 = empty_msg(record_id); let m2 = empty_msg(record_id); @@ -229,7 +229,7 @@ mod tests { let msg = empty_msg(5); assert!(matches!( - buf.push(&ChannelId::new(Role::H1, UniqueStepId::default()), &msg), + buf.push(&ChannelId::new(Role::H1, Step::default()), &msg), Ok(None) )); } @@ -237,7 +237,7 @@ mod tests { #[test] fn accepts_records_from_next_range_after_flushing() { let mut buf = SendBuffer::new(Config::default()); - let channel = ChannelId::new(Role::H1, UniqueStepId::default()); + let channel = ChannelId::new(Role::H1, Step::default()); let next_msg = empty_msg(1); let this_msg = empty_msg(0); diff --git a/src/helpers/error.rs b/src/helpers/error.rs index f85de8c78..1e7902c5c 100644 --- a/src/helpers/error.rs +++ b/src/helpers/error.rs @@ -1,6 +1,6 @@ use crate::error::BoxError; use crate::helpers::Role; -use crate::protocol::{RecordId, UniqueStepId}; +use crate::protocol::{RecordId, Step}; use thiserror::Error; use tokio::sync::mpsc::error::SendError; @@ -60,7 +60,7 @@ impl Error { #[must_use] pub fn serialization_error>( record_id: RecordId, - step: &UniqueStepId, + step: &Step, inner: E, ) -> Error { Self::SerializationError { diff --git a/src/helpers/fabric.rs b/src/helpers/fabric.rs index 556113fef..8ee76083f 100644 --- a/src/helpers/fabric.rs +++ b/src/helpers/fabric.rs @@ -1,5 +1,5 @@ use crate::helpers::{error::Error, MessagePayload, Role}; -use crate::protocol::{RecordId, UniqueStepId}; +use crate::protocol::{RecordId, Step}; use async_trait::async_trait; use futures::Stream; use std::fmt::{Debug, Formatter}; @@ -9,7 +9,7 @@ use std::fmt::{Debug, Formatter}; #[derive(Clone, Eq, PartialEq, Hash)] pub struct ChannelId { pub role: Role, - pub step: UniqueStepId, + pub step: Step, } #[derive(Debug, PartialEq, Eq)] @@ -37,7 +37,7 @@ pub trait Network: Sync { impl ChannelId { #[must_use] - pub fn new(role: Role, step: UniqueStepId) -> Self { + pub fn new(role: Role, step: Step) -> Self { Self { role, step } } } diff --git a/src/helpers/messaging.rs b/src/helpers/messaging.rs index a5a159d1e..8a17be3b7 100644 --- a/src/helpers/messaging.rs +++ b/src/helpers/messaging.rs @@ -11,7 +11,7 @@ use crate::{ helpers::error::Error, helpers::fabric::{ChannelId, MessageEnvelope, Network}, helpers::Role, - protocol::{RecordId, UniqueStepId}, + protocol::{RecordId, Step}, }; use crate::ff::{Field, Int}; @@ -85,7 +85,7 @@ pub struct Gateway { #[derive(Debug)] pub struct Mesh<'a, 'b> { gateway: &'a Gateway, - step: &'b UniqueStepId, + step: &'b Step, } pub(super) struct ReceiveRequest { @@ -201,7 +201,7 @@ impl Gateway { /// between this helper and every other one. The actual connection may be created only when /// `Mesh::send` or `Mesh::receive` methods are called. #[must_use] - pub fn mesh<'a, 'b>(&'a self, step: &'b UniqueStepId) -> Mesh<'a, 'b> { + pub fn mesh<'a, 'b>(&'a self, step: &'b Step) -> Mesh<'a, 'b> { Mesh { gateway: self, step, diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 92a6ce4fb..c10ce7b2e 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -5,7 +5,7 @@ pub use error::MpcHelperClientError; use crate::{ helpers::Role, net::RecordHeaders, - protocol::{QueryId, UniqueStepId}, + protocol::{QueryId, Step}, }; use axum::{ body::Bytes, @@ -19,7 +19,7 @@ use hyper_tls::HttpsConnector; pub struct HttpSendMessagesArgs<'a> { pub query_id: QueryId, - pub step: &'a UniqueStepId, + pub step: &'a Step, pub role: Role, pub offset: u32, pub data_size: u32, @@ -127,7 +127,7 @@ mod tests { const DATA_SIZE: u32 = 8; const DATA_LEN: u32 = 3; let query_id = QueryId; - let step = UniqueStepId::default().narrow("mul_test"); + let step = Step::default().narrow("mul_test"); let role = Role::H1; let offset = 0; let body = &[123; (DATA_SIZE * DATA_LEN) as usize]; diff --git a/src/net/server/handlers/query.rs b/src/net/server/handlers/query.rs index aaa8165aa..f388c995e 100644 --- a/src/net/server/handlers/query.rs +++ b/src/net/server/handlers/query.rs @@ -2,7 +2,7 @@ use crate::helpers::fabric::{ChannelId, MessageChunks}; use crate::helpers::Role; use crate::net::server::MpcHelperServerError; use crate::net::RecordHeaders; -use crate::protocol::{QueryId, UniqueStepId}; +use crate::protocol::{QueryId, Step}; use async_trait::async_trait; use axum::extract::{self, FromRequest, Query, RequestParts}; use axum::http::Request; @@ -13,7 +13,7 @@ use hyper::Body; use tokio::sync::mpsc; /// Used in the axum handler to extract the `query_id` and `step` from the path of the request -pub struct Path(QueryId, UniqueStepId); +pub struct Path(QueryId, Step); #[async_trait] impl FromRequest for Path { @@ -21,7 +21,7 @@ impl FromRequest for Path { async fn from_request(req: &mut RequestParts) -> Result { let extract::Path((query_id, step)) = - extract::Path::<(QueryId, UniqueStepId)>::from_request(req).await?; + extract::Path::<(QueryId, Step)>::from_request(req).await?; Ok(Path(query_id, step)) } } @@ -136,7 +136,7 @@ mod tests { fn build_req( port: u16, query_id: QueryId, - step: &UniqueStepId, + step: &Step, role: Role, offset: u32, body: &'static [u8], @@ -168,7 +168,7 @@ mod tests { async fn send_req( port: u16, query_id: QueryId, - step: &UniqueStepId, + step: &Step, helper_role: Role, offset: u32, body: &'static [u8], @@ -190,7 +190,7 @@ mod tests { // prepare req let query_id = QueryId; let target_helper = Role::H2; - let step = UniqueStepId::default().narrow("test"); + let step = Step::default().narrow("test"); let offset = 0; let body = &[213; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize]; @@ -241,7 +241,7 @@ mod tests { fn default() -> Self { Self { query_id: QueryId.as_ref().to_owned(), - step: UniqueStepId::default().narrow("test").as_ref().to_owned(), + step: Step::default().narrow("test").as_ref().to_owned(), role: Role::H2.as_ref().to_owned(), offset_header: (OFFSET_HEADER_NAME.clone(), 0.into()), body: &[34; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize], @@ -328,7 +328,7 @@ mod tests { // prepare req let query_id = QueryId; - let step = UniqueStepId::default().narrow("test"); + let step = Step::default().narrow("test"); let target_helper = Role::H2; let offset = 0; let body = &[0; (DATA_LEN * MESSAGE_PAYLOAD_SIZE_BYTES) as usize]; diff --git a/src/protocol/attribution/accumulate_credit.rs b/src/protocol/attribution/accumulate_credit.rs index 6f4401193..210c4f262 100644 --- a/src/protocol/attribution/accumulate_credit.rs +++ b/src/protocol/attribution/accumulate_credit.rs @@ -20,7 +20,7 @@ enum Step { BTimesSuccessorCredit, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/attribution/mod.rs b/src/protocol/attribution/mod.rs index ab4dabb6a..aea1e32d2 100644 --- a/src/protocol/attribution/mod.rs +++ b/src/protocol/attribution/mod.rs @@ -50,7 +50,7 @@ impl IterStep { } } -impl crate::protocol::Step for IterStep {} +impl crate::protocol::Substep for IterStep {} impl AsRef for IterStep { fn as_ref(&self) -> &str { diff --git a/src/protocol/boolean/mod.rs b/src/protocol/boolean/mod.rs index 067be5f99..c9dc62b3d 100644 --- a/src/protocol/boolean/mod.rs +++ b/src/protocol/boolean/mod.rs @@ -13,7 +13,7 @@ enum BitOpStep { Step(usize), } -impl crate::protocol::Step for BitOpStep {} +impl crate::protocol::Substep for BitOpStep {} impl AsRef for BitOpStep { fn as_ref(&self) -> &str { diff --git a/src/protocol/boolean/prefix_or.rs b/src/protocol/boolean/prefix_or.rs index c9eafc486..944385021 100644 --- a/src/protocol/boolean/prefix_or.rs +++ b/src/protocol/boolean/prefix_or.rs @@ -310,7 +310,7 @@ enum Step { SetFirstBlockWithOne, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/check_zero.rs b/src/protocol/check_zero.rs index 4ae91cd9f..239699248 100644 --- a/src/protocol/check_zero.rs +++ b/src/protocol/check_zero.rs @@ -19,7 +19,7 @@ enum Step { RevealR, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/context.rs b/src/protocol/context.rs index c56b96861..4c1423873 100644 --- a/src/protocol/context.rs +++ b/src/protocol/context.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{ prss::{IndexedSharedRandomness, SequentialSharedRandomness}, - RecordId, Step, UniqueStepId, + RecordId, Step, Substep, }; use crate::{ ff::Field, @@ -21,7 +21,7 @@ use crate::secret_sharing::{MaliciousReplicated, Replicated, SecretSharing}; #[derive(Clone, Debug)] pub struct ProtocolContext<'a, S: SecretSharing, F> { role: Role, - step: UniqueStepId, + step: Step, prss: &'a PrssEndpoint, gateway: &'a Gateway, accumulator: Option>, @@ -33,7 +33,7 @@ impl<'a, F: Field, SS: SecretSharing> ProtocolContext<'a, SS, F> { pub fn new(role: Role, participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { Self { role, - step: UniqueStepId::default(), + step: Step::default(), prss: participant, gateway, accumulator: None, @@ -50,14 +50,14 @@ impl<'a, F: Field, SS: SecretSharing> ProtocolContext<'a, SS, F> { /// A unique identifier for this stage of the protocol execution. #[must_use] - pub fn step(&self) -> &UniqueStepId { + pub fn step(&self) -> &Step { &self.step } /// Make a sub-context. /// Note that each invocation of this should use a unique value of `step`. #[must_use] - pub fn narrow(&self, step: &S) -> Self { + pub fn narrow(&self, step: &S) -> Self { ProtocolContext { role: self.role, step: self.step.narrow(step), @@ -84,7 +84,7 @@ impl<'a, F: Field, SS: SecretSharing> ProtocolContext<'a, SS, F> { role: self.role, // create a unique step that allows narrowing this context to the same step // if it is bound to a different record id - step: UniqueStepId::from_step_id(&self.step), + step: Step::from_step_id(&self.step), prss: self.prss, gateway: self.gateway, accumulator: self.accumulator.clone(), diff --git a/src/protocol/malicious.rs b/src/protocol/malicious.rs index 89c29efc5..42484e9b6 100644 --- a/src/protocol/malicious.rs +++ b/src/protocol/malicious.rs @@ -21,7 +21,7 @@ enum Step { CheckZero, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 275a0ea8e..27906650d 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -35,14 +35,14 @@ use std::{ /// /// Steps are therefore composed into a `UniqueStepIdentifier`, which collects the complete /// hierarchy of steps at each layer into a unique identifier. -pub trait Step: AsRef {} +pub trait Substep: AsRef {} // In test code, allow a string (or string reference) to be used as a `Step`. #[cfg(any(feature = "test-fixture", debug_assertions))] -impl Step for String {} +impl Substep for String {} #[cfg(any(feature = "test-fixture", debug_assertions))] -impl Step for str {} +impl Substep for str {} /// The representation of a unique step in protocol execution. /// @@ -73,28 +73,28 @@ impl Step for str {} derive(serde::Deserialize), serde(from = "&str") )] -pub struct UniqueStepId { +pub struct Step { id: String, /// This tracks the different values that have been provided to `narrow()`. #[cfg(debug_assertions)] used: Arc>>, } -impl Hash for UniqueStepId { +impl Hash for Step { fn hash(&self, state: &mut H) { state.write(self.id.as_bytes()); } } -impl PartialEq for UniqueStepId { +impl PartialEq for Step { fn eq(&self, other: &Self) -> bool { self.id == other.id } } -impl Eq for UniqueStepId {} +impl Eq for Step {} -impl UniqueStepId { +impl Step { #[must_use] pub fn from_step_id(step: &Self) -> Self { Self { @@ -109,7 +109,7 @@ impl UniqueStepId { /// In a debug build, this checks that the same refine call isn't run twice and that the string /// value of the step doesn't include '/' (which would lead to a bad outcome). #[must_use] - pub fn narrow(&self, step: &S) -> Self { + pub fn narrow(&self, step: &S) -> Self { #[cfg(debug_assertions)] { let s = String::from(step.as_ref()); @@ -130,7 +130,7 @@ impl UniqueStepId { } } -impl Default for UniqueStepId { +impl Default for Step { // TODO(mt): this should might be better if it were to be constructed from // a QueryId rather than using a default. fn default() -> Self { @@ -142,16 +142,16 @@ impl Default for UniqueStepId { } } -impl AsRef for UniqueStepId { +impl AsRef for Step { fn as_ref(&self) -> &str { self.id.as_str() } } -impl From<&str> for UniqueStepId { +impl From<&str> for Step { fn from(id: &str) -> Self { let id = id.strip_prefix('/').unwrap_or(id); - UniqueStepId { + Step { id: id.to_owned(), #[cfg(debug_assertions)] used: Arc::new(Mutex::new(HashSet::new())), @@ -170,7 +170,7 @@ pub enum IpaProtocolStep { Attribution, } -impl Step for IpaProtocolStep {} +impl Substep for IpaProtocolStep {} impl AsRef for IpaProtocolStep { fn as_ref(&self) -> &str { @@ -192,7 +192,7 @@ impl AsRef for IpaProtocolStep { } } -impl Debug for UniqueStepId { +impl Debug for Step { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "step={}", self.id) } diff --git a/src/protocol/modulus_conversion/convert_shares.rs b/src/protocol/modulus_conversion/convert_shares.rs index 77a2379dc..9ab9b77a4 100644 --- a/src/protocol/modulus_conversion/convert_shares.rs +++ b/src/protocol/modulus_conversion/convert_shares.rs @@ -27,7 +27,7 @@ enum Step { BinaryReveal, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/modulus_conversion/double_random.rs b/src/protocol/modulus_conversion/double_random.rs index d12d41d3a..8aa7ea05f 100644 --- a/src/protocol/modulus_conversion/double_random.rs +++ b/src/protocol/modulus_conversion/double_random.rs @@ -18,7 +18,7 @@ enum Step { Xor2, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/mul/malicious.rs b/src/protocol/mul/malicious.rs index 97d9de995..e02f044a7 100644 --- a/src/protocol/mul/malicious.rs +++ b/src/protocol/mul/malicious.rs @@ -14,7 +14,7 @@ enum Step { RandomnessForValidation, } -impl crate::protocol::Step for Step {} +impl crate::protocol::Substep for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { diff --git a/src/protocol/prss.rs b/src/protocol/prss.rs index 892af488d..744e8c3ee 100644 --- a/src/protocol/prss.rs +++ b/src/protocol/prss.rs @@ -17,7 +17,7 @@ use std::{ use std::{collections::HashSet, fmt::Formatter}; use x25519_dalek::{EphemeralSecret, PublicKey}; -use super::UniqueStepId; +use super::Step; /// Keeps track of all indices used to generate shared randomness inside `IndexedSharedRandomness`. /// Any two indices provided to `IndexesSharedRandomness::generate_values` must be unique. @@ -225,7 +225,7 @@ impl Endpoint { /// # Panics /// When used incorrectly. For instance, if you ask for an RNG and then ask /// for a PRSS using the same key. - pub fn indexed(&self, key: &UniqueStepId) -> Arc { + pub fn indexed(&self, key: &Step) -> Arc { self.inner.lock().unwrap().indexed(key.as_ref()) } @@ -235,7 +235,7 @@ impl Endpoint { /// This can only be called once. After that, calls to this function or `indexed` will panic. pub fn sequential( &self, - key: &UniqueStepId, + key: &Step, ) -> (SequentialSharedRandomness, SequentialSharedRandomness) { self.inner.lock().unwrap().sequential(key.as_ref()) } @@ -396,7 +396,7 @@ impl Generator { #[cfg(test)] pub mod test { use super::{Generator, KeyExchange, SequentialSharedRandomness}; - use crate::{ff::Fp31, protocol::UniqueStepId, test_fixture::make_participants}; + use crate::{ff::Fp31, protocol::Step, test_fixture::make_participants}; use rand::prelude::SliceRandom; use rand::{thread_rng, Rng}; use std::mem::drop; @@ -455,7 +455,7 @@ pub mod test { const IDX: u128 = 7; let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let (r1_l, r1_r) = p1.indexed(&step).generate_values(IDX); assert_ne!(r1_l, r1_r); let (r2_l, r2_r) = p2.indexed(&step).generate_values(IDX); @@ -473,7 +473,7 @@ pub mod test { const IDX: u128 = 7; let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let z1 = p1.indexed(&step).zero_u128(IDX); let z2 = p2.indexed(&step).zero_u128(IDX); let z3 = p3.indexed(&step).zero_u128(IDX); @@ -486,7 +486,7 @@ pub mod test { const IDX: u128 = 7; let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let z1 = p1.indexed(&step).zero_xor(IDX); let z2 = p2.indexed(&step).zero_xor(IDX); let z3 = p3.indexed(&step).zero_xor(IDX); @@ -500,7 +500,7 @@ pub mod test { const IDX2: u128 = 21362; let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let r1 = p1.indexed(&step).random_u128(IDX1); let r2 = p2.indexed(&step).random_u128(IDX1); let r3 = p3.indexed(&step).random_u128(IDX1); @@ -523,7 +523,7 @@ pub mod test { // These tests do not check that left != right because // the field might not be large enough. - let step = UniqueStepId::default(); + let step = Step::default(); let (r1_l, r1_r): (Fp31, Fp31) = p1.indexed(&step).generate_fields(IDX); let (r2_l, r2_r) = p2.indexed(&step).generate_fields(IDX); let (r3_l, r3_r) = p3.indexed(&step).generate_fields(IDX); @@ -538,7 +538,7 @@ pub mod test { const IDX: u128 = 72; let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let z1: Fp31 = p1.indexed(&step).zero(IDX); let z2 = p2.indexed(&step).zero(IDX); let z3 = p3.indexed(&step).zero(IDX); @@ -552,7 +552,7 @@ pub mod test { const IDX2: u128 = 12634; let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let s1 = p1.indexed(&step); let s2 = p2.indexed(&step); let s3 = p3.indexed(&step); @@ -588,7 +588,7 @@ pub mod test { } let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let (rng1_l, rng1_r) = p1.sequential(&step); let (rng2_l, rng2_r) = p2.sequential(&step); let (rng3_l, rng3_r) = p3.sequential(&step); @@ -602,7 +602,7 @@ pub mod test { fn indexed_and_sequential() { let (p1, _p2, _p3) = make_participants(); - let base = UniqueStepId::default(); + let base = Step::default(); let idx = p1.indexed(&base.narrow("indexed")); let (mut s_left, mut s_right) = p1.sequential(&base.narrow("sequential")); let (i_left, i_right) = idx.generate_values(0_u128); @@ -621,7 +621,7 @@ pub mod test { fn indexed_then_sequential() { let (p1, _p2, _p3) = make_participants(); - let step = UniqueStepId::default().narrow("test"); + let step = Step::default().narrow("test"); drop(p1.indexed(&step)); // TODO(alex): remove after clippy stops aggroing with no reason // https://github.com/private-attribution/ipa/actions/runs/3340348412/jobs/5530341996 @@ -634,7 +634,7 @@ pub mod test { fn sequential_then_indexed() { let (p1, _p2, _p3) = make_participants(); - let step = UniqueStepId::default().narrow("test"); + let step = Step::default().narrow("test"); // TODO(alex): remove after clippy is fixed #[allow(clippy::let_underscore_drop)] let _ = p1.sequential(&step); @@ -644,7 +644,7 @@ pub mod test { #[test] fn indexed_accepts_unique_index() { let (_, p2, _p3) = make_participants(); - let step = UniqueStepId::default().narrow("test"); + let step = Step::default().narrow("test"); let mut indices = (1..100_u128).collect::>(); indices.shuffle(&mut thread_rng()); let indexed_prss = p2.indexed(&step); @@ -659,7 +659,7 @@ pub mod test { #[should_panic] fn indexed_rejects_the_same_index() { let (p1, _p2, _p3) = make_participants(); - let step = UniqueStepId::default().narrow("test"); + let step = Step::default().narrow("test"); let _ = p1.indexed(&step).random_u128(100_u128); let _ = p1.indexed(&step).random_u128(100_u128); diff --git a/src/protocol/sort/mod.rs b/src/protocol/sort/mod.rs index 05829e3ba..1474781c8 100644 --- a/src/protocol/sort/mod.rs +++ b/src/protocol/sort/mod.rs @@ -1,4 +1,4 @@ -use super::Step; +use super::Substep; use std::fmt::Debug; mod apply; @@ -17,7 +17,7 @@ pub enum SortStep { ComposeStep, } -impl Step for SortStep {} +impl Substep for SortStep {} impl AsRef for SortStep { fn as_ref(&self) -> &str { @@ -37,7 +37,7 @@ pub enum ShuffleStep { Step3, } -impl Step for ShuffleStep {} +impl Substep for ShuffleStep {} impl AsRef for ShuffleStep { fn as_ref(&self) -> &str { @@ -56,7 +56,7 @@ pub enum ApplyInvStep { RevealPermutation, } -impl Step for ApplyInvStep {} +impl Substep for ApplyInvStep {} impl AsRef for ApplyInvStep { fn as_ref(&self) -> &str { @@ -75,7 +75,7 @@ pub enum ComposeStep { UnshuffleRho, } -impl Step for ComposeStep {} +impl Substep for ComposeStep {} impl AsRef for ComposeStep { fn as_ref(&self) -> &str { diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index d87b089fd..d5bd65b26 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -8,7 +8,7 @@ use crate::{ error::BoxError, ff::Field, helpers::{Direction, Role}, - protocol::{context::ProtocolContext, prss::IndexedSharedRandomness, RecordId, Step}, + protocol::{context::ProtocolContext, prss::IndexedSharedRandomness, RecordId, Substep}, secret_sharing::Replicated, }; @@ -24,7 +24,7 @@ enum ShuffleOrUnshuffle { Unshuffle, } -impl Step for ShuffleOrUnshuffle {} +impl Substep for ShuffleOrUnshuffle {} impl AsRef for ShuffleOrUnshuffle { fn as_ref(&self) -> &str { match self { @@ -215,7 +215,7 @@ mod tests { get_two_of_three_random_permutations, shuffle_shares, unshuffle_shares, ShuffleOrUnshuffle, }, - QueryId, UniqueStepId, + QueryId, Step, }, test_fixture::{ generate_shares, make_contexts, make_participants, make_world, narrow_contexts, @@ -231,7 +231,7 @@ mod tests { logging::setup(); let (p1, p2, p3) = make_participants(); - let step = UniqueStepId::default(); + let step = Step::default(); let perm1 = get_two_of_three_random_permutations(BATCH_SIZE, p1.indexed(&step).as_ref()); let perm2 = get_two_of_three_random_permutations(BATCH_SIZE, p2.indexed(&step).as_ref()); let perm3 = get_two_of_three_random_permutations(BATCH_SIZE, p3.indexed(&step).as_ref()); diff --git a/src/test_fixture/fabric.rs b/src/test_fixture/fabric.rs index 20d51293a..44e8cf636 100644 --- a/src/test_fixture/fabric.rs +++ b/src/test_fixture/fabric.rs @@ -8,7 +8,7 @@ use std::pin::Pin; use crate::helpers; use crate::helpers::fabric::{ChannelId, MessageChunks, Network}; use crate::helpers::{Error, Role}; -use crate::protocol::UniqueStepId; +use crate::protocol::Step; use async_trait::async_trait; use futures::Sink; use futures::StreamExt; @@ -41,7 +41,7 @@ pub struct InMemoryEndpoint { pub role: Role, /// Channels that this endpoint is listening to. There are two helper peers for 3 party setting. /// For each peer there are multiple channels open, one per query + step. - channels: Arc>>>, + channels: Arc>>>, tx: Sender, rx: Arc>>>, network: Weak, diff --git a/src/test_fixture/mod.rs b/src/test_fixture/mod.rs index 5d1e694aa..9e6b2e208 100644 --- a/src/test_fixture/mod.rs +++ b/src/test_fixture/mod.rs @@ -8,7 +8,7 @@ use crate::ff::{Field, Fp31}; use crate::helpers::Role; use crate::protocol::context::ProtocolContext; use crate::protocol::prss::Endpoint as PrssEndpoint; -use crate::protocol::Step; +use crate::protocol::Substep; use crate::secret_sharing::{Replicated, SecretSharing}; use rand::rngs::mock::StepRng; use rand::thread_rng; @@ -45,7 +45,7 @@ pub fn make_contexts( #[must_use] pub fn narrow_contexts<'a, F: Field, S: SecretSharing>( contexts: &[ProtocolContext<'a, S, F>; 3], - step: &impl Step, + step: &impl Substep, ) -> [ProtocolContext<'a, S, F>; 3] { // This really wants <[_; N]>::each_ref() contexts From db168e7399453b7db831dfa67b5f39cf09f11064 Mon Sep 17 00:00:00 2001 From: Taiki Yamaguchi Date: Tue, 15 Nov 2022 11:20:11 +0800 Subject: [PATCH 19/24] Simplify the OR arithmetic --- src/protocol/boolean/prefix_or.rs | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/protocol/boolean/prefix_or.rs b/src/protocol/boolean/prefix_or.rs index 0d1e11fce..3c7afd6c6 100644 --- a/src/protocol/boolean/prefix_or.rs +++ b/src/protocol/boolean/prefix_or.rs @@ -32,27 +32,16 @@ impl<'a, F: Field> PrefixOr<'a, F> { } /// Securely computes `[a] | [b] where a, b ∈ {0, 1} ⊆ F_p` - /// - /// * OR can be computed as: `[a] ^ [b] ^ MULT([a], [b])` - /// * XOR([a], [b]) is: `[a] + [b] - 2([a] * [b])` - /// - /// Therefore, - /// - /// let [c] = [a] ^ [b] - /// [c] + [ab] - 2([c] * [ab]) + /// OR can be computed as: `!MULT(![a], ![b])` async fn bit_or( a: Replicated, b: Replicated, ctx: ProtocolContext<'_, Replicated, F>, record_id: RecordId, ) -> Result, BoxError> { - let ab = ctx.narrow(&Step::AMultB).multiply(record_id, a, b).await?; - let c = a + b - (ab * F::from(2)); - let cab = ctx - .narrow(&Step::ABMultC) - .multiply(record_id, c, ab) - .await?; - Ok(c + ab - (cab * F::from(2))) + let one = Replicated::one(ctx.role()); + let result = ctx.multiply(record_id, one - a, one - b).await?; + Ok(one - result) } /// Securely computes `∨ [a_1],...[a_n]` @@ -310,8 +299,6 @@ impl<'a, F: Field> PrefixOr<'a, F> { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum Step { BitwiseOrPerBlock, - AMultB, - ABMultC, BlockWisePrefixOr, InnerProduct, GetFirstBlockWithOne, @@ -323,8 +310,6 @@ impl crate::protocol::Step for Step {} impl AsRef for Step { fn as_ref(&self) -> &str { match self { - Self::AMultB => "a_mult_b", - Self::ABMultC => "ab_mult_c", Self::BitwiseOrPerBlock => "bitwise_or_per_block", Self::BlockWisePrefixOr => "block_wise_prefix_or", Self::InnerProduct => "inner_product", From eed0d94654bc0ad87370ac550df609a3880aee04 Mon Sep 17 00:00:00 2001 From: Taiki Yamaguchi Date: Tue, 15 Nov 2022 11:39:03 +0800 Subject: [PATCH 20/24] in-line PrefixOr::new() in the unit test --- src/protocol/boolean/prefix_or.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/protocol/boolean/prefix_or.rs b/src/protocol/boolean/prefix_or.rs index 3c7afd6c6..3bfc37b0d 100644 --- a/src/protocol/boolean/prefix_or.rs +++ b/src/protocol/boolean/prefix_or.rs @@ -352,14 +352,11 @@ mod tests { .unzip(); // Execute - let pre0 = PrefixOr::new(&s0); - let pre1 = PrefixOr::new(&s1); - let pre2 = PrefixOr::new(&s2); let step = "PrefixOr_Test"; let result = try_join_all(vec![ - pre0.execute(ctx[0].narrow(step), RecordId::from(0_u32)), - pre1.execute(ctx[1].narrow(step), RecordId::from(0_u32)), - pre2.execute(ctx[2].narrow(step), RecordId::from(0_u32)), + PrefixOr::new(&s0).execute(ctx[0].narrow(step), RecordId::from(0_u32)), + PrefixOr::new(&s1).execute(ctx[1].narrow(step), RecordId::from(0_u32)), + PrefixOr::new(&s2).execute(ctx[2].narrow(step), RecordId::from(0_u32)), ]) .await .unwrap(); From 1f2a3860b765779b323b4b5a75b0984517e818e6 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 15 Nov 2022 14:53:38 +1100 Subject: [PATCH 21/24] Add standard distribution to prime fields This should make things easier to work with in tests. --- src/ff/prime_field.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ff/prime_field.rs b/src/ff/prime_field.rs index c5bdde435..ff59e29f6 100644 --- a/src/ff/prime_field.rs +++ b/src/ff/prime_field.rs @@ -87,6 +87,12 @@ macro_rules! field_impl { } } + impl rand::distributions::Distribution<$field> for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> $field { + <$field>::from(rng.gen::()) + } + } + impl std::fmt::Debug for $field { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}_mod{}", self.0, Self::PRIME) From 088808a99f3b942d29e4d10e4bec53160f802ee5 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 15 Nov 2022 15:05:16 +1100 Subject: [PATCH 22/24] Convert tests --- src/protocol/malicious.rs | 6 +++--- .../modulus_conversion/specialized_mul.rs | 16 ++++++++-------- src/protocol/mul/semi_honest.rs | 7 ++++++- src/protocol/reveal.rs | 9 +++------ src/secret_sharing/malicious_replicated.rs | 14 +++++++------- src/test_fixture/mod.rs | 9 +++++++-- src/test_fixture/sharing.rs | 15 ++++++++++----- 7 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/protocol/malicious.rs b/src/protocol/malicious.rs index 42484e9b6..4e5ab196c 100644 --- a/src/protocol/malicious.rs +++ b/src/protocol/malicious.rs @@ -232,8 +232,8 @@ pub mod tests { let context = make_contexts::(&world); let mut rng = rand::thread_rng(); - let a = Fp31::from(rng.gen::()); - let b = Fp31::from(rng.gen::()); + let a = rng.gen::(); + let b = rng.gen::(); let a_shares = share(a, &mut rng); let b_shares = share(b, &mut rng); @@ -327,7 +327,7 @@ pub mod tests { let mut original_inputs = Vec::with_capacity(100); for _ in 0..100 { - let x = Fp31::from(rng.gen::()); + let x = rng.gen::(); original_inputs.push(x); } let shared_inputs: Vec<[Replicated; 3]> = original_inputs diff --git a/src/protocol/modulus_conversion/specialized_mul.rs b/src/protocol/modulus_conversion/specialized_mul.rs index 6c800b91c..1b3805ac7 100644 --- a/src/protocol/modulus_conversion/specialized_mul.rs +++ b/src/protocol/modulus_conversion/specialized_mul.rs @@ -206,8 +206,8 @@ pub mod tests { let mut rng = rand::thread_rng(); for i in 0..10_u32 { - let a = Fp31::from(rng.gen::()); - let b = Fp31::from(rng.gen::()); + let a = rng.gen::(); + let b = rng.gen::(); let record_id = RecordId::from(0_u32); @@ -252,8 +252,8 @@ pub mod tests { let mut futures = Vec::with_capacity(10); for i in 0..10_u32 { - let a = Fp31::from(rng.gen::()); - let b = Fp31::from(rng.gen::()); + let a = rng.gen::(); + let b = rng.gen::(); inputs.push((a, b)); @@ -304,8 +304,8 @@ pub mod tests { let mut rng = rand::thread_rng(); for i in 0..10_u32 { - let a = Fp31::from(rng.gen::()); - let b = Fp31::from(rng.gen::()); + let a = rng.gen::(); + let b = rng.gen::(); let a_shares = share(a, &mut rng); @@ -352,8 +352,8 @@ pub mod tests { let mut futures = Vec::with_capacity(10); for i in 0..10_u32 { - let a = Fp31::from(rng.gen::()); - let b = Fp31::from(rng.gen::()); + let a = rng.gen::(); + let b = rng.gen::(); inputs.push((a, b)); diff --git a/src/protocol/mul/semi_honest.rs b/src/protocol/mul/semi_honest.rs index da9a7af54..4eb9fcc17 100644 --- a/src/protocol/mul/semi_honest.rs +++ b/src/protocol/mul/semi_honest.rs @@ -71,6 +71,8 @@ pub mod tests { make_contexts, make_world, share, validate_and_reconstruct, TestWorld, }; use futures_util::future::join_all; + use rand::distributions::Standard; + use rand::prelude::Distribution; use rand::rngs::mock::StepRng; use rand::RngCore; use std::sync::atomic::{AtomicU32, Ordering}; @@ -142,7 +144,10 @@ pub mod tests { a: u8, b: u8, rng: &mut R, - ) -> Result { + ) -> Result + where + Standard: Distribution, + { let a = F::from(u128::from(a)); let b = F::from(u128::from(b)); diff --git a/src/protocol/reveal.rs b/src/protocol/reveal.rs index 119d7e97f..9c31cd3fe 100644 --- a/src/protocol/reveal.rs +++ b/src/protocol/reveal.rs @@ -119,8 +119,7 @@ mod tests { let ctx = make_contexts::(&world); for i in 0..10_u32 { - let secret = rng.gen::(); - let input = Fp31::from(secret); + let input = rng.gen::(); let share = share(input, &mut rng); let record_id = RecordId::from(i); let results = try_join_all(vec![ @@ -144,8 +143,7 @@ mod tests { let ctx = make_contexts::(&world); for i in 0..10_u32 { - let secret = rng.gen::(); - let input = Fp31::from(secret); + let input = rng.gen::(); let share = share(input, &mut rng); let record_id = RecordId::from(i); let results = try_join_all(vec![ @@ -169,8 +167,7 @@ mod tests { let ctx = make_contexts::(&world); for i in 0..10_u32 { - let secret = rng.gen::(); - let input = Fp31::from(secret); + let input = rng.gen::(); let share = share(input, &mut rng); let record_id = RecordId::from(i); let result = try_join!( diff --git a/src/secret_sharing/malicious_replicated.rs b/src/secret_sharing/malicious_replicated.rs index 59298bb79..a2501167f 100644 --- a/src/secret_sharing/malicious_replicated.rs +++ b/src/secret_sharing/malicious_replicated.rs @@ -118,14 +118,14 @@ mod tests { fn test_local_operations() { let mut rng = rand::thread_rng(); - let a = Fp31::from(rng.gen::()); - let b = Fp31::from(rng.gen::()); - let c = Fp31::from(rng.gen::()); - let d = Fp31::from(rng.gen::()); - let e = Fp31::from(rng.gen::()); - let f = Fp31::from(rng.gen::()); + let a = rng.gen::(); + let b = rng.gen::(); + let c = rng.gen::(); + let d = rng.gen::(); + let e = rng.gen::(); + let f = rng.gen::(); // Randomization constant - let r = Fp31::from(rng.gen::()); + let r = rng.gen::(); let a_shared = share(a, &mut rng); let b_shared = share(b, &mut rng); diff --git a/src/test_fixture/mod.rs b/src/test_fixture/mod.rs index 9e6b2e208..6e33ab254 100644 --- a/src/test_fixture/mod.rs +++ b/src/test_fixture/mod.rs @@ -10,6 +10,8 @@ use crate::protocol::context::ProtocolContext; use crate::protocol::prss::Endpoint as PrssEndpoint; use crate::protocol::Substep; use crate::secret_sharing::{Replicated, SecretSharing}; +use rand::distributions::Standard; +use rand::prelude::Distribution; use rand::rngs::mock::StepRng; use rand::thread_rng; @@ -79,7 +81,10 @@ pub type ReplicatedShares = (Vec>, Vec>, Vec(input: Vec) -> ReplicatedShares { +pub fn generate_shares(input: Vec) -> ReplicatedShares +where + Standard: Distribution, +{ let mut rand = StepRng::new(100, 1); let len = input.len(); @@ -88,7 +93,7 @@ pub fn generate_shares(input: Vec) -> ReplicatedShares { let mut shares2 = Vec::with_capacity(len); for iter in input { - let share = share(T::from(iter), &mut rand); + let share = share(F::from(iter), &mut rand); shares0.push(share[0]); shares1.push(share[1]); shares2.push(share[2]); diff --git a/src/test_fixture/sharing.rs b/src/test_fixture/sharing.rs index a56b53532..0e62a67d4 100644 --- a/src/test_fixture/sharing.rs +++ b/src/test_fixture/sharing.rs @@ -1,14 +1,19 @@ use crate::ff::Field; use crate::secret_sharing::Replicated; -use rand::Rng; -use rand::RngCore; +use rand::{ + distributions::{Distribution, Standard}, + Rng, RngCore, +}; use super::ReplicatedShares; /// Shares `input` into 3 replicated secret shares using the provided `rng` implementation -pub fn share(input: F, rng: &mut R) -> [Replicated; 3] { - let x1 = F::from(rng.gen::()); - let x2 = F::from(rng.gen::()); +pub fn share(input: F, rng: &mut R) -> [Replicated; 3] +where + Standard: Distribution, +{ + let x1 = rng.gen::(); + let x2 = rng.gen::(); let x3 = input - (x1 + x2); [ From 677867b1326b29223251a512900f47b10d6f67a5 Mon Sep 17 00:00:00 2001 From: Taiki Yamaguchi Date: Tue, 15 Nov 2022 12:11:23 +0800 Subject: [PATCH 23/24] Add test cases to cover more bit lengths --- src/protocol/boolean/prefix_or.rs | 88 ++++++++++++++++--------------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/src/protocol/boolean/prefix_or.rs b/src/protocol/boolean/prefix_or.rs index 3bfc37b0d..fa4cd2de1 100644 --- a/src/protocol/boolean/prefix_or.rs +++ b/src/protocol/boolean/prefix_or.rs @@ -333,7 +333,7 @@ mod tests { use rand::{rngs::mock::StepRng, Rng}; use std::iter::zip; - const BITS: usize = 32; + const BITS: [usize; 2] = [16, 32]; const TEST_TRIES: usize = 16; async fn prefix_or(input: &[F]) -> Result, BoxError> { @@ -373,25 +373,26 @@ mod tests { pub async fn fp2() -> Result<(), BoxError> { let mut rng = rand::thread_rng(); - // Test 32-bit bitwise shares with randomly distributed bits, for 16 times. - // The probability of i'th bit being 0 is 1/2^i, so this test covers inputs - // that have all 0's in 5 first bits. - for _ in 0..TEST_TRIES { - let len = BITS; - let input: Vec = (0..len).map(|_| Fp2::from(rng.gen::())).collect(); - let mut expected: Vec = Vec::with_capacity(len); - - // Calculate Prefix-Or of the secret number - input.iter().fold(Fp2::ZERO, |acc, &x| { - expected.push(acc | x); - acc | x - }); - - let result = prefix_or(&input).await?; - - // Verify - assert_eq!(expected.len(), result.len()); - zip(expected, result).for_each(|(e, r)| assert_eq!(e, r)); + // Test n-bit (n = BITS[i]) bitwise shares with randomly distributed + // bits, for 16 times. The probability of i'th bit being 0 is 1/2^i, + // so this test covers inputs that have all 0's in 5 first bits. + for len in BITS { + for _ in 0..TEST_TRIES { + let input: Vec = (0..len).map(|_| Fp2::from(rng.gen::())).collect(); + let mut expected: Vec = Vec::with_capacity(len); + + // Calculate Prefix-Or of the secret number + input.iter().fold(Fp2::ZERO, |acc, &x| { + expected.push(acc | x); + acc | x + }); + + let result = prefix_or(&input).await?; + + // Verify + assert_eq!(expected.len(), result.len()); + zip(expected, result).for_each(|(e, r)| assert_eq!(e, r)); + } } Ok(()) @@ -402,29 +403,30 @@ mod tests { pub async fn fp31() -> Result<(), BoxError> { let mut rng = rand::thread_rng(); - // Test 32-bit bitwise shares with randomly distributed bits, for 16 times. - // The probability of i'th bit being 0 is 1/2^i, so this test covers inputs - // that have all 0's in 5 first bits. - for _ in 0..TEST_TRIES { - let len = BITS; - // Generate a vector of Fp31::ZERO or Fp31::ONE from randomly picked bool values - let input: Vec = (0..len) - .map(|_| Fp31::from(u128::from(rng.gen::()))) - .collect(); - let mut expected: Vec = Vec::with_capacity(len); - - // Calculate Prefix-Or of the secret number - input.iter().fold(0, |acc, &x| { - let sum = acc + x.as_u128(); - expected.push(Fp31::from(sum > 0)); - sum - }); - - let result = prefix_or(&input).await?; - - // Verify - assert_eq!(expected.len(), result.len()); - zip(expected, result).for_each(|(e, r)| assert_eq!(e, r)); + // Test n-bit (n = BITS[i]) bitwise shares with randomly distributed + // bits, for 16 times. The probability of i'th bit being 0 is 1/2^i, + // so this test covers inputs that have all 0's in 5 first bits. + for len in BITS { + for _ in 0..TEST_TRIES { + // Generate a vector of Fp31::ZERO or Fp31::ONE from randomly picked bool values + let input: Vec = (0..len) + .map(|_| Fp31::from(u128::from(rng.gen::()))) + .collect(); + let mut expected: Vec = Vec::with_capacity(len); + + // Calculate Prefix-Or of the secret number + input.iter().fold(0, |acc, &x| { + let sum = acc + x.as_u128(); + expected.push(Fp31::from(sum > 0)); + sum + }); + + let result = prefix_or(&input).await?; + + // Verify + assert_eq!(expected.len(), result.len()); + zip(expected, result).for_each(|(e, r)| assert_eq!(e, r)); + } } Ok(()) From b9c7d5a2bfc9c05c688f68628a021f67582a6c76 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 15 Nov 2022 15:12:39 +1100 Subject: [PATCH 24/24] Use const instead of static --- src/helpers/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index 21d2d9769..be23f9ecb 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -46,7 +46,7 @@ impl Role { #[must_use] pub fn all() -> &'static [Role; 3] { - static VARIANTS: &[Role; 3] = &[Role::H1, Role::H2, Role::H3]; + const VARIANTS: &[Role; 3] = &[Role::H1, Role::H2, Role::H3]; VARIANTS }