From b2f5312c77be559946e7138062cfd4010b71f2e0 Mon Sep 17 00:00:00 2001 From: Marta Mularczyk Date: Mon, 18 Dec 2023 17:29:33 +0100 Subject: [PATCH] Expose more group info --- mls-rs/src/external_client/group.rs | 37 +++------------ mls-rs/src/group/context.rs | 45 +++++++++++++++---- mls-rs/src/group/framing.rs | 2 +- mls-rs/src/group/group_info.rs | 27 ++++++++--- mls-rs/src/group/mod.rs | 15 +++---- mls-rs/src/group/resumption.rs | 6 +-- .../src/by_ref_proposal.rs | 2 +- 7 files changed, 75 insertions(+), 59 deletions(-) diff --git a/mls-rs/src/external_client/group.rs b/mls-rs/src/external_client/group.rs index 4783d44f..3596cfdd 100644 --- a/mls-rs/src/external_client/group.rs +++ b/mls-rs/src/external_client/group.rs @@ -24,7 +24,7 @@ use crate::{ snapshot::RawGroupState, state::GroupState, transcript_hash::InterimTranscriptHash, - validate_group_info, Roster, + validate_group_info, GroupContext, Roster, }, identity::SigningIdentity, protocol_version::ProtocolVersion, @@ -316,7 +316,7 @@ impl ExternalGroup { let key_id = ResumptionPsk { psk_epoch, usage: ResumptionPSKUsage::Application, - psk_group_id: PskGroupId(self.group_id().to_vec()), + psk_group_id: PskGroupId(self.group_context().group_id().to_vec()), }; let proposal = self.psk_proposal(JustPreSharedKeyID::Resumption(key_id))?; @@ -456,7 +456,7 @@ impl ExternalGroup { }; Ok(MlsMessage::new( - self.protocol_version(), + self.group_context().version(), MlsMessagePayload::Plain(plaintext), )) } @@ -466,28 +466,10 @@ impl ExternalGroup { &self.state } - /// Get the unique identifier of this group. + /// Get the current group context summarizing various information about the group. #[inline(always)] - pub fn group_id(&self) -> &[u8] { - &self.group_state().context.group_id - } - - /// Get the current epoch number of the group's state. - #[inline(always)] - pub fn current_epoch(&self) -> u64 { - self.group_state().context.epoch - } - - /// Get the current protocol version in use by the group. - #[inline(always)] - pub fn protocol_version(&self) -> ProtocolVersion { - self.group_state().context.protocol_version - } - - /// Get the current ciphersuite in use by the group. - #[inline(always)] - pub fn cipher_suite(&self) -> CipherSuite { - self.group_state().context.cipher_suite + pub fn group_context(&self) -> &GroupContext { + &self.group_state().context } /// Export the current ratchet tree used within the group. @@ -505,11 +487,6 @@ impl ExternalGroup { self.group_state().public_tree.roster() } - #[inline(always)] - pub fn context_extensions(&self) -> &ExtensionList { - &self.group_state().context.extensions - } - /// Get the /// [transcript hash](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-transcript-hashes) /// for the current epoch that the group is in. @@ -538,7 +515,7 @@ impl ExternalGroup { ) -> Result { let identity = self .identity_provider() - .identity(identity_id, self.context_extensions()) + .identity(identity_id, self.group_context().extensions()) .await .map_err(|error| MlsError::IdentityProviderError(error.into_any_error()))?; diff --git a/mls-rs/src/group/context.rs b/mls-rs/src/group/context.rs index f20947a6..571aeeb5 100644 --- a/mls-rs/src/group/context.rs +++ b/mls-rs/src/group/context.rs @@ -12,20 +12,25 @@ use super::ConfirmedTranscriptHash; #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr( + all(feature = "ffi", not(test)), + safer_ffi_gen::ffi_type(clone, opaque) +)] pub struct GroupContext { - pub protocol_version: ProtocolVersion, - pub cipher_suite: CipherSuite, + pub(crate) protocol_version: ProtocolVersion, + pub(crate) cipher_suite: CipherSuite, #[mls_codec(with = "mls_rs_codec::byte_vec")] - pub group_id: Vec, - pub epoch: u64, + pub(crate) group_id: Vec, + pub(crate) epoch: u64, #[mls_codec(with = "mls_rs_codec::byte_vec")] - pub tree_hash: Vec, - pub confirmed_transcript_hash: ConfirmedTranscriptHash, - pub extensions: ExtensionList, + pub(crate) tree_hash: Vec, + pub(crate) confirmed_transcript_hash: ConfirmedTranscriptHash, + pub(crate) extensions: ExtensionList, } +#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)] impl GroupContext { - pub fn new_group( + pub(crate) fn new_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, group_id: Vec, @@ -42,4 +47,28 @@ impl GroupContext { extensions, } } + + /// Get the current protocol version in use by the group. + pub fn version(&self) -> ProtocolVersion { + self.protocol_version + } + + /// Get the current cipher suite in use by the group. + pub fn cipher_suite(&self) -> CipherSuite { + self.cipher_suite + } + + /// Get the unique identifier of this group. + pub fn group_id(&self) -> &[u8] { + &self.group_id + } + + /// Get the current epoch number of the group's state. + pub fn epoch(&self) -> u64 { + self.epoch + } + + pub fn extensions(&self) -> &ExtensionList { + &self.extensions + } } diff --git a/mls-rs/src/group/framing.rs b/mls-rs/src/group/framing.rs index 0ab4e8e8..b3299b15 100644 --- a/mls-rs/src/group/framing.rs +++ b/mls-rs/src/group/framing.rs @@ -311,7 +311,7 @@ impl MlsMessage { } #[inline(always)] - pub(crate) fn into_group_info(self) -> Option { + pub fn into_group_info(self) -> Option { match self.payload { MlsMessagePayload::GroupInfo(info) => Some(info), _ => None, diff --git a/mls-rs/src/group/group_info.rs b/mls-rs/src/group/group_info.rs index c7f6107d..bad9f2eb 100644 --- a/mls-rs/src/group/group_info.rs +++ b/mls-rs/src/group/group_info.rs @@ -10,13 +10,28 @@ use super::*; #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -pub(crate) struct GroupInfo { - pub group_context: GroupContext, - pub extensions: ExtensionList, - pub confirmation_tag: ConfirmationTag, - pub signer: LeafIndex, +#[cfg_attr( + all(feature = "ffi", not(test)), + safer_ffi_gen::ffi_type(clone, opaque) +)] +pub struct GroupInfo { + pub(crate) group_context: GroupContext, + pub(crate) extensions: ExtensionList, + pub(crate) confirmation_tag: ConfirmationTag, + pub(crate) signer: LeafIndex, #[mls_codec(with = "mls_rs_codec::byte_vec")] - pub signature: Vec, + pub(crate) signature: Vec, +} + +#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)] +impl GroupInfo { + pub fn group_context(&self) -> &GroupContext { + &self.group_context + } + + pub fn extensions(&self) -> &ExtensionList { + &self.extensions + } } #[derive(MlsEncode, MlsSize)] diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 1f84fb61..a9fa7785 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -102,14 +102,12 @@ pub(crate) use group_info::GroupInfo; use self::framing::MlsMessage; pub use self::framing::Sender; pub use commit::*; - +pub use context::GroupContext; pub use roster::*; pub(crate) use transcript_hash::ConfirmedTranscriptHash; pub(crate) use util::*; -pub(crate) use context::*; - #[cfg(all(feature = "by_ref_proposal", feature = "external_client"))] pub use self::message_processor::CachedProposal; @@ -1406,13 +1404,10 @@ where )) } + /// Get the current group context summarizing various information about the group. #[inline(always)] - pub(crate) fn context(&self) -> &GroupContext { - &self.state.context - } - - pub fn context_extensions(&self) -> &ExtensionList { - &self.state.context.extensions + pub fn context(&self) -> &GroupContext { + &self.group_state().context } /// Get the @@ -1503,7 +1498,7 @@ where pub(crate) fn encryption_options(&self) -> Result { self.config .mls_rules() - .encryption_options(&self.roster(), self.context_extensions()) + .encryption_options(&self.roster(), self.group_context().extensions()) .map_err(|e| MlsError::MlsRulesError(e.into_any_error())) } diff --git a/mls-rs/src/group/resumption.rs b/mls-rs/src/group/resumption.rs index 59312715..6554c256 100644 --- a/mls-rs/src/group/resumption.rs +++ b/mls-rs/src/group/resumption.rs @@ -54,7 +54,7 @@ where group_id: &sub_group_id, cipher_suite: self.cipher_suite(), version: self.protocol_version(), - extensions: self.context_extensions(), + extensions: &self.group_state().context.extensions, }; resumption_create_group( @@ -81,7 +81,7 @@ where group_id: &[], cipher_suite: self.cipher_suite(), version: self.protocol_version(), - extensions: self.context_extensions(), + extensions: &self.group_state().context.extensions, }; resumption_join_group( @@ -291,7 +291,7 @@ async fn resumption_join_group( Err(MlsError::CipherSuiteMismatch) } else if verify_group_id && group.group_id() != expected_new_group_params.group_id { Err(MlsError::GroupIdMismatch) - } else if group.context_extensions() != expected_new_group_params.extensions { + } else if &group.group_state().context.extensions != expected_new_group_params.extensions { Err(MlsError::ReInitExtensionsMismatch) } else { Ok((group, new_member_info)) diff --git a/mls-rs/test_harness_integration/src/by_ref_proposal.rs b/mls-rs/test_harness_integration/src/by_ref_proposal.rs index c2b4150e..1bb3b315 100644 --- a/mls-rs/test_harness_integration/src/by_ref_proposal.rs +++ b/mls-rs/test_harness_integration/src/by_ref_proposal.rs @@ -218,7 +218,7 @@ pub(crate) mod inner { let request = request.into_inner(); self.send_proposal(request.state_id, move |group| { - let mut extensions = group.context_extensions().clone(); + let mut extensions = group.context().extensions().clone(); let ext_sender = SigningIdentity::mls_decode(&mut &*request.external_sender).map_err(abort)?;