From 4b058ee086cb5847325ddbcc7f020da212dfdd0a Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Wed, 15 Jan 2025 23:43:01 -0500 Subject: [PATCH] Allow adding psks to use with by ref proposals --- mls-rs/src/group/commit/builder.rs | 34 ++++++++++ mls-rs/src/group/mod.rs | 63 ++++++++++++++++++- .../group/proposal_filter/filtering_common.rs | 24 +++---- 3 files changed, 104 insertions(+), 17 deletions(-) diff --git a/mls-rs/src/group/commit/builder.rs b/mls-rs/src/group/commit/builder.rs index 73fa594c..495f927e 100644 --- a/mls-rs/src/group/commit/builder.rs +++ b/mls-rs/src/group/commit/builder.rs @@ -229,6 +229,40 @@ where Ok(self.with_proposal(proposal)) } + /// Add an external PSK that can be used to fulfil PSK requirements that were + /// established via a [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) + /// from another client during the current epoch. + pub fn apply_external_psk(self, id: ExternalPskId, psk: crate::psk::PreSharedKey) -> Self { + self.apply_psk(JustPreSharedKeyID::External(id), psk) + } + + /// Add an resumption PSK that can be used to fulfil PSK requirements that were + /// established via a [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) + /// from another client during the current epoch. + pub fn apply_resumption_psk(self, id: ResumptionPsk, psk: crate::psk::PreSharedKey) -> Self { + self.apply_psk(JustPreSharedKeyID::Resumption(id), psk) + } + + #[cfg(feature = "psk")] + fn apply_psk(mut self, id: JustPreSharedKeyID, psk: crate::psk::PreSharedKey) -> Self { + if let Some(secret_input) = self + .proposals + .psks + .iter() + .filter(|proposal| proposal.is_by_reference()) + .find_map(|p| { + (p.proposal.psk.key_id == id).then(|| PskSecretInput { + id: p.proposal.psk.clone(), + psk: psk.clone(), + }) + }) + { + self.psks.push(secret_input); + } + + self + } + /// Insert a /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with /// an external PSK into the current commit that is being built. diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 1bc236bc..e357c6dc 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -3675,7 +3675,7 @@ mod tests { #[cfg(feature = "psk")] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] - async fn can_process_with_psk() { + async fn can_process_commit_with_psk_by_value() { let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let (mut bob, _) = alice.join("bob").await; @@ -3711,6 +3711,67 @@ mod tests { .unwrap(); } + #[cfg(feature = "psk")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn can_process_commit_with_psk_by_reference() { + let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; + let (mut bob, _) = alice.join("bob").await; + + let psk_id_external = ExternalPskId::new(vec![0]); + let psk_external = PreSharedKey::from(vec![1]); + let psk_epoch = bob.context().epoch; + + let psk_id_resumption = ResumptionPsk { + usage: ResumptionPSKUsage::Application, + psk_group_id: bob.context().group_id.clone().into(), + psk_epoch, + }; + + let psk_resumption = bob.resumption_secret(psk_epoch).await.unwrap(); + + let psk_external_proposal = alice + .propose_external_psk(psk_id_external.clone(), Vec::new()) + .await + .unwrap(); + + let psk_resumption_proposal = alice + .propose_resumption_psk(psk_epoch, Vec::new()) + .await + .unwrap(); + + // This fails due to not having the secrets for by reference psk + let commit = alice + .commit_builder() + .apply_external_psk(psk_id_external.clone(), psk_external.clone()) + .apply_resumption_psk(psk_id_resumption.clone(), psk_resumption.clone()) + .build() + .await + .unwrap(); + + let commit_changes = alice.apply_pending_commit().await.unwrap(); + + // Make sure the proposals are actually in the commit to verify the rest of the test is + // working as expected + assert_matches!(commit_changes.effect, CommitEffect::NewEpoch(epoch) if epoch.unused_proposals.is_empty()); + + bob.process_incoming_message(psk_external_proposal) + .await + .unwrap(); + + bob.process_incoming_message(psk_resumption_proposal) + .await + .unwrap(); + + bob.commit_processor(commit.commit_message) + .await + .unwrap() + .with_external_psk(psk_id_external, psk_external) + .with_resumption_psk(psk_id_resumption, psk_resumption) + .process() + .await + .unwrap(); + } + #[cfg(feature = "by_ref_proposal")] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn invalid_update_does_not_prevent_other_updates() { diff --git a/mls-rs/src/group/proposal_filter/filtering_common.rs b/mls-rs/src/group/proposal_filter/filtering_common.rs index 6d473b30..94cd0758 100644 --- a/mls-rs/src/group/proposal_filter/filtering_common.rs +++ b/mls-rs/src/group/proposal_filter/filtering_common.rs @@ -382,21 +382,13 @@ where #[cfg(not(feature = "std"))] let is_new_id = !ids_seen.contains(&p.proposal.psk); - let external_id_is_valid = match &p.proposal.psk.key_id { - JustPreSharedKeyID::External(id) => { - if psks.iter().any( - |one_id| matches!(one_id, JustPreSharedKeyID::External(ext_id) if ext_id == id), - ) { - Ok(()) - } else { - Err(MlsError::MissingRequiredPsk) - } - } - JustPreSharedKeyID::Resumption(_) => Ok(()), - }; + let has_required_psk_secret = psks + .contains(&p.proposal.psk.key_id) + .then_some(()) + .ok_or_else(|| MlsError::MissingRequiredPsk); #[cfg(not(feature = "psk"))] - let external_id_is_valid = Ok(()); + let has_required_psk_secret = Ok(()); #[cfg(not(feature = "by_ref_proposal"))] if !valid { @@ -405,8 +397,8 @@ where return Err(MlsError::InvalidPskNonceLength); } else if !is_new_id { return Err(MlsError::DuplicatePskIds); - } else if external_id_is_valid.is_err() { - return external_id_is_valid; + } else if has_required_psk_secret.is_err() { + return has_required_psk_secret; } #[cfg(feature = "by_ref_proposal")] @@ -418,7 +410,7 @@ where } else if !is_new_id { Err(MlsError::DuplicatePskIds) } else { - external_id_is_valid + has_required_psk_secret }; if !apply_strategy(strategy, p.is_by_reference(), res)? {