Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework leaf node extensions to work via parameters rather than as a c… #196

Merged
merged 6 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,10 @@ impl Client {
/// See [`mls_rs::Client::generate_key_package_message`] for
/// details.
pub async fn generate_key_package_message(&self) -> Result<Message, Error> {
let message = self.inner.generate_key_package_message().await?;
let message = self
.inner
.generate_key_package_message(Default::default(), Default::default())
.await?;
Ok(message.into())
}

Expand All @@ -403,10 +406,14 @@ impl Client {
let inner = match group_id {
Some(group_id) => {
self.inner
.create_group_with_id(group_id, extensions)
.create_group_with_id(group_id, extensions, Default::default())
.await?
}
None => {
self.inner
.create_group(extensions, Default::default())
.await?
}
None => self.inner.create_group(extensions).await?,
};
Ok(Group {
inner: Arc::new(Mutex::new(inner)),
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/benches/group_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ use mls_rs_crypto_openssl::OpensslCryptoProvider;

fn bench(c: &mut Criterion) {
let alice = make_client("alice")
.create_group(Default::default())
.create_group(Default::default(), Default::default())
.unwrap();

const MAX_ADD_COUNT: usize = 1000;

let key_packages = (0..MAX_ADD_COUNT)
.map(|i| {
make_client(&format!("bob-{i}"))
.generate_key_package_message()
.generate_key_package_message(Default::default(), Default::default())
.unwrap()
})
.collect::<Vec<_>>();
Expand Down
5 changes: 3 additions & 2 deletions mls-rs/examples/basic_server_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ fn main() -> Result<(), MlsError> {
let bob = make_client("bob")?;

// Alice creates a group with bob
let mut alice_group = alice.create_group(ExtensionList::default())?;
let bob_key_package = bob.generate_key_package_message()?;
let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;
let bob_key_package =
bob.generate_key_package_message(Default::default(), Default::default())?;

let welcome = &alice_group
.commit_builder()
Expand Down
5 changes: 3 additions & 2 deletions mls-rs/examples/basic_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ fn main() -> Result<(), MlsError> {
let bob = make_client(crypto_provider.clone(), "bob")?;

// Alice creates a new group.
let mut alice_group = alice.create_group(ExtensionList::default())?;
let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;

// Bob generates a key package that Alice needs to add Bob to the group.
let bob_key_package = bob.generate_key_package_message()?;
let bob_key_package =
bob.generate_key_package_message(Default::default(), Default::default())?;

// Alice issues a commit that adds Bob to the group.
let alice_commit = alice_group
Expand Down
9 changes: 6 additions & 3 deletions mls-rs/examples/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,13 @@ fn main() -> Result<(), CustomError> {
let roster = vec![alice.credential];
context_extensions.set_from(RosterExtension { roster })?;

let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?;
let mut alice_tablet_group =
make_client(alice_tablet)?.create_group(context_extensions, Default::default())?;

// Alice can add her other device
let alice_pc_client = make_client(alice_pc)?;
let key_package = alice_pc_client.generate_key_package_message()?;
let key_package =
alice_pc_client.generate_key_package_message(Default::default(), Default::default())?;

let welcome = alice_tablet_group
.commit_builder()
Expand All @@ -387,7 +389,8 @@ fn main() -> Result<(), CustomError> {

// Alice cannot add bob's devices yet
let bob_tablet_client = make_client(bob_tablet)?;
let key_package = bob_tablet_client.generate_key_package_message()?;
let key_package =
bob_tablet_client.generate_key_package_message(Default::default(), Default::default())?;

let res = alice_tablet_group
.commit_builder()
Expand Down
10 changes: 6 additions & 4 deletions mls-rs/examples/large_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
let bob_client = make_client(crypto_provider.clone(), &make_name(0))?;

let bob_group = bob_client.create_group(Default::default())?;
let bob_group = bob_client.create_group(Default::default(), Default::default())?;

let mut groups = vec![bob_group];

for i in 0..(num_groups - 1) {
let bob_client = make_client(crypto_provider.clone(), &make_name(i + 1))?;

// The new client generates a key package.
let bob_kpkg = bob_client.generate_key_package_message()?;
let bob_kpkg =
bob_client.generate_key_package_message(Default::default(), Default::default())?;

// Last group sends a commit adding the new client to the group.
let commit = groups
Expand Down Expand Up @@ -100,7 +101,7 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
let alice_client = make_client(crypto_provider.clone(), &make_name(0))?;

let mut alice_group = alice_client.create_group(Default::default())?;
let mut alice_group = alice_client.create_group(Default::default(), Default::default())?;

let bob_clients = (0..(num_groups - 1))
.map(|i| make_client(crypto_provider.clone(), &make_name(i + 1)))
Expand All @@ -110,7 +111,8 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
let mut commit_builder = alice_group.commit_builder();

for bob_client in &bob_clients {
let bob_kpkg = bob_client.generate_key_package_message()?;
let bob_kpkg =
bob_client.generate_key_package_message(Default::default(), Default::default())?;
commit_builder = commit_builder.add_member(bob_kpkg)?;
}

Expand Down
4 changes: 3 additions & 1 deletion mls-rs/examples/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ fn main() {
.signing_identity(signing_identity, secret_key, CIPHERSUITE)
.build();

let mut alice_group = alice_client.create_group(Default::default()).unwrap();
let mut alice_group = alice_client
.create_group(Default::default(), Default::default())
.unwrap();

alice_group.commit(Vec::new()).unwrap();
alice_group.apply_pending_commit().unwrap();
Expand Down
66 changes: 51 additions & 15 deletions mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,23 @@ where
///
/// A key package message may only be used once.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn generate_key_package_message(&self) -> Result<MlsMessage, MlsError> {
Ok(self.generate_key_package().await?.key_package_message())
pub async fn generate_key_package_message(
&self,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<MlsMessage, MlsError> {
Ok(self
.generate_key_package(key_package_extensions, leaf_node_extensions)
.await?
.key_package_message())
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn generate_key_package(&self) -> Result<KeyPackageGeneration, MlsError> {
async fn generate_key_package(
&self,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<KeyPackageGeneration, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?;

let cipher_suite_provider = self
Expand All @@ -454,8 +465,8 @@ where
.generate(
self.config.lifetime(),
self.config.capabilities(),
self.config.key_package_extensions(),
self.config.leaf_node_extensions(),
key_package_extensions,
leaf_node_extensions,
)
.await?;

Expand Down Expand Up @@ -486,6 +497,7 @@ where
&self,
group_id: Vec<u8>,
group_context_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<Group<C>, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?;

Expand All @@ -496,6 +508,7 @@ where
self.version,
signing_identity.clone(),
group_context_extensions,
leaf_node_extensions,
self.signer()?.clone(),
)
.await
Expand All @@ -510,6 +523,7 @@ where
pub async fn create_group(
&self,
group_context_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<Group<C>, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?;

Expand All @@ -520,6 +534,7 @@ where
self.version,
signing_identity.clone(),
group_context_extensions,
leaf_node_extensions,
self.signer()?.clone(),
)
.await
Expand Down Expand Up @@ -674,6 +689,8 @@ where
group_info: &MlsMessage,
tree_data: Option<crate::group::ExportedTree<'_>>,
authenticated_data: Vec<u8>,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<MlsMessage, MlsError> {
let protocol_version = group_info.version;

Expand Down Expand Up @@ -702,7 +719,10 @@ where
)
.await?;

let key_package = self.generate_key_package().await?.key_package;
let key_package = self
.generate_key_package(key_package_extensions, leaf_node_extensions)
.await?
.key_package;

(key_package.cipher_suite == cipher_suite)
.then_some(())
Expand Down Expand Up @@ -745,11 +765,6 @@ where
.ok_or(MlsError::SignerNotFound)
}

/// Returns key package extensions used by this client
pub fn key_package_extensions(&self) -> ExtensionList {
self.config.key_package_extensions()
}

/// The [KeyPackageStorage] that this client was configured to use.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository {
Expand Down Expand Up @@ -793,14 +808,24 @@ pub(crate) mod test_utils {
cipher_suite: CipherSuite,
identity: &str,
) -> (Client<TestClientConfig>, MlsMessage) {
test_client_with_key_pkg_custom(protocol_version, cipher_suite, identity, |_| {}).await
test_client_with_key_pkg_custom(
protocol_version,
cipher_suite,
identity,
Default::default(),
Default::default(),
|_| {},
)
.await
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn test_client_with_key_pkg_custom<F>(
protocol_version: ProtocolVersion,
cipher_suite: CipherSuite,
identity: &str,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
mut config: F,
) -> (Client<TestClientConfig>, MlsMessage)
where
Expand All @@ -816,7 +841,10 @@ pub(crate) mod test_utils {

config(&mut client.config);

let key_package = client.generate_key_package_message().await.unwrap();
let key_package = client
.generate_key_package_message(key_package_extensions, leaf_node_extensions)
.await
.unwrap();

(client, key_package)
}
Expand Down Expand Up @@ -863,7 +891,10 @@ mod tests {
.build();

// TODO: Tests around extensions
let key_package = client.generate_key_package_message().await.unwrap();
let key_package = client
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();

assert_eq!(key_package.version, protocol_version);

Expand Down Expand Up @@ -902,6 +933,8 @@ mod tests {
&alice_group.group_info_message(true).await.unwrap(),
None,
vec![],
Default::default(),
Default::default(),
)
.await
.unwrap();
Expand Down Expand Up @@ -1047,7 +1080,10 @@ mod tests {
.signing_identity(alice_identity.clone(), secret_key, TEST_CIPHER_SUITE)
.build();

let msg = alice.generate_key_package_message().await.unwrap();
let msg = alice
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
let res = alice.commit_external(msg).await.map(|_| ());

assert_matches!(res, Err(MlsError::UnexpectedMessageType));
Expand Down
Loading
Loading