Skip to content

Commit

Permalink
Don't take ownership for messages where not necessary (#84)
Browse files Browse the repository at this point in the history
* Take welcome message by reference

* Take group info by reference

* Take ciphertext by reference

* Fixup

* Fixup

* Simplify tests

* Fixup

---------

Co-authored-by: Marta Mularczyk <[email protected]>
  • Loading branch information
mulmarta and Marta Mularczyk authored Feb 26, 2024
1 parent 13eba2c commit fd94672
Show file tree
Hide file tree
Showing 25 changed files with 238 additions and 356 deletions.
5 changes: 2 additions & 3 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,8 @@ impl Client {
/// Join an existing group.
///
/// See [`mls_rs::Client::join_group`] for details.
pub async fn join_group(&self, welcome_message: Arc<Message>) -> Result<JoinInfo, Error> {
let welcome_message = arc_unwrap_or_clone(welcome_message);
let (group, new_member_info) = self.inner.join_group(None, welcome_message.inner).await?;
pub async fn join_group(&self, welcome_message: &Message) -> Result<JoinInfo, Error> {
let (group, new_member_info) = self.inner.join_group(None, &welcome_message.inner).await?;

let group = Arc::new(Group {
inner: Arc::new(Mutex::new(group)),
Expand Down
16 changes: 6 additions & 10 deletions mls-rs/examples/basic_server_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use mls_rs::{
builder::MlsConfig as ExternalMlsConfig, ExternalClient, ExternalReceivedMessage,
ExternalSnapshot,
},
group::{CachedProposal, ExportedTree, ReceivedMessage},
group::{CachedProposal, ReceivedMessage},
identity::{
basic::{BasicCredential, BasicIdentityProvider},
SigningIdentity,
Expand Down Expand Up @@ -39,11 +39,11 @@ struct BasicServer {

impl BasicServer {
// Client uploads group data after creating the group
fn create_group(group_info: &[u8], tree: ExportedTree) -> Result<Self, MlsError> {
fn create_group(group_info: &[u8]) -> Result<Self, MlsError> {
let server = make_server();
let group_info = MlsMessage::from_bytes(group_info)?;

let group = server.observe_group(group_info, Some(tree))?;
let group = server.observe_group(group_info, None)?;

Ok(Self {
group_state: group.snapshot().to_bytes()?,
Expand Down Expand Up @@ -143,22 +143,18 @@ fn main() -> Result<(), MlsError> {
let mut alice_group = alice.create_group(ExtensionList::default())?;
let bob_key_package = bob.generate_key_package_message()?;

let welcome = alice_group
let welcome = &alice_group
.commit_builder()
.add_member(bob_key_package)?
.build()?
.welcome_messages
.pop()
.expect("key package shouldn't be rejected");
.welcome_messages[0];

let (mut bob_group, _) = bob.join_group(None, welcome)?;
alice_group.apply_pending_commit()?;

// Server starts observing Alice's group
let group_info = alice_group.group_info_message(true)?.to_bytes()?;
let tree = alice_group.export_tree();

let mut server = BasicServer::create_group(&group_info, tree)?;
let mut server = BasicServer::create_group(&group_info)?;

// Bob uploads a proposal
let proposal = bob_group
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/examples/basic_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn main() -> Result<(), MlsError> {
let bob_key_package = bob.generate_key_package_message()?;

// Alice issues a commit that adds Bob to the group.
let mut alice_commit = alice_group
let alice_commit = alice_group
.commit_builder()
.add_member(bob_key_package)?
.build()?;
Expand All @@ -61,7 +61,7 @@ fn main() -> Result<(), MlsError> {
alice_group.apply_pending_commit()?;

// Bob joins the group with the welcome message created as part of Alice's commit.
let (mut bob_group, _) = bob.join_group(None, alice_commit.welcome_messages.pop().unwrap())?;
let (mut bob_group, _) = bob.join_group(None, &alice_commit.welcome_messages[0])?;

// Alice encrypts an application message to Bob.
let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?;
Expand Down
6 changes: 3 additions & 3 deletions mls-rs/examples/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ fn main() -> Result<(), CustomError> {
.remove(0);

alice_tablet_group.apply_pending_commit()?;
let (mut alice_pc_group, _) = alice_pc_client.join_group(None, welcome)?;
let (mut alice_pc_group, _) = alice_pc_client.join_group(None, &welcome)?;

// Alice cannot add bob's devices yet
let bob_tablet_client = make_client(bob_tablet)?;
Expand All @@ -401,13 +401,13 @@ fn main() -> Result<(), CustomError> {
new_user: bob.credential,
};

let mut commit = alice_tablet_group
let commit = alice_tablet_group
.commit_builder()
.custom_proposal(add_bob.to_custom_proposal()?)
.add_member(key_package)?
.build()?;

bob_tablet_client.join_group(None, commit.welcome_messages.remove(0))?;
bob_tablet_client.join_group(None, &commit.welcome_messages[0])?;
alice_tablet_group.apply_pending_commit()?;
alice_pc_group.process_incoming_message(commit.commit_message)?;

Expand Down
9 changes: 4 additions & 5 deletions mls-rs/examples/large_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
let bob_kpkg = bob_client.generate_key_package_message()?;

// Last group sends a commit adding the new client to the group.
let mut commit = groups
let commit = groups
.last_mut()
.unwrap()
.commit_builder()
Expand All @@ -85,8 +85,7 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
groups.last_mut().unwrap().apply_pending_commit()?;

// The new member joins.
let welcome_message = commit.welcome_messages.pop().unwrap();
let (bob_group, _info) = bob_client.join_group(None, welcome_message)?;
let (bob_group, _info) = bob_client.join_group(None, &commit.welcome_messages[0])?;

groups.push(bob_group);
}
Expand Down Expand Up @@ -115,15 +114,15 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
commit_builder = commit_builder.add_member(bob_kpkg)?;
}

let welcome_message: mls_rs::MlsMessage = commit_builder.build()?.welcome_messages.remove(0);
let welcome_message = &commit_builder.build()?.welcome_messages[0];

alice_group.apply_pending_commit()?;

// Bob's clients join the group.
let mut groups = vec![alice_group];

for bob_client in &bob_clients {
let (bob_group, _info) = bob_client.join_group(None, welcome_message.clone())?;
let (bob_group, _info) = bob_client.join_group(None, welcome_message)?;
groups.push(bob_group);
}

Expand Down
30 changes: 10 additions & 20 deletions mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::group::framing::MlsMessage;
use crate::group::{
framing::{Content, MlsMessagePayload, PublicMessage, Sender, WireFormat},
message_signature::AuthenticatedContent,
process_group_info,
proposal::{AddProposal, Proposal},
};
use crate::group::{ExportedTree, Group, NewMemberInfo};
Expand Down Expand Up @@ -536,7 +535,7 @@ where
pub async fn join_group(
&self,
tree_data: Option<ExportedTree<'_>>,
welcome_message: MlsMessage,
welcome_message: &MlsMessage,
) -> Result<(Group<C>, NewMemberInfo), MlsError> {
Group::join(
welcome_message,
Expand Down Expand Up @@ -627,7 +626,7 @@ where
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn external_add_proposal(
&self,
group_info: MlsMessage,
group_info: &MlsMessage,
tree_data: Option<crate::group::ExportedTree<'_>>,
authenticated_data: Vec<u8>,
) -> Result<MlsMessage, MlsError> {
Expand All @@ -638,7 +637,7 @@ where
}

let group_info = group_info
.into_group_info()
.as_group_info()
.ok_or(MlsError::UnexpectedMessageType)?;

let cipher_suite = group_info.group_context.cipher_suite;
Expand All @@ -649,15 +648,14 @@ where
.cipher_suite_provider(cipher_suite)
.ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;

let group_context = process_group_info(
crate::group::validate_group_info_joiner(
protocol_version,
group_info,
tree_data,
&self.config.identity_provider(),
&cipher_suite_provider,
)
.await?
.group_context;
.await?;

let key_package = self.generate_key_package().await?.key_package;

Expand All @@ -667,7 +665,7 @@ where

let message = AuthenticatedContent::new_signed(
&cipher_suite_provider,
&group_context,
&group_info.group_context,
Sender::NewMemberProposal,
Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
key_package,
Expand Down Expand Up @@ -830,12 +828,8 @@ mod tests {

let proposal = bob
.external_add_proposal(
alice_group
.group
.group_info_message_allowing_ext_commit(true)
.await
.unwrap(),
Some(alice_group.group.export_tree()),
&alice_group.group.group_info_message(true).await.unwrap(),
None,
vec![],
)
.await
Expand Down Expand Up @@ -890,7 +884,7 @@ mod tests {

let group_info_msg = alice_group
.group
.group_info_message_allowing_ext_commit(false)
.group_info_message_allowing_ext_commit(true)
.await
.unwrap();

Expand All @@ -904,10 +898,7 @@ mod tests {
.signing_identity(new_client_identity.clone(), secret_key, TEST_CIPHER_SUITE)
.build();

let mut builder = new_client
.external_commit_builder()
.unwrap()
.with_tree_data(alice_group.group.export_tree().into_owned());
let mut builder = new_client.external_commit_builder().unwrap();

if do_remove {
builder = builder.with_removal(1);
Expand Down Expand Up @@ -1013,7 +1004,6 @@ mod tests {
let (_, external_commit) = carol
.external_commit_builder()
.unwrap()
.with_tree_data(bob_group.group.export_tree().into_owned())
.build(group_info_msg)
.await
.unwrap();
Expand Down
19 changes: 10 additions & 9 deletions mls-rs/src/external_client/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use crate::{
snapshot::RawGroupState,
state::GroupState,
transcript_hash::InterimTranscriptHash,
validate_group_info, ContentType, ExportedTree, GroupContext, GroupInfo, Roster, Welcome,
validate_group_info_joiner, ContentType, ExportedTree, GroupContext, GroupInfo, Roster,
Welcome,
},
identity::SigningIdentity,
protocol_version::ProtocolVersion,
Expand Down Expand Up @@ -126,9 +127,9 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
group_info.group_context.cipher_suite,
)?;

let join_context = validate_group_info(
let public_tree = validate_group_info_joiner(
protocol_version,
group_info,
&group_info,
tree_data,
&config.identity_provider(),
&cipher_suite_provider,
Expand All @@ -137,19 +138,19 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {

let interim_transcript_hash = InterimTranscriptHash::create(
&cipher_suite_provider,
&join_context.group_context.confirmed_transcript_hash,
&join_context.confirmation_tag,
&group_info.group_context.confirmed_transcript_hash,
&group_info.confirmation_tag,
)
.await?;

Ok(Self {
config,
signing_data,
state: GroupState::new(
join_context.group_context,
join_context.public_tree,
group_info.group_context,
public_tree,
interim_transcript_hash,
join_context.confirmation_tag,
group_info.confirmation_tag,
),
cipher_suite_provider,
})
Expand Down Expand Up @@ -597,7 +598,7 @@ where
#[cfg(feature = "private_message")]
async fn process_ciphertext(
&mut self,
cipher_text: PrivateMessage,
cipher_text: &PrivateMessage,
) -> Result<EventOrContent<Self::OutputType>, MlsError> {
Ok(EventOrContent::Event(ExternalReceivedMessage::Ciphertext(
cipher_text.content_type,
Expand Down
12 changes: 6 additions & 6 deletions mls-rs/src/group/ciphertext_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ where
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn open(
&mut self,
ciphertext: PrivateMessage,
ciphertext: &PrivateMessage,
) -> Result<AuthenticatedContent, MlsError> {
// Decrypt the sender data with the derived sender_key and sender_nonce from the message
// epoch's key schedule
Expand Down Expand Up @@ -236,7 +236,7 @@ where
.decrypt(
&self.cipher_suite_provider,
&ciphertext.ciphertext,
&PrivateContentAAD::from(&ciphertext).mls_encode_to_vec()?,
&PrivateContentAAD::from(ciphertext).mls_encode_to_vec()?,
&sender_data.reuse_guard,
)
.await
Expand All @@ -252,7 +252,7 @@ where
group_id: ciphertext.group_id.clone(),
epoch: ciphertext.epoch,
sender,
authenticated_data: ciphertext.authenticated_data,
authenticated_data: ciphertext.authenticated_data.clone(),
content: ciphertext_content.content,
},
auth: ciphertext_content.auth,
Expand Down Expand Up @@ -335,7 +335,7 @@ mod test {

let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite);

let decrypted = receiver_processor.open(ciphertext).await.unwrap();
let decrypted = receiver_processor.open(&ciphertext).await.unwrap();

assert_eq!(decrypted, test_data.content);
}
Expand Down Expand Up @@ -384,7 +384,7 @@ mod test {
.await
.unwrap();

let res = ciphertext_processor.open(ciphertext).await;
let res = ciphertext_processor.open(&ciphertext).await;

assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
}
Expand All @@ -403,7 +403,7 @@ mod test {
ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len());
receiver_group.group.private_tree.self_index = LeafIndex::new(1);

let res = ciphertext_processor.open(ciphertext).await;
let res = ciphertext_processor.open(&ciphertext).await;

assert!(res.is_err());
}
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/src/group/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ mod tests {
.welcome_messages
.remove(0);

let (_, context) = bob_client.join_group(None, welcome_message).await.unwrap();
let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();

assert_eq!(
context
Expand Down Expand Up @@ -1243,7 +1243,7 @@ mod tests {
.find(|w| w.welcome_key_package_references().contains(&&kp_ref))
.unwrap();

client.join_group(None, welcome.clone()).await.unwrap();
client.join_group(None, welcome).await.unwrap();

assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
}
Expand Down
Loading

0 comments on commit fd94672

Please sign in to comment.