Skip to content

Commit

Permalink
implementing safe encrypt/decrypt with context
Browse files Browse the repository at this point in the history
  • Loading branch information
Ellen Arteca committed Jan 28, 2025
1 parent 29affc7 commit 30369cc
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
32 changes: 32 additions & 0 deletions mls-rs/src/group/component_operation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use crate::client::MlsError;
use crate::tree_kem::hpke_encryption::HpkeEncryptable;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

pub type ComponentID = u32;

#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct ComponentOperationLabel {
component_id: ComponentID,
context: Vec<u8>,
}

impl HpkeEncryptable for ComponentOperationLabel {
const ENCRYPT_LABEL: &'static str = "MLS 1.0 Application";

fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
Self::mls_decode(&mut bytes.as_slice()).map_err(Into::into)
}

fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
self.mls_encode_to_vec().map_err(Into::into)
}
}

impl ComponentOperationLabel {
pub fn new(component_id: u32, context: Vec<u8>) -> Self {
Self {
component_id,
context,
}
}
}
174 changes: 174 additions & 0 deletions mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ use crate::psk::{
#[cfg(feature = "private_message")]
use ciphertext_processor::*;

use component_operation::{ComponentID, ComponentOperationLabel};
use confirmation_tag::*;
use framing::*;
use key_schedule::*;
Expand Down Expand Up @@ -111,6 +112,7 @@ pub use self::message_processor::CachedProposal;
mod ciphertext_processor;

mod commit;
pub mod component_operation;
pub(crate) mod confirmation_tag;
pub(crate) mod epoch;
pub(crate) mod framing;
Expand Down Expand Up @@ -628,6 +630,33 @@ where
Ok(hpke_ciphertext)
}

pub fn safe_encrypt_with_context_to_recipient(
&self,
recipient_index: u32,
component_id: ComponentID,
context: &[u8],
associated_data: Option<&[u8]>,
plaintext: &[u8],
) -> Result<HpkeCiphertext, MlsError> {
let component_operation_label =
ComponentOperationLabel::new(component_id, context.to_vec());
let member_leaf_node = self
.group_state()
.public_tree
.get_leaf_node(LeafIndex(recipient_index))?;
let member_public_key = &member_leaf_node.public_key;
let hpke_ciphertext = self
.cipher_suite_provider
.hpke_seal(
member_public_key,
&component_operation_label.get_bytes()?,
associated_data,
plaintext,
)
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
Ok(hpke_ciphertext)
}

/// HPKE decrypts a message sent to the current member.
///
/// Takes `HpkeCiphertext` generated by `hpke_encrypt_to_recipient` intended for the
Expand Down Expand Up @@ -658,6 +687,32 @@ where
Ok(plaintext)
}

pub fn safe_decrypt_with_context_for_current_member(
&self,
component_id: ComponentID,
context: &[u8],
associated_data: Option<&[u8]>,
hpke_ciphertext: HpkeCiphertext,
) -> Result<Vec<u8>, MlsError> {
let component_operation_label =
ComponentOperationLabel::new(component_id, context.to_vec());
let self_private_key = &self.private_tree.secret_keys[0]
.as_ref()
.ok_or(MlsError::InvalidTreeKemPrivateKey)?;
let self_public_key = &self.current_user_leaf_node()?.public_key;
let plaintext = self
.cipher_suite_provider
.hpke_open(
&hpke_ciphertext,
self_private_key,
self_public_key,
&component_operation_label.get_bytes()?,
associated_data,
)
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
Ok(plaintext)
}

