Skip to content

Commit

Permalink
Allow storing the ratchet tree separately (#214)
Browse files Browse the repository at this point in the history
* Allow storing the ratchet tree separately

* Fixup

* Fixup

* Fixup

---------

Co-authored-by: Marta Mularczyk <[email protected]>
  • Loading branch information
mulmarta and Marta Mularczyk authored Nov 12, 2024
1 parent b6d257e commit eb221ce
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mls-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mls-rs"
version = "0.42.2"
version = "0.42.3"
edition = "2021"
description = "An implementation of Messaging Layer Security (RFC 9420)"
homepage = "https://github.com/awslabs/mls-rs"
Expand Down
25 changes: 25 additions & 0 deletions mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,31 @@ where
Group::from_snapshot(self.config.clone(), snapshot).await
}

/// Load an existing group state into this client using the
/// [GroupStateStorage](crate::GroupStateStorage) that
/// this client was configured to use. The tree is taken from
/// `tree_data` instead of the stored state.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[inline(never)]
pub async fn load_group_with_ratchet_tree(
&self,
group_id: &[u8],
tree_data: ExportedTree<'_>,
) -> Result<Group<C>, MlsError> {
let snapshot = self
.config
.group_state_storage()
.state(group_id)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
.ok_or(MlsError::GroupNotFound)?;

let mut snapshot = Snapshot::mls_decode(&mut &*snapshot)?;
snapshot.state.public_tree.nodes = tree_data.0.into_owned();

Group::from_snapshot(self.config.clone(), snapshot).await
}

/// Request to join an existing [group](crate::group::Group).
///
/// An existing group member will need to perform a
Expand Down
15 changes: 15 additions & 0 deletions mls-rs/src/external_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ where
ExternalGroup::from_snapshot(self.config.clone(), snapshot).await
}

/// Load an existing observed group by loading a snapshot that was
/// generated by
/// [ExternalGroup::snapshot](self::ExternalGroup::snapshot). The tree
/// is taken from `tree_data` instead of the stored state.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn load_group_with_ratchet_tree(
&self,
mut snapshot: ExternalSnapshot,
tree_data: ExportedTree<'_>,
) -> Result<ExternalGroup<C>, MlsError> {
snapshot.state.public_tree.nodes = tree_data.0.into_owned();

ExternalGroup::from_snapshot(self.config.clone(), snapshot).await
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn validate_key_package(
&self,
Expand Down
56 changes: 53 additions & 3 deletions mls-rs/src/external_client/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ where
#[derive(Debug, MlsEncode, MlsSize, MlsDecode, PartialEq, Clone)]
pub struct ExternalSnapshot {
version: u16,
state: RawGroupState,
pub(crate) state: RawGroupState,
signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
}

Expand Down Expand Up @@ -697,6 +697,23 @@ where
}
}

