Skip to content

Commit

Permalink
Use stored state for apply-pending-commit (#229)
Browse files Browse the repository at this point in the history
* Use stored state for apply-pending-commit

* Bump version

* Make serialization of Sanapshot backwards compatible

* Fixup

* Fixup

* Make snapshot backwards compatible

* Fixup

* Fixup

* Fixup

---------

Co-authored-by: Marta Mularczyk <[email protected]>
Co-authored-by: Tom Leavy <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent 158a9d3 commit 66d6717
Show file tree
Hide file tree
Showing 17 changed files with 562 additions and 188 deletions.
2 changes: 1 addition & 1 deletion mls-rs-codec/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mls-rs-codec"
version = "0.5.3"
version = "0.5.4"
edition = "2021"
description = "TLS codec and MLS specific encoding used by mls-rs"
homepage = "https://github.com/awslabs/mls-rs"
Expand Down
45 changes: 45 additions & 0 deletions mls-rs-codec/src/bool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::{MlsDecode, MlsEncode, MlsSize};
use alloc::vec::Vec;

impl MlsSize for bool {
fn mls_encoded_len(&self) -> usize {
1
}
}

impl MlsEncode for bool {
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error> {
writer.push(*self as u8);
Ok(())
}
}

impl MlsDecode for bool {
fn mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error> {
MlsDecode::mls_decode(reader).map(|i: u8| i != 0)
}
}

#[cfg(test)]
mod tests {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test as test;

use crate::{MlsDecode, MlsEncode};

use alloc::vec;

#[test]
fn round_trip() {
assert_eq!(false.mls_encode_to_vec().unwrap(), vec![0]);
assert_eq!(true.mls_encode_to_vec().unwrap(), vec![1]);

let vec = vec![true, true, false];
let bytes = vec.mls_encode_to_vec().unwrap();
assert_eq!(vec, Vec::mls_decode(&mut &*bytes).unwrap())
}
}
1 change: 1 addition & 0 deletions mls-rs-codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod byte_vec;

pub mod iter;

mod bool;
mod cow;
mod map;
mod option;
Expand Down
2 changes: 1 addition & 1 deletion mls-rs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ serde = ["dep:serde", "zeroize/serde", "hex/serde", "dep:serde_bytes"]
last_resort_key_package_ext = []

[dependencies]
mls-rs-codec = { version = "0.5.2", path = "../mls-rs-codec", default-features = false}
mls-rs-codec = { version = "0.5", path = "../mls-rs-codec", default-features = false}
zeroize = { version = "1", default-features = false, features = ["alloc", "zeroize_derive"] }
arbitrary = { version = "1", features = ["derive"], optional = true }
thiserror = { version = "1.0.40", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl From<mls_rs::group::CommitEffect> for CommitEffect {
},
group::CommitEffect::Removed {
new_epoch: _,
remove_proposal: _,
remover: _,
} => CommitEffect::Removed,
group::CommitEffect::ReInit(_) => CommitEffect::ReInit,
}
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mls-rs"
version = "0.43.2"
version = "0.43.3"
edition = "2021"
description = "An implementation of Messaging Layer Security (RFC 9420)"
homepage = "https://github.com/awslabs/mls-rs"
Expand Down Expand Up @@ -54,7 +54,7 @@ fuzz_util = ["test_util", "default", "dep:once_cell", "dep:mls-rs-crypto-openssl
mls-rs-core = { path = "../mls-rs-core", default-features = false, version = "0.20.0" }
mls-rs-identity-x509 = { path = "../mls-rs-identity-x509", default-features = false, version = "0.13.0", optional = true }
zeroize = { version = "1", default-features = false, features = ["alloc", "zeroize_derive"] }
mls-rs-codec = { version = "0.5.2", path = "../mls-rs-codec", default-features = false}
mls-rs-codec = { version = "0.5", path = "../mls-rs-codec", default-features = false}
thiserror = { version = "1.0.40", optional = true }
itertools = { version = "0.12.0", default-features = false, features = ["use_alloc"]}
cfg-if = "1"
Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ mod tests {
ReceivedMessage::Commit(CommitMessageDescription {
effect: CommitEffect::Removed {
new_epoch: _,
remove_proposal: _
remover: _
},
..
})
Expand Down
133 changes: 89 additions & 44 deletions mls-rs/src/group/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ use super::{
message_signature::AuthenticatedContent,
mls_rules::CommitDirection,
proposal::{Proposal, ProposalOrRef},
EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo, Welcome,
CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
PendingCommitSnapshot, Welcome,
};

#[cfg(not(feature = "by_ref_proposal"))]
Expand All @@ -64,25 +66,29 @@ pub(crate) struct Commit {
}

#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct CommitGeneration {
pub content: AuthenticatedContent,
pub pending_private_tree: TreeKemPrivate,
pub pending_commit_secret: PathSecret,
pub commit_message_hash: MessageHash,
pub(crate) struct PendingCommit {
pub(crate) state: GroupState,
pub(crate) epoch_secrets: EpochSecrets,
pub(crate) private_tree: TreeKemPrivate,
pub(crate) key_schedule: KeySchedule,
pub(crate) signer: SignatureSecretKey,

pub(crate) output: CommitMessageDescription,

pub(crate) commit_message_hash: MessageHash,
}

#[cfg_attr(
all(feature = "ffi", not(test)),
safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone)]
pub struct CommitSecrets(pub(crate) CommitGeneration);
pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);

