diff --git a/chitchat/src/digest.rs b/chitchat/src/digest.rs index 9bb6f7a..5abc424 100644 --- a/chitchat/src/digest.rs +++ b/chitchat/src/digest.rs @@ -10,17 +10,30 @@ pub(crate) struct NodeDigest { pub(crate) max_version: Version, } -impl NodeDigest { - pub(crate) fn new( - heartbeat: Heartbeat, - last_gc_version: Version, - max_version: Version, - ) -> Self { - Self { +impl Serializable for NodeDigest { + fn serialize(&self, buf: &mut Vec) { + self.heartbeat.serialize(buf); + self.last_gc_version.serialize(buf); + self.max_version.serialize(buf); + } + + fn serialized_len(&self) -> usize { + self.heartbeat.serialized_len() + + self.last_gc_version.serialized_len() + + self.max_version.serialized_len() + } +} + +impl Deserializable for NodeDigest { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let heartbeat = Heartbeat::deserialize(buf)?; + let last_gc_version = Version::deserialize(buf)?; + let max_version = Version::deserialize(buf)?; + Ok(NodeDigest { heartbeat, last_gc_version, max_version, - } + }) } } @@ -43,7 +56,11 @@ impl Digest { last_gc_version: Version, max_version: Version, ) { - let node_digest = NodeDigest::new(heartbeat, last_gc_version, max_version); + let node_digest = NodeDigest { + heartbeat, + last_gc_version, + max_version, + }; self.node_digests.insert(node, node_digest); } } @@ -53,18 +70,14 @@ impl Serializable for Digest { (self.node_digests.len() as u16).serialize(buf); for (chitchat_id, node_digest) in &self.node_digests { chitchat_id.serialize(buf); - node_digest.heartbeat.serialize(buf); - node_digest.last_gc_version.serialize(buf); - node_digest.max_version.serialize(buf); + node_digest.serialize(buf); } } fn serialized_len(&self) -> usize { let mut len = (self.node_digests.len() as u16).serialized_len(); for (chitchat_id, node_digest) in &self.node_digests { len += chitchat_id.serialized_len(); - len += node_digest.heartbeat.serialized_len(); - len += node_digest.last_gc_version.serialized_len(); - len += node_digest.max_version.serialized_len(); + len += node_digest.serialized_len(); } len } @@ -77,12 +90,39 @@ impl Deserializable for Digest { for _ in 0..num_nodes { let chitchat_id = ChitchatId::deserialize(buf)?; - let heartbeat = Heartbeat::deserialize(buf)?; - let max_version = u64::deserialize(buf)?; - let last_gc_version = u64::deserialize(buf)?; - let node_digest = NodeDigest::new(heartbeat, last_gc_version, max_version); + let node_digest = NodeDigest::deserialize(buf)?; node_digests.insert(chitchat_id, node_digest); } Ok(Digest { node_digests }) } } + + +#[cfg(test)] +mod tests { + use crate::digest::{Digest, NodeDigest}; + use crate::serialize::test_serdeser_aux; + use crate::{ChitchatId, Heartbeat}; + + #[test] + fn test_digests_serialization() { + let node_digest = NodeDigest { + heartbeat: crate::Heartbeat(100u64), + last_gc_version: 2, + max_version: 3, + }; + test_serdeser_aux(&node_digest, 24); + } + + #[test] + fn test_digest() { + let mut digest = Digest::default(); + let node1 = ChitchatId::for_local_test(10_001); + let node2 = ChitchatId::for_local_test(10_002); + let node3 = ChitchatId::for_local_test(10_002); + digest.add_node(node1, Heartbeat(101), 1, 11); + digest.add_node(node2, Heartbeat(102), 20, 12); + digest.add_node(node3, Heartbeat(103), 0, 13); + test_serdeser_aux(&digest, 104); + } +} diff --git a/chitchat/src/transport/channel.rs b/chitchat/src/transport/channel.rs index 137076c..e511938 100644 --- a/chitchat/src/transport/channel.rs +++ b/chitchat/src/transport/channel.rs @@ -7,7 +7,7 @@ use async_trait::async_trait; use tokio::sync::mpsc::{Receiver, Sender}; use tracing::info; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; use crate::transport::{Socket, Transport}; use crate::ChitchatMessage; @@ -56,6 +56,15 @@ impl Transport for ChannelTransport { } } +fn serialize_deserialize_chitchat_message(message: ChitchatMessage) -> ChitchatMessage { + let buf = message.serialize_to_vec(); + assert_eq!(buf.len(), message.serialized_len()); + let mut read_cursor: &[u8] = &buf[..]; + let message = ChitchatMessage::deserialize(&mut read_cursor).unwrap(); + assert!(read_cursor.is_empty()); + message +} + impl ChannelTransport { pub fn with_mtu(mtu: usize) -> Self { Self { @@ -92,6 +101,8 @@ impl ChannelTransport { to_addr: SocketAddr, message: ChitchatMessage, ) -> anyhow::Result<()> { + // We serialize/deserialize message to get closer to the real world. + let message = serialize_deserialize_chitchat_message(message); let num_bytes = message.serialized_len(); if let Some(mtu) = self.mtu_opt { if num_bytes > mtu {