From 19f2d833dafb6cc8bca3366b87ac6939687db552 Mon Sep 17 00:00:00 2001 From: Marta Mularczyk Date: Fri, 5 Jan 2024 11:54:46 +0100 Subject: [PATCH] Compute proposal reference from MlsMessage --- mls-rs/src/group/framing.rs | 52 +++++++++++++++++++++++++++++++- mls-rs/src/group/proposal_ref.rs | 11 +++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/mls-rs/src/group/framing.rs b/mls-rs/src/group/framing.rs index b3299b15..c5d66e19 100644 --- a/mls-rs/src/group/framing.rs +++ b/mls-rs/src/group/framing.rs @@ -419,6 +419,23 @@ impl MlsMessage { kp.to_reference(cipher_suite).await.map(Some) } + + /// If this is a plaintext proposal, return the proposal reference that can be matched e.g. with + /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals). + #[cfg(feature = "by_ref_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn proposal_reference( + &self, + cipher_suite: &C, + ) -> Result, MlsError> { + matches!(&self.payload, MlsMessagePayload::Plain(ptx) if ptx.content.content_type() == ContentType::Proposal) + .then_some(()) + .ok_or(MlsError::UnexpectedMessageType)?; + + ProposalRef::from_bytes(cipher_suite, &self.payload.mls_encode_to_vec()?) + .await + .map(|r| r.to_vec()) + } } #[cfg(feature = "custom_proposal")] @@ -545,7 +562,14 @@ pub(crate) mod test_utils { mod tests { use assert_matches::assert_matches; - use crate::group::framing::test_utils::get_test_ciphertext_content; + use crate::{ + client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, + crypto::test_utils::test_cipher_suite_provider, + group::{ + framing::test_utils::get_test_ciphertext_content, + proposal_ref::test_utils::auth_content_from_proposal, + }, + }; use super::*; @@ -575,4 +599,30 @@ mod tests { assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(_))); } + + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn proposal_ref() { + let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); + + let test_auth = auth_content_from_proposal( + Proposal::Remove(RemoveProposal { + to_remove: LeafIndex(0), + }), + Sender::External(0), + ); + + let expected_ref = ProposalRef::from_content(&cs, &test_auth).await.unwrap(); + + let test_message = MlsMessage { + version: TEST_PROTOCOL_VERSION, + payload: MlsMessagePayload::Plain(PublicMessage { + content: test_auth.content, + auth: test_auth.auth, + membership_tag: None, + }), + }; + + let computed_ref = test_message.proposal_reference(&cs).unwrap(); + assert_eq!(computed_ref, expected_ref.to_vec()); + } } diff --git a/mls-rs/src/group/proposal_ref.rs b/mls-rs/src/group/proposal_ref.rs index bfbd9d28..71d8e5b8 100644 --- a/mls-rs/src/group/proposal_ref.rs +++ b/mls-rs/src/group/proposal_ref.rs @@ -33,6 +33,17 @@ impl ProposalRef { .await?, )) } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub(crate) async fn from_bytes( + cipher_suite_provider: &CS, + bytes: &[u8], + ) -> Result { + Ok(ProposalRef( + HashReference::compute(bytes, b"MLS 1.0 Proposal Reference", cipher_suite_provider) + .await?, + )) + } } #[cfg(test)]