impl CommitSecrets {
/// Deserialize the commit secrets from bytes
pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
Ok(CommitGeneration::mls_decode(&mut &*bytes).map(Self)?)
Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
}

/// Serialize the commit secrets to bytes
Expand Down Expand Up @@ -359,7 +365,7 @@ where
)
.await?;

self.group.pending_commit = Some(pending_commit);
self.group.pending_commit = pending_commit.try_into()?;

Ok(output)
}
Expand All @@ -383,7 +389,12 @@ where
)
.await?;

Ok((output, CommitSecrets(pending_commit)))
Ok((
output,
CommitSecrets(PendingCommitSnapshot::PendingCommit(
pending_commit.mls_encode_to_vec()?,
)),
))
}
}

Expand Down Expand Up @@ -481,8 +492,8 @@ where
new_signer: Option<SignatureSecretKey>,
new_signing_identity: Option<SigningIdentity>,
new_leaf_node_extensions: Option<ExtensionList>,
) -> Result<(CommitOutput, CommitGeneration), MlsError> {
if self.pending_commit.is_some() {
) -> Result<(CommitOutput, PendingCommit), MlsError> {
if !self.pending_commit.is_none() {
return Err(MlsError::ExistingPendingCommit);
}

Expand All @@ -503,7 +514,7 @@ where
Sender::Member(*self.private_tree.self_index)
};

let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer);
let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
let old_signer = &self.signer;

#[cfg(feature = "std")]
Expand Down Expand Up @@ -544,15 +555,13 @@ where
self.private_tree.self_index = provisional_private_tree.self_index;
}

let mut provisional_group_context = provisional_state.group_context;

// Decide whether to populate the path field: If the path field is required based on the
// proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
// sender MAY omit the path field at its discretion.
let commit_options = mls_rules
.commit_options(
&provisional_state.public_tree.roster(),
&provisional_group_context,
&provisional_state.group_context,
&provisional_state.applied_proposals,
)
.map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
Expand Down Expand Up @@ -582,9 +591,9 @@ where
&mut provisional_private_tree,
)
.encap(
&mut provisional_group_context,
&mut provisional_state.group_context,
&provisional_state.indexes_of_added_kpkgs,
new_signer_ref,
&new_signer,
Some(self.config.leaf_properties(new_leaf_node_extensions)),
new_signing_identity,
&self.cipher_suite_provider,
Expand All @@ -608,7 +617,7 @@ where
)
.await?;

provisional_group_context.tree_hash = provisional_state
provisional_state.group_context.tree_hash = provisional_state
.public_tree
.tree_hash(&self.cipher_suite_provider)
.await?;
Expand All @@ -632,7 +641,7 @@ where
.collect();

let commit = Commit {
proposals: provisional_state.applied_proposals.into_proposals_or_refs(),
proposals: provisional_state.applied_proposals.proposals_or_refs(),
path: update_path,
};

Expand All @@ -659,26 +668,33 @@ where
)
.await?;

provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash;
provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;

let key_schedule_result = KeySchedule::from_key_schedule(
&self.key_schedule,
&commit_secret,
&provisional_group_context,
&provisional_state.group_context,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
self.state.public_tree.total_leaf_count(),
provisional_state.public_tree.total_leaf_count(),
&psk_secret,
&self.cipher_suite_provider,
)
.await?;

let confirmation_tag = ConfirmationTag::create(
&key_schedule_result.confirmation_key,
&provisional_group_context.confirmed_transcript_hash,
&provisional_state.group_context.confirmed_transcript_hash,
&self.cipher_suite_provider,
)
.await?;

let interim_transcript_hash = InterimTranscriptHash::create(
self.cipher_suite_provider(),
&provisional_state.group_context.confirmed_transcript_hash,
&confirmation_tag,
)
.await?;

auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());

let ratchet_tree_ext = commit_options
Expand All @@ -705,10 +721,10 @@ where

let info = self
.make_group_info(
&provisional_group_context,
&provisional_state.group_context,
extensions,
&confirmation_tag,
new_signer_ref,
&new_signer,
)
.await?;

Expand All @@ -728,10 +744,10 @@ where

let welcome_group_info = self
.make_group_info(
&provisional_group_context,
&provisional_state.group_context,
welcome_group_info_extensions,
&confirmation_tag,
new_signer_ref,
&new_signer,
)
.await?;

Expand All @@ -754,11 +770,11 @@ where
#[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
let encrypted_path_secrets: Vec<_> = added_key_pkgs
.into_par_iter()
.zip(provisional_state.indexes_of_added_kpkgs)
.zip(&provisional_state.indexes_of_added_kpkgs)
.map(|(key_package, leaf_index)| {
self.encrypt_group_secrets(
&key_package,
leaf_index,
*leaf_index,
&key_schedule_result.joiner_secret,
path_secrets,
#[cfg(feature = "psk")]
Expand All @@ -774,12 +790,12 @@ where

for (key_package, leaf_index) in added_key_pkgs
.into_iter()
.zip(provisional_state.indexes_of_added_kpkgs)
.zip(&provisional_state.indexes_of_added_kpkgs)
{
secrets.push(
self.encrypt_group_secrets(
&key_package,
leaf_index,
*leaf_index,
&key_schedule_result.joiner_secret,
path_secrets,
#[cfg(feature = "psk")]
Expand All @@ -805,20 +821,49 @@ where

let commit_message = self.format_for_wire(auth_content.clone()).await?;

let pending_commit = CommitGeneration {
content: auth_content,
pending_private_tree: provisional_private_tree,
pending_commit_secret: commit_secret,
// TODO is it necessary to clone the tree here? or can we just output serialized bytes?
let ratchet_tree = (!commit_options.ratchet_tree_extension)
.then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));

let pending_reinit = provisional_state
.applied_proposals
.reinitializations
.first();

let pending_commit = PendingCommit {
output: CommitMessageDescription {
is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
authenticated_data: auth_content.content.authenticated_data,
committer: *provisional_private_tree.self_index,
effect: match pending_reinit {
Some(r) => CommitEffect::ReInit(r.clone()),
None => CommitEffect::NewEpoch(
NewEpoch::new(self.state.clone(), &provisional_state).into(),
),
},
},

state: GroupState {
#[cfg(feature = "by_ref_proposal")]
proposals: crate::group::ProposalCache::new(
self.protocol_version(),
self.group_id().to_vec(),
),
context: provisional_state.group_context,
public_tree: provisional_state.public_tree,
interim_transcript_hash,
pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
confirmation_tag,
},

commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
.await?,
};
signer: new_signer,
epoch_secrets: key_schedule_result.epoch_secrets,
key_schedule: key_schedule_result.key_schedule,

let ratchet_tree = (!commit_options.ratchet_tree_extension)
.then(|| ExportedTree::new(provisional_state.public_tree.nodes));

if let Some(signer) = new_signer {
self.signer = signer;
}
private_tree: provisional_private_tree,
};

let output = CommitOutput {
commit_message,
Expand Down
Loading

0 comments on commit 66d6717

Please sign in to comment.