diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index b84b8e49..c7221750 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -312,9 +312,8 @@ impl Client { /// Join an existing group. /// /// See [`mls_rs::Client::join_group`] for details. - pub async fn join_group(&self, welcome_message: Arc) -> Result { - let welcome_message = arc_unwrap_or_clone(welcome_message); - let (group, new_member_info) = self.inner.join_group(None, welcome_message.inner).await?; + pub async fn join_group(&self, welcome_message: &Message) -> Result { + let (group, new_member_info) = self.inner.join_group(None, &welcome_message.inner).await?; let group = Arc::new(Group { inner: Arc::new(Mutex::new(group)), diff --git a/mls-rs/examples/basic_server_usage.rs b/mls-rs/examples/basic_server_usage.rs index b4540210..fba71da5 100644 --- a/mls-rs/examples/basic_server_usage.rs +++ b/mls-rs/examples/basic_server_usage.rs @@ -9,7 +9,7 @@ use mls_rs::{ builder::MlsConfig as ExternalMlsConfig, ExternalClient, ExternalReceivedMessage, ExternalSnapshot, }, - group::{CachedProposal, ExportedTree, ReceivedMessage}, + group::{CachedProposal, ReceivedMessage}, identity::{ basic::{BasicCredential, BasicIdentityProvider}, SigningIdentity, @@ -39,11 +39,11 @@ struct BasicServer { impl BasicServer { // Client uploads group data after creating the group - fn create_group(group_info: &[u8], tree: ExportedTree) -> Result { + fn create_group(group_info: &[u8]) -> Result { let server = make_server(); let group_info = MlsMessage::from_bytes(group_info)?; - let group = server.observe_group(group_info, Some(tree))?; + let group = server.observe_group(group_info, None)?; Ok(Self { group_state: group.snapshot().to_bytes()?, @@ -143,22 +143,18 @@ fn main() -> Result<(), MlsError> { let mut alice_group = alice.create_group(ExtensionList::default())?; let bob_key_package = bob.generate_key_package_message()?; - let welcome = alice_group + let welcome = &alice_group .commit_builder() .add_member(bob_key_package)? .build()? - .welcome_messages - .pop() - .expect("key package shouldn't be rejected"); + .welcome_messages[0]; let (mut bob_group, _) = bob.join_group(None, welcome)?; alice_group.apply_pending_commit()?; // Server starts observing Alice's group let group_info = alice_group.group_info_message(true)?.to_bytes()?; - let tree = alice_group.export_tree(); - - let mut server = BasicServer::create_group(&group_info, tree)?; + let mut server = BasicServer::create_group(&group_info)?; // Bob uploads a proposal let proposal = bob_group diff --git a/mls-rs/examples/basic_usage.rs b/mls-rs/examples/basic_usage.rs index 9bef35b3..c49af8f1 100644 --- a/mls-rs/examples/basic_usage.rs +++ b/mls-rs/examples/basic_usage.rs @@ -50,7 +50,7 @@ fn main() -> Result<(), MlsError> { let bob_key_package = bob.generate_key_package_message()?; // Alice issues a commit that adds Bob to the group. - let mut alice_commit = alice_group + let alice_commit = alice_group .commit_builder() .add_member(bob_key_package)? .build()?; @@ -61,7 +61,7 @@ fn main() -> Result<(), MlsError> { alice_group.apply_pending_commit()?; // Bob joins the group with the welcome message created as part of Alice's commit. - let (mut bob_group, _) = bob.join_group(None, alice_commit.welcome_messages.pop().unwrap())?; + let (mut bob_group, _) = bob.join_group(None, &alice_commit.welcome_messages[0])?; // Alice encrypts an application message to Bob. let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?; diff --git a/mls-rs/examples/custom.rs b/mls-rs/examples/custom.rs index d20fac6a..f5d93273 100644 --- a/mls-rs/examples/custom.rs +++ b/mls-rs/examples/custom.rs @@ -383,7 +383,7 @@ fn main() -> Result<(), CustomError> { .remove(0); alice_tablet_group.apply_pending_commit()?; - let (mut alice_pc_group, _) = alice_pc_client.join_group(None, welcome)?; + let (mut alice_pc_group, _) = alice_pc_client.join_group(None, &welcome)?; // Alice cannot add bob's devices yet let bob_tablet_client = make_client(bob_tablet)?; @@ -401,13 +401,13 @@ fn main() -> Result<(), CustomError> { new_user: bob.credential, }; - let mut commit = alice_tablet_group + let commit = alice_tablet_group .commit_builder() .custom_proposal(add_bob.to_custom_proposal()?) .add_member(key_package)? .build()?; - bob_tablet_client.join_group(None, commit.welcome_messages.remove(0))?; + bob_tablet_client.join_group(None, &commit.welcome_messages[0])?; alice_tablet_group.apply_pending_commit()?; alice_pc_group.process_incoming_message(commit.commit_message)?; diff --git a/mls-rs/examples/large_group.rs b/mls-rs/examples/large_group.rs index 7f317b3e..c4377437 100644 --- a/mls-rs/examples/large_group.rs +++ b/mls-rs/examples/large_group.rs @@ -69,7 +69,7 @@ fn make_groups_best_case( let bob_kpkg = bob_client.generate_key_package_message()?; // Last group sends a commit adding the new client to the group. - let mut commit = groups + let commit = groups .last_mut() .unwrap() .commit_builder() @@ -85,8 +85,7 @@ fn make_groups_best_case( groups.last_mut().unwrap().apply_pending_commit()?; // The new member joins. - let welcome_message = commit.welcome_messages.pop().unwrap(); - let (bob_group, _info) = bob_client.join_group(None, welcome_message)?; + let (bob_group, _info) = bob_client.join_group(None, &commit.welcome_messages[0])?; groups.push(bob_group); } @@ -115,7 +114,7 @@ fn make_groups_worst_case( commit_builder = commit_builder.add_member(bob_kpkg)?; } - let welcome_message: mls_rs::MlsMessage = commit_builder.build()?.welcome_messages.remove(0); + let welcome_message = &commit_builder.build()?.welcome_messages[0]; alice_group.apply_pending_commit()?; @@ -123,7 +122,7 @@ fn make_groups_worst_case( let mut groups = vec![alice_group]; for bob_client in &bob_clients { - let (bob_group, _info) = bob_client.join_group(None, welcome_message.clone())?; + let (bob_group, _info) = bob_client.join_group(None, welcome_message)?; groups.push(bob_group); } diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs index 7281ad23..0799ada7 100644 --- a/mls-rs/src/client.rs +++ b/mls-rs/src/client.rs @@ -11,7 +11,6 @@ use crate::group::framing::MlsMessage; use crate::group::{ framing::{Content, MlsMessagePayload, PublicMessage, Sender, WireFormat}, message_signature::AuthenticatedContent, - process_group_info, proposal::{AddProposal, Proposal}, }; use crate::group::{ExportedTree, Group, NewMemberInfo}; @@ -536,7 +535,7 @@ where pub async fn join_group( &self, tree_data: Option>, - welcome_message: MlsMessage, + welcome_message: &MlsMessage, ) -> Result<(Group, NewMemberInfo), MlsError> { Group::join( welcome_message, @@ -627,7 +626,7 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn external_add_proposal( &self, - group_info: MlsMessage, + group_info: &MlsMessage, tree_data: Option>, authenticated_data: Vec, ) -> Result { @@ -638,7 +637,7 @@ where } let group_info = group_info - .into_group_info() + .as_group_info() .ok_or(MlsError::UnexpectedMessageType)?; let cipher_suite = group_info.group_context.cipher_suite; @@ -649,15 +648,14 @@ where .cipher_suite_provider(cipher_suite) .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?; - let group_context = process_group_info( + crate::group::validate_group_info_joiner( protocol_version, group_info, tree_data, &self.config.identity_provider(), &cipher_suite_provider, ) - .await? - .group_context; + .await?; let key_package = self.generate_key_package().await?.key_package; @@ -667,7 +665,7 @@ where let message = AuthenticatedContent::new_signed( &cipher_suite_provider, - &group_context, + &group_info.group_context, Sender::NewMemberProposal, Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal { key_package, @@ -830,12 +828,8 @@ mod tests { let proposal = bob .external_add_proposal( - alice_group - .group - .group_info_message_allowing_ext_commit(true) - .await - .unwrap(), - Some(alice_group.group.export_tree()), + &alice_group.group.group_info_message(true).await.unwrap(), + None, vec![], ) .await @@ -890,7 +884,7 @@ mod tests { let group_info_msg = alice_group .group - .group_info_message_allowing_ext_commit(false) + .group_info_message_allowing_ext_commit(true) .await .unwrap(); @@ -904,10 +898,7 @@ mod tests { .signing_identity(new_client_identity.clone(), secret_key, TEST_CIPHER_SUITE) .build(); - let mut builder = new_client - .external_commit_builder() - .unwrap() - .with_tree_data(alice_group.group.export_tree().into_owned()); + let mut builder = new_client.external_commit_builder().unwrap(); if do_remove { builder = builder.with_removal(1); @@ -1013,7 +1004,6 @@ mod tests { let (_, external_commit) = carol .external_commit_builder() .unwrap() - .with_tree_data(bob_group.group.export_tree().into_owned()) .build(group_info_msg) .await .unwrap(); diff --git a/mls-rs/src/external_client/group.rs b/mls-rs/src/external_client/group.rs index ae9ee205..89399480 100644 --- a/mls-rs/src/external_client/group.rs +++ b/mls-rs/src/external_client/group.rs @@ -24,7 +24,8 @@ use crate::{ snapshot::RawGroupState, state::GroupState, transcript_hash::InterimTranscriptHash, - validate_group_info, ContentType, ExportedTree, GroupContext, GroupInfo, Roster, Welcome, + validate_group_info_joiner, ContentType, ExportedTree, GroupContext, GroupInfo, Roster, + Welcome, }, identity::SigningIdentity, protocol_version::ProtocolVersion, @@ -126,9 +127,9 @@ impl ExternalGroup { group_info.group_context.cipher_suite, )?; - let join_context = validate_group_info( + let public_tree = validate_group_info_joiner( protocol_version, - group_info, + &group_info, tree_data, &config.identity_provider(), &cipher_suite_provider, @@ -137,8 +138,8 @@ impl ExternalGroup { let interim_transcript_hash = InterimTranscriptHash::create( &cipher_suite_provider, - &join_context.group_context.confirmed_transcript_hash, - &join_context.confirmation_tag, + &group_info.group_context.confirmed_transcript_hash, + &group_info.confirmation_tag, ) .await?; @@ -146,10 +147,10 @@ impl ExternalGroup { config, signing_data, state: GroupState::new( - join_context.group_context, - join_context.public_tree, + group_info.group_context, + public_tree, interim_transcript_hash, - join_context.confirmation_tag, + group_info.confirmation_tag, ), cipher_suite_provider, }) @@ -597,7 +598,7 @@ where #[cfg(feature = "private_message")] async fn process_ciphertext( &mut self, - cipher_text: PrivateMessage, + cipher_text: &PrivateMessage, ) -> Result, MlsError> { Ok(EventOrContent::Event(ExternalReceivedMessage::Ciphertext( cipher_text.content_type, diff --git a/mls-rs/src/group/ciphertext_processor.rs b/mls-rs/src/group/ciphertext_processor.rs index 111de4ea..bf70f5dd 100644 --- a/mls-rs/src/group/ciphertext_processor.rs +++ b/mls-rs/src/group/ciphertext_processor.rs @@ -194,7 +194,7 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn open( &mut self, - ciphertext: PrivateMessage, + ciphertext: &PrivateMessage, ) -> Result { // Decrypt the sender data with the derived sender_key and sender_nonce from the message // epoch's key schedule @@ -236,7 +236,7 @@ where .decrypt( &self.cipher_suite_provider, &ciphertext.ciphertext, - &PrivateContentAAD::from(&ciphertext).mls_encode_to_vec()?, + &PrivateContentAAD::from(ciphertext).mls_encode_to_vec()?, &sender_data.reuse_guard, ) .await @@ -252,7 +252,7 @@ where group_id: ciphertext.group_id.clone(), epoch: ciphertext.epoch, sender, - authenticated_data: ciphertext.authenticated_data, + authenticated_data: ciphertext.authenticated_data.clone(), content: ciphertext_content.content, }, auth: ciphertext_content.auth, @@ -335,7 +335,7 @@ mod test { let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite); - let decrypted = receiver_processor.open(ciphertext).await.unwrap(); + let decrypted = receiver_processor.open(&ciphertext).await.unwrap(); assert_eq!(decrypted, test_data.content); } @@ -384,7 +384,7 @@ mod test { .await .unwrap(); - let res = ciphertext_processor.open(ciphertext).await; + let res = ciphertext_processor.open(&ciphertext).await; assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf)) } @@ -403,7 +403,7 @@ mod test { ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len()); receiver_group.group.private_tree.self_index = LeafIndex::new(1); - let res = ciphertext_processor.open(ciphertext).await; + let res = ciphertext_processor.open(&ciphertext).await; assert!(res.is_err()); } diff --git a/mls-rs/src/group/commit.rs b/mls-rs/src/group/commit.rs index 0bec5962..3f3dbbf1 100644 --- a/mls-rs/src/group/commit.rs +++ b/mls-rs/src/group/commit.rs @@ -1006,7 +1006,7 @@ mod tests { .welcome_messages .remove(0); - let (_, context) = bob_client.join_group(None, welcome_message).await.unwrap(); + let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap(); assert_eq!( context @@ -1243,7 +1243,7 @@ mod tests { .find(|w| w.welcome_key_package_references().contains(&&kp_ref)) .unwrap(); - client.join_group(None, welcome.clone()).await.unwrap(); + client.join_group(None, welcome).await.unwrap(); assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1); } diff --git a/mls-rs/src/group/external_commit.rs b/mls-rs/src/group/external_commit.rs index 32bd09b3..34b10427 100644 --- a/mls-rs/src/group/external_commit.rs +++ b/mls-rs/src/group/external_commit.rs @@ -11,8 +11,7 @@ use crate::{ epoch::SenderDataSecret, key_schedule::{InitSecret, KeySchedule}, proposal::{ExternalInit, Proposal, RemoveProposal}, - validate_group_info, EpochSecrets, ExternalPubExt, LeafIndex, LeafNode, MlsError, - TreeKemPrivate, + EpochSecrets, ExternalPubExt, LeafIndex, LeafNode, MlsError, TreeKemPrivate, }, Group, MlsMessage, }; @@ -40,7 +39,7 @@ use crate::group::{ PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID}, }; -use super::ExportedTree; +use super::{validate_group_info_joiner, ExportedTree}; /// A builder that aids with the construction of an external commit. #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))] @@ -164,9 +163,9 @@ impl ExternalCommitBuilder { .get_as::()? .ok_or(MlsError::MissingExternalPubExtension)?; - let join_context = validate_group_info( + let public_tree = validate_group_info_joiner( protocol_version, - group_info, + &group_info, self.tree_data, &self.config.identity_provider(), &cipher_suite, @@ -195,8 +194,8 @@ impl ExternalCommitBuilder { let (mut group, _) = Group::join_with( self.config, - cipher_suite.clone(), - join_context, + group_info, + public_tree, KeySchedule::new(init_secret), epoch_secrets, TreeKemPrivate::new_for_external(), diff --git a/mls-rs/src/group/framing.rs b/mls-rs/src/group/framing.rs index cc5bbf25..8663b968 100644 --- a/mls-rs/src/group/framing.rs +++ b/mls-rs/src/group/framing.rs @@ -394,6 +394,14 @@ impl MlsMessage { } } + #[inline(always)] + pub fn as_group_info(&self) -> Option<&GroupInfo> { + match &self.payload { + MlsMessagePayload::GroupInfo(info) => Some(info), + _ => None, + } + } + #[inline(always)] pub fn into_key_package(self) -> Option { match self.payload { diff --git a/mls-rs/src/group/interop_test_vectors/passive_client.rs b/mls-rs/src/group/interop_test_vectors/passive_client.rs index e3e75766..ea1780ba 100644 --- a/mls-rs/src/group/interop_test_vectors/passive_client.rs +++ b/mls-rs/src/group/interop_test_vectors/passive_client.rs @@ -188,7 +188,7 @@ async fn interop_passive_client() { .ratchet_tree .map(|t| ExportedTree::from_bytes(&t.0).unwrap()); - let (mut group, _info) = client.join_group(tree, welcome).await.unwrap(); + let (mut group, _info) = client.join_group(tree, &welcome).await.unwrap(); assert_eq!( group.epoch_authenticator().unwrap().to_vec(), @@ -631,7 +631,7 @@ pub async fn add_random_members( let commit = commit_output.welcome_messages[0].clone(); let group = client - .join_group(Some(tree_data.clone()), commit) + .join_group(Some(tree_data.clone()), &commit) .await .unwrap() .0; diff --git a/mls-rs/src/group/message_processor.rs b/mls-rs/src/group/message_processor.rs index cf591979..8084a583 100644 --- a/mls-rs/src/group/message_processor.rs +++ b/mls-rs/src/group/message_processor.rs @@ -10,15 +10,13 @@ use super::{ }, message_signature::AuthenticatedContent, mls_rules::{CommitDirection, MlsRules}, - process_group_info, proposal_filter::ProposalBundle, state::GroupState, transcript_hash::InterimTranscriptHash, - transcript_hashes, ExportedTree, GroupContext, GroupInfo, Welcome, + transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, Welcome, }; use crate::{ client::MlsError, - extension::RatchetTreeExt, key_package::validate_key_package_properties, time::MlsTime, tree_kem::{ @@ -486,10 +484,15 @@ pub(crate) trait MessageProcessor: Send + Sync { self.verify_plaintext_authentication(plaintext).await } #[cfg(feature = "private_message")] - MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(cipher_text).await, + MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await, MlsMessagePayload::GroupInfo(group_info) => { - self.validate_group_info(group_info.clone(), message.version) - .await?; + validate_group_info_member( + self.group_state(), + message.version, + &group_info, + self.cipher_suite_provider(), + ) + .await?; Ok(EventOrContent::Event(group_info.into())) } @@ -946,38 +949,6 @@ pub(crate) trait MessageProcessor: Send + Sync { Ok(()) } - async fn validate_group_info( - &self, - group_info: GroupInfo, - version: ProtocolVersion, - ) -> Result<(), MlsError> { - let state = self.group_state(); - - let self_tree = ExportedTree::new_borrowed(&state.public_tree.nodes); - - if let Some(tree) = group_info.extensions.get_as::()? { - (tree.tree_data == self_tree) - .then_some(()) - .ok_or(MlsError::InvalidGroupInfo)?; - } - - (group_info.group_context == state.context - && group_info.confirmation_tag == state.confirmation_tag) - .then_some(()) - .ok_or(MlsError::InvalidGroupInfo)?; - - process_group_info( - version, - group_info, - Some(self_tree), - &self.identity_provider(), - self.cipher_suite_provider(), - ) - .await?; - - Ok(()) - } - fn validate_welcome( &self, welcome: &Welcome, @@ -1005,7 +976,7 @@ pub(crate) trait MessageProcessor: Send + Sync { #[cfg(feature = "private_message")] async fn process_ciphertext( &mut self, - cipher_text: PrivateMessage, + cipher_text: &PrivateMessage, ) -> Result, MlsError>; async fn verify_plaintext_authentication( diff --git a/mls-rs/src/group/message_verifier.rs b/mls-rs/src/group/message_verifier.rs index 8674bd86..7a2bc59b 100644 --- a/mls-rs/src/group/message_verifier.rs +++ b/mls-rs/src/group/message_verifier.rs @@ -296,7 +296,7 @@ mod tests { let (bob_client, bob_key_pkg) = test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await; - let mut commit_output = alice + let commit_output = alice .group .commit_builder() .add_member(bob_key_pkg) @@ -308,7 +308,7 @@ mod tests { alice.group.apply_pending_commit().await.unwrap(); let (bob, _) = Group::join( - commit_output.welcome_messages.remove(0), + &commit_output.welcome_messages[0], None, bob_client.config, bob_client.signer.unwrap(), diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 77a2e19d..97c124ee 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -390,7 +390,7 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn join( - welcome: MlsMessage, + welcome: &MlsMessage, tree_data: Option>, config: C, signer: SignatureSecretKey, @@ -408,7 +408,7 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn from_welcome_message( - welcome: MlsMessage, + welcome: &MlsMessage, tree_data: Option>, config: C, signer: SignatureSecretKey, @@ -420,9 +420,9 @@ where return Err(MlsError::UnsupportedProtocolVersion(protocol_version)); } - let welcome = welcome - .into_welcome() - .ok_or(MlsError::UnexpectedMessageType)?; + let MlsMessagePayload::Welcome(welcome) = &welcome.payload else { + return Err(MlsError::UnexpectedMessageType); + }; let cipher_suite_provider = cipher_suite_provider(config.crypto_provider(), welcome.cipher_suite)?; @@ -501,9 +501,9 @@ where let group_info = GroupInfo::mls_decode(&mut &**decrypted_group_info)?; - let join_context = validate_group_info( + let public_tree = validate_group_info_joiner( protocol_version, - group_info, + &group_info, tree_data, &config.identity_provider(), &cipher_suite_provider, @@ -514,8 +514,7 @@ where // to the leaf_node field of the KeyPackage. If no such field exists, return an error. Let // index represent the index of this node among the leaves in the tree, namely the index of // the node in the tree array divided by two. - let self_index = join_context - .public_tree + let self_index = public_tree .find_leaf_node(&key_package_generation.key_package.leaf_node) .ok_or(MlsError::WelcomeKeyPackageNotFound)?; @@ -529,9 +528,9 @@ where private_tree .update_secrets( &cipher_suite_provider, - join_context.signer_index, + group_info.signer, path_secret, - &join_context.public_tree, + &public_tree, ) .await?; } @@ -541,20 +540,20 @@ where let key_schedule_result = KeySchedule::from_joiner( &cipher_suite_provider, &group_secrets.joiner_secret, - &join_context.group_context, + &group_info.group_context, #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] - join_context.public_tree.total_leaf_count(), + public_tree.total_leaf_count(), &psk_secret, ) .await?; // Verify the confirmation tag in the GroupInfo using the derived confirmation key and the // confirmed_transcript_hash from the GroupInfo. - if !join_context + if !group_info .confirmation_tag .matches( &key_schedule_result.confirmation_key, - &join_context.group_context.confirmed_transcript_hash, + &group_info.group_context.confirmed_transcript_hash, &cipher_suite_provider, ) .await? @@ -564,8 +563,8 @@ where Self::join_with( config, - cipher_suite_provider, - join_context, + group_info, + public_tree, key_schedule_result.key_schedule, key_schedule_result.epoch_secrets, private_tree, @@ -579,41 +578,46 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn join_with( config: C, - cipher_suite_provider: ::CipherSuiteProvider, - join_context: JoinContext, + group_info: GroupInfo, + public_tree: TreeKemPublic, key_schedule: KeySchedule, epoch_secrets: EpochSecrets, private_tree: TreeKemPrivate, used_key_package_ref: Option, signer: SignatureSecretKey, ) -> Result<(Self, NewMemberInfo), MlsError> { + let cs = group_info.group_context.cipher_suite; + + let cs = config + .crypto_provider() + .cipher_suite_provider(cs) + .ok_or(MlsError::UnsupportedCipherSuite(cs))?; + // Use the confirmed transcript hash and confirmation tag to compute the interim transcript // hash in the new state. let interim_transcript_hash = InterimTranscriptHash::create( - &cipher_suite_provider, - &join_context.group_context.confirmed_transcript_hash, - &join_context.confirmation_tag, + &cs, + &group_info.group_context.confirmed_transcript_hash, + &group_info.confirmation_tag, ) .await?; let state_repo = GroupStateRepository::new( #[cfg(feature = "prior_epoch")] - join_context.group_context.group_id.clone(), + group_info.group_context.group_id.clone(), config.group_state_storage(), config.key_package_repo(), used_key_package_ref, ) .await?; - let group_info_extensions = join_context.group_info_extensions.clone(); - let group = Group { config, state: GroupState::new( - join_context.group_context, - join_context.public_tree, + group_info.group_context, + public_tree, interim_transcript_hash, - join_context.confirmation_tag, + group_info.confirmation_tag, ), private_tree, key_schedule, @@ -624,13 +628,13 @@ where commit_modifiers: Default::default(), epoch_secrets, state_repo, - cipher_suite_provider, + cipher_suite_provider: cs, #[cfg(feature = "psk")] previous_psk: None, signer, }; - Ok((group, NewMemberInfo::new(group_info_extensions))) + Ok((group, NewMemberInfo::new(group_info.extensions))) } #[inline(always)] @@ -1193,7 +1197,7 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn decrypt_incoming_ciphertext( &mut self, - message: PrivateMessage, + message: &PrivateMessage, ) -> Result { let epoch_id = message.epoch; @@ -1604,7 +1608,7 @@ where #[cfg(feature = "private_message")] async fn process_ciphertext( &mut self, - cipher_text: PrivateMessage, + cipher_text: &PrivateMessage, ) -> Result, MlsError> { self.decrypt_incoming_ciphertext(cipher_text) .await @@ -2161,7 +2165,7 @@ mod tests { test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await; // Add bob to the group - let mut commit_output = test_group + let commit_output = test_group .group .commit_builder() .add_member(bob_key_package) @@ -2172,7 +2176,7 @@ mod tests { // Group from Bob's perspective let bob_group = Group::join( - commit_output.welcome_messages.remove(0), + &commit_output.welcome_messages[0], None, bob_client.config, bob_client.signer.unwrap(), @@ -2699,7 +2703,6 @@ mod tests { let (bob_group, commit) = bob .external_commit_builder() .unwrap() - .with_tree_data(alice_group.group.export_tree().into_owned()) .build( alice_group .group @@ -2893,27 +2896,18 @@ mod tests { .await .unwrap(); - let (mut alice_sub_group, mut welcome) = alice + let (mut alice_sub_group, welcome) = alice .group .branch(b"subgroup".to_vec(), vec![new_key_pkg]) .await .unwrap(); - let welcome = welcome.remove(0); + let welcome = &welcome[0]; - let (mut bob_sub_group, _) = bob - .group - .join_subgroup(welcome.clone(), None) - .await - .unwrap(); + let (mut bob_sub_group, _) = bob.group.join_subgroup(welcome, None).await.unwrap(); // Carol can't join - let res = carol - .group - .join_subgroup(welcome, Some(alice_sub_group.export_tree())) - .await - .map(|_| ()); - + let res = carol.group.join_subgroup(welcome, None).await.map(|_| ()); assert_matches!(res, Err(_)); // Alice and Bob can still talk @@ -3834,7 +3828,7 @@ mod tests { bob.config.secret_store().insert(psk_id.clone(), psk); - let mut commit = alice + let commit = alice .commit_builder() .add_member(key_pkg) .unwrap() @@ -3844,7 +3838,7 @@ mod tests { .await .unwrap(); - bob.join_group(None, commit.welcome_messages.remove(0)) + bob.join_group(None, &commit.welcome_messages[0]) .await .unwrap(); } @@ -3911,7 +3905,7 @@ mod tests { alice.apply_pending_commit().await.unwrap(); let mut bob = bob_client - .join_group(None, commit.welcome_messages[0].clone()) + .join_group(None, &commit.welcome_messages[0]) .await .unwrap() .0; @@ -3930,13 +3924,13 @@ mod tests { .unwrap(); let mut carol = carol_client - .join_group(None, commit.welcome_messages[0].clone()) + .join_group(None, &commit.welcome_messages[0]) .await .unwrap() .0; let mut dave = dave_client - .join_group(None, commit.welcome_messages[0].clone()) + .join_group(None, &commit.welcome_messages[0]) .await .unwrap() .0; diff --git a/mls-rs/src/group/resumption.rs b/mls-rs/src/group/resumption.rs index e77d5159..3478ef3c 100644 --- a/mls-rs/src/group/resumption.rs +++ b/mls-rs/src/group/resumption.rs @@ -74,7 +74,7 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn join_subgroup( &self, - welcome: MlsMessage, + welcome: &MlsMessage, tree_data: Option>, ) -> Result<(Group, NewMemberInfo), MlsError> { let expected_new_group_prams = ResumptionGroupParameters { @@ -204,7 +204,7 @@ impl ReinitClient { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn join( self, - welcome: MlsMessage, + welcome: &MlsMessage, tree_data: Option>, ) -> Result<(Group, NewMemberInfo), MlsError> { let reinit = self.reinit; @@ -274,7 +274,7 @@ async fn resumption_create_group( async fn resumption_join_group( config: C, signer: SignatureSecretKey, - welcome: MlsMessage, + welcome: &MlsMessage, tree_data: Option>, expected_new_group_params: ResumptionGroupParameters<'_>, verify_group_id: bool, diff --git a/mls-rs/src/group/test_utils.rs b/mls-rs/src/group/test_utils.rs index 51320cda..764d5e6f 100644 --- a/mls-rs/src/group/test_utils.rs +++ b/mls-rs/src/group/test_utils.rs @@ -80,7 +80,7 @@ impl TestGroup { // Add new member to the group let CommitOutput { - mut welcome_messages, + welcome_messages, ratchet_tree, commit_message, .. @@ -100,7 +100,7 @@ impl TestGroup { // Group from new member's perspective let (new_group, _) = Group::join( - welcome_messages.pop().unwrap(), + &welcome_messages[0], ratchet_tree, new_client.config.clone(), new_client.signer.clone().unwrap(), @@ -375,7 +375,7 @@ pub(crate) async fn get_test_groups_with_features( groups.push( client - .join_group(None, commit_output.welcome_messages[0].clone()) + .join_group(None, &commit_output.welcome_messages[0]) .await .unwrap() .0, @@ -488,7 +488,7 @@ impl MessageProcessor for GroupWithoutKeySchedule { #[cfg_attr(coverage_nightly, coverage(off))] async fn process_ciphertext( &mut self, - cipher_text: PrivateMessage, + cipher_text: &PrivateMessage, ) -> Result, MlsError> { self.inner.process_ciphertext(cipher_text).await } diff --git a/mls-rs/src/group/util.rs b/mls-rs/src/group/util.rs index 104d4f86..dadfafac 100644 --- a/mls-rs/src/group/util.rs +++ b/mls-rs/src/group/util.rs @@ -14,140 +14,108 @@ use crate::{ protocol_version::ProtocolVersion, signer::Signable, tree_kem::{node::LeafIndex, tree_validator::TreeValidator, TreeKemPublic}, - CipherSuiteProvider, CryptoProvider, ExtensionList, + CipherSuiteProvider, CryptoProvider, }; #[cfg(feature = "by_ref_proposal")] use crate::extension::ExternalSendersExt; use super::{ - confirmation_tag::ConfirmationTag, framing::Sender, message_signature::AuthenticatedContent, + framing::Sender, message_signature::AuthenticatedContent, transcript_hash::InterimTranscriptHash, ConfirmedTranscriptHash, EncryptedGroupSecrets, - ExportedTree, GroupContext, GroupInfo, + ExportedTree, GroupInfo, GroupState, }; use super::message_processor::ProvisionalState; -#[derive(Clone, Debug)] -#[non_exhaustive] -pub(crate) struct JoinContext { - pub group_info_extensions: ExtensionList, - pub group_context: GroupContext, - pub confirmation_tag: ConfirmationTag, - pub public_tree: TreeKemPublic, - pub signer_index: LeafIndex, -} - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -pub(crate) async fn process_group_info( - msg_protocol_version: ProtocolVersion, - group_info: GroupInfo, - tree_data: Option>, - id_provider: &I, +pub(crate) async fn validate_group_info_common( + msg_version: ProtocolVersion, + group_info: &GroupInfo, + tree: &TreeKemPublic, cs: &C, -) -> Result -where - C: CipherSuiteProvider, - I: IdentityProvider, -{ - let tree_data = match group_info.extensions.get_as::()? { - Some(ext) => ext.tree_data, - None => tree_data.ok_or(MlsError::RatchetTreeNotFound)?, - }; - - let context_ext = &group_info.group_context.extensions; - - let public_tree = - TreeKemPublic::import_node_data(tree_data.into(), id_provider, context_ext).await?; - - let group_protocol_version = group_info.group_context.protocol_version; - - if msg_protocol_version != group_protocol_version { +) -> Result<(), MlsError> { + if msg_version != group_info.group_context.protocol_version { return Err(MlsError::ProtocolVersionMismatch); } - let cipher_suite = cs.cipher_suite(); - - if group_info.group_context.cipher_suite != cipher_suite { + if group_info.group_context.cipher_suite != cs.cipher_suite() { return Err(MlsError::CipherSuiteMismatch); } - let sender_key_package = public_tree.get_leaf_node(group_info.signer)?; + let sender_leaf = &tree.get_leaf_node(group_info.signer)?; group_info - .verify(cs, &sender_key_package.signing_identity.signature_key, &()) + .verify(cs, &sender_leaf.signing_identity.signature_key, &()) .await?; - let confirmation_tag = group_info.confirmation_tag; - let signer_index = group_info.signer; - - let group_context = GroupContext { - protocol_version: msg_protocol_version, - cipher_suite, - group_id: group_info.group_context.group_id, - epoch: group_info.group_context.epoch, - tree_hash: group_info.group_context.tree_hash, - confirmed_transcript_hash: group_info.group_context.confirmed_transcript_hash, - extensions: group_info.group_context.extensions, - }; + Ok(()) +} - Ok(JoinContext { - group_info_extensions: group_info.extensions, - group_context, - confirmation_tag, - public_tree, - signer_index, - }) +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +pub(crate) async fn validate_group_info_member( + self_state: &GroupState, + msg_version: ProtocolVersion, + group_info: &GroupInfo, + cs: &C, +) -> Result<(), MlsError> { + validate_group_info_common(msg_version, group_info, &self_state.public_tree, cs).await?; + + let self_tree = ExportedTree::new_borrowed(&self_state.public_tree.nodes); + + if let Some(tree) = group_info.extensions.get_as::()? { + (tree.tree_data == self_tree) + .then_some(()) + .ok_or(MlsError::InvalidGroupInfo)?; + } + + (group_info.group_context == self_state.context + && group_info.confirmation_tag == self_state.confirmation_tag) + .then_some(()) + .ok_or(MlsError::InvalidGroupInfo)?; + + Ok(()) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -pub(crate) async fn validate_group_info( - msg_protocol_version: ProtocolVersion, - group_info: GroupInfo, - tree_data: Option>, - identity_provider: &I, - cipher_suite_provider: &C, -) -> Result { - let mut join_context = process_group_info( - msg_protocol_version, - group_info, - tree_data, - identity_provider, - cipher_suite_provider, - ) - .await?; +pub(crate) async fn validate_group_info_joiner( + msg_version: ProtocolVersion, + group_info: &GroupInfo, + tree: Option>, + id_provider: &I, + cs: &C, +) -> Result +where + C: CipherSuiteProvider, + I: IdentityProvider, +{ + let tree = match group_info.extensions.get_as::()? { + Some(ext) => ext.tree_data, + None => tree.ok_or(MlsError::RatchetTreeNotFound)?, + }; + + let context = &group_info.group_context; + + let mut tree = + TreeKemPublic::import_node_data(tree.into(), id_provider, &context.extensions).await?; // Verify the integrity of the ratchet tree - let tree_validator = TreeValidator::new( - cipher_suite_provider, - &join_context.group_context.group_id, - &join_context.group_context.tree_hash, - &join_context.group_context.extensions, - identity_provider, - ); - - tree_validator - .validate(&mut join_context.public_tree) + TreeValidator::new(cs, context, id_provider) + .validate(&mut tree) .await?; #[cfg(feature = "by_ref_proposal")] - if let Some(ext_senders) = join_context - .group_context - .extensions - .get_as::()? - { + if let Some(ext_senders) = context.extensions.get_as::()? { // TODO do joiners verify group against current time?? ext_senders - .verify_all( - identity_provider, - None, - &join_context.group_context.extensions, - ) + .verify_all(id_provider, None, &context.extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; } - Ok(join_context) + validate_group_info_common(msg_version, group_info, &tree, cs).await?; + + Ok(tree) } pub(crate) fn commit_sender( diff --git a/mls-rs/src/test_utils/mod.rs b/mls-rs/src/test_utils/mod.rs index 42410ad2..3346bafe 100644 --- a/mls-rs/src/test_utils/mod.rs +++ b/mls-rs/src/test_utils/mod.rs @@ -142,7 +142,7 @@ pub async fn get_test_groups( for client in &receiver_clients { let (test_client, _info) = client - .join_group(Some(tree_data.clone()), welcome[0].clone()) + .join_group(Some(tree_data.clone()), &welcome[0]) .await .unwrap(); diff --git a/mls-rs/src/tree_kem/interop_test_vectors.rs b/mls-rs/src/tree_kem/interop_test_vectors.rs index 6d57b143..50e0077a 100644 --- a/mls-rs/src/tree_kem/interop_test_vectors.rs +++ b/mls-rs/src/tree_kem/interop_test_vectors.rs @@ -5,10 +5,7 @@ use alloc::vec; use alloc::vec::Vec; use mls_rs_codec::{MlsDecode, MlsEncode}; -use mls_rs_core::{ - crypto::{CipherSuite, CipherSuiteProvider}, - extension::ExtensionList, -}; +use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider}; use itertools::Itertools; @@ -78,6 +75,8 @@ impl ValidationTestCase { #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[cfg_attr(coverage_nightly, coverage(off))] async fn validation() { + use crate::group::test_utils::get_test_group_context; + #[cfg(mls_build_async)] let test_cases: Vec = load_test_case_json!( interop_tree_validation, @@ -117,16 +116,14 @@ async fn validation() { assert_eq!(&tree.nodes.get_resolution_index(i as u32).unwrap(), res) }); - TreeValidator::new( - &cs, - &test_case.group_id, - &tree_hash, - &ExtensionList::new(), - &BasicIdentityProvider, - ) - .validate(&mut tree) - .await - .unwrap(); + let mut context = get_test_group_context(1, test_case.cipher_suite.into()).await; + context.tree_hash = tree_hash; + context.group_id = test_case.group_id; + + TreeValidator::new(&cs, &context, &BasicIdentityProvider) + .validate(&mut tree) + .await + .unwrap(); } } diff --git a/mls-rs/src/tree_kem/tree_validator.rs b/mls-rs/src/tree_kem/tree_validator.rs index dcfee6e8..26d4baf1 100644 --- a/mls-rs/src/tree_kem/tree_validator.rs +++ b/mls-rs/src/tree_kem/tree_validator.rs @@ -12,10 +12,10 @@ use tree_math::TreeIndex; use super::node::{Node, NodeIndex}; use crate::client::MlsError; use crate::crypto::CipherSuiteProvider; +use crate::group::GroupContext; use crate::iter::wrap_impl_iter; use crate::tree_kem::math as tree_math; use crate::tree_kem::{leaf_node_validator::LeafNodeValidator, TreeKemPublic}; -use mls_rs_core::extension::ExtensionList; use mls_rs_core::identity::IdentityProvider; #[cfg(all(not(mls_build_async), feature = "rayon"))] @@ -38,19 +38,17 @@ where impl<'a, C: IdentityProvider, CSP: CipherSuiteProvider> TreeValidator<'a, C, CSP> { pub fn new( cipher_suite_provider: &'a CSP, - group_id: &'a [u8], - tree_hash: &'a [u8], - group_context_extensions: &'a ExtensionList, + context: &'a GroupContext, identity_provider: &'a C, ) -> Self { TreeValidator { - expected_tree_hash: tree_hash, + expected_tree_hash: &context.tree_hash, leaf_node_validator: LeafNodeValidator::new( cipher_suite_provider, identity_provider, - Some(group_context_extensions), + Some(&context.extensions), ), - group_id, + group_id: &context.group_id, cipher_suite_provider, } } @@ -168,7 +166,7 @@ mod tests { client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::test_cipher_suite_provider, crypto::test_utils::TestCryptoProvider, - group::test_utils::{get_test_group_context, random_bytes, TEST_GROUP}, + group::test_utils::{get_test_group_context, random_bytes}, identity::basic::BasicIdentityProvider, tree_kem::{ kem::TreeKem, @@ -238,17 +236,12 @@ mod tests { let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); let mut test_tree = get_valid_tree(cipher_suite).await; - let expected_tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap(); - let extensions = ExtensionList::new(); + let mut context = get_test_group_context(1, cipher_suite).await; + context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap(); - let validator = TreeValidator::new( - &cipher_suite_provider, - TEST_GROUP, - &expected_tree_hash, - &extensions, - &BasicIdentityProvider, - ); + let validator = + TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider); validator.validate(&mut test_tree).await.unwrap(); } @@ -258,18 +251,12 @@ mod tests { async fn test_tree_hash_mismatch() { for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { let mut test_tree = get_valid_tree(cipher_suite).await; - let expected_tree_hash = random_bytes(32); let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); - let extensions = ExtensionList::new(); + let context = get_test_group_context(1, cipher_suite).await; - let validator = TreeValidator::new( - &cipher_suite_provider, - b"", - &expected_tree_hash, - &extensions, - &BasicIdentityProvider, - ); + let validator = + TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider); let res = validator.validate(&mut test_tree).await; @@ -286,17 +273,11 @@ mod tests { parent_node.parent_hash = ParentHash::from(random_bytes(32)); let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); - let expected_tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap(); + let mut context = get_test_group_context(1, cipher_suite).await; + context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap(); - let extensions = ExtensionList::new(); - - let validator = TreeValidator::new( - &cipher_suite_provider, - b"", - &expected_tree_hash, - &extensions, - &BasicIdentityProvider, - ); + let validator = + TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider); let res = validator.validate(&mut test_tree).await; @@ -316,16 +297,11 @@ mod tests { .signature = random_bytes(32); let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); - let expected_tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap(); - let extensions = ExtensionList::new(); + let mut context = get_test_group_context(1, cipher_suite).await; + context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap(); - let validator = TreeValidator::new( - &cipher_suite_provider, - b"", - &expected_tree_hash, - &extensions, - &BasicIdentityProvider, - ); + let validator = + TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider); let res = validator.validate(&mut test_tree).await; diff --git a/mls-rs/test_harness_integration/src/branch_reinit.rs b/mls-rs/test_harness_integration/src/branch_reinit.rs index 030f2239..c7f3fe01 100644 --- a/mls-rs/test_harness_integration/src/branch_reinit.rs +++ b/mls-rs/test_harness_integration/src/branch_reinit.rs @@ -75,7 +75,7 @@ pub(crate) mod inner { .map_err(abort)?; let (group, _info) = reinit_client - .join(welcome, get_tree(&request.ratchet_tree)?) + .join(&welcome, get_tree(&request.ratchet_tree)?) .map_err(abort)?; let resp = JoinGroupResponse { @@ -120,7 +120,7 @@ pub(crate) mod inner { let welcome = MlsMessage::from_bytes(&request.welcome).map_err(abort)?; - let (new_group, _info) = group.join_subgroup(welcome, tree).map_err(abort)?; + let (new_group, _info) = group.join_subgroup(&welcome, tree).map_err(abort)?; let resp = HandleBranchResponse { state_id: request.state_id, diff --git a/mls-rs/test_harness_integration/src/by_ref_proposal.rs b/mls-rs/test_harness_integration/src/by_ref_proposal.rs index 2d54f13d..ae51d504 100644 --- a/mls-rs/test_harness_integration/src/by_ref_proposal.rs +++ b/mls-rs/test_harness_integration/src/by_ref_proposal.rs @@ -371,7 +371,7 @@ pub(crate) mod external_proposal { let proposal = client .client - .external_add_proposal(group_info, None, vec![]) + .external_add_proposal(&group_info, None, vec![]) .map_err(abort)? .to_bytes() .map_err(abort)?; diff --git a/mls-rs/test_harness_integration/src/main.rs b/mls-rs/test_harness_integration/src/main.rs index 5bbfe7d6..5b1a472f 100644 --- a/mls-rs/test_harness_integration/src/main.rs +++ b/mls-rs/test_harness_integration/src/main.rs @@ -271,7 +271,7 @@ impl MlsClient for MlsClientImpl { let (group, _) = client .client - .join_group(get_tree(&request.ratchet_tree)?, welcome_msg) + .join_group(get_tree(&request.ratchet_tree)?, &welcome_msg) .map_err(abort)?; let epoch_authenticator = group.epoch_authenticator().map_err(abort)?.to_vec(); diff --git a/mls-rs/tests/client_tests.rs b/mls-rs/tests/client_tests.rs index 3a32aa66..1f4b28a1 100644 --- a/mls-rs/tests/client_tests.rs +++ b/mls-rs/tests/client_tests.rs @@ -170,24 +170,20 @@ async fn test_create( .await .unwrap(); - let welcome = alice_group + let welcome = &alice_group .commit_builder() .add_member(bob_key_pkg) .unwrap() .build() .await .unwrap() - .welcome_messages - .remove(0); + .welcome_messages[0]; // Upon server confirmation, alice applies the commit to her own state alice_group.apply_pending_commit().await.unwrap(); // Bob receives the welcome message and joins the group - let (bob_group, _) = bob - .join_group(Some(alice_group.export_tree()), welcome) - .await - .unwrap(); + let (bob_group, _) = bob.join_group(None, welcome).await.unwrap(); assert!(Group::equal_group_state(&alice_group, &bob_group)); } @@ -517,7 +513,6 @@ async fn external_commits_work( let (new_group, commit) = client .external_commit_builder() .unwrap() - .with_tree_data(existing_group.export_tree().into_owned()) .build(group_info) .await .unwrap(); @@ -605,22 +600,18 @@ async fn reinit_works() { let mut alice_group = alice1.create_group(ExtensionList::new()).await.unwrap(); let kp = bob1.generate_key_package_message().await.unwrap(); - let welcome = alice_group + let welcome = &alice_group .commit_builder() .add_member(kp) .unwrap() .build() .await .unwrap() - .welcome_messages - .remove(0); + .welcome_messages[0]; alice_group.apply_pending_commit().await.unwrap(); - let (mut bob_group, _) = bob1 - .join_group(Some(alice_group.export_tree()), welcome) - .await - .unwrap(); + let (mut bob_group, _) = bob1.join_group(None, welcome).await.unwrap(); // Alice proposes reinit let reinit_proposal_message = alice_group @@ -702,12 +693,8 @@ async fn reinit_works() { // Bob produces key package, alice commits, bob joins let kp = bob2.generate_key_package().await.unwrap(); - let (mut alice_group, mut welcome) = alice2.commit(vec![kp]).await.unwrap(); - - let (mut bob_group, _) = bob2 - .join(welcome.remove(0), Some(alice_group.export_tree())) - .await - .unwrap(); + let (mut alice_group, welcome) = alice2.commit(vec![kp]).await.unwrap(); + let (mut bob_group, _) = bob2.join(&welcome[0], None).await.unwrap(); assert!(bob_group.cipher_suite() == suite2); @@ -716,7 +703,7 @@ async fn reinit_works() { let kp = carol.generate_key_package_message().await.unwrap(); - let mut commit_output = alice_group + let commit_output = alice_group .commit_builder() .add_member(kp) .unwrap() @@ -732,10 +719,7 @@ async fn reinit_works() { .unwrap(); carol - .join_group( - Some(alice_group.export_tree()), - commit_output.welcome_messages.remove(0), - ) + .join_group(None, &commit_output.welcome_messages[0]) .await .unwrap(); }