/// Create a snapshot of this group's current internal state.
/// The tree is not included in the state and can be stored
/// separately by calling [`Group::export_tree`].
pub fn snapshot_without_ratchet_tree(&mut self) -> ExternalSnapshot {
let tree = std::mem::take(&mut self.state.public_tree.nodes);

let snapshot = ExternalSnapshot {
state: RawGroupState::export(&self.state),
version: 1,
signing_data: self.signing_data.clone(),
};

self.state.public_tree.nodes = tree;

snapshot
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn from_snapshot(
config: C,
Expand Down Expand Up @@ -816,15 +833,15 @@ mod tests {
external_client::{
group::test_utils::make_external_group_with_config,
tests_utils::{TestExternalClientBuilder, TestExternalClientConfig},
ExternalGroup, ExternalReceivedMessage, ExternalSnapshot,
ExternalClient, ExternalGroup, ExternalReceivedMessage, ExternalSnapshot,
},
group::{
framing::{Content, MlsMessagePayload},
message_processor::CommitEffect,
proposal::{AddProposal, Proposal, ProposalOrRef},
proposal_ref::ProposalRef,
test_utils::{test_group, TestGroup},
CommitMessageDescription, ProposalMessageDescription,
CommitMessageDescription, ExportedTree, ProposalMessageDescription,
},
identity::{test_utils::get_test_signing_identity, SigningIdentity},
key_package::test_utils::{test_key_package, test_key_package_message},
Expand Down Expand Up @@ -1366,4 +1383,37 @@ mod tests {

assert_matches!(update, ExternalReceivedMessage::Welcome);
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn external_group_can_be_stored_without_tree() {
let mut server =
make_external_group(&test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await).await;

let snapshot_with_tree = server.snapshot().mls_encode_to_vec().unwrap();

let snapshot_without_tree = server
.snapshot_without_ratchet_tree()
.mls_encode_to_vec()
.unwrap();

let tree = server.state.public_tree.nodes.mls_encode_to_vec().unwrap();
let empty_tree = Vec::<u8>::new().mls_encode_to_vec().unwrap();

assert_eq!(
snapshot_with_tree.len() - snapshot_without_tree.len(),
tree.len() - empty_tree.len()
);

let exported_tree = server.export_tree().unwrap();

let restored = ExternalClient::new(server.config.clone(), None)
.load_group_with_ratchet_tree(
ExternalSnapshot::from_bytes(&snapshot_without_tree).unwrap(),
ExportedTree::from_bytes(&exported_tree).unwrap(),
)
.await
.unwrap();

assert_eq!(restored.group_state(), server.group_state());
}
}
36 changes: 35 additions & 1 deletion mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2010,8 +2010,11 @@ mod tests {
#[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
use super::test_utils::test_group_custom_config;

#[cfg(any(feature = "psk", feature = "std"))]
use crate::client::Client;

#[cfg(feature = "psk")]
use crate::{client::Client, psk::PreSharedKey};
use crate::psk::PreSharedKey;

#[cfg(any(feature = "by_ref_proposal", feature = "private_message"))]
use crate::group::test_utils::random_bytes;
Expand Down Expand Up @@ -4370,4 +4373,35 @@ mod tests {

assert!(!group.commit_required());
}

// Testing with std is sufficient. Non-std creates incompatible storage and a lot of special cases.
#[cfg(feature = "std")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn can_be_stored_without_tree() {
let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let storage = group.config.group_state_storage().inner;

group.write_to_storage().await.unwrap();
let snapshot_with_tree = storage.lock().unwrap().drain().next().unwrap().1;

group.write_to_storage_without_ratchet_tree().await.unwrap();
let snapshot_without_tree = storage.lock().unwrap().iter().next().unwrap().1.clone();

let tree = group.state.public_tree.nodes.mls_encode_to_vec().unwrap();
let empty_tree = Vec::<u8>::new().mls_encode_to_vec().unwrap();

assert_eq!(
snapshot_with_tree.state_data.len() - snapshot_without_tree.state_data.len(),
tree.len() - empty_tree.len()
);

let exported_tree = group.export_tree();

let restored = Client::new(group.config.clone(), None, None, TEST_PROTOCOL_VERSION)
.load_group_with_ratchet_tree(group.group_id(), exported_tree)
.await
.unwrap();

assert_eq!(restored.group_state(), group.group_state());
}
}
13 changes: 13 additions & 0 deletions mls-rs/src/group/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ where
self.state_repo.write_to_storage(self.snapshot()).await
}

/// Write the current state of the group to the
/// [`GroupStorageProvider`](crate::GroupStateStorage)
/// that is currently in use by the group.
/// The tree is not included in the state and can be stored
/// separately by calling [`Group::export_tree`].
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn write_to_storage_without_ratchet_tree(&mut self) -> Result<(), MlsError> {
let mut snapshot = self.snapshot();
snapshot.state.public_tree.nodes = Default::default();

self.state_repo.write_to_storage(snapshot).await
}

pub(crate) fn snapshot(&self) -> Snapshot {
Snapshot {
state: RawGroupState::export(&self.state),
Expand Down

0 comments on commit eb221ce

Please sign in to comment.