/// Index within the group's state for the local group instance.
///
/// This index corresponds to indexes in content descriptions within
Expand Down Expand Up @@ -2108,6 +2163,7 @@ mod tests {
client::test_utils::{test_client_with_key_pkg_custom, TEST_CUSTOM_PROPOSAL_TYPE},
client_builder::{ClientBuilder, MlsConfig},
group::{
component_operation::ComponentID,
mls_rules::{CommitDirection, CommitSource},
proposal_filter::ProposalBundle,
},
Expand Down Expand Up @@ -2384,6 +2440,41 @@ mod tests {
assert_eq!(plaintext.to_vec(), hpke_decrypted);
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn safe_context_test_hpke_encrypt_decrypt() {
let component_id: ComponentID = 1;
let (alice_group, bob_group) =
test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
let receiver_index = alice_group.current_member_index();
let sender_index = bob_group.current_member_index();

let context_info: Vec<u8> = vec![
receiver_index.try_into().unwrap(),
sender_index.try_into().unwrap(),
];
let plaintext = b"message";

let hpke_ciphertext = bob_group
.safe_encrypt_with_context_to_recipient(
receiver_index,
component_id,
&context_info,
None,
plaintext,
)
.unwrap();
let hpke_decrypted = alice_group
.safe_decrypt_with_context_for_current_member(
component_id,
&context_info,
None,
hpke_ciphertext,
)
.unwrap();

assert_eq!(plaintext.to_vec(), hpke_decrypted);
}

#[cfg(feature = "non_domain_separated_hpke_encrypt_decrypt")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_hpke_non_recipient_cant_decrypt() {
Expand Down Expand Up @@ -2415,6 +2506,47 @@ mod tests {
assert_matches!(hpke_decrypted, Err(MlsError::CryptoProviderError(_)));
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn safe_context_test_hpke_non_recipient_cant_decrypt() {
let component_id: ComponentID = 345;
let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let (mut bob, _) = alice.join("bob").await;
let (carol, commit) = alice.join("carol").await;

// Apply the commit that adds carol
bob.process_incoming_message(commit).await.unwrap();

let receiver_index = alice.current_member_index();
let sender_index = bob.current_member_index();

let context_info: Vec<u8> = vec![
receiver_index.try_into().unwrap(),
sender_index.try_into().unwrap(),
];
let plaintext = b"message";

let hpke_ciphertext = bob
.safe_encrypt_with_context_to_recipient(
receiver_index,
component_id,
&context_info,
None,
plaintext,
)
.unwrap();

// different recipient tries to decrypt
let hpke_decrypted = carol.safe_decrypt_with_context_for_current_member(
component_id,
&context_info,
None,
hpke_ciphertext,
);

// should fail because carol can't decrypt the message encrypted for alice
assert_matches!(hpke_decrypted, Err(MlsError::CryptoProviderError(_)));
}

#[cfg(feature = "non_domain_separated_hpke_encrypt_decrypt")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_hpke_can_decrypt_after_group_changes() {
Expand Down Expand Up @@ -2451,6 +2583,48 @@ mod tests {
assert_eq!(plaintext.to_vec(), hpke_decrypted);
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn safe_context_test_hpke_can_decrypt_after_group_changes() {
let component_id: ComponentID = 2;
let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let (mut bob, _) = alice.join("bob").await;

let receiver_index = alice.current_member_index();
let sender_index = bob.current_member_index();
let context_info: Vec<u8> = vec![
receiver_index.try_into().unwrap(),
sender_index.try_into().unwrap(),
];
let associated_data: Vec<u8> = vec![1, 2, 3, 4];
let plaintext = b"message";

// encrypt the message to alice
let hpke_ciphertext = bob
.safe_encrypt_with_context_to_recipient(
receiver_index,
component_id,
&context_info,
Some(&associated_data),
plaintext,
)
.unwrap();

// add carol to the group
let (_carol, commit) = alice.join("carol").await;
bob.process_incoming_message(commit).await.unwrap();

// make sure alice can still decrypt
let hpke_decrypted = alice
.safe_decrypt_with_context_for_current_member(
component_id,
&context_info,
Some(&associated_data),
hpke_ciphertext,
)
.unwrap();
assert_eq!(plaintext.to_vec(), hpke_decrypted);
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn test_two_member_group(
protocol_version: ProtocolVersion,
Expand Down

0 comments on commit 30369cc

Please sign in to comment.