Skip to content

Commit

Permalink
Add context to identity functions (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
mulmarta authored Dec 5, 2023
1 parent 56c5848 commit cc0ff86
Show file tree
Hide file tree
Showing 25 changed files with 271 additions and 205 deletions.
7 changes: 6 additions & 1 deletion mls-rs-core/src/identity/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ pub trait IdentityProvider: Send + Sync {
///
/// The MLS protocol requires that each member of a group has a unique
/// set of identifiers according to the application.
async fn identity(&self, signing_identity: &SigningIdentity) -> Result<Vec<u8>, Self::Error>;
async fn identity(
&self,
signing_identity: &SigningIdentity,
extensions: &ExtensionList,
) -> Result<Vec<u8>, Self::Error>;

/// Determines if `successor` can remove `predecessor` as part of an external commit.
///
Expand All @@ -83,6 +87,7 @@ pub trait IdentityProvider: Send + Sync {
&self,
predecessor: &SigningIdentity,
successor: &SigningIdentity,
extensions: &ExtensionList,
) -> Result<bool, Self::Error>;

/// Credential types that are supported by this provider.
Expand Down
2 changes: 2 additions & 0 deletions mls-rs-identity-x509/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ where
async fn identity(
&self,
signing_id: &mls_rs_core::identity::SigningIdentity,
_extensions: &ExtensionList,
) -> Result<Vec<u8>, Self::Error> {
self.identity(signing_id)
}
Expand All @@ -222,6 +223,7 @@ where
&self,
predecessor: &mls_rs_core::identity::SigningIdentity,
successor: &mls_rs_core::identity::SigningIdentity,
_extensions: &ExtensionList,
) -> Result<bool, Self::Error> {
self.valid_successor(predecessor, successor)
}
Expand Down
1 change: 0 additions & 1 deletion mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,6 @@ where
protocol_version,
group_info,
tree_data,
#[cfg(feature = "tree_index")]
&self.config.identity_provider(),
&cipher_suite_provider,
)
Expand Down
8 changes: 6 additions & 2 deletions mls-rs/src/external_client/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
) -> Result<Member, MlsError> {
let identity = self
.identity_provider()
.identity(identity_id)
.identity(identity_id, self.context_extensions())
.await
.map_err(|error| MlsError::IdentityProviderError(error.into_any_error()))?;

Expand All @@ -541,7 +541,11 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {

#[cfg(not(feature = "tree_index"))]
let index = tree
.get_leaf_node_with_identity(&identity, &self.identity_provider())
.get_leaf_node_with_identity(
&identity,
&self.identity_provider(),
self.context_extensions(),
)
.await?;

let index = index.ok_or(MlsError::MemberNotFound)?;
Expand Down
6 changes: 4 additions & 2 deletions mls-rs/src/group/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1356,7 +1356,7 @@ mod tests {
) -> bool {
if let Some(extensions) = extensions {
if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
self.identity(identity).await.unwrap()[0] == ext.foo
self.identity(identity, extensions).await.unwrap()[0] == ext.foo
} else {
true
}
Expand Down Expand Up @@ -1397,9 +1397,10 @@ mod tests {
async fn identity(
&self,
signing_identity: &SigningIdentity,
extensions: &ExtensionList,
) -> Result<Vec<u8>, Self::Error> {
self.0
.identity(signing_identity)
.identity(signing_identity, extensions)
.await
.map_err(|_| IdentityProviderWithExtensionError {})
}
Expand All @@ -1408,6 +1409,7 @@ mod tests {
&self,
_predecessor: &SigningIdentity,
_successor: &SigningIdentity,
_extensions: &ExtensionList,
) -> Result<bool, Self::Error> {
Ok(true)
}
Expand Down
15 changes: 5 additions & 10 deletions mls-rs/src/group/interop_test_vectors/tree_kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit,
GroupContext, PathSecret, Sender,
},
identity::basic::BasicIdentityProvider,
tree_kem::{
node::{LeafIndex, NodeVec},
TreeKemPrivate, TreeKemPublic, UpdatePath,
Expand All @@ -21,9 +22,6 @@ use alloc::vec::Vec;
use mls_rs_codec::MlsDecode;
use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList};

#[cfg(feature = "tree_index")]
use crate::identity::basic::BasicIdentityProvider;

#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
struct TreeKemTestCase {
pub cipher_suite: u16,
Expand Down Expand Up @@ -83,13 +81,10 @@ async fn tree_kem() {
// Import the public ratchet tree
let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();

let mut tree = TreeKemPublic::import_node_data(
nodes,
#[cfg(feature = "tree_index")]
&BasicIdentityProvider,
)
.await
.unwrap();
let mut tree =
TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
.await
.unwrap();

// Construct GroupContext
let group_context = GroupContext {
Expand Down
15 changes: 5 additions & 10 deletions mls-rs/src/group/interop_test_vectors/tree_modifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ use crate::{
test_utils::TEST_GROUP,
LeafIndex, Sender, TreeKemPublic,
},
identity::basic::BasicIdentityProvider,
key_package::test_utils::test_key_package,
tree_kem::{
leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners,
},
};

#[cfg(feature = "tree_index")]
use crate::identity::basic::BasicIdentityProvider;

#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
struct TreeModsTestCase {
#[serde(with = "hex::serde")]
Expand Down Expand Up @@ -112,13 +110,10 @@ async fn tree_modifications_interop() {
for test_case in test_cases.into_iter() {
let nodes = NodeVec::mls_decode(&mut &*test_case.tree_before).unwrap();

let tree_before = TreeKemPublic::import_node_data(
nodes,
#[cfg(feature = "tree_index")]
&BasicIdentityProvider,
)
.await
.unwrap();
let tree_before =
TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
.await
.unwrap();

let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();

Expand Down
1 change: 1 addition & 0 deletions mls-rs/src/group/message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ pub(crate) trait MessageProcessor: Send + Sync {
.apply_update_path(
sender,
update_path,
&provisional_state.group_context.extensions,
self.identity_provider(),
self.cipher_suite_provider(),
)
Expand Down
16 changes: 13 additions & 3 deletions mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,13 @@ where
.check_if_valid(&leaf_node, ValidationContext::Add(None))
.await?;

let (mut public_tree, private_tree) =
TreeKemPublic::derive(leaf_node, leaf_node_secret, &config.identity_provider()).await?;
let (mut public_tree, private_tree) = TreeKemPublic::derive(
leaf_node,
leaf_node_secret,
&config.identity_provider(),
&group_context_extensions,
)
.await?;

let tree_hash = public_tree.tree_hash(&cipher_suite_provider).await?;

Expand Down Expand Up @@ -1320,7 +1325,11 @@ where

#[cfg(not(feature = "tree_index"))]
let index = tree
.get_leaf_node_with_identity(identity, &self.identity_provider())
.get_leaf_node_with_identity(
identity,
&self.identity_provider(),
&self.state.context.extensions,
)
.await?;

let index = index.ok_or(MlsError::MemberNotFound)?;
Expand Down Expand Up @@ -1615,6 +1624,7 @@ where
.apply_update_path(
sender,
update_path,
&provisional_state.group_context.extensions,
self.identity_provider(),
self.cipher_suite_provider(),
)
Expand Down
62 changes: 41 additions & 21 deletions mls-rs/src/group/proposal_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,10 @@ mod tests {
)
.await;

let (pub_tree, priv_tree) = TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider)
.await
.unwrap();
let (pub_tree, priv_tree) =
TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
.await
.unwrap();

(priv_tree.self_index, pub_tree)
}
Expand Down Expand Up @@ -717,10 +718,14 @@ mod tests {

let sender = LeafIndex(0);

let (mut tree, _) =
TreeKemPublic::derive(sender_leaf, sender_leaf_secret, &BasicIdentityProvider)
.await
.unwrap();
let (mut tree, _) = TreeKemPublic::derive(
sender_leaf,
sender_leaf_secret,
&BasicIdentityProvider,
&Default::default(),
)
.await
.unwrap();

let add_package = test_key_package(protocol_version, cipher_suite, "dave").await;

Expand Down Expand Up @@ -774,6 +779,7 @@ mod tests {
expected_tree
.batch_edit(
&mut bundle,
&Default::default(),
&BasicIdentityProvider,
&cipher_suite_provider,
true,
Expand Down Expand Up @@ -1699,9 +1705,14 @@ mod tests {
get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
let alice = 0;

let (mut tree, _) = TreeKemPublic::derive(alice_leaf, alice_secret, &BasicIdentityProvider)
.await
.unwrap();
let (mut tree, _) = TreeKemPublic::derive(
alice_leaf,
alice_secret,
&BasicIdentityProvider,
&Default::default(),
)
.await
.unwrap();

let bob_node = get_basic_test_node(TEST_CIPHER_SUITE, "bob").await;

Expand Down Expand Up @@ -2983,9 +2994,10 @@ mod tests {
.await
.unwrap();

let (pub_tree, priv_tree) = TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider)
.await
.unwrap();
let (pub_tree, priv_tree) =
TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
.await
.unwrap();

(priv_tree.self_index, pub_tree)
};
Expand Down Expand Up @@ -3387,10 +3399,14 @@ mod tests {
let (alice_leaf, alice_secret, alice_signer) =
get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;

let (mut tree, priv_tree) =
TreeKemPublic::derive(alice_leaf.clone(), alice_secret, &BasicIdentityProvider)
.await
.unwrap();
let (mut tree, priv_tree) = TreeKemPublic::derive(
alice_leaf.clone(),
alice_secret,
&BasicIdentityProvider,
&Default::default(),
)
.await
.unwrap();

let alice = priv_tree.self_index;

Expand Down Expand Up @@ -3460,10 +3476,14 @@ mod tests {
let (alice_leaf, alice_secret, alice_signer) =
get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;

let (mut tree, priv_tree) =
TreeKemPublic::derive(alice_leaf.clone(), alice_secret, &BasicIdentityProvider)
.await
.unwrap();
let (mut tree, priv_tree) = TreeKemPublic::derive(
alice_leaf.clone(),
alice_secret,
&BasicIdentityProvider,
&Default::default(),
)
.await
.unwrap();

let alice = priv_tree.self_index;

Expand Down
7 changes: 6 additions & 1 deletion mls-rs/src/group/proposal_filter/filtering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ where
let added = new_tree
.batch_edit(
&mut applied_proposals,
group_extensions_in_use,
self.identity_provider,
self.cipher_suite_provider,
strategy.is_ignore(),
Expand Down Expand Up @@ -195,7 +196,11 @@ where

let valid_successor = self
.identity_provider
.valid_successor(&old_leaf.signing_identity, &leaf.signing_identity)
.valid_successor(
&old_leaf.signing_identity,
&leaf.signing_identity,
group_extensions_in_use,
)
.await
.map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
.and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor));
Expand Down
Loading

0 comments on commit cc0ff86

Please sign in to comment.