From cc0ff86c018ad22825f0d4cfd04e735b8bc25342 Mon Sep 17 00:00:00 2001 From: mulmarta <103590845+mulmarta@users.noreply.github.com> Date: Wed, 6 Dec 2023 00:47:33 +0100 Subject: [PATCH] Add context to identity functions (#52) --- mls-rs-core/src/identity/provider.rs | 7 +- mls-rs-identity-x509/src/provider.rs | 2 + mls-rs/src/client.rs | 1 - mls-rs/src/external_client/group.rs | 8 +- mls-rs/src/group/commit.rs | 6 +- .../group/interop_test_vectors/tree_kem.rs | 15 +-- .../tree_modifications.rs | 15 +-- mls-rs/src/group/message_processor.rs | 1 + mls-rs/src/group/mod.rs | 16 ++- mls-rs/src/group/proposal_cache.rs | 62 ++++++--- mls-rs/src/group/proposal_filter/filtering.rs | 7 +- .../group/proposal_filter/filtering_common.rs | 24 +++- .../group/proposal_filter/filtering_lite.rs | 1 + mls-rs/src/group/snapshot.rs | 2 +- mls-rs/src/group/util.rs | 80 ++--------- mls-rs/src/identity.rs | 47 ++++--- mls-rs/src/identity/basic.rs | 7 +- mls-rs/src/tree_kem/interop_test_vectors.rs | 2 +- mls-rs/src/tree_kem/kem.rs | 13 +- mls-rs/src/tree_kem/leaf_node_validator.rs | 7 +- mls-rs/src/tree_kem/mod.rs | 126 ++++++++++++------ mls-rs/src/tree_kem/private.rs | 12 +- mls-rs/src/tree_kem/tree_hash.rs | 6 +- mls-rs/src/tree_kem/tree_index.rs | 8 +- mls-rs/src/tree_kem/update_path.rs | 1 + 25 files changed, 271 insertions(+), 205 deletions(-) diff --git a/mls-rs-core/src/identity/provider.rs b/mls-rs-core/src/identity/provider.rs index 353b80cc..7b5fcc44 100644 --- a/mls-rs-core/src/identity/provider.rs +++ b/mls-rs-core/src/identity/provider.rs @@ -71,7 +71,11 @@ pub trait IdentityProvider: Send + Sync { /// /// The MLS protocol requires that each member of a group has a unique /// set of identifiers according to the application. - async fn identity(&self, signing_identity: &SigningIdentity) -> Result, Self::Error>; + async fn identity( + &self, + signing_identity: &SigningIdentity, + extensions: &ExtensionList, + ) -> Result, Self::Error>; /// Determines if `successor` can remove `predecessor` as part of an external commit. /// @@ -83,6 +87,7 @@ pub trait IdentityProvider: Send + Sync { &self, predecessor: &SigningIdentity, successor: &SigningIdentity, + extensions: &ExtensionList, ) -> Result; /// Credential types that are supported by this provider. diff --git a/mls-rs-identity-x509/src/provider.rs b/mls-rs-identity-x509/src/provider.rs index 8508c85b..2c6e03d2 100644 --- a/mls-rs-identity-x509/src/provider.rs +++ b/mls-rs-identity-x509/src/provider.rs @@ -214,6 +214,7 @@ where async fn identity( &self, signing_id: &mls_rs_core::identity::SigningIdentity, + _extensions: &ExtensionList, ) -> Result, Self::Error> { self.identity(signing_id) } @@ -222,6 +223,7 @@ where &self, predecessor: &mls_rs_core::identity::SigningIdentity, successor: &mls_rs_core::identity::SigningIdentity, + _extensions: &ExtensionList, ) -> Result { self.valid_successor(predecessor, successor) } diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs index b1fae1a5..db7a4ac1 100644 --- a/mls-rs/src/client.rs +++ b/mls-rs/src/client.rs @@ -649,7 +649,6 @@ where protocol_version, group_info, tree_data, - #[cfg(feature = "tree_index")] &self.config.identity_provider(), &cipher_suite_provider, ) diff --git a/mls-rs/src/external_client/group.rs b/mls-rs/src/external_client/group.rs index 22e815d7..715033ce 100644 --- a/mls-rs/src/external_client/group.rs +++ b/mls-rs/src/external_client/group.rs @@ -530,7 +530,7 @@ impl ExternalGroup { ) -> Result { let identity = self .identity_provider() - .identity(identity_id) + .identity(identity_id, self.context_extensions()) .await .map_err(|error| MlsError::IdentityProviderError(error.into_any_error()))?; @@ -541,7 +541,11 @@ impl ExternalGroup { #[cfg(not(feature = "tree_index"))] let index = tree - .get_leaf_node_with_identity(&identity, &self.identity_provider()) + .get_leaf_node_with_identity( + &identity, + &self.identity_provider(), + self.context_extensions(), + ) .await?; let index = index.ok_or(MlsError::MemberNotFound)?; diff --git a/mls-rs/src/group/commit.rs b/mls-rs/src/group/commit.rs index 90351b99..71adc068 100644 --- a/mls-rs/src/group/commit.rs +++ b/mls-rs/src/group/commit.rs @@ -1356,7 +1356,7 @@ mod tests { ) -> bool { if let Some(extensions) = extensions { if let Some(ext) = extensions.get_as::().unwrap() { - self.identity(identity).await.unwrap()[0] == ext.foo + self.identity(identity, extensions).await.unwrap()[0] == ext.foo } else { true } @@ -1397,9 +1397,10 @@ mod tests { async fn identity( &self, signing_identity: &SigningIdentity, + extensions: &ExtensionList, ) -> Result, Self::Error> { self.0 - .identity(signing_identity) + .identity(signing_identity, extensions) .await .map_err(|_| IdentityProviderWithExtensionError {}) } @@ -1408,6 +1409,7 @@ mod tests { &self, _predecessor: &SigningIdentity, _successor: &SigningIdentity, + _extensions: &ExtensionList, ) -> Result { Ok(true) } diff --git a/mls-rs/src/group/interop_test_vectors/tree_kem.rs b/mls-rs/src/group/interop_test_vectors/tree_kem.rs index 3e589bba..9ae0aee8 100644 --- a/mls-rs/src/group/interop_test_vectors/tree_kem.rs +++ b/mls-rs/src/group/interop_test_vectors/tree_kem.rs @@ -10,6 +10,7 @@ use crate::{ message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit, GroupContext, PathSecret, Sender, }, + identity::basic::BasicIdentityProvider, tree_kem::{ node::{LeafIndex, NodeVec}, TreeKemPrivate, TreeKemPublic, UpdatePath, @@ -21,9 +22,6 @@ use alloc::vec::Vec; use mls_rs_codec::MlsDecode; use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList}; -#[cfg(feature = "tree_index")] -use crate::identity::basic::BasicIdentityProvider; - #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] struct TreeKemTestCase { pub cipher_suite: u16, @@ -83,13 +81,10 @@ async fn tree_kem() { // Import the public ratchet tree let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap(); - let mut tree = TreeKemPublic::import_node_data( - nodes, - #[cfg(feature = "tree_index")] - &BasicIdentityProvider, - ) - .await - .unwrap(); + let mut tree = + TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default()) + .await + .unwrap(); // Construct GroupContext let group_context = GroupContext { diff --git a/mls-rs/src/group/interop_test_vectors/tree_modifications.rs b/mls-rs/src/group/interop_test_vectors/tree_modifications.rs index 79a5b6ab..6f82d34e 100644 --- a/mls-rs/src/group/interop_test_vectors/tree_modifications.rs +++ b/mls-rs/src/group/interop_test_vectors/tree_modifications.rs @@ -18,15 +18,13 @@ use crate::{ test_utils::TEST_GROUP, LeafIndex, Sender, TreeKemPublic, }, + identity::basic::BasicIdentityProvider, key_package::test_utils::test_key_package, tree_kem::{ leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners, }, }; -#[cfg(feature = "tree_index")] -use crate::identity::basic::BasicIdentityProvider; - #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] struct TreeModsTestCase { #[serde(with = "hex::serde")] @@ -112,13 +110,10 @@ async fn tree_modifications_interop() { for test_case in test_cases.into_iter() { let nodes = NodeVec::mls_decode(&mut &*test_case.tree_before).unwrap(); - let tree_before = TreeKemPublic::import_node_data( - nodes, - #[cfg(feature = "tree_index")] - &BasicIdentityProvider, - ) - .await - .unwrap(); + let tree_before = + TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default()) + .await + .unwrap(); let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap(); diff --git a/mls-rs/src/group/message_processor.rs b/mls-rs/src/group/message_processor.rs index 0013e3ba..4557b62d 100644 --- a/mls-rs/src/group/message_processor.rs +++ b/mls-rs/src/group/message_processor.rs @@ -885,6 +885,7 @@ pub(crate) trait MessageProcessor: Send + Sync { .apply_update_path( sender, update_path, + &provisional_state.group_context.extensions, self.identity_provider(), self.cipher_suite_provider(), ) diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 8624f3fe..1f84fb61 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -305,8 +305,13 @@ where .check_if_valid(&leaf_node, ValidationContext::Add(None)) .await?; - let (mut public_tree, private_tree) = - TreeKemPublic::derive(leaf_node, leaf_node_secret, &config.identity_provider()).await?; + let (mut public_tree, private_tree) = TreeKemPublic::derive( + leaf_node, + leaf_node_secret, + &config.identity_provider(), + &group_context_extensions, + ) + .await?; let tree_hash = public_tree.tree_hash(&cipher_suite_provider).await?; @@ -1320,7 +1325,11 @@ where #[cfg(not(feature = "tree_index"))] let index = tree - .get_leaf_node_with_identity(identity, &self.identity_provider()) + .get_leaf_node_with_identity( + identity, + &self.identity_provider(), + &self.state.context.extensions, + ) .await?; let index = index.ok_or(MlsError::MemberNotFound)?; @@ -1615,6 +1624,7 @@ where .apply_update_path( sender, update_path, + &provisional_state.group_context.extensions, self.identity_provider(), self.cipher_suite_provider(), ) diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 1cca87d9..4ec498f2 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -655,9 +655,10 @@ mod tests { ) .await; - let (pub_tree, priv_tree) = TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider) - .await - .unwrap(); + let (pub_tree, priv_tree) = + TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default()) + .await + .unwrap(); (priv_tree.self_index, pub_tree) } @@ -717,10 +718,14 @@ mod tests { let sender = LeafIndex(0); - let (mut tree, _) = - TreeKemPublic::derive(sender_leaf, sender_leaf_secret, &BasicIdentityProvider) - .await - .unwrap(); + let (mut tree, _) = TreeKemPublic::derive( + sender_leaf, + sender_leaf_secret, + &BasicIdentityProvider, + &Default::default(), + ) + .await + .unwrap(); let add_package = test_key_package(protocol_version, cipher_suite, "dave").await; @@ -774,6 +779,7 @@ mod tests { expected_tree .batch_edit( &mut bundle, + &Default::default(), &BasicIdentityProvider, &cipher_suite_provider, true, @@ -1699,9 +1705,14 @@ mod tests { get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await; let alice = 0; - let (mut tree, _) = TreeKemPublic::derive(alice_leaf, alice_secret, &BasicIdentityProvider) - .await - .unwrap(); + let (mut tree, _) = TreeKemPublic::derive( + alice_leaf, + alice_secret, + &BasicIdentityProvider, + &Default::default(), + ) + .await + .unwrap(); let bob_node = get_basic_test_node(TEST_CIPHER_SUITE, "bob").await; @@ -2983,9 +2994,10 @@ mod tests { .await .unwrap(); - let (pub_tree, priv_tree) = TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider) - .await - .unwrap(); + let (pub_tree, priv_tree) = + TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default()) + .await + .unwrap(); (priv_tree.self_index, pub_tree) }; @@ -3387,10 +3399,14 @@ mod tests { let (alice_leaf, alice_secret, alice_signer) = get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await; - let (mut tree, priv_tree) = - TreeKemPublic::derive(alice_leaf.clone(), alice_secret, &BasicIdentityProvider) - .await - .unwrap(); + let (mut tree, priv_tree) = TreeKemPublic::derive( + alice_leaf.clone(), + alice_secret, + &BasicIdentityProvider, + &Default::default(), + ) + .await + .unwrap(); let alice = priv_tree.self_index; @@ -3460,10 +3476,14 @@ mod tests { let (alice_leaf, alice_secret, alice_signer) = get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await; - let (mut tree, priv_tree) = - TreeKemPublic::derive(alice_leaf.clone(), alice_secret, &BasicIdentityProvider) - .await - .unwrap(); + let (mut tree, priv_tree) = TreeKemPublic::derive( + alice_leaf.clone(), + alice_secret, + &BasicIdentityProvider, + &Default::default(), + ) + .await + .unwrap(); let alice = priv_tree.self_index; diff --git a/mls-rs/src/group/proposal_filter/filtering.rs b/mls-rs/src/group/proposal_filter/filtering.rs index e318ee0b..8e67ff58 100644 --- a/mls-rs/src/group/proposal_filter/filtering.rs +++ b/mls-rs/src/group/proposal_filter/filtering.rs @@ -141,6 +141,7 @@ where let added = new_tree .batch_edit( &mut applied_proposals, + group_extensions_in_use, self.identity_provider, self.cipher_suite_provider, strategy.is_ignore(), @@ -195,7 +196,11 @@ where let valid_successor = self .identity_provider - .valid_successor(&old_leaf.signing_identity, &leaf.signing_identity) + .valid_successor( + &old_leaf.signing_identity, + &leaf.signing_identity, + group_extensions_in_use, + ) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())) .and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor)); diff --git a/mls-rs/src/group/proposal_filter/filtering_common.rs b/mls-rs/src/group/proposal_filter/filtering_common.rs index c14bb723..37ae0898 100644 --- a/mls-rs/src/group/proposal_filter/filtering_common.rs +++ b/mls-rs/src/group/proposal_filter/filtering_common.rs @@ -166,6 +166,7 @@ where external_leaf, self.original_tree, self.identity_provider, + self.original_group_extensions, ) .await?; @@ -201,6 +202,7 @@ where &mut output.new_tree, external_leaf.clone(), self.identity_provider, + self.original_group_extensions, ) .await?, ); @@ -504,6 +506,7 @@ async fn ensure_at_most_one_removal_for_self( external_leaf: &LeafNode, tree: &TreeKemPublic, identity_provider: &C, + extensions: &ExtensionList, ) -> Result<(), MlsError> where C: IdentityProvider, @@ -512,8 +515,14 @@ where match (removals.next(), removals.next()) { (Some(removal), None) => { - ensure_removal_is_for_self(&removal.proposal, external_leaf, tree, identity_provider) - .await + ensure_removal_is_for_self( + &removal.proposal, + external_leaf, + tree, + identity_provider, + extensions, + ) + .await } (Some(_), Some(_)) => Err(MlsError::ExternalCommitWithMoreThanOneRemove), (None, _) => Ok(()), @@ -526,6 +535,7 @@ async fn ensure_removal_is_for_self( external_leaf: &LeafNode, tree: &TreeKemPublic, identity_provider: &C, + extensions: &ExtensionList, ) -> Result<(), MlsError> where C: IdentityProvider, @@ -533,7 +543,11 @@ where let existing_signing_id = &tree.get_leaf_node(removal.to_remove)?.signing_identity; identity_provider - .valid_successor(existing_signing_id, &external_leaf.signing_identity) + .valid_successor( + existing_signing_id, + &external_leaf.signing_identity, + extensions, + ) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))? .then_some(()) @@ -553,6 +567,8 @@ async fn insert_external_leaf( tree: &mut TreeKemPublic, leaf_node: LeafNode, identity_provider: &I, + extensions: &ExtensionList, ) -> Result { - tree.add_leaf(leaf_node, identity_provider, None).await + tree.add_leaf(leaf_node, identity_provider, extensions, None) + .await } diff --git a/mls-rs/src/group/proposal_filter/filtering_lite.rs b/mls-rs/src/group/proposal_filter/filtering_lite.rs index acd07c1f..09ca3899 100644 --- a/mls-rs/src/group/proposal_filter/filtering_lite.rs +++ b/mls-rs/src/group/proposal_filter/filtering_lite.rs @@ -98,6 +98,7 @@ where let added = new_tree .batch_edit_lite( proposals, + group_extensions_in_use, self.identity_provider, self.cipher_suite_provider, ) diff --git a/mls-rs/src/group/snapshot.rs b/mls-rs/src/group/snapshot.rs index fc0491f9..851b3640 100644 --- a/mls-rs/src/group/snapshot.rs +++ b/mls-rs/src/group/snapshot.rs @@ -110,7 +110,7 @@ impl RawGroupState { let mut public_tree = self.public_tree; public_tree - .initialize_index_if_necessary(identity_provider) + .initialize_index_if_necessary(identity_provider, &context.extensions) .await?; Ok(GroupState { diff --git a/mls-rs/src/group/util.rs b/mls-rs/src/group/util.rs index 060861e3..fe299136 100644 --- a/mls-rs/src/group/util.rs +++ b/mls-rs/src/group/util.rs @@ -43,22 +43,6 @@ pub(crate) struct JoinContext { pub signer_index: LeafIndex, } -#[cfg(not(feature = "tree_index"))] -#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -pub(crate) async fn process_group_info( - msg_protocol_version: ProtocolVersion, - group_info: GroupInfo, - tree_data: Option<&[u8]>, - cs: &C, -) -> Result -where - C: CipherSuiteProvider, -{ - let public_tree = find_tree(tree_data, group_info.extensions.get_as()?).await?; - process_group_info_with_tree(msg_protocol_version, group_info, public_tree, cs).await -} - -#[cfg(feature = "tree_index")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn process_group_info( msg_protocol_version: ProtocolVersion, @@ -71,27 +55,21 @@ where C: CipherSuiteProvider, I: IdentityProvider, { - let public_tree = find_tree(tree_data, group_info.extensions.get_as()?, id_provider).await?; - process_group_info_with_tree(msg_protocol_version, group_info, public_tree, cs).await -} + let tree_data = match group_info.extensions.get_as::()? { + Some(ext) => ext.tree_data, + None => NodeVec::mls_decode(&mut tree_data.ok_or(MlsError::RatchetTreeNotFound)?)?, + }; + + let context_ext = &group_info.group_context.extensions; + let public_tree = TreeKemPublic::import_node_data(tree_data, id_provider, context_ext).await?; -#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -async fn process_group_info_with_tree( - msg_protocol_version: ProtocolVersion, - group_info: GroupInfo, - public_tree: TreeKemPublic, - cipher_suite_provider: &C, -) -> Result -where - C: CipherSuiteProvider, -{ let group_protocol_version = group_info.group_context.protocol_version; if msg_protocol_version != group_protocol_version { return Err(MlsError::ProtocolVersionMismatch); } - let cipher_suite = cipher_suite_provider.cipher_suite(); + let cipher_suite = cs.cipher_suite(); if group_info.group_context.cipher_suite != cipher_suite { return Err(MlsError::CipherSuiteMismatch); @@ -100,11 +78,7 @@ where let sender_key_package = public_tree.get_leaf_node(group_info.signer)?; group_info - .verify( - cipher_suite_provider, - &sender_key_package.signing_identity.signature_key, - &(), - ) + .verify(cs, &sender_key_package.signing_identity.signature_key, &()) .await?; let confirmation_tag = group_info.confirmation_tag; @@ -141,7 +115,6 @@ pub(crate) async fn validate_group_info( - tree_data: Option<&[u8]>, - extension: Option, - identity_provider: &C, -) -> Result -where - C: IdentityProvider, -{ - TreeKemPublic::import_node_data(find_node_data(tree_data, extension)?, identity_provider).await -} - -#[cfg(not(feature = "tree_index"))] -#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -pub(crate) async fn find_tree( - tree_data: Option<&[u8]>, - extension: Option, -) -> Result { - TreeKemPublic::import_node_data(find_node_data(tree_data, extension)?).await -} - -pub(crate) fn find_node_data( - tree_data: Option<&[u8]>, - extension: Option, -) -> Result { - match tree_data { - Some(tree_data) => Ok(NodeVec::mls_decode(&mut &*tree_data)?), - None => { - let tree_extension = extension.ok_or(MlsError::RatchetTreeNotFound)?; - Ok(tree_extension.tree_data) - } - } -} - pub(crate) fn commit_sender( sender: &Sender, provisional_state: &ProvisionalState, diff --git a/mls-rs/src/identity.rs b/mls-rs/src/identity.rs index 6534211f..cd067103 100644 --- a/mls-rs/src/identity.rs +++ b/mls-rs/src/identity.rs @@ -83,24 +83,28 @@ pub(crate) mod test_utils { &self, signing_id: &SigningIdentity, ) -> Result, BasicWithCustomProviderError> { - self.basic.identity(signing_id).await.or_else(|_| { - signing_id - .credential - .as_custom() - .map(|c| { - if c.credential_type == CredentialType::from(Self::CUSTOM_CREDENTIAL_TYPE) - || self.allow_any_custom - { - Ok(c.data.to_vec()) - } else { - Err(BasicWithCustomProviderError(c.credential_type)) - } - }) - .transpose()? - .ok_or_else(|| { - BasicWithCustomProviderError(signing_id.credential.credential_type()) - }) - }) + self.basic + .identity(signing_id, &Default::default()) + .await + .or_else(|_| { + signing_id + .credential + .as_custom() + .map(|c| { + if c.credential_type + == CredentialType::from(Self::CUSTOM_CREDENTIAL_TYPE) + || self.allow_any_custom + { + Ok(c.data.to_vec()) + } else { + Err(BasicWithCustomProviderError(c.credential_type)) + } + }) + .transpose()? + .ok_or_else(|| { + BasicWithCustomProviderError(signing_id.credential.credential_type()) + }) + }) } } @@ -137,7 +141,11 @@ pub(crate) mod test_utils { Ok(()) } - async fn identity(&self, signing_id: &SigningIdentity) -> Result, Self::Error> { + async fn identity( + &self, + signing_id: &SigningIdentity, + _extensions: &ExtensionList, + ) -> Result, Self::Error> { self.resolve_custom_identity(signing_id).await } @@ -145,6 +153,7 @@ pub(crate) mod test_utils { &self, predecessor: &SigningIdentity, successor: &SigningIdentity, + _extensions: &ExtensionList, ) -> Result { let predecessor = self.resolve_custom_identity(predecessor).await?; let successor = self.resolve_custom_identity(successor).await?; diff --git a/mls-rs/src/identity/basic.rs b/mls-rs/src/identity/basic.rs index 5e48bad0..e5987331 100644 --- a/mls-rs/src/identity/basic.rs +++ b/mls-rs/src/identity/basic.rs @@ -81,7 +81,11 @@ impl IdentityProvider for BasicIdentityProvider { resolve_basic_identity(signing_identity).map(|_| ()) } - async fn identity(&self, signing_identity: &SigningIdentity) -> Result, Self::Error> { + async fn identity( + &self, + signing_identity: &SigningIdentity, + _extensions: &ExtensionList, + ) -> Result, Self::Error> { resolve_basic_identity(signing_identity).map(|b| b.identifier.to_vec()) } @@ -89,6 +93,7 @@ impl IdentityProvider for BasicIdentityProvider { &self, predecessor: &SigningIdentity, successor: &SigningIdentity, + _extensions: &ExtensionList, ) -> Result { Ok(resolve_basic_identity(predecessor)? == resolve_basic_identity(successor)?) } diff --git a/mls-rs/src/tree_kem/interop_test_vectors.rs b/mls-rs/src/tree_kem/interop_test_vectors.rs index 0b06cb4f..9659fc41 100644 --- a/mls-rs/src/tree_kem/interop_test_vectors.rs +++ b/mls-rs/src/tree_kem/interop_test_vectors.rs @@ -95,8 +95,8 @@ async fn validation() { let mut tree = TreeKemPublic::import_node_data( NodeVec::mls_decode(&mut &*test_case.tree).unwrap(), - #[cfg(feature = "tree_index")] &BasicIdentityProvider, + &Default::default(), ) .await .unwrap(); diff --git a/mls-rs/src/tree_kem/kem.rs b/mls-rs/src/tree_kem/kem.rs index e37a936d..3a42205c 100644 --- a/mls-rs/src/tree_kem/kem.rs +++ b/mls-rs/src/tree_kem/kem.rs @@ -545,10 +545,14 @@ mod tests { get_basic_test_node_sig_key(cipher_suite, "encap").await; // Build a test tree we can clone for all leaf nodes - let (mut test_tree, mut encap_private_key) = - TreeKemPublic::derive(encap_node, encap_hpke_secret, &BasicIdentityProvider) - .await - .unwrap(); + let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive( + encap_node, + encap_hpke_secret, + &BasicIdentityProvider, + &Default::default(), + ) + .await + .unwrap(); test_tree .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider) @@ -619,6 +623,7 @@ mod tests { tree.apply_update_path( LeafIndex(0), &validated_update_path, + &Default::default(), BasicIdentityProvider, &cipher_suite_provider, ) diff --git a/mls-rs/src/tree_kem/leaf_node_validator.rs b/mls-rs/src/tree_kem/leaf_node_validator.rs index 92135387..99eba83f 100644 --- a/mls-rs/src/tree_kem/leaf_node_validator.rs +++ b/mls-rs/src/tree_kem/leaf_node_validator.rs @@ -683,7 +683,11 @@ pub(crate) mod test_utils { } #[cfg_attr(coverage_nightly, coverage(off))] - async fn identity(&self, signing_id: &SigningIdentity) -> Result, Self::Error> { + async fn identity( + &self, + signing_id: &SigningIdentity, + _extensions: &ExtensionList, + ) -> Result, Self::Error> { Ok(signing_id.credential.mls_encode_to_vec().unwrap()) } @@ -692,6 +696,7 @@ pub(crate) mod test_utils { &self, _predecessor: &SigningIdentity, _successor: &SigningIdentity, + _extensions: &ExtensionList, ) -> Result { Err(TestFailureError) } diff --git a/mls-rs/src/tree_kem/mod.rs b/mls-rs/src/tree_kem/mod.rs index 54345e3d..1156e0d8 100644 --- a/mls-rs/src/tree_kem/mod.rs +++ b/mls-rs/src/tree_kem/mod.rs @@ -8,6 +8,7 @@ use alloc::vec::Vec; use core::fmt::Display; use itertools::Itertools; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; +use mls_rs_core::extension::ExtensionList; use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider}; @@ -84,20 +85,12 @@ impl TreeKemPublic { Default::default() } - #[cfg(not(feature = "tree_index"))] - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub(crate) async fn import_node_data(nodes: NodeVec) -> Result { - Ok(TreeKemPublic { - nodes, - ..Default::default() - }) - } - - #[cfg(feature = "tree_index")] + #[cfg_attr(not(feature = "tree_index"), allow(unused))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn import_node_data( nodes: NodeVec, identity_provider: &IP, + extensions: &ExtensionList, ) -> Result where IP: IdentityProvider, @@ -107,7 +100,8 @@ impl TreeKemPublic { ..Default::default() }; - tree.initialize_index_if_necessary(identity_provider) + #[cfg(feature = "tree_index")] + tree.initialize_index_if_necessary(identity_provider, extensions) .await?; Ok(tree) @@ -118,12 +112,20 @@ impl TreeKemPublic { pub(crate) async fn initialize_index_if_necessary( &mut self, identity_provider: &IP, + extensions: &ExtensionList, ) -> Result<(), MlsError> { if !self.index.is_initialized() { self.index = TreeIndex::new(); for (leaf_index, leaf) in self.nodes.non_empty_leaves() { - index_insert(&mut self.index, leaf, leaf_index, identity_provider).await?; + index_insert( + &mut self.index, + leaf, + leaf_index, + identity_provider, + extensions, + ) + .await?; } } @@ -141,10 +143,11 @@ impl TreeKemPublic { &self, identity: &[u8], id_provider: &I, + extensions: &ExtensionList, ) -> Result, MlsError> { for (i, leaf) in self.nodes.non_empty_leaves() { let leaf_id = id_provider - .identity(&leaf.signing_identity) + .identity(&leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; @@ -165,11 +168,12 @@ impl TreeKemPublic { leaf_node: LeafNode, secret_key: HpkeSecretKey, identity_provider: &I, + extensions: &ExtensionList, ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> { let mut public_tree = TreeKemPublic::new(); public_tree - .add_leaf(leaf_node, identity_provider, None) + .add_leaf(leaf_node, identity_provider, extensions, None) .await?; let private_tree = TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key); @@ -225,7 +229,9 @@ impl TreeKemPublic { let mut added = vec![]; for leaf in leaf_nodes.into_iter() { - start = self.add_leaf(leaf, id_provider, Some(start)).await?; + start = self + .add_leaf(leaf, id_provider, &Default::default(), Some(start)) + .await?; added.push(start); } @@ -261,6 +267,7 @@ impl TreeKemPublic { &mut self, sender: LeafIndex, update_path: &ValidatedUpdatePath, + extensions: &ExtensionList, identity_provider: IP, cipher_suite_provider: &CP, ) -> Result<(), MlsError> @@ -276,7 +283,7 @@ impl TreeKemPublic { #[cfg(feature = "tree_index")] let original_identity = identity_provider - .identity(&original_leaf_node.signing_identity) + .identity(&original_leaf_node.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; @@ -302,6 +309,7 @@ impl TreeKemPublic { &update_path.leaf_node, sender, &identity_provider, + extensions, ) .await?; @@ -329,6 +337,7 @@ impl TreeKemPublic { pub async fn batch_edit( &mut self, proposal_bundle: &mut ProposalBundle, + extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, filter: bool, @@ -350,7 +359,9 @@ impl TreeKemPublic { #[cfg(feature = "tree_index")] if let Ok(old_leaf) = &res { // If this fails, it's not because the proposal is bad. - let identity = identity(&old_leaf.signing_identity, id_provider).await?; + let identity = + identity(&old_leaf.signing_identity, id_provider, extensions).await?; + self.index.remove(old_leaf, &identity); } @@ -371,7 +382,9 @@ impl TreeKemPublic { match self.nodes.blank_leaf_node(index) { Ok(old_leaf) => { #[cfg(feature = "tree_index")] - let old_id = identity(&old_leaf.signing_identity, id_provider).await?; + let old_id = + identity(&old_leaf.signing_identity, id_provider, extensions).await?; + #[cfg(feature = "tree_index")] self.index.remove(&old_leaf, &old_id); @@ -396,9 +409,11 @@ impl TreeKemPublic { // all updates. for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() { #[cfg(feature = "tree_index")] - let res = index_insert(&mut self.index, &new_leaf, index, id_provider).await; + let res = + index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await; + #[cfg(not(feature = "tree_index"))] - let res = index_insert(&self.nodes, &new_leaf, index, id_provider).await; + let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await; let err = res.is_err(); @@ -412,9 +427,12 @@ impl TreeKemPublic { updated_indices.push(index); } else { #[cfg(feature = "tree_index")] - let res = index_insert(&mut self.index, &old_leaf, index, id_provider).await; + let res = + index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await; + #[cfg(not(feature = "tree_index"))] - let res = index_insert(&self.nodes, &old_leaf, index, id_provider).await; + let res = + index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await; if res.is_ok() { self.nodes.insert_leaf(index, old_leaf); @@ -465,7 +483,9 @@ impl TreeKemPublic { .leaf_node .clone(); - let res = self.add_leaf(leaf, id_provider, Some(start)).await; + let res = self + .add_leaf(leaf, id_provider, extensions, Some(start)) + .await; if let Ok(index) = res { start = index; @@ -502,6 +522,7 @@ impl TreeKemPublic { pub async fn batch_edit_lite( &mut self, proposal_bundle: &ProposalBundle, + extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, ) -> Result, MlsError> @@ -517,7 +538,10 @@ impl TreeKemPublic { { // If this fails, it's not because the proposal is bad. let old_leaf = self.nodes.blank_leaf_node(index)?; - let identity = identity(&old_leaf.signing_identity, id_provider).await?; + + let identity = + identity(&old_leaf.signing_identity, id_provider, extensions).await?; + self.index.remove(&old_leaf, &identity); } @@ -533,7 +557,9 @@ impl TreeKemPublic { for p in &proposal_bundle.additions { let leaf = p.proposal.key_package.leaf_node.clone(); - start = self.add_leaf(leaf, id_provider, Some(start)).await?; + start = self + .add_leaf(leaf, id_provider, extensions, Some(start)) + .await?; added.push(start); } @@ -557,15 +583,16 @@ impl TreeKemPublic { &mut self, leaf: LeafNode, id_provider: &I, + extensions: &ExtensionList, start: Option, ) -> Result { let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0))); #[cfg(feature = "tree_index")] - index_insert(&mut self.index, &leaf, index, id_provider).await?; + index_insert(&mut self.index, &leaf, index, id_provider, extensions).await?; #[cfg(not(feature = "tree_index"))] - index_insert(&self.nodes, &leaf, index, id_provider).await?; + index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?; self.nodes.insert_leaf(index, leaf); self.update_unmerged(index)?; @@ -579,9 +606,10 @@ impl TreeKemPublic { async fn identity( signing_id: &SigningIdentity, provider: &I, + extensions: &ExtensionList, ) -> Result, MlsError> { provider - .identity(signing_id) + .identity(signing_id, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())) } @@ -617,8 +645,14 @@ impl TreeKemPublic { bundle.add(p, Sender::Member(leaf_index), ProposalSource::ByValue); bundle.update_senders = vec![LeafIndex(leaf_index)]; - self.batch_edit(&mut bundle, identity_provider, cipher_suite_provider, true) - .await?; + self.batch_edit( + &mut bundle, + &Default::default(), + identity_provider, + cipher_suite_provider, + true, + ) + .await?; Ok(()) } @@ -648,12 +682,23 @@ impl TreeKemPublic { } #[cfg(feature = "by_ref_proposal")] - self.batch_edit(&mut bundle, identity_provider, cipher_suite_provider, true) - .await?; + self.batch_edit( + &mut bundle, + &Default::default(), + identity_provider, + cipher_suite_provider, + true, + ) + .await?; #[cfg(not(feature = "by_ref_proposal"))] - self.batch_edit_lite(&bundle, identity_provider, cipher_suite_provider) - .await?; + self.batch_edit_lite( + &bundle, + &Default::default(), + identity_provider, + cipher_suite_provider, + ) + .await?; bundle .removals @@ -715,6 +760,7 @@ pub(crate) mod test_utils { creator_leaf.clone(), creator_hpke_secret.clone(), &BasicIdentityProvider, + &Default::default(), ) .await .unwrap(); @@ -964,13 +1010,10 @@ mod tests { let exported = test_tree.public.export_node_data(); - let imported = TreeKemPublic::import_node_data( - exported, - #[cfg(feature = "tree_index")] - &BasicIdentityProvider, - ) - .await - .unwrap(); + let imported = + TreeKemPublic::import_node_data(exported, &BasicIdentityProvider, &Default::default()) + .await + .unwrap(); assert_eq!(test_tree.public.nodes, imported.nodes); @@ -1412,6 +1455,7 @@ mod tests { tree.batch_edit( &mut bundle, + &Default::default(), &BasicIdentityProvider, &cipher_suite_provider, true, diff --git a/mls-rs/src/tree_kem/private.rs b/mls-rs/src/tree_kem/private.rs index ccda56ed..402f2410 100644 --- a/mls-rs/src/tree_kem/private.rs +++ b/mls-rs/src/tree_kem/private.rs @@ -160,10 +160,14 @@ mod tests { get_basic_test_node_sig_key(cipher_suite, "charlie").await; // Create a new public tree with Alice - let (mut public_tree, mut alice_private) = - TreeKemPublic::derive(alice_leaf, alice_hpke_secret, &BasicIdentityProvider) - .await - .unwrap(); + let (mut public_tree, mut alice_private) = TreeKemPublic::derive( + alice_leaf, + alice_hpke_secret, + &BasicIdentityProvider, + &Default::default(), + ) + .await + .unwrap(); // Add bob and charlie to the tree public_tree diff --git a/mls-rs/src/tree_kem/tree_hash.rs b/mls-rs/src/tree_kem/tree_hash.rs index 7968a6b3..55f4c66d 100644 --- a/mls-rs/src/tree_kem/tree_hash.rs +++ b/mls-rs/src/tree_kem/tree_hash.rs @@ -347,12 +347,10 @@ mod tests { use crate::{ cipher_suite::CipherSuite, crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider}, + identity::basic::BasicIdentityProvider, tree_kem::{node::NodeVec, parent_hash::test_utils::get_test_tree_fig_12}, }; - #[cfg(feature = "tree_index")] - use crate::identity::basic::BasicIdentityProvider; - use super::*; #[derive(serde::Deserialize, serde::Serialize)] @@ -408,8 +406,8 @@ mod tests { let mut tree = TreeKemPublic::import_node_data( NodeVec::mls_decode(&mut &*one_case.tree_data).unwrap(), - #[cfg(feature = "tree_index")] &BasicIdentityProvider, + &Default::default(), ) .await .unwrap(); diff --git a/mls-rs/src/tree_kem/tree_index.rs b/mls-rs/src/tree_kem/tree_index.rs index e51a4416..b0a9eb49 100644 --- a/mls-rs/src/tree_kem/tree_index.rs +++ b/mls-rs/src/tree_kem/tree_index.rs @@ -63,9 +63,10 @@ pub(super) async fn index_insert( new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, + extensions: &ExtensionList, ) -> Result<(), MlsError> { let new_id = id_provider - .identity(&new_leaf.signing_identity) + .identity(&new_leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; @@ -79,9 +80,10 @@ pub(super) async fn index_insert( new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, + extensions: &ExtensionList, ) -> Result<(), MlsError> { let new_id = id_provider - .identity(&new_leaf.signing_identity) + .identity(&new_leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; @@ -95,7 +97,7 @@ pub(super) async fn index_insert( .ok_or(MlsError::DuplicateLeafData(*i))?; let id = id_provider - .identity(&leaf.signing_identity) + .identity(&leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; diff --git a/mls-rs/src/tree_kem/update_path.rs b/mls-rs/src/tree_kem/update_path.rs index 737a545d..5e975b83 100644 --- a/mls-rs/src/tree_kem/update_path.rs +++ b/mls-rs/src/tree_kem/update_path.rs @@ -71,6 +71,7 @@ pub(crate) async fn validate_update_path