diff --git a/.github/workflows/interop_tests.yml b/.github/workflows/interop_tests.yml index 6d7294a2..78dc440b 100644 --- a/.github/workflows/interop_tests.yml +++ b/.github/workflows/interop_tests.yml @@ -130,4 +130,4 @@ jobs: for config in `ls mls-rs-configs | grep -E "(application)|(commit_by_value)|(branch)|(welcome_join)" | sed -e "s/mls-rs-configs\///"`; do >&2 echo $config && test-runner/test-runner --client localhost:50001 --client localhost:50002 --suite 1 --config mls-rs-configs/$config ; done > /dev/null kill %1 kill %2 - \ No newline at end of file + diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs index 1ebe0e83..e6a380a1 100644 --- a/mls-rs/src/client.rs +++ b/mls-rs/src/client.rs @@ -653,23 +653,20 @@ where self.config.clone(), group_info_msg, ) - .await? .build() .await } - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub async fn external_commit_builder( + pub fn external_commit_builder( &self, group_info_msg: MlsMessage, ) -> Result, MlsError> { - ExternalCommitBuilder::new( + Ok(ExternalCommitBuilder::new( self.signer()?.clone(), self.signing_identity()?.0.clone(), self.config.clone(), group_info_msg, - ) - .await + )) } /// Load an existing group state into this client using the @@ -796,13 +793,6 @@ where .ok_or(MlsError::SignerNotFound) } - /// The [PreSharedKeyStorage](crate::PreSharedKeyStorage) that - /// this client was configured to use. - #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)] - pub fn secret_store(&self) -> ::PskStore { - self.config.secret_store() - } - /// The [GroupStateStorage] that this client was configured to use. #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)] pub fn group_state_storage(&self) -> ::GroupStateStorage { @@ -892,8 +882,6 @@ mod tests { #[cfg(feature = "by_ref_proposal")] use crate::group::proposal::Proposal; use crate::group::test_utils::test_group; - #[cfg(feature = "psk")] - use crate::group::test_utils::test_group_custom_config; #[cfg(feature = "by_ref_proposal")] use crate::group::ReceivedMessage; #[cfg(feature = "psk")] @@ -998,18 +986,9 @@ mod tests { let psk = PreSharedKey::from(b"psk".to_vec()); let psk_id = ExternalPskId::new(b"psk id".to_vec()); - let mut alice_group = - test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |c| { - c.psk(psk_id.clone(), psk.clone()) - }) - .await; + let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; - let (mut bob_group, _) = alice_group - .join_with_custom_config("bob", false, |c| { - c.0.psk_store.insert(psk_id.clone(), psk.clone()); - }) - .await - .unwrap(); + let (mut bob_group, _) = alice_group.join("bob").await; let group_info_msg = alice_group .group_info_message_allowing_ext_commit(true) @@ -1022,21 +1001,17 @@ mod tests { get_test_signing_identity(TEST_CIPHER_SUITE, new_client_id.as_bytes()).await; let new_client = TestClientBuilder::new_for_test() - .psk(psk_id.clone(), psk) .signing_identity(new_client_identity.clone(), secret_key, TEST_CIPHER_SUITE) .build(); - let mut builder = new_client - .external_commit_builder(group_info_msg) - .await - .unwrap(); + let mut builder = new_client.external_commit_builder(group_info_msg).unwrap(); if do_remove { builder = builder.with_removal(1); } if with_psk { - builder = builder.with_external_psk(psk_id).unwrap(); + builder = builder.with_external_psk(psk_id.clone(), psk.clone()); } let (new_group, external_commit) = builder.build().await?; @@ -1046,14 +1021,22 @@ mod tests { assert_eq!(new_group.roster().members_iter().count(), num_members); let _ = alice_group - .process_incoming_message(external_commit.clone()) + .commit_processor(external_commit.clone()) + .await + .unwrap() + .with_external_psk(psk_id.clone(), psk.clone()) + .process() .await .unwrap(); let bob_current_epoch = bob_group.current_epoch(); let message = bob_group - .process_incoming_message(external_commit) + .commit_processor(external_commit.clone()) + .await + .unwrap() + .with_external_psk(psk_id, psk) + .process() .await .unwrap(); @@ -1067,13 +1050,13 @@ mod tests { assert_matches!( message, - ReceivedMessage::Commit(CommitMessageDescription { + CommitMessageDescription { effect: CommitEffect::Removed { new_epoch: _, remover: _ }, .. - }) + } ); } @@ -1136,7 +1119,6 @@ mod tests { let (_, external_commit) = carol .external_commit_builder(group_info_msg) - .await .unwrap() .build() .await diff --git a/mls-rs/src/client_builder.rs b/mls-rs/src/client_builder.rs index 512e0635..1c36192d 100644 --- a/mls-rs/src/client_builder.rs +++ b/mls-rs/src/client_builder.rs @@ -18,8 +18,7 @@ use crate::{ identity::CredentialType, identity::SigningIdentity, protocol_version::ProtocolVersion, - psk::{ExternalPskId, PreSharedKey}, - storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryPreSharedKeyStorage}, + storage_provider::in_memory::InMemoryGroupStateStorage, tree_kem::{Capabilities, Lifetime}, Sealed, }; @@ -31,35 +30,24 @@ use alloc::vec::Vec; #[cfg(feature = "sqlite")] use mls_rs_provider_sqlite::{ + connection_strategy::ConnectionStrategy, storage::SqLiteGroupStateStorage, SqLiteDataStorageEngine, SqLiteDataStorageError, - { - connection_strategy::ConnectionStrategy, - storage::{SqLiteGroupStateStorage, SqLitePreSharedKeyStorage}, - }, }; #[cfg(feature = "private_message")] pub use crate::group::padding::PaddingMode; /// Base client configuration type when instantiating `ClientBuilder` -pub type BaseConfig = Config< - InMemoryPreSharedKeyStorage, - InMemoryGroupStateStorage, - Missing, - DefaultMlsRules, - Missing, ->; +pub type BaseConfig = Config; /// Base client configuration type when instantiating `ClientBuilder` -pub type BaseInMemoryConfig = - Config; +pub type BaseInMemoryConfig = Config; -pub type EmptyConfig = Config; +pub type EmptyConfig = Config; /// Base client configuration that is backed by SQLite storage. #[cfg(feature = "sqlite")] -pub type BaseSqlConfig = - Config; +pub type BaseSqlConfig = Config; /// Builder for [`Client`] /// @@ -179,7 +167,6 @@ impl ClientBuilder { pub fn new() -> Self { Self(Config(ConfigInner { settings: Default::default(), - psk_store: Default::default(), group_state_storage: Default::default(), identity_provider: Missing, mls_rules: DefaultMlsRules::new(), @@ -195,7 +182,6 @@ impl ClientBuilder { pub fn new_empty() -> Self { Self(Config(ConfigInner { settings: Default::default(), - psk_store: Missing, group_state_storage: Missing, identity_provider: Missing, mls_rules: Missing, @@ -215,7 +201,6 @@ impl ClientBuilder { ) -> Result { Ok(Self(Config(ConfigInner { settings: Default::default(), - psk_store: storage.pre_shared_key_storage()?, group_state_storage: storage.group_state_storage()?, identity_provider: Missing, mls_rules: DefaultMlsRules::new(), @@ -286,28 +271,6 @@ impl ClientBuilder { ClientBuilder(c) } - /// Set the PSK store to be used by the client. - /// - /// By default, an in-memory store is used. - pub fn psk_store

(self, psk_store: P) -> ClientBuilder> - where - P: PreSharedKeyStorage, - { - let Config(c) = self.0.into_config(); - - ClientBuilder(Config(ConfigInner { - settings: c.settings, - psk_store, - group_state_storage: c.group_state_storage, - identity_provider: c.identity_provider, - mls_rules: c.mls_rules, - crypto_provider: c.crypto_provider, - signer: c.signer, - signing_identity: c.signing_identity, - version: c.version, - })) - } - /// Set the group state storage to be used by the client. /// /// By default, an in-memory storage is used. @@ -322,7 +285,6 @@ impl ClientBuilder { ClientBuilder(Config(ConfigInner { settings: c.settings, - psk_store: c.psk_store, group_state_storage, identity_provider: c.identity_provider, crypto_provider: c.crypto_provider, @@ -345,7 +307,6 @@ impl ClientBuilder { ClientBuilder(Config(ConfigInner { settings: c.settings, - psk_store: c.psk_store, group_state_storage: c.group_state_storage, identity_provider, mls_rules: c.mls_rules, @@ -368,7 +329,6 @@ impl ClientBuilder { ClientBuilder(Config(ConfigInner { settings: c.settings, - psk_store: c.psk_store, group_state_storage: c.group_state_storage, identity_provider: c.identity_provider, mls_rules: c.mls_rules, @@ -394,7 +354,6 @@ impl ClientBuilder { ClientBuilder(Config(ConfigInner { settings: c.settings, - psk_store: c.psk_store, group_state_storage: c.group_state_storage, identity_provider: c.identity_provider, mls_rules, @@ -439,7 +398,6 @@ impl ClientBuilder { impl ClientBuilder where - C::PskStore: PreSharedKeyStorage + Clone, C::GroupStateStorage: GroupStateStorage + Clone, C::IdentityProvider: IdentityProvider + Clone, C::MlsRules: MlsRules + Clone, @@ -469,39 +427,14 @@ where } } -impl> ClientBuilder { - /// Add a PSK to the in-memory PSK store. - pub fn psk( - self, - psk_id: ExternalPskId, - psk: PreSharedKey, - ) -> ClientBuilder> { - let mut c = self.0.into_config(); - c.0.psk_store.insert(psk_id, psk); - ClientBuilder(c) - } -} - /// Marker type for required `ClientBuilder` services that have not been specified yet. #[derive(Debug)] pub struct Missing; -/// Change the PSK store used by a client configuration. -/// -/// See [`ClientBuilder::psk_store`]. -pub type WithPskStore = Config< - P, - ::GroupStateStorage, - ::IdentityProvider, - ::MlsRules, - ::CryptoProvider, ->; - /// Change the group state storage used by a client configuration. /// /// See [`ClientBuilder::group_state_storage`]. pub type WithGroupStateStorage = Config< - ::PskStore, G, ::IdentityProvider, ::MlsRules, @@ -512,7 +445,6 @@ pub type WithGroupStateStorage = Config< /// /// See [`ClientBuilder::identity_provider`]. pub type WithIdentityProvider = Config< - ::PskStore, ::GroupStateStorage, I, ::MlsRules, @@ -523,7 +455,6 @@ pub type WithIdentityProvider = Config< /// /// See [`ClientBuilder::mls_rules`]. pub type WithMlsRules = Config< - ::PskStore, ::GroupStateStorage, ::IdentityProvider, Pr, @@ -534,7 +465,6 @@ pub type WithMlsRules = Config< /// /// See [`ClientBuilder::crypto_provider`]. pub type WithCryptoProvider = Config< - ::PskStore, ::GroupStateStorage, ::IdentityProvider, ::MlsRules, @@ -543,7 +473,6 @@ pub type WithCryptoProvider = Config< /// Helper alias for `Config`. pub type IntoConfigOutput = Config< - ::PskStore, ::GroupStateStorage, ::IdentityProvider, ::MlsRules, @@ -552,22 +481,19 @@ pub type IntoConfigOutput = Config< /// Helper alias to make a `Config` from a `ClientConfig` pub type MakeConfig = Config< - ::PskStore, ::GroupStateStorage, ::IdentityProvider, ::MlsRules, ::CryptoProvider, >; -impl ClientConfig for ConfigInner +impl ClientConfig for ConfigInner where - Ps: PreSharedKeyStorage + Clone, Gss: GroupStateStorage + Clone, Ip: IdentityProvider + Clone, Pr: MlsRules + Clone, Cp: CryptoProvider + Clone, { - type PskStore = Ps; type GroupStateStorage = Gss; type IdentityProvider = Ip; type MlsRules = Pr; @@ -585,10 +511,6 @@ where self.mls_rules.clone() } - fn secret_store(&self) -> Self::PskStore { - self.psk_store.clone() - } - fn group_state_storage(&self) -> Self::GroupStateStorage { self.group_state_storage.clone() } @@ -619,17 +541,16 @@ where } } -impl Sealed for Config {} +impl Sealed for Config {} -impl MlsConfig for Config +impl MlsConfig for Config where - Ps: PreSharedKeyStorage + Clone, Gss: GroupStateStorage + Clone, Ip: IdentityProvider + Clone, Pr: MlsRules + Clone, Cp: CryptoProvider + Clone, { - type Output = ConfigInner; + type Output = ConfigInner; fn get(&self) -> &Self::Output { &self.0 @@ -649,7 +570,6 @@ pub trait MlsConfig: Clone + Send + Sync + Sealed { /// Blanket implementation so that `T: MlsConfig` implies `T: ClientConfig` impl ClientConfig for T { - type PskStore = ::PskStore; type GroupStateStorage = ::GroupStateStorage; type IdentityProvider = ::IdentityProvider; type MlsRules = ::MlsRules; @@ -671,10 +591,6 @@ impl ClientConfig for T { self.get().mls_rules() } - fn secret_store(&self) -> Self::PskStore { - self.get().secret_store() - } - fn group_state_storage(&self) -> Self::GroupStateStorage { self.get().group_state_storage() } @@ -739,7 +655,6 @@ pub(crate) fn recreate_config( l.not_after - l.not_before }, }, - psk_store: c.secret_store(), group_state_storage: c.group_state_storage(), identity_provider: c.identity_provider(), mls_rules: c.mls_rules(), @@ -762,12 +677,11 @@ mod private { use crate::client_builder::{IntoConfigOutput, Settings}; #[derive(Clone, Debug)] - pub struct Config(pub(crate) ConfigInner); + pub struct Config(pub(crate) ConfigInner); #[derive(Clone, Debug)] - pub struct ConfigInner { + pub struct ConfigInner { pub(crate) settings: Settings, - pub(crate) psk_store: Ps, pub(crate) group_state_storage: Gss, pub(crate) identity_provider: Ip, pub(crate) mls_rules: Pr, @@ -778,7 +692,6 @@ mod private { } pub trait IntoConfig { - type PskStore; type GroupStateStorage; type IdentityProvider; type MlsRules; @@ -787,8 +700,7 @@ mod private { fn into_config(self) -> IntoConfigOutput; } - impl IntoConfig for Config { - type PskStore = Ps; + impl IntoConfig for Config { type GroupStateStorage = Gss; type IdentityProvider = Ip; type MlsRules = Pr; @@ -804,7 +716,6 @@ use mls_rs_core::{ crypto::{CryptoProvider, SignatureSecretKey}, group::GroupStateStorage, identity::IdentityProvider, - psk::PreSharedKeyStorage, }; use private::{Config, ConfigInner, IntoConfig}; diff --git a/mls-rs/src/client_config.rs b/mls-rs/src/client_config.rs index 0b7e7ea0..2c1a28e2 100644 --- a/mls-rs/src/client_config.rs +++ b/mls-rs/src/client_config.rs @@ -11,13 +11,9 @@ use crate::{ ExtensionList, }; use alloc::vec::Vec; -use mls_rs_core::{ - crypto::CryptoProvider, group::GroupStateStorage, identity::IdentityProvider, - psk::PreSharedKeyStorage, -}; +use mls_rs_core::{crypto::CryptoProvider, group::GroupStateStorage, identity::IdentityProvider}; pub trait ClientConfig: Send + Sync + Clone { - type PskStore: PreSharedKeyStorage + Clone; type GroupStateStorage: GroupStateStorage + Clone; type IdentityProvider: IdentityProvider + Clone; type MlsRules: MlsRules + Clone; @@ -29,7 +25,6 @@ pub trait ClientConfig: Send + Sync + Clone { fn mls_rules(&self) -> Self::MlsRules; - fn secret_store(&self) -> Self::PskStore; fn group_state_storage(&self) -> Self::GroupStateStorage; fn identity_provider(&self) -> Self::IdentityProvider; fn crypto_provider(&self) -> Self::CryptoProvider; diff --git a/mls-rs/src/external_client/group.rs b/mls-rs/src/external_client/group.rs index bc877c2d..d8300afa 100644 --- a/mls-rs/src/external_client/group.rs +++ b/mls-rs/src/external_client/group.rs @@ -36,11 +36,13 @@ use crate::{ identity::SigningIdentity, mls_rules::{CommitSource, ProposalBundle}, protocol_version::ProtocolVersion, - psk::AlwaysFoundPskStorage, tree_kem::{leaf_node::LeafNode, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate}, CryptoProvider, KeyPackage, MlsMessage, }; +#[cfg(feature = "psk")] +use crate::psk::secret::PskSecretInput; + #[cfg(feature = "by_ref_proposal")] use crate::{ group::{ @@ -581,7 +583,6 @@ where { type MlsRules = C::MlsRules; type IdentityProvider = C::IdentityProvider; - type PreSharedKeyStorage = AlwaysFoundPskStorage; type OutputType = ExternalReceivedMessage; type CipherSuiteProvider = ::CipherSuiteProvider; @@ -617,6 +618,7 @@ where interim_transcript_hash: InterimTranscriptHash, confirmation_tag: &ConfirmationTag, provisional_public_state: ProvisionalState, + #[cfg(feature = "psk")] _psks: &[PskSecretInput], ) -> Result<(), MlsError> { self.state.context = provisional_public_state.group_context; #[cfg(feature = "by_ref_proposal")] @@ -632,10 +634,6 @@ where self.config.identity_provider() } - fn psk_storage(&self) -> Self::PreSharedKeyStorage { - AlwaysFoundPskStorage - } - fn group_state(&self) -> &GroupState { &self.state } diff --git a/mls-rs/src/group/commit/builder.rs b/mls-rs/src/group/commit/builder.rs index e53e1616..592369b9 100644 --- a/mls-rs/src/group/commit/builder.rs +++ b/mls-rs/src/group/commit/builder.rs @@ -5,6 +5,8 @@ use alloc::boxed::Box; use alloc::{vec, vec::Vec}; +#[cfg(feature = "psk")] +use itertools::Itertools; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError}; @@ -34,8 +36,13 @@ use crate::WireFormat; #[cfg(feature = "psk")] use crate::{ - group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk}, - psk::ExternalPskId, + group::{ + JustPreSharedKeyID, PreSharedKeyProposal, PskGroupId, ResumptionPSKUsage, ResumptionPsk, + }, + psk::{ + secret::{PskSecret, PskSecretInput}, + ExternalPskId, + }, }; use crate::group::{ @@ -45,13 +52,15 @@ use crate::group::{ message_hash::MessageHash, message_processor::{path_update_required, MessageProcessor}, message_signature::AuthenticatedContent, - mls_rules::CommitDirection, proposal::Proposal, CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree, Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch, PendingCommitSnapshot, Welcome, }; +#[cfg(feature = "by_ref_proposal")] +use crate::group::mls_rules::CommitDirection; + #[cfg(feature = "custom_proposal")] use crate::group::proposal::CustomProposal; @@ -176,6 +185,8 @@ where new_signer: Option, new_signing_identity: Option, new_leaf_node_extensions: Option, + #[cfg(feature = "psk")] + psks: Vec>, sender: Sender, } @@ -222,13 +233,59 @@ where Ok(self.with_proposal(proposal)) } + /// Add an external PSK that can be used to fulfil PSK requirements that were + /// established via a [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) + /// from another client during the current epoch. + #[cfg(feature = "psk")] + pub fn apply_external_psk(self, id: ExternalPskId, psk: crate::psk::PreSharedKey) -> Self { + self.apply_psk(JustPreSharedKeyID::External(id), psk) + } + + /// Add an resumption PSK that can be used to fulfil PSK requirements that were + /// established via a [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) + /// from another client during the current epoch. + #[cfg(feature = "psk")] + pub fn apply_resumption_psk(self, id: ResumptionPsk, psk: crate::psk::PreSharedKey) -> Self { + self.apply_psk(JustPreSharedKeyID::Resumption(id), psk) + } + + #[cfg(feature = "psk")] + fn apply_psk(mut self, id: JustPreSharedKeyID, psk: crate::psk::PreSharedKey) -> Self { + if let Some((i, proposal)) = self + .proposals + .psks + .iter() + .filter(|proposal| proposal.is_by_reference()) + .find_position(|p| p.proposal.psk.key_id == id) + { + let id = proposal.proposal.psk.clone(); + let secret_input = PskSecretInput { id, psk }; + self.psks[i] = Some(secret_input); + } + + self + } + /// Insert a /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with /// an external PSK into the current commit that is being built. #[cfg(feature = "psk")] - pub fn add_external_psk(self, psk_id: ExternalPskId) -> Result { - let key_id = JustPreSharedKeyID::External(psk_id); - let proposal = self.group.psk_proposal(key_id)?; + pub fn add_external_psk( + mut self, + id: ExternalPskId, + psk: crate::psk::PreSharedKey, + ) -> Result { + use crate::psk::PreSharedKeyID; + + let key_id = JustPreSharedKeyID::External(id); + let id = PreSharedKeyID::new(key_id, &self.group.cipher_suite_provider)?; + + self.psks.push(Some(PskSecretInput { + id: id.clone(), + psk, + })); + + let proposal = Proposal::Psk(PreSharedKeyProposal { psk: id }); Ok(self.with_proposal(proposal)) } @@ -236,9 +293,13 @@ where /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with /// a resumption PSK into the current commit that is being built. #[cfg(feature = "psk")] - pub fn add_resumption_psk(self, psk_epoch: u64) -> Result { + pub fn add_resumption_psk( + self, + psk_epoch: u64, + psk: crate::psk::PreSharedKey, + ) -> Result { let group_id = self.group.group_id().to_vec(); - self.add_resumption_psk_for_group(psk_epoch, group_id) + self.add_resumption_psk_for_group(psk_epoch, group_id, psk) } /// Insert a @@ -246,10 +307,13 @@ where /// a resumption PSK into the current commit that is being built. #[cfg(feature = "psk")] pub fn add_resumption_psk_for_group( - self, + mut self, psk_epoch: u64, group_id: Vec, + psk: crate::psk::PreSharedKey, ) -> Result { + use crate::psk::PreSharedKeyID; + let psk_id = ResumptionPsk { psk_epoch, usage: ResumptionPSKUsage::Application, @@ -257,7 +321,14 @@ where }; let key_id = JustPreSharedKeyID::Resumption(psk_id); - let proposal = self.group.psk_proposal(key_id)?; + let id = PreSharedKeyID::new(key_id, &self.group.cipher_suite_provider)?; + + self.psks.push(Some(PskSecretInput { + id: id.clone(), + psk, + })); + + let proposal = Proposal::Psk(PreSharedKeyProposal { psk: id }); Ok(self.with_proposal(proposal)) } @@ -366,6 +437,13 @@ where /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules). #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn build(self) -> Result { + #[cfg(feature = "psk")] + let psks = self + .psks + .into_iter() + .map(|psk| psk.ok_or(MlsError::MissingRequiredPsk)) + .try_collect()?; + let (output, pending_commit) = self .group .commit_internal( @@ -376,6 +454,8 @@ where self.new_signer, self.new_signing_identity, self.new_leaf_node_extensions, + #[cfg(feature = "psk")] + psks, self.sender, ) .await?; @@ -391,6 +471,13 @@ where /// A detached commit can be applied using `Group::apply_detached_commit`. #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> { + #[cfg(feature = "psk")] + let psks = self + .psks + .into_iter() + .map(|psk| psk.ok_or(MlsError::MissingRequiredPsk)) + .try_collect()?; + let (output, pending_commit) = self .group .commit_internal( @@ -401,6 +488,8 @@ where self.new_signer, self.new_signing_identity, self.new_leaf_node_extensions, + #[cfg(feature = "psk")] + psks, self.sender, ) .await?; @@ -484,11 +573,18 @@ where /// Create a new commit builder that can include proposals /// by-value. pub fn commit_builder(&mut self) -> CommitBuilder { + #[cfg(feature = "by_ref_proposal")] + let proposals = self.state.proposals.prepare_commit(); + + #[cfg(not(feature = "by_ref_proposal"))] + let proposals = Default::default(); + CommitBuilder { - #[cfg(feature = "by_ref_proposal")] - proposals: self.state.proposals.prepare_commit(), - #[cfg(not(feature = "by_ref_proposal"))] - proposals: Default::default(), + #[cfg(all(feature = "by_ref_proposal", feature = "psk"))] + psks: vec![None; proposals.psk_proposals().len()], + #[cfg(all(not(feature = "by_ref_proposal"), feature = "psk"))] + psks: vec![], + proposals, sender: Sender::Member(*self.private_tree.self_index), group: self, authenticated_data: Default::default(), @@ -512,6 +608,7 @@ where new_signer: Option, new_signing_identity: Option, new_leaf_node_extensions: Option, + #[cfg(feature = "psk")] psks: Vec, sender: Sender, ) -> Result<(CommitOutput, PendingCommit), MlsError> { if !self.pending_commit.is_none() { @@ -548,8 +645,13 @@ where &self.config.identity_provider(), &self.cipher_suite_provider, time, + #[cfg(feature = "by_ref_proposal")] CommitDirection::Send, - &self.config.secret_store(), + #[cfg(feature = "psk")] + &psks + .iter() + .map(|psk| psk.id.key_id.clone()) + .collect::>(), &committer, ) .await?; @@ -636,12 +738,22 @@ where }; #[cfg(feature = "psk")] - let (psk_secret, psks) = self - .get_psk(&provisional_state.applied_proposals.psks) - .await?; + let (psk_secret, psk_ids) = { + if let Some(previous) = self.previous_psk.as_ref() { + ( + PskSecret::calculate(&[previous.clone()], &self.cipher_suite_provider).await?, + vec![previous.id.clone()], + ) + } else { + ( + PskSecret::calculate(&psks, &self.cipher_suite_provider).await?, + psks.into_iter().map(|psk| psk.id).collect::>(), + ) + } + }; #[cfg(not(feature = "psk"))] - let psk_secret = self.get_psk(); + let psk_secret = crate::psk::secret::PskSecret::new(&self.cipher_suite_provider); let added_key_pkgs: Vec<_> = provisional_state .applied_proposals @@ -788,7 +900,7 @@ where &key_schedule_result.joiner_secret, path_secrets, #[cfg(feature = "psk")] - psks.clone(), + psk_ids.clone(), &encrypted_group_info, ) }) @@ -809,7 +921,7 @@ where &key_schedule_result.joiner_secret, path_secrets, #[cfg(feature = "psk")] - psks.clone(), + psk_ids.clone(), &encrypted_group_info, ) .await?, @@ -1003,7 +1115,7 @@ mod tests { #[cfg(feature = "psk")] use crate::{ group::proposal::PreSharedKeyProposal, - psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID}, + psk::{JustPreSharedKeyID, PreSharedKeyID}, }; use super::*; @@ -1157,17 +1269,15 @@ mod tests { #[cfg(feature = "psk")] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_commit_builder_psk() { + use mls_rs_core::psk::PreSharedKey; + let mut group = test_commit_builder_group().await; let test_psk = ExternalPskId::new(vec![1]); - - group - .config - .secret_store() - .insert(test_psk.clone(), PreSharedKey::from(vec![1])); + let psk_data = PreSharedKey::new(vec![2]); let commit_output = group .commit_builder() - .add_external_psk(test_psk.clone()) + .add_external_psk(test_psk.clone(), psk_data) .unwrap() .build() .await diff --git a/mls-rs/src/group/commit/processor.rs b/mls-rs/src/group/commit/processor.rs index 93da9757..0204acde 100644 --- a/mls-rs/src/group/commit/processor.rs +++ b/mls-rs/src/group/commit/processor.rs @@ -18,11 +18,20 @@ use crate::{ CommitEffect, CommitMessageDescription, ConfirmationTag, Content, EventOrContent, InterimTranscriptHash, MessageProcessor, NewEpoch, }, - mls_rules::{CommitDirection, CommitSource, ProposalBundle}, + mls_rules::{CommitSource, ProposalBundle}, tree_kem::{leaf_node::LeafNode, node::LeafIndex, validate_update_path, UpdatePath}, Group, MlsMessage, }; +#[cfg(feature = "by_ref_proposal")] +use crate::mls_rules::CommitDirection; + +#[cfg(feature = "psk")] +use mls_rs_core::psk::{ExternalPskId, PreSharedKey}; + +#[cfg(feature = "psk")] +use crate::psk::{secret::PskSecretInput, JustPreSharedKeyID, ResumptionPsk}; + pub(crate) struct InternalCommitProcessor<'a, P: MessageProcessor> { // Group pub(crate) processor: &'a mut P, @@ -38,6 +47,10 @@ pub(crate) struct InternalCommitProcessor<'a, P: MessageProcessor> { // Processing options pub(crate) time_sent: Option, + + // Outside inputs + #[cfg(feature = "psk")] + pub(crate) psks: Vec<(JustPreSharedKeyID, PreSharedKey)>, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -114,6 +127,9 @@ pub(crate) async fn commit_processor_from_content( .confirmation_tag .ok_or(MlsError::InvalidConfirmationTag)?, time_sent: None, + // FIXME: Where can this actually come from??? + #[cfg(feature = "psk")] + psks: Default::default(), }) } @@ -124,9 +140,6 @@ pub(crate) async fn process_commit( let id_provider = commit_processor.processor.identity_provider(); let cs_provider = commit_processor.processor.cipher_suite_provider(); - // TODO remove - let psk_storage = commit_processor.processor.psk_storage(); - let mut provisional_state = commit_processor .processor .group_state() @@ -136,8 +149,14 @@ pub(crate) async fn process_commit( &id_provider, &cs_provider, commit_processor.time_sent, + #[cfg(feature = "by_ref_proposal")] CommitDirection::Receive, - &psk_storage, + #[cfg(feature = "psk")] + &commit_processor + .psks + .iter() + .map(|(id, _)| id.clone()) + .collect::>(), &commit_processor.committer, ) .await?; @@ -226,6 +245,25 @@ pub(crate) async fn process_commit( .tree_hash(&cs_provider) .await?; + #[cfg(feature = "psk")] + let psk_inputs = provisional_state + .applied_proposals + .psks + .iter() + .map(|id| { + commit_processor + .psks + .iter() + .find_map(|(psk_id, psk)| { + (*psk_id == id.proposal.psk.key_id).then(|| PskSecretInput { + id: id.proposal.psk.clone(), + psk: psk.clone(), + }) + }) + .ok_or_else(|| MlsError::MissingRequiredPsk) + }) + .collect::, _>>()?; + if !is_self_removed { // Update the key schedule to calculate new private keys commit_processor @@ -235,6 +273,8 @@ pub(crate) async fn process_commit( commit_processor.interim_transcript_hash, &commit_processor.confirmation_tag, provisional_state, + #[cfg(feature = "psk")] + &psk_inputs, ) .await?; } @@ -263,6 +303,18 @@ impl CommitProcessor<'_, C> { }) } + #[cfg(feature = "psk")] + pub fn with_external_psk(mut self, id: ExternalPskId, psk: crate::psk::PreSharedKey) -> Self { + self.0.psks.push((JustPreSharedKeyID::External(id), psk)); + self + } + + #[cfg(feature = "psk")] + pub fn with_resumption_psk(mut self, id: ResumptionPsk, psk: crate::psk::PreSharedKey) -> Self { + self.0.psks.push((JustPreSharedKeyID::Resumption(id), psk)); + self + } + pub fn proposals_mut(&mut self) -> &mut ProposalBundle { &mut self.0.proposals } @@ -288,6 +340,24 @@ impl CommitProcessor<'_, C> { &self.0.authenticated_data } + #[cfg(feature = "psk")] + pub fn required_external_psk(&self) -> impl Iterator { + self.0 + .proposals + .psk_proposals() + .iter() + .filter_map(|p| p.proposal.external_psk_id()) + } + + #[cfg(feature = "psk")] + pub fn required_resumption_psk(&self) -> impl Iterator { + self.0 + .proposals + .psk_proposals() + .iter() + .filter_map(|p| p.proposal.resumption_psk_id()) + } + pub fn context(&self) -> &GroupContext { self.0.processor.context() } diff --git a/mls-rs/src/group/external_commit.rs b/mls-rs/src/group/external_commit.rs index 422eaa79..5a7110cf 100644 --- a/mls-rs/src/group/external_commit.rs +++ b/mls-rs/src/group/external_commit.rs @@ -4,7 +4,6 @@ use mls_rs_core::{ crypto::SignatureSecretKey, extension::ExtensionList, identity::SigningIdentity, - protocol_version::ProtocolVersion, }; use crate::{ @@ -13,8 +12,8 @@ use crate::{ cipher_suite_provider, epoch::SenderDataSecret, key_schedule::{InitSecret, KeySchedule}, - proposal::{ExternalInit, Proposal, RemoveProposal}, - EpochSecrets, ExternalPubExt, LeafIndex, LeafNode, MlsError, TreeKemPrivate, + proposal::{ExternalInit, Proposal}, + EpochSecrets, ExternalPubExt, LeafNode, MlsError, TreeKemPrivate, }, mls_rules::{ProposalBundle, ProposalSource}, Group, MlsMessage, @@ -32,17 +31,14 @@ use alloc::vec; use alloc::vec::Vec; #[cfg(feature = "psk")] -use mls_rs_core::psk::{ExternalPskId, PreSharedKey}; +use crate::psk::{secret::PskSecretInput, ExternalPskId, PreSharedKey}; #[cfg(feature = "psk")] use crate::group::{ PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID}, }; -use super::{validate_tree_and_info_joiner, ExportedTree, GroupInfo, Sender}; - -#[cfg(feature = "custom_proposal")] -use super::PublicMessage; +use super::{validate_tree_and_info_joiner, ExportedTree, Sender}; /// A builder that aids with the construction of an external commit. #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))] @@ -52,62 +48,40 @@ pub struct ExternalCommitBuilder { leaf_node_extensions: ExtensionList, config: C, tree_data: Option>, + to_remove: Option, + #[cfg(feature = "psk")] + external_psks: Vec<(ExternalPskId, PreSharedKey)>, authenticated_data: Vec, #[cfg(feature = "custom_proposal")] - received_custom_proposals: Vec, - proposals: ProposalBundle, - group_info: GroupInfo, - protocol_version: ProtocolVersion, - init_secret: InitSecret, + custom_proposals: Vec, + #[cfg(feature = "custom_proposal")] + received_custom_proposals: Vec, + group_info: MlsMessage, } impl ExternalCommitBuilder { - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub(crate) async fn new( + pub(crate) fn new( signer: SignatureSecretKey, signing_identity: SigningIdentity, config: C, group_info: MlsMessage, - ) -> Result { - let protocol_version = group_info.version; - - if !config.version_supported(protocol_version) { - return Err(MlsError::UnsupportedProtocolVersion(protocol_version)); - } - - let group_info = group_info - .into_group_info() - .ok_or(MlsError::UnexpectedMessageType)?; - - let cipher_suite = cipher_suite_provider( - config.crypto_provider(), - group_info.group_context.cipher_suite, - )?; - - let external_pub_ext = group_info - .extensions - .get_as::()? - .ok_or(MlsError::MissingExternalPubExtension)?; - - let (init_secret, kem_output) = - InitSecret::encode_for_external(&cipher_suite, &external_pub_ext.external_pub).await?; - - let builder = Self { - tree_data: None, - authenticated_data: Vec::new(), + ) -> Self { + Self { signer, signing_identity, - leaf_node_extensions: Default::default(), config, + group_info, + tree_data: None, + authenticated_data: Vec::new(), + leaf_node_extensions: Default::default(), + to_remove: Default::default(), + #[cfg(feature = "custom_proposal")] + custom_proposals: Vec::new(), #[cfg(feature = "custom_proposal")] received_custom_proposals: Vec::new(), - proposals: Default::default(), - group_info, - protocol_version, - init_secret, - }; - - Ok(builder.with_proposal(Proposal::ExternalInit(ExternalInit { kem_output }))) + #[cfg(feature = "psk")] + external_psks: Default::default(), + } } #[must_use] @@ -124,9 +98,10 @@ impl ExternalCommitBuilder { /// Propose the removal of an old version of the client as part of the external commit. /// Only one such proposal is allowed. pub fn with_removal(self, to_remove: u32) -> Self { - self.with_proposal(Proposal::Remove(RemoveProposal { - to_remove: LeafIndex(to_remove), - })) + Self { + to_remove: Some(to_remove), + ..self + } } #[must_use] @@ -140,23 +115,17 @@ impl ExternalCommitBuilder { #[cfg(feature = "psk")] /// Add an external psk to the group as part of the external commit. - pub fn with_external_psk(self, psk: ExternalPskId) -> Result { - let cipher_suite = cipher_suite_provider( - self.config.crypto_provider(), - self.group_info.group_context.cipher_suite, - )?; - - let key_id = JustPreSharedKeyID::External(psk); - let psk = PreSharedKeyID::new(key_id, &cipher_suite)?; - let proposal = Proposal::Psk(PreSharedKeyProposal { psk }); - Ok(self.with_proposal(proposal)) + pub fn with_external_psk(mut self, id: ExternalPskId, psk: PreSharedKey) -> Self { + self.external_psks.push((id, psk)); + self } #[cfg(feature = "custom_proposal")] #[must_use] /// Insert a [`CustomProposal`] into the current commit that is being built. - pub fn with_custom_proposal(self, proposal: CustomProposal) -> Self { - self.with_proposal(Proposal::Custom(proposal)) + pub fn with_custom_proposal(mut self, proposal: CustomProposal) -> Self { + self.custom_proposals.push(Proposal::Custom(proposal)); + self } #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))] @@ -168,20 +137,9 @@ impl ExternalCommitBuilder { /// The authenticity of the proposal is NOT fully verified. It is only verified the /// same way as by [`ExternalGroup`](`crate::external_client::ExternalGroup`). /// The proposal MUST be an MlsPlaintext, else the [`Self::build`] function will fail. - pub fn with_received_custom_proposal(mut self, proposal: MlsMessage) -> Result { - let MlsMessagePayload::Plain(plaintext) = proposal.payload else { - return Err(MlsError::UnexpectedMessageType); - }; - - let super::Content::Proposal(proposal) = plaintext.content.content.clone() else { - return Err(MlsError::UnexpectedMessageType); - }; - - // We store proposal to verify authenticity later. At this point this may not be possible if we - // don't have the tree. - self.received_custom_proposals.push(plaintext); - - Ok(self.with_proposal(*proposal)) + pub fn with_received_custom_proposal(mut self, proposal: MlsMessage) -> Self { + self.received_custom_proposals.push(proposal); + self } /// Change the committer's leaf node extensions as part of making this commit. @@ -192,17 +150,19 @@ impl ExternalCommitBuilder { } } - fn with_proposal(mut self, proposal: Proposal) -> Self { - self.proposals - .add(proposal, Sender::NewMemberCommit, ProposalSource::ByValue); - - self - } - /// Build the external commit using a GroupInfo message provided by an existing group member. #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn build(self) -> Result<(Group, MlsMessage), MlsError> { - let group_info = self.group_info; + let protocol_version = self.group_info.version; + + if !self.config.version_supported(protocol_version) { + return Err(MlsError::UnsupportedProtocolVersion(protocol_version)); + } + + let group_info = self + .group_info + .into_group_info() + .ok_or(MlsError::UnexpectedMessageType)?; let cipher_suite = cipher_suite_provider( self.config.crypto_provider(), @@ -210,7 +170,7 @@ impl ExternalCommitBuilder { )?; let public_tree = validate_tree_and_info_joiner( - self.protocol_version, + protocol_version, &group_info, self.tree_data, &self.config.identity_provider(), @@ -235,31 +195,90 @@ impl ExternalCommitBuilder { secret_tree: SecretTree::empty(), }; + let external_pub_ext = group_info + .extensions + .get_as::()? + .ok_or(MlsError::MissingExternalPubExtension)?; + + let (init_secret, kem_output) = + InitSecret::encode_for_external(&cipher_suite, &external_pub_ext.external_pub).await?; + let (mut group, _) = Group::join_with( self.config, group_info, public_tree, - KeySchedule::new(self.init_secret), + KeySchedule::new(init_secret), epoch_secrets, TreeKemPrivate::new_for_external(), self.signer, ) .await?; + let mut proposals = vec![Proposal::ExternalInit(ExternalInit { kem_output })]; + + if let Some(to_remove) = self.to_remove { + proposals.push(Proposal::Remove(to_remove.into())); + } + + #[cfg(feature = "psk")] + let psks = self + .external_psks + .into_iter() + .map(|(psk_id, psk_secret)| { + let key_id = + PreSharedKeyID::new(JustPreSharedKeyID::External(psk_id), &cipher_suite)?; + + proposals.push(Proposal::Psk(PreSharedKeyProposal { + psk: key_id.clone(), + })); + + Ok(PskSecretInput { + id: key_id, + psk: psk_secret, + }) + }) + .collect::, MlsError>>()?; + + #[cfg(feature = "custom_proposal")] + { + let mut custom_proposals = self.custom_proposals; + proposals.append(&mut custom_proposals); + } + #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))] for message in self.received_custom_proposals { - verify_plaintext_authentication(&cipher_suite, message, None, &group.state).await?; + let MlsMessagePayload::Plain(plaintext) = message.payload else { + return Err(MlsError::UnexpectedMessageType); + }; + + let super::Content::Proposal(proposal) = plaintext.content.content.clone() else { + return Err(MlsError::UnexpectedMessageType); + }; + + verify_plaintext_authentication(&cipher_suite, plaintext, None, &group.state).await?; + + proposals.push(*proposal); } + let proposal_bundle = + proposals + .into_iter() + .fold(ProposalBundle::default(), |mut bundle, proposal| { + bundle.add(proposal, Sender::NewMemberCommit, ProposalSource::ByValue); + bundle + }); + let (commit_output, pending_commit) = group .commit_internal( - self.proposals, + proposal_bundle, Some(&leaf_node), self.authenticated_data, Default::default(), None, None, None, + #[cfg(feature = "psk")] + psks, Sender::NewMemberCommit, ) .await?; diff --git a/mls-rs/src/group/interop_test_vectors/passive_client.rs b/mls-rs/src/group/interop_test_vectors/passive_client.rs index b9a89ff0..5a739ffe 100644 --- a/mls-rs/src/group/interop_test_vectors/passive_client.rs +++ b/mls-rs/src/group/interop_test_vectors/passive_client.rs @@ -167,14 +167,10 @@ async fn interop_passive_client() { let id = key_package.leaf_node.signing_identity.clone(); let key = test_case.signature_priv.clone().into(); - let mut client_builder = ClientBuilder::new() + let client_builder = ClientBuilder::new() .crypto_provider(crypto_provider) .identity_provider(BasicIdentityProvider::new()); - for psk in test_case.external_psks { - client_builder = client_builder.psk(ExternalPskId::new(psk.psk_id), psk.psk.into()); - } - let client = client_builder .signing_identity(id, key, cs.cipher_suite()) .build(); @@ -193,6 +189,10 @@ async fn interop_passive_client() { .await .unwrap(); + joiner = test_case.external_psks.iter().fold(joiner, |joiner, psk| { + joiner.with_external_psk(psk.psk_id.clone().into(), psk.psk.clone().into()) + }); + if let Some(tree) = test_case.ratchet_tree { joiner = joiner.ratchet_tree(ExportedTree::from_bytes(&tree.0).unwrap()) } @@ -216,10 +216,33 @@ async fn interop_passive_client() { let message = MlsMessage::from_bytes(&epoch.commit).unwrap(); - group - .process_incoming_message_with_time(message, MlsTime::now()) - .await - .unwrap(); + let group_clone = group.clone(); + + let mut processor = group.commit_processor(message).await.unwrap(); + + processor = test_case + .external_psks + .iter() + .fold(processor, |processor, psk| { + processor.with_external_psk(psk.psk_id.clone().into(), psk.psk.clone().into()) + }); + + for resumption_psk in processor + .required_resumption_psk() + .cloned() + .collect::>() + .iter() + { + processor = processor.with_resumption_psk( + resumption_psk.clone(), + group_clone + .resumption_secret(resumption_psk.psk_epoch) + .await + .unwrap(), + ); + } + + processor.process().await.unwrap(); assert_eq!( epoch.epoch_authenticator, @@ -263,9 +286,17 @@ async fn invite_passive_client( .add_member(key_pckg.key_package_message.clone()) .unwrap(); + let external_psk = TestExternalPsk { + psk_id: TEST_EXT_PSK_ID.to_vec(), + psk: make_test_ext_psk(), + }; + if with_psk { commit_builder = commit_builder - .add_external_psk(ExternalPskId::new(TEST_EXT_PSK_ID.to_vec())) + .add_external_psk( + external_psk.psk_id.clone().into(), + external_psk.psk.clone().into(), + ) .unwrap(); } @@ -273,11 +304,6 @@ async fn invite_passive_client( all_process_message(groups, &commit.commit_message, 0, true).await; - let external_psk = TestExternalPsk { - psk_id: TEST_EXT_PSK_ID.to_vec(), - psk: make_test_ext_psk(), - }; - TestCase { cipher_suite: cs.cipher_suite().into(), key_package: key_pckg.key_package_data.key_package_bytes, @@ -342,16 +368,25 @@ pub async fn generate_passive_client_proposal_tests() -> Vec { let test_case = commit_by_value( &mut groups[1].clone(), - |b| b.add_external_psk(psk.clone()).unwrap(), + |b| { + b.add_external_psk(psk.clone(), make_test_ext_psk().into()) + .unwrap() + }, partial_test_case.clone(), ) .await; test_cases.push(test_case); + let epoch_to_resume = groups[1].current_epoch() - 1; + let resumption_psk = groups[1].resumption_secret(epoch_to_resume).await.unwrap(); + let test_case = commit_by_value( &mut groups[5].clone(), - |b| b.add_resumption_psk(groups[1].current_epoch() - 1).unwrap(), + |b| { + b.add_resumption_psk(epoch_to_resume, resumption_psk) + .unwrap() + }, partial_test_case.clone(), ) .await; @@ -367,6 +402,9 @@ pub async fn generate_passive_client_proposal_tests() -> Vec { test_cases.push(test_case); + let epoch_to_resume = groups[4].current_epoch() - 1; + let resumption_psk = groups[4].resumption_secret(epoch_to_resume).await.unwrap(); + let test_case = commit_by_value( &mut groups[3].clone(), |b| { @@ -374,9 +412,9 @@ pub async fn generate_passive_client_proposal_tests() -> Vec { .unwrap() .remove_member(5) .unwrap() - .add_external_psk(psk.clone()) + .add_external_psk(psk.clone(), make_test_ext_psk().into()) .unwrap() - .add_resumption_psk(groups[4].current_epoch() - 1) + .add_resumption_psk(epoch_to_resume, resumption_psk) .unwrap() .set_group_context_ext(Default::default()) .unwrap() @@ -661,11 +699,10 @@ pub async fn add_random_members( for client in &clients { let commit = commit_output.welcome_messages[0].clone(); - let group = client + let (group, _) = client .join_group(Some(tree_data.clone()), &commit) .await - .unwrap() - .0; + .unwrap(); groups.push(group); } diff --git a/mls-rs/src/group/key_schedule.rs b/mls-rs/src/group/key_schedule.rs index 77c53352..a91f77f5 100644 --- a/mls-rs/src/group/key_schedule.rs +++ b/mls-rs/src/group/key_schedule.rs @@ -587,10 +587,7 @@ mod tests { #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))] use crate::{ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}, - group::{ - key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret, - PskSecret, - }, + group::{key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret}, }; #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))] @@ -603,6 +600,9 @@ mod tests { use super::test_utils::get_test_key_schedule; use super::KeySchedule; + #[cfg(feature = "psk")] + use crate::psk::secret::PskSecret; + #[derive(serde::Deserialize, serde::Serialize)] struct TestCase { cipher_suite: u16, diff --git a/mls-rs/src/group/message_processor.rs b/mls-rs/src/group/message_processor.rs index 720f4d2c..8b983cdd 100644 --- a/mls-rs/src/group/message_processor.rs +++ b/mls-rs/src/group/message_processor.rs @@ -16,6 +16,10 @@ use super::{ transcript_hash::InterimTranscriptHash, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal, RemoveProposal, Welcome, }; + +#[cfg(feature = "psk")] +use crate::psk::secret::PskSecretInput; + use crate::{ client::MlsError, key_package::validate_key_package_properties, @@ -37,7 +41,6 @@ use core::fmt::{self, Debug}; use mls_rs_core::{ identity::{IdentityProvider, MemberValidationContext}, protocol_version::ProtocolVersion, - psk::PreSharedKeyStorage, }; #[cfg(feature = "by_ref_proposal")] @@ -477,7 +480,6 @@ pub(crate) trait MessageProcessor: Send + Sync + Sized { type MlsRules: MlsRules; type IdentityProvider: IdentityProvider; type CipherSuiteProvider: CipherSuiteProvider; - type PreSharedKeyStorage: PreSharedKeyStorage; async fn process_incoming_message( &mut self, @@ -649,7 +651,6 @@ pub(crate) trait MessageProcessor: Send + Sync + Sized { fn group_state_mut(&mut self) -> &mut GroupState; fn identity_provider(&self) -> Self::IdentityProvider; fn cipher_suite_provider(&self) -> Self::CipherSuiteProvider; - fn psk_storage(&self) -> Self::PreSharedKeyStorage; fn removal_proposal( &self, @@ -796,6 +797,7 @@ pub(crate) trait MessageProcessor: Send + Sync + Sized { interim_transcript_hash: InterimTranscriptHash, confirmation_tag: &ConfirmationTag, provisional_public_state: ProvisionalState, + #[cfg(feature = "psk")] psks: &[PskSecretInput], ) -> Result<(), MlsError>; } diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 1b81ac7f..9d3e9219 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -8,6 +8,8 @@ use core::fmt::{self, Debug}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; use mls_rs_core::identity::MemberValidationContext; +#[cfg(feature = "psk")] +use mls_rs_core::psk::PreSharedKey; use mls_rs_core::secret::Secret; use mls_rs_core::time::MlsTime; use snapshot::PendingCommitSnapshot; @@ -20,7 +22,6 @@ use crate::extension::RatchetTreeExt; use crate::identity::SigningIdentity; use crate::key_package::{KeyPackage, KeyPackageRef}; use crate::protocol_version::ProtocolVersion; -use crate::psk::secret::PskSecret; use crate::psk::PreSharedKeyID; use crate::signer::Signable; use crate::tree_kem::hpke_encryption::HpkeEncryptable; @@ -49,8 +50,8 @@ pub use self::resumption::ReinitClient; #[cfg(feature = "psk")] use crate::psk::{ - resolver::PskResolver, secret::PskSecretInput, ExternalPskId, JustPreSharedKeyID, PskGroupId, - ResumptionPSKUsage, ResumptionPsk, + secret::{PskSecret, PskSecretInput}, + ExternalPskId, JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk, }; #[cfg(feature = "private_message")] @@ -1395,37 +1396,6 @@ where a.state == b.state && a.key_schedule == b.key_schedule && a.epoch_secrets == b.epoch_secrets } - #[cfg(feature = "psk")] - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn get_psk( - &self, - psks: &[ProposalInfo], - ) -> Result<(PskSecret, Vec), MlsError> { - if let Some(psk) = self.previous_psk.clone() { - // TODO consider throwing error if psks not empty - let psk_id = vec![psk.id.clone()]; - let psk = PskSecret::calculate(&[psk], &self.cipher_suite_provider()).await?; - - Ok((psk, psk_id)) - } else { - let psks = psks - .iter() - .map(|psk| psk.proposal.psk.clone()) - .collect::>(); - - let psk = PskResolver { - group_context: Some(self.context()), - current_epoch: Some(&self.epoch_secrets), - prior_epochs: Some(&self.state_repo), - psk_store: &self.config.secret_store(), - } - .resolve_to_secret(&psks, &self.cipher_suite_provider()) - .await?; - - Ok((psk, psks)) - } - } - #[cfg(feature = "private_message")] pub(crate) fn encryption_options(&self) -> Result { self.config @@ -1434,11 +1404,6 @@ where .map_err(|e| MlsError::MlsRulesError(e.into_any_error())) } - #[cfg(not(feature = "psk"))] - fn get_psk(&self) -> PskSecret { - PskSecret::new(&self.cipher_suite_provider()) - } - #[cfg(feature = "secret_tree_access")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[inline(never)] @@ -1473,47 +1438,25 @@ where } impl Group { + // FIXME: This is temporary until we get rid of the group state storage #[cfg(feature = "psk")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub(crate) async fn psk_secret( - config: &C, - cipher_suite_provider: &CS, - psks: &[PreSharedKeyID], - additional_psk: Option, - ) -> Result { - if let Some(psk) = additional_psk { - let psk_id = psks.first().ok_or(MlsError::UnexpectedPskId)?; - - match &psk_id.key_id { - JustPreSharedKeyID::Resumption(r) if r.usage != ResumptionPSKUsage::Application => { - Ok(()) - } - _ => Err(MlsError::UnexpectedPskId), - }?; - - let mut psk = psk; - psk.id.psk_nonce = psk_id.psk_nonce.clone(); - PskSecret::calculate(&[psk], cipher_suite_provider).await - } else { - PskResolver::<::GroupStateStorage, ::PskStore> { - group_context: None, - current_epoch: None, - prior_epochs: None, - psk_store: &config.secret_store(), - } - .resolve_to_secret(psks, cipher_suite_provider) - .await + pub async fn resumption_secret(&self, epoch: u64) -> Result { + if epoch == self.current_epoch() { + return Ok(self.epoch_secrets.resumption_secret.clone()); } - } - #[cfg(not(feature = "psk"))] - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub(crate) async fn psk_secret( - _config: &C, - cipher_suite_provider: &CS, - _psks: &[PreSharedKeyID], - ) -> Result { - Ok(PskSecret::new(cipher_suite_provider)) + #[cfg(feature = "prior_epoch")] + let res = self + .state_repo + .resumption_secret(epoch) + .await? + .ok_or_else(|| MlsError::EpochNotFound); + + #[cfg(not(feature = "prior_epoch"))] + let res = Err(MlsError::EpochNotFound); + + res } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -1625,7 +1568,6 @@ where { type MlsRules = C::MlsRules; type IdentityProvider = C::IdentityProvider; - type PreSharedKeyStorage = C::PskStore; type OutputType = ReceivedMessage; type CipherSuiteProvider = ::CipherSuiteProvider; @@ -1720,6 +1662,7 @@ where interim_transcript_hash: InterimTranscriptHash, confirmation_tag: &ConfirmationTag, provisional_state: ProvisionalState, + #[cfg(feature = "psk")] psks: &[PskSecretInput], ) -> Result<(), MlsError> { let commit_secret = if let Some(secrets) = secrets { self.private_tree = secrets.0; @@ -1747,12 +1690,16 @@ where }; #[cfg(feature = "psk")] - let (psk, _) = self - .get_psk(&provisional_state.applied_proposals.psks) - .await?; + let psk = { + if let Some(psk) = self.previous_psk.clone() { + PskSecret::calculate(&[psk], &self.cipher_suite_provider()).await + } else { + PskSecret::calculate(psks, &self.cipher_suite_provider()).await + } + }?; #[cfg(not(feature = "psk"))] - let psk = self.get_psk(); + let psk = crate::psk::secret::PskSecret::new(&self.cipher_suite_provider); let key_schedule_result = KeySchedule::from_key_schedule( &key_schedule, @@ -1807,10 +1754,6 @@ where self.config.identity_provider() } - fn psk_storage(&self) -> Self::PreSharedKeyStorage { - self.config.secret_store() - } - fn group_state(&self) -> &GroupState { &self.state } @@ -2204,6 +2147,7 @@ mod tests { let (mut bob_group, _) = bob_client .join_group(None, &commit_output.welcome_messages[0]) .await?; + // This no longer deletes the key package bob_group.write_to_storage()?; @@ -2435,7 +2379,7 @@ mod tests { let group = test_group(protocol_version, cipher_suite).await; let info = group - .group_info_message(false) + .group_info_message(true) .await .unwrap() .into_group_info() @@ -2451,10 +2395,11 @@ mod tests { group.group.config, info_msg, ) - .await - .map(|_| {}); + .build() + .await; - assert_matches!(res, Err(MlsError::MissingExternalPubExtension)); + // assert_matches! needs Debug which Group doesn't have + assert!(matches!(res, Err(MlsError::MissingExternalPubExtension))); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] @@ -2477,7 +2422,6 @@ mod tests { test_client .external_commit_builder(commit_output.external_commit_group_info.unwrap()) - .await .unwrap() .build() .await @@ -2591,7 +2535,6 @@ mod tests { .await .unwrap(), ) - .await .unwrap() .build() .await @@ -2638,7 +2581,6 @@ mod tests { .await .unwrap(), ) - .await .unwrap() .with_tree_data(alice_group.export_tree().into_owned()) .build() @@ -3712,26 +3654,123 @@ mod tests { test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await; let psk_id = ExternalPskId::new(vec![0]); - let psk = PreSharedKey::from(vec![0]); + let psk = PreSharedKey::from(vec![1]); - alice - .config - .secret_store() - .insert(psk_id.clone(), psk.clone()); + let commit = alice + .commit_builder() + .add_member(key_pkg) + .unwrap() + .add_external_psk(psk_id.clone(), psk.clone()) + .unwrap() + .build() + .await + .unwrap(); + + bob.join_group_custom(None, &commit.welcome_messages[0], |joiner| { + joiner.with_external_psk(psk_id, psk) + }) + .await + .unwrap(); + } + + #[cfg(feature = "psk")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn can_process_commit_with_psk_by_value() { + let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; + let (mut bob, _) = alice.join("bob").await; - bob.config.secret_store().insert(psk_id.clone(), psk); + let psk_id_external = ExternalPskId::new(vec![0]); + let psk_external = PreSharedKey::from(vec![1]); + let psk_epoch = bob.context().epoch; + + let psk_id_resumption = ResumptionPsk { + usage: ResumptionPSKUsage::Application, + psk_group_id: bob.context().group_id.clone().into(), + psk_epoch, + }; + + let psk_resumption = bob.resumption_secret(psk_epoch).await.unwrap(); let commit = alice .commit_builder() - .add_member(key_pkg) + .add_external_psk(psk_id_external.clone(), psk_external.clone()) .unwrap() - .add_external_psk(psk_id) + .add_resumption_psk(psk_epoch, psk_resumption.clone()) .unwrap() .build() .await .unwrap(); - bob.join_group(None, &commit.welcome_messages[0]) + bob.commit_processor(commit.commit_message) + .await + .unwrap() + .with_external_psk(psk_id_external, psk_external) + .with_resumption_psk(psk_id_resumption, psk_resumption) + .process() + .await + .unwrap(); + } + + #[cfg(feature = "psk")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn can_process_commit_with_psk_by_reference() { + let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; + let (mut bob, _) = alice.join("bob").await; + + let psk_id_external1 = ExternalPskId::new(vec![0]); + let psk_id_external2 = ExternalPskId::new(vec![1]); + let psk_external = PreSharedKey::from(vec![2]); + let psk_epoch = bob.context().epoch; + + let psk_id_resumption = ResumptionPsk { + usage: ResumptionPSKUsage::Application, + psk_group_id: bob.context().group_id.clone().into(), + psk_epoch, + }; + + let psk_resumption = bob.resumption_secret(psk_epoch).await.unwrap(); + + let psk_external_proposal = alice + .propose_external_psk(psk_id_external1.clone(), Vec::new()) + .await + .unwrap(); + + let psk_resumption_proposal = alice + .propose_resumption_psk(psk_epoch, Vec::new()) + .await + .unwrap(); + + let commit = alice + .commit_builder() + .apply_external_psk(psk_id_external1.clone(), psk_external.clone()) + .add_external_psk(psk_id_external2.clone(), psk_external.clone()) + .unwrap() + .apply_resumption_psk(psk_id_resumption.clone(), psk_resumption.clone()) + .build() + .await + .unwrap(); + + let commit_changes = alice.apply_pending_commit().await.unwrap(); + + // Make sure the proposals are actually in the commit to verify the rest of the test is + // working as expected + assert_matches!(commit_changes.effect, CommitEffect::NewEpoch(epoch) if epoch.unused_proposals.is_empty()); + + bob.process_incoming_message(psk_external_proposal) + .await + .unwrap(); + + bob.process_incoming_message(psk_resumption_proposal) + .await + .unwrap(); + + bob.commit_processor(commit.commit_message) + .await + .unwrap() + .with_external_psk(psk_id_external2.clone(), psk_external.clone()) + .with_external_psk(psk_id_external1, psk_external) + .with_resumption_psk(psk_id_resumption, psk_resumption) + .process() .await .unwrap(); } @@ -3797,11 +3836,10 @@ mod tests { alice.apply_pending_commit().await.unwrap(); - let mut bob = bob_client + let (mut bob, _) = bob_client .join_group(None, &commit.welcome_messages[0]) .await - .unwrap() - .0; + .unwrap(); bob.write_to_storage().await.unwrap(); @@ -3955,7 +3993,6 @@ mod tests { let commit = client_with_custom_rules(b"bob", mls_rules) .await .external_commit_builder(group_info) - .await .unwrap() .with_custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![])) .build() @@ -3989,10 +4026,8 @@ mod tests { let (_, commit) = client_with_custom_rules(b"bob", mls_rules) .await .external_commit_builder(group_info) - .await .unwrap() .with_received_custom_proposal(by_ref) - .unwrap() .build() .await .unwrap(); diff --git a/mls-rs/src/group/proposal.rs b/mls-rs/src/group/proposal.rs index 1aab618f..6913b761 100644 --- a/mls-rs/src/group/proposal.rs +++ b/mls-rs/src/group/proposal.rs @@ -22,7 +22,7 @@ pub use mls_rs_core::extension::ExtensionList; pub use mls_rs_core::group::ProposalType; #[cfg(feature = "psk")] -use crate::psk::{ExternalPskId, JustPreSharedKeyID, PreSharedKeyID}; +use crate::psk::{ExternalPskId, JustPreSharedKeyID, PreSharedKeyID, ResumptionPsk}; #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -160,6 +160,13 @@ impl PreSharedKeyProposal { JustPreSharedKeyID::Resumption(_) => None, } } + + pub fn resumption_psk_id(&self) -> Option<&ResumptionPsk> { + match self.psk.key_id { + JustPreSharedKeyID::External(_) => None, + JustPreSharedKeyID::Resumption(ref ext) => Some(ext), + } + } } #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)] diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 7d62cf1c..dc608129 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -5,9 +5,7 @@ use alloc::vec::Vec; use super::{ - message_processor::ProvisionalState, - mls_rules::{CommitDirection, CommitSource}, - GroupState, ProposalOrRef, + message_processor::ProvisionalState, mls_rules::CommitSource, GroupState, ProposalOrRef, }; use crate::{ client::MlsError, @@ -18,12 +16,16 @@ use crate::{ time::MlsTime, }; +#[cfg(feature = "psk")] +use crate::psk::JustPreSharedKeyID; + #[cfg(feature = "by_ref_proposal")] use crate::{ group::{ message_hash::MessageHash, Proposal, ProposalMessageDescription, ProposalRef, ProtocolVersion, }, + mls_rules::CommitDirection, MlsMessage, }; @@ -32,9 +34,7 @@ use crate::tree_kem::leaf_node::LeafNode; #[cfg(feature = "by_ref_proposal")] use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; -use mls_rs_core::{ - crypto::CipherSuiteProvider, identity::IdentityProvider, psk::PreSharedKeyStorage, -}; +use mls_rs_core::{crypto::CipherSuiteProvider, identity::IdentityProvider}; #[cfg(feature = "by_ref_proposal")] use core::fmt::{self, Debug}; @@ -223,7 +223,7 @@ impl GroupState { #[inline(never)] #[allow(clippy::too_many_arguments)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub(crate) async fn apply_resolved( + pub(crate) async fn apply_resolved( &self, proposals: ProposalBundle, external_leaf: Option<&LeafNode>, @@ -231,14 +231,12 @@ impl GroupState { cipher_suite_provider: &CSP, commit_time: Option, #[cfg(feature = "by_ref_proposal")] direction: CommitDirection, - #[cfg(not(feature = "by_ref_proposal"))] _: CommitDirection, - psk_storage: &PSK, + #[cfg(feature = "psk")] psks: &[JustPreSharedKeyID], sender: &CommitSource, ) -> Result where C: IdentityProvider, CSP: CipherSuiteProvider, - PSK: PreSharedKeyStorage, { #[cfg(feature = "by_ref_proposal")] let all_proposals = proposals.clone(); @@ -255,7 +253,8 @@ impl GroupState { &self.context, external_leaf, identity_provider, - psk_storage, + #[cfg(feature = "psk")] + psks, ); #[cfg(feature = "by_ref_proposal")] @@ -337,7 +336,6 @@ fn unused_proposals( pub(crate) mod test_utils { use mls_rs_core::{ crypto::CipherSuiteProvider, extension::ExtensionList, identity::IdentityProvider, - psk::PreSharedKeyStorage, }; use crate::{ @@ -353,11 +351,11 @@ pub(crate) mod test_utils { }, identity::{basic::BasicIdentityProvider, test_utils::BasicWithCustomProvider}, mls_rules::{CommitSource, ProposalSource}, - psk::AlwaysFoundPskStorage, }; - use super::{CachedProposal, MlsError, ProposalCache}; + use super::{CachedProposal, JustPreSharedKeyID, MlsError, ProposalCache}; + use alloc::vec; use alloc::vec::Vec; impl CachedProposal { @@ -367,7 +365,7 @@ pub(crate) mod test_utils { } #[derive(Debug)] - pub(crate) struct CommitReceiver<'a, C, P, CSP> { + pub(crate) struct CommitReceiver<'a, C, CSP> { tree: &'a TreeKemPublic, sender: Sender, receiver: LeafIndex, @@ -375,10 +373,10 @@ pub(crate) mod test_utils { identity_provider: C, cipher_suite_provider: CSP, group_context_extensions: ExtensionList, - with_psk_storage: P, + psks: Vec, } - impl<'a, CSP> CommitReceiver<'a, BasicWithCustomProvider, AlwaysFoundPskStorage, CSP> { + impl<'a, CSP> CommitReceiver<'a, BasicWithCustomProvider, CSP> { pub fn new( tree: &'a TreeKemPublic, sender: S, @@ -395,20 +393,19 @@ pub(crate) mod test_utils { cache: make_proposal_cache(), identity_provider: BasicWithCustomProvider::new(BasicIdentityProvider), group_context_extensions: Default::default(), - with_psk_storage: AlwaysFoundPskStorage, + psks: vec![], cipher_suite_provider, } } } - impl<'a, C, P, CSP> CommitReceiver<'a, C, P, CSP> + impl<'a, C, CSP> CommitReceiver<'a, C, CSP> where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { #[cfg(feature = "by_ref_proposal")] - pub fn with_identity_provider(self, validator: V) -> CommitReceiver<'a, V, P, CSP> + pub fn with_identity_provider(self, validator: V) -> CommitReceiver<'a, V, CSP> where V: IdentityProvider, { @@ -419,15 +416,12 @@ pub(crate) mod test_utils { cache: self.cache, identity_provider: validator, group_context_extensions: self.group_context_extensions, - with_psk_storage: self.with_psk_storage, + psks: self.psks, cipher_suite_provider: self.cipher_suite_provider, } } - pub fn with_psk_storage(self, v: V) -> CommitReceiver<'a, C, V, CSP> - where - V: PreSharedKeyStorage, - { + pub fn with_psks(self, v: Vec) -> CommitReceiver<'a, C, CSP> { CommitReceiver { tree: self.tree, sender: self.sender, @@ -435,7 +429,7 @@ pub(crate) mod test_utils { cache: self.cache, identity_provider: self.identity_provider, group_context_extensions: self.group_context_extensions, - with_psk_storage: v, + psks: v, cipher_suite_provider: self.cipher_suite_provider, } } @@ -471,7 +465,7 @@ pub(crate) mod test_utils { &self.identity_provider, &self.cipher_suite_provider, self.tree, - &self.with_psk_storage, + &self.psks, ) .await } @@ -484,7 +478,7 @@ pub(crate) mod test_utils { impl ProposalCache { #[allow(clippy::too_many_arguments)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub async fn resolve_for_commit_default( + pub async fn resolve_for_commit_default( &self, sender: Sender, proposal_list: Vec, @@ -493,11 +487,10 @@ pub(crate) mod test_utils { identity_provider: &C, cipher_suite_provider: &CSP, public_tree: &TreeKemPublic, - psk_storage: &P, + #[cfg(feature = "psk")] psks: &[JustPreSharedKeyID], ) -> Result where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { let mut context = @@ -523,8 +516,10 @@ pub(crate) mod test_utils { identity_provider, cipher_suite_provider, None, + #[cfg(feature = "by_ref_proposal")] CommitDirection::Receive, - psk_storage, + #[cfg(feature = "psk")] + psks, &committer, ) .await @@ -532,7 +527,7 @@ pub(crate) mod test_utils { #[allow(clippy::too_many_arguments)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub async fn prepare_commit_default( + pub async fn prepare_commit_default( &self, sender: Sender, additional_proposals: Vec, @@ -541,11 +536,10 @@ pub(crate) mod test_utils { cipher_suite_provider: &CSP, public_tree: &TreeKemPublic, external_leaf: Option<&LeafNode>, - psk_storage: &P, + #[cfg(feature = "psk")] psks: &[JustPreSharedKeyID], ) -> Result where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { let state = GroupState::new( @@ -572,8 +566,10 @@ pub(crate) mod test_utils { identity_provider, cipher_suite_provider, None, + #[cfg(feature = "by_ref_proposal")] CommitDirection::Send, - psk_storage, + #[cfg(feature = "psk")] + psks, &committer, ) .await @@ -616,7 +612,6 @@ mod tests { identity::basic::BasicIdentityProvider, identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider}, key_package::test_utils::test_key_package, - psk::AlwaysFoundPskStorage, tree_kem::{ leaf_node::{ test_utils::{ @@ -650,14 +645,12 @@ mod tests { use crate::group::proposal::CustomProposal; use assert_matches::assert_matches; - use core::convert::Infallible; use itertools::Itertools; use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider}; use mls_rs_core::extension::ExtensionList; use mls_rs_core::group::{Capabilities, ProposalType}; use mls_rs_core::identity::IdentityProvider; use mls_rs_core::protocol_version::ProtocolVersion; - use mls_rs_core::psk::{PreSharedKey, PreSharedKeyStorage}; use mls_rs_core::{ extension::MlsExtension, identity::{Credential, CredentialType, CustomCredential}, @@ -945,7 +938,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -983,7 +976,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1032,7 +1025,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1065,7 +1058,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1106,7 +1099,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1149,7 +1142,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1202,7 +1195,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1221,7 +1214,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, &tree, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1237,8 +1230,10 @@ mod tests { let (alice, tree) = new_tree("alice").await; let cache = make_proposal_cache(); + let psk_id = b"ted"; + let proposal = Proposal::Psk(make_external_psk( - b"ted", + psk_id, crate::psk::PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(), )); @@ -1251,7 +1246,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[JustPreSharedKeyID::External(psk_id.to_vec().into())], ) .await; @@ -1298,7 +1293,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1335,7 +1330,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1367,7 +1362,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1400,7 +1395,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, public_tree, - &AlwaysFoundPskStorage, + &[], ) .await } @@ -1466,7 +1461,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, &public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1514,7 +1509,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, &public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1562,7 +1557,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, &public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1632,7 +1627,7 @@ mod tests { &BasicIdentityProvider, &cipher_suite_provider, public_tree, - &AlwaysFoundPskStorage, + &[], ) .await; @@ -1660,7 +1655,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1694,7 +1689,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1742,7 +1737,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1756,10 +1751,11 @@ mod tests { let (alice, tree) = new_tree("alice").await; let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cache = make_proposal_cache(); + let psk_id = JustPreSharedKeyID::External(ExternalPskId::new(vec![])); let psk = Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID::new( - JustPreSharedKeyID::External(ExternalPskId::new(vec![])), + psk_id.clone(), &test_cipher_suite_provider(TEST_CIPHER_SUITE), ) .unwrap(), @@ -1778,7 +1774,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[psk_id], ) .await .unwrap(); @@ -1808,7 +1804,7 @@ mod tests { &cipher_suite_provider, &tree, None, - &AlwaysFoundPskStorage, + &[], ) .await .unwrap(); @@ -1817,17 +1813,16 @@ mod tests { } #[derive(Debug)] - struct CommitSender<'a, C, P, CSP> { + struct CommitSender<'a, C, CSP> { cipher_suite_provider: CSP, tree: &'a TreeKemPublic, sender: LeafIndex, cache: ProposalCache, additional_proposals: Vec, identity_provider: C, - psk_storage: P, + psks: Vec, } - - impl<'a, CSP> CommitSender<'a, BasicWithCustomProvider, AlwaysFoundPskStorage, CSP> { + impl<'a, CSP> CommitSender<'a, BasicWithCustomProvider, CSP> { fn new(tree: &'a TreeKemPublic, sender: LeafIndex, cipher_suite_provider: CSP) -> Self { Self { tree, @@ -1835,20 +1830,19 @@ mod tests { cache: make_proposal_cache(), additional_proposals: Vec::new(), identity_provider: BasicWithCustomProvider::new(BasicIdentityProvider::new()), - psk_storage: AlwaysFoundPskStorage, + psks: vec![], cipher_suite_provider, } } } - impl<'a, C, P, CSP> CommitSender<'a, C, P, CSP> + impl<'a, C, CSP> CommitSender<'a, C, CSP> where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { #[cfg(feature = "by_ref_proposal")] - fn with_identity_provider(self, identity_provider: V) -> CommitSender<'a, V, P, CSP> + fn with_identity_provider(self, identity_provider: V) -> CommitSender<'a, V, CSP> where V: IdentityProvider, { @@ -1859,7 +1853,7 @@ mod tests { sender: self.sender, cache: self.cache, additional_proposals: self.additional_proposals, - psk_storage: self.psk_storage, + psks: self.psks, } } @@ -1879,19 +1873,8 @@ mod tests { self } - fn with_psk_storage(self, v: V) -> CommitSender<'a, C, V, CSP> - where - V: PreSharedKeyStorage, - { - CommitSender { - tree: self.tree, - sender: self.sender, - cache: self.cache, - additional_proposals: self.additional_proposals, - identity_provider: self.identity_provider, - psk_storage: v, - cipher_suite_provider: self.cipher_suite_provider, - } + fn with_psks(self, psks: Vec) -> CommitSender<'a, C, CSP> { + CommitSender { psks, ..self } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -1906,7 +1889,7 @@ mod tests { &self.cipher_suite_provider, self.tree, None, - &self.psk_storage, + &self.psks, ) .await?; @@ -2779,6 +2762,7 @@ mod tests { alice, test_cipher_suite_provider(TEST_CIPHER_SUITE), ) + .with_psks(vec![JustPreSharedKeyID::External(b"foo".to_vec().into())]) .receive([psk_proposal.clone(), psk_proposal]) .await; @@ -2792,6 +2776,7 @@ mod tests { let psk_proposal = Proposal::Psk(new_external_psk(b"foo")); let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .with_psks(vec![JustPreSharedKeyID::External(b"foo".to_vec().into())]) .with_additional([psk_proposal.clone(), psk_proposal]) .send() .await; @@ -2814,6 +2799,7 @@ mod tests { let processed_proposals = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .with_psks(vec![JustPreSharedKeyID::External(b"foo".to_vec().into())]) .cache( proposal_info[0].proposal_ref().unwrap().clone(), proposal.clone(), @@ -3674,21 +3660,6 @@ mod tests { assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); } - #[cfg(feature = "psk")] - #[derive(Debug)] - struct AlwaysNotFoundPskStorage; - - #[cfg(feature = "psk")] - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - #[cfg_attr(mls_build_async, maybe_async::must_be_async)] - impl PreSharedKeyStorage for AlwaysNotFoundPskStorage { - type Error = Infallible; - - async fn get(&self, _: &ExternalPskId) -> Result, Self::Error> { - Ok(None) - } - } - #[cfg(feature = "psk")] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn receiving_external_psk_with_unknown_id_fails() { @@ -3700,7 +3671,7 @@ mod tests { alice, test_cipher_suite_provider(TEST_CIPHER_SUITE), ) - .with_psk_storage(AlwaysNotFoundPskStorage) + .with_psks(vec![]) .receive([Proposal::Psk(new_external_psk(b"abc"))]) .await; @@ -3713,7 +3684,7 @@ mod tests { let (alice, tree) = new_tree("alice").await; let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) - .with_psk_storage(AlwaysNotFoundPskStorage) + .with_psks(vec![]) .with_additional([Proposal::Psk(new_external_psk(b"abc"))]) .send() .await; @@ -3730,7 +3701,7 @@ mod tests { let processed_proposals = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) - .with_psk_storage(AlwaysNotFoundPskStorage) + .with_psks(vec![]) .cache( proposal_info.proposal_ref().unwrap().clone(), proposal.clone(), @@ -3792,7 +3763,8 @@ mod tests { committer, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE), - ); + ) + .with_psks(vec![JustPreSharedKeyID::External(b"ted".to_vec().into())]); #[cfg(feature = "by_ref_proposal")] let extensions: ExtensionList = diff --git a/mls-rs/src/group/proposal_filter/filtering.rs b/mls-rs/src/group/proposal_filter/filtering.rs index 7ce85663..60ec8e36 100644 --- a/mls-rs/src/group/proposal_filter/filtering.rs +++ b/mls-rs/src/group/proposal_filter/filtering.rs @@ -32,7 +32,6 @@ use alloc::vec::Vec; use mls_rs_core::{ error::IntoAnyError, identity::{IdentityProvider, MemberValidationContext}, - psk::PreSharedKeyStorage, }; #[cfg(not(any(mls_build_async, feature = "rayon")))] @@ -49,10 +48,9 @@ use {crate::iter::ParallelIteratorExt, rayon::prelude::*}; #[cfg(mls_build_async)] use futures::{StreamExt, TryStreamExt}; -impl ProposalApplier<'_, C, P, CSP> +impl ProposalApplier<'_, C, CSP> where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -81,7 +79,8 @@ where strategy, self.cipher_suite_provider, &mut proposals, - self.psk_storage, + #[cfg(feature = "psk")] + self.psks, ) .await?; diff --git a/mls-rs/src/group/proposal_filter/filtering_common.rs b/mls-rs/src/group/proposal_filter/filtering_common.rs index 68ee32a9..94cd0758 100644 --- a/mls-rs/src/group/proposal_filter/filtering_common.rs +++ b/mls-rs/src/group/proposal_filter/filtering_common.rs @@ -28,7 +28,7 @@ use crate::extension::ExternalSendersExt; use mls_rs_core::{error::IntoAnyError, identity::MemberValidationContext}; use alloc::vec::Vec; -use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage}; +use mls_rs_core::identity::IdentityProvider; use crate::group::{ExternalInit, ProposalType, RemoveProposal}; @@ -45,13 +45,14 @@ use std::collections::HashSet; use super::filtering::{apply_strategy, filter_out_invalid_proposers, FilterStrategy}; #[derive(Debug)] -pub(crate) struct ProposalApplier<'a, C, P, CSP> { +pub(crate) struct ProposalApplier<'a, C, CSP> { pub original_tree: &'a TreeKemPublic, pub cipher_suite_provider: &'a CSP, pub original_context: &'a GroupContext, pub external_leaf: Option<&'a LeafNode>, pub identity_provider: &'a C, - pub psk_storage: &'a P, + #[cfg(feature = "psk")] + pub psks: &'a [JustPreSharedKeyID], } #[derive(Debug)] @@ -64,10 +65,9 @@ pub(crate) struct ApplyProposalsOutput { pub(crate) new_context_extensions: Option, } -impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP> +impl<'a, C, CSP> ProposalApplier<'a, C, CSP> where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { #[allow(clippy::too_many_arguments)] @@ -77,7 +77,7 @@ where original_context: &'a GroupContext, external_leaf: Option<&'a LeafNode>, identity_provider: &'a C, - psk_storage: &'a P, + #[cfg(feature = "psk")] psks: &'a [JustPreSharedKeyID], ) -> Self { Self { original_tree, @@ -85,7 +85,8 @@ where original_context, external_leaf, identity_provider, - psk_storage, + #[cfg(feature = "psk")] + psks, } } @@ -156,7 +157,8 @@ where &mut proposals, #[cfg(not(feature = "by_ref_proposal"))] proposals, - self.psk_storage, + #[cfg(feature = "psk")] + self.psks, ) .await?; @@ -338,15 +340,14 @@ where #[cfg(feature = "psk")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -pub(crate) async fn filter_out_invalid_psks( +pub(crate) async fn filter_out_invalid_psks( #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy, cipher_suite_provider: &CP, #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle, #[cfg(feature = "by_ref_proposal")] proposals: &mut ProposalBundle, - psk_storage: &P, + #[cfg(feature = "psk")] psks: &[crate::psk::JustPreSharedKeyID], ) -> Result<(), MlsError> where - P: PreSharedKeyStorage, CP: CipherSuiteProvider, { let kdf_extract_size = cipher_suite_provider.kdf_extract_size(); @@ -381,20 +382,13 @@ where #[cfg(not(feature = "std"))] let is_new_id = !ids_seen.contains(&p.proposal.psk); - let external_id_is_valid = match &p.proposal.psk.key_id { - JustPreSharedKeyID::External(id) => psk_storage - .contains(id) - .await - .map_err(|e| MlsError::PskStoreError(e.into_any_error())) - .and_then(|found| { - if found { - Ok(()) - } else { - Err(MlsError::MissingRequiredPsk) - } - }), - JustPreSharedKeyID::Resumption(_) => Ok(()), - }; + let has_required_psk_secret = psks + .contains(&p.proposal.psk.key_id) + .then_some(()) + .ok_or_else(|| MlsError::MissingRequiredPsk); + + #[cfg(not(feature = "psk"))] + let has_required_psk_secret = Ok(()); #[cfg(not(feature = "by_ref_proposal"))] if !valid { @@ -403,8 +397,8 @@ where return Err(MlsError::InvalidPskNonceLength); } else if !is_new_id { return Err(MlsError::DuplicatePskIds); - } else if external_id_is_valid.is_err() { - return external_id_is_valid; + } else if has_required_psk_secret.is_err() { + return has_required_psk_secret; } #[cfg(feature = "by_ref_proposal")] @@ -416,7 +410,7 @@ where } else if !is_new_id { Err(MlsError::DuplicatePskIds) } else { - external_id_is_valid + has_required_psk_secret }; if !apply_strategy(strategy, p.is_by_reference(), res)? { @@ -439,15 +433,13 @@ where #[cfg(not(feature = "psk"))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -pub(crate) async fn filter_out_invalid_psks( +pub(crate) async fn filter_out_invalid_psks( #[cfg(feature = "by_ref_proposal")] _: FilterStrategy, _: &CP, #[cfg(not(feature = "by_ref_proposal"))] _: &ProposalBundle, #[cfg(feature = "by_ref_proposal")] _: &mut ProposalBundle, - _: &P, ) -> Result<(), MlsError> where - P: PreSharedKeyStorage, CP: CipherSuiteProvider, { Ok(()) diff --git a/mls-rs/src/group/proposal_filter/filtering_lite.rs b/mls-rs/src/group/proposal_filter/filtering_lite.rs index cadd92e3..1924f3bb 100644 --- a/mls-rs/src/group/proposal_filter/filtering_lite.rs +++ b/mls-rs/src/group/proposal_filter/filtering_lite.rs @@ -17,10 +17,7 @@ use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, Pro #[cfg(feature = "by_ref_proposal")] use {crate::extension::ExternalSendersExt, mls_rs_core::error::IntoAnyError}; -use mls_rs_core::{ - identity::{IdentityProvider, MemberValidationContext}, - psk::PreSharedKeyStorage, -}; +use mls_rs_core::identity::{IdentityProvider, MemberValidationContext}; #[cfg(feature = "custom_proposal")] use itertools::Itertools; @@ -42,10 +39,9 @@ use crate::group::{ #[cfg(all(feature = "std", feature = "psk"))] use std::collections::HashSet; -impl ProposalApplier<'_, C, P, CSP> +impl ProposalApplier<'_, C, CSP> where C: IdentityProvider, - P: PreSharedKeyStorage, CSP: CipherSuiteProvider, { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -56,7 +52,14 @@ where commit_time: Option, ) -> Result { filter_out_removal_of_committer(commit_sender, proposals)?; - filter_out_invalid_psks(self.cipher_suite_provider, proposals, self.psk_storage).await?; + + filter_out_invalid_psks( + self.cipher_suite_provider, + proposals, + #[cfg(feature = "psk")] + self.psks, + ) + .await?; #[cfg(feature = "by_ref_proposal")] filter_out_invalid_group_extensions(proposals, self.identity_provider, commit_time).await?; diff --git a/mls-rs/src/group/state_repo.rs b/mls-rs/src/group/state_repo.rs index 8d30a499..eef292b7 100644 --- a/mls-rs/src/group/state_repo.rs +++ b/mls-rs/src/group/state_repo.rs @@ -14,9 +14,6 @@ use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage}; use super::snapshot::Snapshot; -#[cfg(feature = "psk")] -use crate::group::ResumptionPsk; - #[cfg(feature = "psk")] use mls_rs_core::psk::PreSharedKey; @@ -80,23 +77,20 @@ impl GroupStateRepository { #[cfg(feature = "psk")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub async fn resumption_secret( - &self, - psk_id: &ResumptionPsk, - ) -> Result, MlsError> { + pub async fn resumption_secret(&self, epoch_id: u64) -> Result, MlsError> { // Search the local inserts cache if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) { - if psk_id.psk_epoch >= min { + if epoch_id >= min { return Ok(self .pending_commit .inserts - .get((psk_id.psk_epoch - min) as usize) + .get((epoch_id - min) as usize) .map(|e| e.secrets.resumption_secret.clone())); } } // Search the local updates cache - let maybe_pending = self.find_pending(psk_id.psk_epoch); + let maybe_pending = self.find_pending(epoch_id); if let Some(pending) = maybe_pending { return Ok(Some( @@ -109,7 +103,7 @@ impl GroupStateRepository { // Search the stored cache self.storage - .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch) + .epoch(&self.group_id, epoch_id) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret)) @@ -224,7 +218,6 @@ mod tests { group::{ epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret}, test_utils::{random_bytes, TEST_GROUP}, - PskGroupId, ResumptionPSKUsage, }, storage_provider::in_memory::InMemoryGroupStateStorage, }; @@ -272,14 +265,8 @@ mod tests { #[cfg(not(feature = "std"))] assert!(test_repo.storage.inner.lock().is_empty()); - let psk_id = ResumptionPsk { - psk_epoch: 0, - psk_group_id: PskGroupId(test_repo.group_id.clone()), - usage: ResumptionPSKUsage::Application, - }; - // Make sure you can recall an epoch sitting as a pending insert - let resumption = test_repo.resumption_secret(&psk_id).await.unwrap(); + let resumption = test_repo.resumption_secret(0).await.unwrap(); let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned(); assert_eq!( @@ -349,13 +336,7 @@ mod tests { ); // Make sure you can access an epoch pending update - let psk_id = ResumptionPsk { - psk_epoch: 0, - psk_group_id: PskGroupId(test_repo.group_id.clone()), - usage: ResumptionPSKUsage::Application, - }; - - let owned = test_repo.resumption_secret(&psk_id).await.unwrap(); + let owned = test_repo.resumption_secret(0).await.unwrap(); assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret)); // Write the update to storage diff --git a/mls-rs/src/group/test_utils.rs b/mls-rs/src/group/test_utils.rs index d8d1ec94..b2b077cf 100644 --- a/mls-rs/src/group/test_utils.rs +++ b/mls-rs/src/group/test_utils.rs @@ -416,7 +416,6 @@ impl GroupWithoutKeySchedule { impl MessageProcessor for GroupWithoutKeySchedule { type CipherSuiteProvider = as MessageProcessor>::CipherSuiteProvider; type OutputType = as MessageProcessor>::OutputType; - type PreSharedKeyStorage = as MessageProcessor>::PreSharedKeyStorage; type IdentityProvider = as MessageProcessor>::IdentityProvider; type MlsRules = as MessageProcessor>::MlsRules; @@ -437,10 +436,6 @@ impl MessageProcessor for GroupWithoutKeySchedule { self.inner.cipher_suite_provider() } - fn psk_storage(&self) -> Self::PreSharedKeyStorage { - self.inner.psk_storage() - } - fn removal_proposal( &self, provisional_state: &ProvisionalState, @@ -488,6 +483,7 @@ impl MessageProcessor for GroupWithoutKeySchedule { _interim_transcript_hash: InterimTranscriptHash, _confirmation_tag: &ConfirmationTag, provisional_public_state: ProvisionalState, + #[cfg(feature = "psk")] _psks: &[PskSecretInput], ) -> Result<(), MlsError> { self.provisional_public_state = Some(provisional_public_state); self.secrets = secrets; diff --git a/mls-rs/src/group_joiner.rs b/mls-rs/src/group_joiner.rs index a9fad728..2f5e8762 100644 --- a/mls-rs/src/group_joiner.rs +++ b/mls-rs/src/group_joiner.rs @@ -9,9 +9,6 @@ use mls_rs_core::{ protocol_version::ProtocolVersion, }; -#[cfg(feature = "psk")] -use mls_rs_core::psk::ExternalPskId; - use crate::{ client_config::ClientConfig, error::MlsError, @@ -27,7 +24,13 @@ use crate::{ }; #[cfg(feature = "psk")] -use crate::psk::{JustPreSharedKeyID, ResumptionPsk}; +use alloc::vec::Vec; + +#[cfg(feature = "psk")] +use crate::psk::{ + secret::PskSecretInput, ExternalPskId, JustPreSharedKeyID, PreSharedKey, ResumptionPSKUsage, + ResumptionPsk, +}; pub struct GroupJoiner<'a, 'b, C> { // Parsed data @@ -41,6 +44,8 @@ pub struct GroupJoiner<'a, 'b, C> { // Inputted by application tree: Option>, signer: Option, + #[cfg(feature = "psk")] + psks: Vec<(JustPreSharedKeyID, PreSharedKey)>, // Needed for reinit #[cfg(feature = "psk")] @@ -90,11 +95,21 @@ impl<'a, 'b, C: ClientConfig> GroupJoiner<'a, 'b, C> { } } - // TODO with_psks + #[cfg(feature = "psk")] + pub fn with_external_psk(mut self, id: ExternalPskId, psk: PreSharedKey) -> Self { + self.psks.push((JustPreSharedKeyID::External(id), psk)); + self + } + + #[cfg(feature = "psk")] + pub fn with_resumption_psk(mut self, id: ResumptionPsk, psk: PreSharedKey) -> Self { + self.psks.push((JustPreSharedKeyID::Resumption(id), psk)); + self + } // Reinit #[cfg(feature = "psk")] - pub(crate) fn additional_psk(self, additional_psk: crate::psk::secret::PskSecretInput) -> Self { + pub(crate) fn additional_psk(self, additional_psk: PskSecretInput) -> Self { Self { additional_psk: Some(additional_psk), ..self @@ -126,6 +141,8 @@ impl<'a, 'b, C: ClientConfig> GroupJoiner<'a, 'b, C> { signer, #[cfg(feature = "psk")] additional_psk: None, + #[cfg(feature = "psk")] + psks: Default::default(), }) } @@ -137,14 +154,47 @@ impl<'a, 'b, C: ClientConfig> GroupJoiner<'a, 'b, C> { let cipher_suite_provider = cipher_suite_provider(self.config.crypto_provider(), self.welcome.cipher_suite)?; - let psk_secret = Group::psk_secret( - &self.config, - &cipher_suite_provider, - &self.group_secrets.psks, - #[cfg(feature = "psk")] - self.additional_psk.take(), - ) - .await?; + #[cfg(feature = "psk")] + let psk_secret = if let Some(psk) = &self.additional_psk { + let psk_id = self + .group_secrets + .psks + .first() + .ok_or(MlsError::UnexpectedPskId)?; + + match &psk_id.key_id { + JustPreSharedKeyID::Resumption(r) if r.usage != ResumptionPSKUsage::Application => { + Ok(()) + } + _ => Err(MlsError::UnexpectedPskId), + }?; + + let mut psk = psk.clone(); + psk.id.psk_nonce = psk_id.psk_nonce.clone(); + PskSecret::calculate(&[psk], &cipher_suite_provider).await + } else { + let psk_inputs = self + .group_secrets + .psks + .iter() + .map(|id| { + self.psks + .iter() + .find_map(|(psk_id, psk)| { + (*psk_id == id.key_id).then(|| PskSecretInput { + id: id.clone(), + psk: psk.clone(), + }) + }) + .ok_or_else(|| MlsError::MissingRequiredPsk) + }) + .collect::, _>>()?; + + PskSecret::calculate(&psk_inputs, &cipher_suite_provider).await + }?; + + #[cfg(not(feature = "psk"))] + let psk_secret = PskSecret::new(&cipher_suite_provider); // From the joiner_secret in the decrypted GroupSecrets object and the PSKs specified in // the GroupSecrets, derive the welcome_secret and using that the welcome_key and @@ -265,36 +315,17 @@ mod tests { use mls_rs_core::psk::{ExternalPskId, PreSharedKey}; use crate::{ - client::test_utils::{ - test_client_with_key_pkg_custom, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION, - }, + client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, psk::{PskGroupId, ResumptionPSKUsage, ResumptionPsk}, - storage_provider::in_memory::InMemoryPreSharedKeyStorage, }; #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn outputs_correct_join_info() { - let mut psk_store = InMemoryPreSharedKeyStorage::default(); - - let (alice, _kp_alice) = test_client_with_key_pkg_custom( - TEST_PROTOCOL_VERSION, - TEST_CIPHER_SUITE, - "alice", - Default::default(), - Default::default(), - |c| c.0.psk_store = psk_store.clone(), - ) - .await; - - let (bob, kp_bob) = test_client_with_key_pkg_custom( - TEST_PROTOCOL_VERSION, - TEST_CIPHER_SUITE, - "bob", - Default::default(), - Default::default(), - |c| c.0.psk_store = psk_store.clone(), - ) - .await; + let (alice, _kp_alice) = + test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await; + + let (bob, kp_bob) = + test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await; let mut group_alice = alice .create_group(Default::default(), Default::default()) @@ -310,10 +341,13 @@ mod tests { .unwrap(); group_alice.apply_pending_commit().await.unwrap(); + let commit2 = group_alice.commit(vec![]).await.unwrap(); group_alice.apply_pending_commit().await.unwrap(); group_alice.write_to_storage().await.unwrap(); + let resumption_psk = group_alice.resumption_secret(1).await.unwrap(); + let mut group_bob = bob .join_group(None, &commit1.welcome_messages[0]) .await @@ -328,7 +362,7 @@ mod tests { group_bob.write_to_storage().await.unwrap(); let psk_id = ExternalPskId::new(b"123".into()); - psk_store.insert(psk_id.clone(), PreSharedKey::new(b"123".into())); + let psk_secret = PreSharedKey::new(b"456".into()); let kp_alice = alice .key_package_builder(None) @@ -337,16 +371,22 @@ mod tests { .await .unwrap(); - let commit = bob + let mut bob_group = bob .create_group(Default::default(), Default::default()) .await - .unwrap() + .unwrap(); + + let commit = bob_group .commit_builder() .add_member(kp_alice.key_package_message) .unwrap() - .add_external_psk(psk_id.clone()) + .add_external_psk(psk_id.clone(), psk_secret.clone()) .unwrap() - .add_resumption_psk_for_group(1, group_alice.group_id().to_vec()) + .add_resumption_psk_for_group( + 1, + group_alice.group_id().to_vec(), + resumption_psk.clone(), + ) .unwrap() .build() .await @@ -360,7 +400,10 @@ mod tests { let external_psks = joiner.required_external_psks().collect::>(); assert_eq!(external_psks, vec![&psk_id]); - let resumption_psks = joiner.required_resumption_psks().collect::>(); + let resumption_psks = joiner + .required_resumption_psks() + .cloned() + .collect::>(); let expected_resumption_psk = ResumptionPsk { usage: ResumptionPSKUsage::Application, @@ -368,8 +411,14 @@ mod tests { psk_epoch: 1, }; - assert_eq!(resumption_psks, vec![&expected_resumption_psk]); - + assert_eq!(resumption_psks, vec![expected_resumption_psk]); assert_eq!(joiner.cipher_suite(), TEST_CIPHER_SUITE); + + joiner + .with_resumption_psk(resumption_psks[0].clone(), resumption_psk) + .with_external_psk(psk_id.clone(), psk_secret) + .join() + .await + .unwrap(); } } diff --git a/mls-rs/src/psk.rs b/mls-rs/src/psk.rs index 148d100d..cbbd2558 100644 --- a/mls-rs/src/psk.rs +++ b/mls-rs/src/psk.rs @@ -4,16 +4,8 @@ use alloc::vec::Vec; -#[cfg(any(test, feature = "external_client"))] -use alloc::vec; - use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; -#[cfg(any(test, feature = "external_client"))] -use mls_rs_core::psk::PreSharedKeyStorage; - -#[cfg(any(test, feature = "external_client"))] -use core::convert::Infallible; use core::fmt::{self, Debug}; #[cfg(feature = "psk")] @@ -22,8 +14,6 @@ use crate::{client::MlsError, CipherSuiteProvider}; #[cfg(feature = "psk")] use mls_rs_core::error::IntoAnyError; -#[cfg(feature = "psk")] -pub(crate) mod resolver; pub(crate) mod secret; pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey}; @@ -76,6 +66,12 @@ impl Debug for PskGroupId { } } +impl From> for PskGroupId { + fn from(value: Vec) -> Self { + Self(value) + } +} + #[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -131,21 +127,6 @@ struct PSKLabel<'a> { count: u16, } -#[cfg(any(test, feature = "external_client"))] -#[derive(Clone, Copy, Debug)] -pub(crate) struct AlwaysFoundPskStorage; - -#[cfg(any(test, feature = "external_client"))] -#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -#[cfg_attr(mls_build_async, maybe_async::must_be_async)] -impl PreSharedKeyStorage for AlwaysFoundPskStorage { - type Error = Infallible; - - async fn get(&self, _: &ExternalPskId) -> Result, Self::Error> { - Ok(Some(vec![].into())) - } -} - #[cfg(feature = "psk")] #[cfg(test)] pub(crate) mod test_utils { diff --git a/mls-rs/src/psk/resolver.rs b/mls-rs/src/psk/resolver.rs deleted file mode 100644 index 1b2b5ca5..00000000 --- a/mls-rs/src/psk/resolver.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// Copyright by contributors to this project. -// SPDX-License-Identifier: (Apache-2.0 OR MIT) - -use alloc::vec::Vec; -use mls_rs_core::{ - crypto::CipherSuiteProvider, - error::IntoAnyError, - group::GroupStateStorage, - psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage}, -}; - -use crate::{ - client::MlsError, - group::{epoch::EpochSecrets, state_repo::GroupStateRepository, GroupContext}, - psk::secret::PskSecret, -}; - -use super::{secret::PskSecretInput, JustPreSharedKeyID, PreSharedKeyID, ResumptionPsk}; - -pub(crate) struct PskResolver<'a, GS, PS> -where - GS: GroupStateStorage, - PS: PreSharedKeyStorage, -{ - pub group_context: Option<&'a GroupContext>, - pub current_epoch: Option<&'a EpochSecrets>, - pub prior_epochs: Option<&'a GroupStateRepository>, - pub psk_store: &'a PS, -} - -impl PskResolver<'_, GS, PS> { - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result { - if let Some(ctx) = self.group_context { - if ctx.epoch == psk_id.psk_epoch && ctx.group_id == psk_id.psk_group_id.0 { - let epoch = self.current_epoch.ok_or(MlsError::OldGroupStateNotFound)?; - return Ok(epoch.resumption_secret.clone()); - } - } - - #[cfg(feature = "prior_epoch")] - if let Some(eps) = self.prior_epochs { - if let Some(psk) = eps.resumption_secret(psk_id).await? { - return Ok(psk); - } - } - - Err(MlsError::OldGroupStateNotFound) - } - - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn resolve_external(&self, psk_id: &ExternalPskId) -> Result { - self.psk_store - .get(psk_id) - .await - .map_err(|e| MlsError::PskStoreError(e.into_any_error()))? - .ok_or(MlsError::MissingRequiredPsk) - } - - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn resolve(&self, id: &[PreSharedKeyID]) -> Result, MlsError> { - let mut secret_inputs = Vec::new(); - - for id in id { - let psk = match &id.key_id { - JustPreSharedKeyID::External(external) => self.resolve_external(external).await, - JustPreSharedKeyID::Resumption(resumption) => { - self.resolve_resumption(resumption).await - } - }?; - - secret_inputs.push(PskSecretInput { - id: id.clone(), - psk, - }) - } - - Ok(secret_inputs) - } - - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub async fn resolve_to_secret( - &self, - id: &[PreSharedKeyID], - cipher_suite_provider: &P, - ) -> Result { - let psk = self.resolve(id).await?; - PskSecret::calculate(&psk, cipher_suite_provider).await - } -} diff --git a/mls-rs/src/psk/secret.rs b/mls-rs/src/psk/secret.rs index 4fe9cc83..87d54b4d 100644 --- a/mls-rs/src/psk/secret.rs +++ b/mls-rs/src/psk/secret.rs @@ -25,7 +25,7 @@ use crate::{ }; #[cfg(feature = "psk")] -#[derive(Clone)] +#[derive(Debug, Clone)] pub(crate) struct PskSecretInput { pub id: PreSharedKeyID, pub psk: PreSharedKey, diff --git a/mls-rs/src/test_utils/mod.rs b/mls-rs/src/test_utils/mod.rs index 5cb4f6d2..cfde43f1 100644 --- a/mls-rs/src/test_utils/mod.rs +++ b/mls-rs/src/test_utils/mod.rs @@ -16,13 +16,13 @@ use mls_rs_core::{ identity::{BasicCredential, Credential, SigningIdentity}, key_package::KeyPackageData, protocol_version::ProtocolVersion, - psk::ExternalPskId, }; use crate::{ client_builder::{ClientBuilder, MlsConfig}, error::MlsError, group::{framing::MlsMessageDescription, ExportedTree, NewMemberInfo}, + group_joiner::GroupJoiner, identity::basic::BasicIdentityProvider, mls_rules::{CommitOptions, DefaultMlsRules}, storage_provider::in_memory::InMemoryKeyPackageStorage, @@ -50,11 +50,15 @@ impl TestClient { } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub async fn join_group( - &self, - tree: Option>, - welcome: &MlsMessage, - ) -> Result<(Group, NewMemberInfo), MlsError> { + pub async fn join_group_custom<'a, 'b, F>( + &'a self, + tree: Option>, + welcome: &'a MlsMessage, + joiner_edit: F, + ) -> Result<(Group, NewMemberInfo), MlsError> + where + F: FnOnce(GroupJoiner<'a, 'b, C>) -> GroupJoiner<'a, 'b, C>, + { let MlsMessageDescription::Welcome { key_package_refs, .. } = welcome.description() @@ -70,9 +74,20 @@ impl TestClient { joiner = joiner.ratchet_tree(tree) }; + joiner = joiner_edit(joiner); + joiner.join().await } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn join_group<'a>( + &'a self, + tree: Option>, + welcome: &'a MlsMessage, + ) -> Result<(Group, NewMemberInfo), MlsError> { + self.join_group_custom(tree, welcome, |joiner| joiner).await + } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn generate_key_package(&self) -> Result { let package = self.client.key_package_builder(None)?.build().await?; @@ -145,6 +160,11 @@ pub fn make_test_ext_psk() -> Vec { b"secret psk key".to_vec() } +#[cfg_attr(coverage_nightly, coverage(off))] +pub fn make_test_resumption_psk() -> Vec { + b"secret resumption key".to_vec() +} + pub fn is_edwards(cs: u16) -> bool { [ CipherSuite::CURVE25519_AES128, @@ -187,10 +207,6 @@ pub async fn generate_basic_client( .crypto_provider(crypto.clone()) .identity_provider(BasicIdentityProvider::new()) .mls_rules(mls_rules) - .psk( - ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()), - make_test_ext_psk().into(), - ) .used_protocol_version(protocol_version) .signing_identity(identity, secret_key, cipher_suite); diff --git a/mls-rs/test_harness_integration/configs/external_join.json b/mls-rs/test_harness_integration/configs/external_join.json index d6c379db..343a6298 100644 --- a/mls-rs/test_harness_integration/configs/external_join.json +++ b/mls-rs/test_harness_integration/configs/external_join.json @@ -5,12 +5,6 @@ {"action": "externalJoin", "actor": "alice", "joiner": "bob"} ], - "with_psk": [ - {"action": "createGroup", "actor": "alice"}, - {"action": "installExternalPSK", "clients": ["alice"]}, - {"action": "externalJoin", "actor": "alice", "joiner": "bob", "psks": [1]} - ], - "removing_prior": [ {"action": "createGroup", "actor": "alice"}, {"action": "externalJoin", "actor": "alice", "joiner": "bob"}, diff --git a/mls-rs/test_harness_integration/configs/external_join_with_psk.json b/mls-rs/test_harness_integration/configs/external_join_with_psk.json new file mode 100644 index 00000000..5943b0ad --- /dev/null +++ b/mls-rs/test_harness_integration/configs/external_join_with_psk.json @@ -0,0 +1,9 @@ +{ + "scripts": { + "with_psk": [ + {"action": "createGroup", "actor": "alice"}, + {"action": "installExternalPSK", "clients": ["alice"]}, + {"action": "externalJoin", "actor": "alice", "joiner": "bob", "psks": [1]} + ] + } +} diff --git a/mls-rs/test_harness_integration/src/main.rs b/mls-rs/test_harness_integration/src/main.rs index 2b3a1593..3804dc3f 100644 --- a/mls-rs/test_harness_integration/src/main.rs +++ b/mls-rs/test_harness_integration/src/main.rs @@ -30,6 +30,9 @@ use mls_rs::{ MlsMessage, MlsMessageDescription, MlsRules, }; +#[cfg(feature = "psk")] +use mls_rs::{error::MlsError, group::CommitBuilder}; + #[cfg(feature = "by_ref_proposal")] use mls_rs::external_client::builder::ExternalBaseConfig; @@ -282,18 +285,51 @@ impl MlsClient for MlsClientImpl { let tree = get_tree(&request.ratchet_tree)?; - let (group, _) = client + let mut joiner = client .client .group_joiner(&welcome_msg, pkg_data) - .and_then(|joiner| { - match tree { - Some(tree) => joiner.ratchet_tree(tree), - None => joiner, - } - .join() - }) .map_err(abort)?; + if let Some(tree) = tree { + joiner = joiner.ratchet_tree(tree); + } + + #[cfg(feature = "psk")] + let required_external_psks = joiner.required_external_psks().cloned().collect::>(); + + #[cfg(feature = "psk")] + let joiner = required_external_psks + .into_iter() + .try_fold(joiner, |joiner, psk| { + let psk_secret = client + .psk_store + .get(&psk) + .ok_or(Status::aborted("missing psk"))?; + + Ok::<_, Status>(joiner.with_external_psk(psk, psk_secret)) + })?; + + #[cfg(feature = "psk")] + let required_resumption_psks = joiner + .required_resumption_psks() + .cloned() + .collect::>(); + + #[cfg(feature = "psk")] + let joiner = required_resumption_psks + .into_iter() + .try_fold(joiner, |joiner, psk| { + let resumption_psk = client + .client + .load_group(&psk.psk_group_id.0) + .and_then(|g| g.resumption_secret(psk.psk_epoch)) + .map_err(abort)?; + + Ok::<_, Status>(joiner.with_resumption_psk(psk, resumption_psk)) + })?; + + let (group, _) = joiner.join().map_err(abort)?; + let epoch_authenticator = group.epoch_authenticator().map_err(abort)?.to_vec(); client.group = Some(group); client.set_enc_controls(request.encrypt_handshake).await; @@ -360,6 +396,15 @@ impl MlsClient for MlsClientImpl { builder = builder.with_removal(removed_index); } + #[cfg(feature = "psk")] + let builder = request + .psks + .clone() + .into_iter() + .fold(builder, |builder, psk| { + builder.with_external_psk(psk.psk_id.into(), psk.psk_secret.into()) + }); + let (group, commit) = builder.build().map_err(abort)?; let epoch_authenticator = group.epoch_authenticator().map_err(abort)?.to_vec(); @@ -726,7 +771,17 @@ impl MlsClientImpl { let roster = group.roster().members(); - let mut commit_builder = group.commit_builder(); + #[cfg(feature = "psk")] + let group_clone = group.clone(); + + let commit_builder = group.commit_builder(); + + #[cfg(feature = "psk")] + let commit_builder = self + .inject_psks(commit_builder, &client.psk_store, &group_clone) + .map_err(abort)?; + + let mut commit_builder = commit_builder; for proposal in request.by_value { match proposal.proposal_type.as_slice() { @@ -746,12 +801,25 @@ impl MlsClientImpl { #[cfg(feature = "psk")] PROPOSAL_DESC_EXTERNAL_PSK => { let psk_id = ExternalPskId::new(proposal.psk_id.to_vec()); - commit_builder = commit_builder.add_external_psk(psk_id).map_err(abort)?; + + let psk = client + .psk_store + .get(&psk_id) + .ok_or(MlsError::MissingRequiredPsk) + .map_err(abort)?; + + commit_builder = commit_builder + .add_external_psk(psk_id, psk) + .map_err(abort)?; } #[cfg(feature = "psk")] PROPOSAL_DESC_RESUMPTION_PSK => { + let resumption_psk = group_clone + .resumption_secret(proposal.epoch_id) + .map_err(abort)?; + commit_builder = commit_builder - .add_resumption_psk(proposal.epoch_id) + .add_resumption_psk(proposal.epoch_id, resumption_psk) .map_err(abort)?; } PROPOSAL_DESC_GCE => { @@ -791,6 +859,35 @@ impl MlsClientImpl { Ok(Response::new(resp)) } + #[cfg(feature = "psk")] + fn inject_psks<'a>( + &self, + mut commit_builder: CommitBuilder<'a, TestClientConfig>, + psk_store: &InMemoryPreSharedKeyStorage, + group_clone: &Group, + ) -> Result, MlsError> { + let required_psks = commit_builder + .proposals() + .psk_proposals() + .iter() + .map(|psk| psk.proposal.clone()) + .collect::>(); + + for psk_proposal in required_psks { + if let Some(psk_id) = psk_proposal.external_psk_id() { + let psk = psk_store.get(psk_id).ok_or(MlsError::MissingRequiredPsk)?; + commit_builder = commit_builder.apply_external_psk(psk_id.clone(), psk); + } + + if let Some(psk_id) = psk_proposal.resumption_psk_id() { + let psk = group_clone.resumption_secret(psk_id.psk_epoch)?; + commit_builder = commit_builder.apply_resumption_psk(psk_id.clone(), psk); + } + } + + Ok(commit_builder) + } + async fn handle_commit( &self, request: Request, @@ -798,9 +895,11 @@ impl MlsClientImpl { let request = request.into_inner(); let clients = &mut self.clients.lock().await; - let group = clients + let client = clients .get_mut(&request.state_id) - .ok_or_else(|| Status::aborted("no group with such index."))? + .ok_or_else(|| Status::aborted("no group with such index."))?; + + let group = client .group .as_mut() .ok_or_else(|| Status::aborted("no group with such index."))?; @@ -812,17 +911,56 @@ impl MlsClientImpl { let commit = MlsMessage::from_bytes(&request.commit).map_err(abort)?; - let message = group.process_incoming_message(commit).map_err(abort)?; + #[cfg(feature = "psk")] + let group_clone = group.clone(); + + let processor = group.commit_processor(commit).map_err(abort)?; + + #[cfg(feature = "psk")] + let required_external_psks = processor + .required_external_psk() + .cloned() + .collect::>(); + + #[cfg(feature = "psk")] + let processor = + required_external_psks + .into_iter() + .try_fold(processor, |processor, psk| { + let psk_secret = client + .psk_store + .get(&psk) + .ok_or(Status::aborted("missing psk"))?; + + Ok::<_, Status>(processor.with_external_psk(psk, psk_secret)) + })?; + + #[cfg(feature = "psk")] + let required_resumption_psks = processor + .required_resumption_psk() + .cloned() + .collect::>(); + + #[cfg(feature = "psk")] + let processor = + required_resumption_psks + .into_iter() + .try_fold(processor, |processor, psk| { + let resumption_psk = group_clone + .resumption_secret(psk.psk_epoch) + .map_err(abort)?; + + Ok::<_, Status>(processor.with_resumption_psk(psk, resumption_psk)) + })?; + + let message = processor.process().map_err(abort)?; let resp = HandleCommitResponse { state_id: request.state_id, epoch_authenticator: group.epoch_authenticator().map_err(abort)?.to_vec(), }; - match message { - ReceivedMessage::Commit(update) => Ok((Response::new(resp), update.effect)), - _ => Err(Status::aborted("message not a commit.")), - } + Ok((Response::new(resp), message.effect)) } } @@ -845,7 +983,6 @@ async fn create_client(cipher_suite: u16, identity: &[u8]) -> Result