Merge from main
akoshelev committed Nov 15, 2022
2 parents 24db628 + f8556eb commit 6dd9b4b
Showing 37 changed files with 749 additions and 681 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -33,7 +33,6 @@ hyper-tls = { version = "0.5.0", optional = true }
metrics = "0.20.1"
metrics-tracing-context = { version = "0.12.0", optional = true }
metrics-util = { version = "0.14.0", optional = true }
permutation = "0.4.1"
pin-project = "1.0.12"
rand = "0.8"
rand_chacha = "0.3.1"
@@ -44,6 +43,7 @@ serde = { version = "1.0", optional = true, features = ["derive"] }
serde_json = { version = "1.0", optional = true }
sha2 = "0.10.6"
thiserror = "1.0"
tinyvec = { version = "1.6.0" }
tokio = { version = "1.21.2", optional = true, features = ["rt", "rt-multi-thread", "macros"] }
tokio-stream = { version = "0.1.10", optional = true }
tokio-util = { version = "0.7.4", optional = true }
6 changes: 6 additions & 0 deletions src/ff/prime_field.rs
@@ -87,6 +87,12 @@ macro_rules! field_impl {
}
}

impl rand::distributions::Distribution<$field> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $field {
<$field>::from(rng.gen::<u128>())
}
}

impl std::fmt::Debug for $field {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}_mod{}", self.0, Self::PRIME)
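The Distribution impl added above lets callers draw random field elements straight from an Rng: it samples a u128 and reduces it through the field's existing From<u128> conversion. A minimal, self-contained sketch of the same pattern, using a hypothetical Fp31 stand-in rather than the macro-generated fields:

use rand::distributions::{Distribution, Standard};
use rand::Rng;

// Hypothetical stand-in for a macro-generated prime field; not the crate's actual type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Fp31(u8);

impl Fp31 {
    const PRIME: u8 = 31;
}

impl From<u128> for Fp31 {
    fn from(v: u128) -> Self {
        // Reduce an arbitrary integer into the field, mirroring the conversion the macro provides.
        Fp31((v % u128::from(Self::PRIME)) as u8)
    }
}

impl Distribution<Fp31> for Standard {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Fp31 {
        // Same shape as the impl in the diff: draw a u128 and let From<u128> do the reduction.
        Fp31::from(rng.gen::<u128>())
    }
}

fn main() {
    // With the Distribution impl in place, `gen` can produce field elements directly.
    let x: Fp31 = rand::thread_rng().gen();
    assert!(x.0 < Fp31::PRIME);
    println!("sampled {x:?}");
}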
39 changes: 26 additions & 13 deletions src/helpers/buffers/receive.rs
@@ -1,4 +1,5 @@
use crate::helpers::fabric::{ChannelId, MessageEnvelope};
use crate::helpers::fabric::ChannelId;
use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES};
use crate::protocol::RecordId;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
@@ -13,14 +14,15 @@ use tokio::sync::oneshot;
#[allow(clippy::module_name_repetitions)]
pub struct ReceiveBuffer {
inner: HashMap<ChannelId, HashMap<RecordId, ReceiveBufItem>>,
record_ids: HashMap<ChannelId, RecordId>,
}

#[derive(Debug)]
enum ReceiveBufItem {
/// There is an outstanding request to receive the message but this helper hasn't seen it yet
Requested(oneshot::Sender<Box<[u8]>>),
Requested(oneshot::Sender<MessagePayload>),
/// Message has been received but nobody requested it yet
Received(Box<[u8]>),
Received(MessagePayload),
}

impl ReceiveBuffer {
@@ -29,7 +31,7 @@ impl ReceiveBuffer {
&mut self,
channel_id: ChannelId,
record_id: RecordId,
sender: oneshot::Sender<Box<[u8]>>,
sender: oneshot::Sender<MessagePayload>,
) {
match self.inner.entry(channel_id).or_default().entry(record_id) {
Entry::Occupied(entry) => match entry.remove() {
@@ -48,29 +50,40 @@ }
}
}

/// Process message that has been received
pub fn receive_messages(&mut self, channel_id: &ChannelId, messages: Vec<MessageEnvelope>) {
for msg in messages {
/// Process messages that have been received. It assumes messages arrive in order, so the first
/// chunk belongs to the record range [0..chunk.len()), the second chunk to [chunk.len()..2*chunk.len()),
/// etc. It does not require all chunks to be the same size; that assumption is baked into the
/// send buffers.
pub fn receive_messages(&mut self, channel_id: &ChannelId, messages: &[u8]) {
let offset = self
.record_ids
.entry(channel_id.clone())
.or_insert_with(|| RecordId::from(0_u32));

for msg in messages.chunks(MESSAGE_PAYLOAD_SIZE_BYTES) {
let payload = msg.try_into().unwrap();
match self
.inner
.entry(channel_id.clone())
.or_default()
.entry(msg.record_id)
.entry(*offset)
{
Entry::Occupied(entry) => match entry.remove() {
ReceiveBufItem::Requested(s) => {
s.send(msg.payload).unwrap_or_else(|_| {
tracing::warn!("No listener for message {:?}", msg.record_id);
s.send(payload).unwrap_or_else(|_| {
tracing::warn!("No listener for message {:?}", offset);
});
}
ReceiveBufItem::Received(_) => {
panic!("Duplicate message for the same record {:?}", msg.record_id);
panic!("Duplicate message for the same record {:?}", offset)
}
},
Entry::Vacant(entry) => {
entry.insert(ReceiveBufItem::Received(msg.payload));
entry.insert(ReceiveBufItem::Received(payload));
}
}
};

*offset += 1;
}
}
}
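Per the new doc comment on receive_messages, record ids are no longer carried with each message: the buffer splits the incoming byte slice into MESSAGE_PAYLOAD_SIZE_BYTES chunks and assigns consecutive record ids per channel, continuing from wherever the previous delivery on that channel stopped. A rough, self-contained sketch of that bookkeeping (plain String keys and a hard-coded payload size stand in for ChannelId and the crate's constant):

use std::collections::HashMap;

// Assumed payload width for the sketch; the crate derives it from the field size.
const MESSAGE_PAYLOAD_SIZE_BYTES: usize = 8;

/// Splits `bytes` into fixed-size payloads and pairs each with the next record id
/// for `channel`, advancing the per-channel counter as it goes.
fn assign_record_ids(
    offsets: &mut HashMap<String, u32>,
    channel: &str,
    bytes: &[u8],
) -> Vec<(u32, [u8; MESSAGE_PAYLOAD_SIZE_BYTES])> {
    let offset = offsets.entry(channel.to_owned()).or_insert(0);
    let mut out = Vec::new();
    for chunk in bytes.chunks(MESSAGE_PAYLOAD_SIZE_BYTES) {
        // Every chunk is expected to be exactly one payload wide.
        out.push((*offset, chunk.try_into().expect("full payload")));
        *offset += 1;
    }
    out
}

fn main() {
    let mut offsets = HashMap::new();
    // Two consecutive deliveries on the same channel continue the record-id sequence.
    let first = assign_record_ids(&mut offsets, "H1/protocol-step", &[0u8; 16]);
    let second = assign_record_ids(&mut offsets, "H1/protocol-step", &[0u8; 8]);
    assert_eq!(first.iter().map(|(id, _)| *id).collect::<Vec<_>>(), vec![0, 1]);
    assert_eq!(second[0].0, 2);
}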
118 changes: 33 additions & 85 deletions src/helpers/buffers/send.rs
@@ -1,13 +1,14 @@
use crate::helpers::buffers::fsv::FixedSizeByteVec;
use crate::helpers::fabric::{ChannelId, MessageEnvelope};
use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES};
use crate::protocol::RecordId;
use std::collections::HashMap;
use std::ops::Range;

/// Use the buffer that allocates 8 bytes per element. It could probably go down to 4 if the
/// only thing IPA sends is a single field value. To support arbitrarily sized values, it needs
/// to be at least 16 bytes to be able to store a fat pointer in it.
type ByteBuf = FixedSizeByteVec<8>;
type ByteBuf = FixedSizeByteVec<{ MESSAGE_PAYLOAD_SIZE_BYTES }>;

/// Buffer that keeps messages that must be sent to other helpers
#[derive(Debug)]
@@ -32,7 +33,7 @@ pub enum PushError {
Duplicate {
channel_id: ChannelId,
record_id: RecordId,
previous_value: Box<[u8]>,
previous_value: MessagePayload,
},
}

@@ -67,14 +68,11 @@ impl SendBuffer {
}
}

/// TODO: change the output to Vec<u8> - we no longer need a wrapper. The raw byte vector
/// will be communicated down to the network layer.
#[allow(clippy::needless_pass_by_value)] // will be fixed when tiny/smallvec is used
pub fn push(
&mut self,
channel_id: &ChannelId,
msg: MessageEnvelope,
) -> Result<Option<Vec<MessageEnvelope>>, PushError> {
msg: &MessageEnvelope,
) -> Result<Option<Vec<u8>>, PushError> {
assert!(
msg.payload.len() <= ByteBuf::ELEMENT_SIZE_BYTES,
"Message payload exceeds the maximum allowed size"
@@ -110,30 +108,11 @@ impl SendBuffer {
return Err(PushError::Duplicate {
record_id: msg.record_id,
channel_id: channel_id.clone(),
previous_value: Box::new(v),
previous_value: v.try_into().unwrap(),
});
}

Ok(if let Some(data) = buf.take() {
// The next chunk is ready to be drained as byte vec has accumulated enough elements
// in its first region. Drain it and move the elements to the caller.
// TODO: get rid of `Vec<MessageEnvelope>` and move `Vec<u8>` instead.
let start_record_id = buf.elements_drained() - data.len() / ByteBuf::ELEMENT_SIZE_BYTES;

let envs = data
.chunks(ByteBuf::ELEMENT_SIZE_BYTES)
.enumerate()
.map(|(i, chunk)| {
let record_id = RecordId::from(start_record_id + i);
let payload = chunk.to_vec().into_boxed_slice();
MessageEnvelope { record_id, payload }
})
.collect::<Vec<_>>();

Some(envs)
} else {
None
})
Ok(buf.take())
}
}

@@ -157,21 +136,21 @@ impl Config {

#[cfg(test)]
mod tests {
use crate::helpers::buffers::send::{Config, PushError};
use crate::helpers::buffers::send::{ByteBuf, Config, PushError};
use crate::helpers::buffers::SendBuffer;
use crate::helpers::Role;
use crate::protocol::{RecordId, UniqueStepId};
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::cmp::Ordering;
use crate::protocol::{RecordId, Step};

use tinyvec::array_vec;

use crate::helpers::fabric::{ChannelId, MessageEnvelope};

impl Clone for MessageEnvelope {
fn clone(&self) -> Self {
MessageEnvelope {
record_id: self.record_id,
payload: self.payload.clone(),
// tinyvec implements copy for small arrays
payload: self.payload,
}
}
}
@@ -183,66 +162,63 @@ mod tests {
let msg = empty_msg(record_id);

assert!(matches!(
buf.push(&ChannelId::new(Role::H1, UniqueStepId::default()), msg),
buf.push(&ChannelId::new(Role::H1, Step::default()), &msg),
Err(PushError::OutOfRange { .. }),
));
}

#[test]
fn does_not_corrupt_messages() {
let c1 = ChannelId::new(Role::H1, UniqueStepId::default());
let c1 = ChannelId::new(Role::H1, Step::default());
let mut buf = SendBuffer::new(Config::default().items_in_batch(10));

let batch = (0u8..10)
.find_map(|i| {
let msg = MessageEnvelope {
record_id: RecordId::from(u32::from(i)),
payload: i.to_le_bytes().to_vec().into_boxed_slice(),
payload: array_vec!([u8; ByteBuf::ELEMENT_SIZE_BYTES] => i),
};
buf.push(&c1, msg).ok().flatten()
buf.push(&c1, &msg).ok().flatten()
})
.unwrap();

for v in batch {
let payload = u64::from_le_bytes(v.payload.as_ref().try_into().unwrap());
for (i, v) in batch.chunks(ByteBuf::ELEMENT_SIZE_BYTES).enumerate() {
let payload = u64::from_le_bytes(v.try_into().unwrap());
assert!(payload < u64::from(u8::MAX));

assert_eq!(
u32::from(u8::try_from(payload).unwrap()),
u32::from(v.record_id),
);
assert_eq!(usize::from(u8::try_from(payload).unwrap()), i);
}
}

#[test]
fn offset_is_per_channel() {
let mut buf = SendBuffer::new(Config::default());
let c1 = ChannelId::new(Role::H1, UniqueStepId::default());
let c2 = ChannelId::new(Role::H2, UniqueStepId::default());
let c1 = ChannelId::new(Role::H1, Step::default());
let c2 = ChannelId::new(Role::H2, Step::default());

let m1 = empty_msg(0);
let m2 = empty_msg(1);

buf.push(&c1, m1).unwrap();
buf.push(&c1, m2.clone()).unwrap();
buf.push(&c1, &m1).unwrap();
buf.push(&c1, &m2).unwrap();

assert!(matches!(
buf.push(&c2, m2),
buf.push(&c2, &m2),
Err(PushError::OutOfRange { .. }),
));
}

#[test]
fn rejects_duplicates() {
let mut buf = SendBuffer::new(Config::default().items_in_batch(10));
let channel = ChannelId::new(Role::H1, UniqueStepId::default());
let channel = ChannelId::new(Role::H1, Step::default());
let record_id = RecordId::from(3_u32);
let m1 = empty_msg(record_id);
let m2 = empty_msg(record_id);

assert!(matches!(buf.push(&channel, m1), Ok(None)));
assert!(matches!(buf.push(&channel, &m1), Ok(None)));
assert!(matches!(
buf.push(&channel, m2),
buf.push(&channel, &m2),
Err(PushError::Duplicate { .. })
));
}
@@ -253,50 +229,22 @@ mod tests {
let msg = empty_msg(5);

assert!(matches!(
buf.push(&ChannelId::new(Role::H1, UniqueStepId::default()), msg),
buf.push(&ChannelId::new(Role::H1, Step::default()), &msg),
Ok(None)
));
}

#[test]
fn accepts_records_from_next_range_after_flushing() {
let mut buf = SendBuffer::new(Config::default());
let channel = ChannelId::new(Role::H1, UniqueStepId::default());
let channel = ChannelId::new(Role::H1, Step::default());
let next_msg = empty_msg(1);
let this_msg = empty_msg(0);

// this_msg belongs to current range, should be accepted
assert!(matches!(buf.push(&channel, this_msg), Ok(Some(_))));
assert!(matches!(buf.push(&channel, &this_msg), Ok(Some(_))));
// next_msg belongs to the next valid range, which must have become the current range by now
assert!(matches!(buf.push(&channel, next_msg), Ok(Some(_))));
}

#[test]
fn returns_sorted_batch() {
let channel = ChannelId::new(Role::H1, UniqueStepId::default());
let mut buf = SendBuffer::new(Config::default().items_in_batch(10));

let mut record_ids = (0..10).collect::<Vec<_>>();
record_ids.shuffle(&mut thread_rng());

let mut batch_processed = false;
for record in record_ids {
let msg = empty_msg(record);

if let Some(batch) = buf.push(&channel, msg).ok().flatten() {
// todo: use https://doc.rust-lang.org/std/vec/struct.Vec.html#method.is_sorted_by
// or https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.is_sorted when stable
let is_sorted = batch
.as_slice()
.windows(2)
.all(|w| w[0].record_id.cmp(&w[1].record_id) != Ordering::Greater);

assert!(is_sorted, "batch {batch:?} is not sorted by record_id");
batch_processed = true;
}
}

assert!(batch_processed);
assert!(matches!(buf.push(&channel, &next_msg), Ok(Some(_))));
}

fn empty_msg<I: TryInto<u32>>(record_id: I) -> MessageEnvelope
Expand All @@ -305,7 +253,7 @@ mod tests {
{
MessageEnvelope {
record_id: RecordId::from(record_id.try_into().unwrap()),
payload: Box::new([]),
payload: array_vec!(),
}
}
}
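The rewritten push no longer wraps the drained batch in Vec<MessageEnvelope>: because every payload occupies exactly ELEMENT_SIZE_BYTES at a position determined by its record id, the receiver can recover record ids from byte offsets alone, so a raw Vec<u8> is enough. A toy, in-order-only sketch of that idea (the real buffer also accepts out-of-order records within its active range, and its sizes come from Config and the payload constant):

// Hard-coded sizes for the sketch only.
const ELEMENT_SIZE_BYTES: usize = 8;
const ITEMS_IN_BATCH: usize = 4;

struct TinySendBuffer {
    buf: Vec<u8>,
    // Number of elements already drained, i.e. the record id of the first element of the current batch.
    drained: usize,
}

impl TinySendBuffer {
    fn new() -> Self {
        Self {
            buf: Vec::with_capacity(ELEMENT_SIZE_BYTES * ITEMS_IN_BATCH),
            drained: 0,
        }
    }

    /// Buffers the payload for `record_id`; once a full batch has accumulated, hands the raw
    /// bytes back so they can go straight to the network layer.
    fn push(&mut self, record_id: usize, payload: [u8; ELEMENT_SIZE_BYTES]) -> Option<Vec<u8>> {
        // This toy version insists on in-order pushes; the real buffer tracks a whole range.
        assert_eq!(record_id, self.drained + self.buf.len() / ELEMENT_SIZE_BYTES);
        self.buf.extend_from_slice(&payload);
        if self.buf.len() == ELEMENT_SIZE_BYTES * ITEMS_IN_BATCH {
            self.drained += ITEMS_IN_BATCH;
            Some(std::mem::take(&mut self.buf))
        } else {
            None
        }
    }
}

fn main() {
    let mut buf = TinySendBuffer::new();
    let mut batch = None;
    for record_id in 0..ITEMS_IN_BATCH {
        batch = buf.push(record_id, (record_id as u64).to_le_bytes());
    }
    let bytes = batch.expect("four pushes fill exactly one batch");
    // Record ids are implicit: the i-th chunk of the batch is record `first_record_id + i`.
    for (i, chunk) in bytes.chunks(ELEMENT_SIZE_BYTES).enumerate() {
        assert_eq!(u64::from_le_bytes(chunk.try_into().unwrap()), i as u64);
    }
}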
4 changes: 2 additions & 2 deletions src/helpers/error.rs
@@ -1,6 +1,6 @@
use crate::error::BoxError;
use crate::helpers::Role;
use crate::protocol::{RecordId, UniqueStepId};
use crate::protocol::{RecordId, Step};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;

@@ -60,7 +60,7 @@ impl Error {
#[must_use]
pub fn serialization_error<E: Into<BoxError>>(
record_id: RecordId,
step: &UniqueStepId,
step: &Step,
inner: E,
) -> Error {
Self::SerializationError {