Skip to content

Commit

Permalink
[uniffi] Add support for custom GroupStateStorage interface (#86)
Browse files Browse the repository at this point in the history
* [uniffi] Add support for custom GroupStateStorage interface

* Allow external languages to implement callbacks

* Remove unwrap from group_state.rs

* Fix simple_scenario_sync and remove async for now

* Ignore async tests
  • Loading branch information
tomleavy authored Feb 27, 2024
1 parent 6d107b8 commit 7d8c898
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 44 deletions.
1 change: 1 addition & 0 deletions mls-rs-uniffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ crate-type = ["lib", "cdylib"]
name = "mls_rs_uniffi"

[dependencies]
async-trait = "0.1.77"
maybe-async = "0.2.10"
mls-rs = { path = "../mls-rs" }
mls-rs-core = { path = "../mls-rs-core" }
Expand Down
43 changes: 43 additions & 0 deletions mls-rs-uniffi/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::{fmt::Debug, sync::Arc};

use mls_rs::{
client_builder::{self, WithGroupStateStorage},
identity::basic,
};
use mls_rs_core::error::IntoAnyError;
use mls_rs_crypto_openssl::OpensslCryptoProvider;

use self::group_state::GroupStateStorageWrapper;

mod group_state;

#[derive(Debug, thiserror::Error, uniffi::Error)]
#[uniffi(flat_error)]
#[non_exhaustive]
pub enum FFICallbackError {
#[error("data preparation error")]
DataPreparationError {
#[from]
inner: mls_rs_core::mls_rs_codec::Error,
},
#[error("unexpected callback error")]
UnexpectedCallbackError {
#[from]
inner: uniffi::UnexpectedUniFFICallbackError,
},
}

impl IntoAnyError for FFICallbackError {}

pub type UniFFIConfig = client_builder::WithIdentityProvider<
basic::BasicIdentityProvider,
client_builder::WithCryptoProvider<
OpensslCryptoProvider,
WithGroupStateStorage<GroupStateStorageWrapper, client_builder::BaseConfig>,
>,
>;

#[derive(Debug, Clone, uniffi::Record)]
pub struct ClientConfig {
pub group_state_storage: Arc<dyn group_state::GroupStateStorage>,
}
130 changes: 130 additions & 0 deletions mls-rs-uniffi/src/config/group_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use std::{fmt::Debug, sync::Arc};

use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode};

use super::FFICallbackError;

#[derive(Clone, Debug, uniffi::Record)]
pub struct GroupState {
pub id: Vec<u8>,
pub data: Vec<u8>,
}

impl mls_rs_core::group::GroupState for GroupState {
fn id(&self) -> Vec<u8> {
self.id.clone()
}
}

#[derive(Clone, Debug, uniffi::Record)]
pub struct EpochRecord {
pub id: u64,
pub data: Vec<u8>,
}

impl mls_rs_core::group::EpochRecord for EpochRecord {
fn id(&self) -> u64 {
self.id
}
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
#[uniffi::export(with_foreign)]
pub trait GroupStateStorage: Send + Sync + Debug {
async fn state(&self, group_id: Vec<u8>) -> Result<Option<Vec<u8>>, FFICallbackError>;
async fn epoch(
&self,
group_id: Vec<u8>,
epoch_id: u64,
) -> Result<Option<Vec<u8>>, FFICallbackError>;

async fn write(
&self,
state: GroupState,
epoch_inserts: Vec<EpochRecord>,
epoch_updates: Vec<EpochRecord>,
) -> Result<(), FFICallbackError>;

async fn max_epoch_id(&self, group_id: Vec<u8>) -> Result<Option<u64>, FFICallbackError>;
}

#[derive(Debug, Clone)]
pub(crate) struct GroupStateStorageWrapper(Arc<dyn GroupStateStorage>);

impl From<Arc<dyn GroupStateStorage>> for GroupStateStorageWrapper {
fn from(value: Arc<dyn GroupStateStorage>) -> Self {
Self(value)
}
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper {
type Error = FFICallbackError;

async fn state<T>(&self, group_id: &[u8]) -> Result<Option<T>, Self::Error>
where
T: mls_rs_core::group::GroupState + MlsEncode + MlsDecode,
{
let state_data = self.0.state(group_id.to_vec())?;

state_data
.as_deref()
.map(|v| T::mls_decode(&mut &*v))
.transpose()
.map_err(Into::into)
}

async fn epoch<T>(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<T>, Self::Error>
where
T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode,
{
let epoch_data = self.0.epoch(group_id.to_vec(), epoch_id)?;

epoch_data
.as_deref()
.map(|v| T::mls_decode(&mut &*v))
.transpose()
.map_err(Into::into)
}

async fn write<ST, ET>(
&mut self,
state: ST,
epoch_inserts: Vec<ET>,
epoch_updates: Vec<ET>,
) -> Result<(), Self::Error>
where
ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync,
ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync,
{
let state = GroupState {
id: state.id(),
data: state.mls_encode_to_vec()?,
};

let epoch_to_record = |v: ET| -> Result<_, Self::Error> {
Ok(EpochRecord {
id: v.id(),
data: v.mls_encode_to_vec()?,
})
};

let inserts = epoch_inserts
.into_iter()
.map(epoch_to_record)
.collect::<Result<Vec<_>, _>>()?;

let updates = epoch_updates
.into_iter()
.map(epoch_to_record)
.collect::<Result<Vec<_>, _>>()?;

self.0.write(state, inserts, updates)
}

async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
self.0.max_epoch_id(group_id.to_vec())
}
}
65 changes: 44 additions & 21 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@
//!
//! [UniFFI]: https://mozilla.github.io/uniffi-rs/
mod config;
#[cfg(test)]
pub mod test_utils;

use std::sync::Arc;

use config::{ClientConfig, UniFFIConfig};
#[cfg(not(mls_build_async))]
use std::sync::Mutex;
#[cfg(mls_build_async)]
use tokio::sync::Mutex;

use mls_rs::client_builder;
use mls_rs::error::{IntoAnyError, MlsError};
use mls_rs::group;
use mls_rs::identity::basic;
Expand All @@ -52,12 +53,12 @@ fn arc_unwrap_or_clone<T: Clone>(arc: Arc<T>) -> T {
#[uniffi(flat_error)]
#[non_exhaustive]
pub enum Error {
#[error("A mls-rs error occurred")]
#[error("A mls-rs error occurred: {inner}")]
MlsError {
#[from]
inner: mls_rs::error::MlsError,
},
#[error("An unknown error occurred")]
#[error("An unknown error occurred: {inner}")]
AnyError {
#[from]
inner: mls_rs::error::AnyError,
Expand Down Expand Up @@ -96,11 +97,6 @@ pub struct SignatureKeypair {
secret_key: Arc<SignatureSecretKey>,
}

pub type Config = client_builder::WithIdentityProvider<
basic::BasicIdentityProvider,
client_builder::WithCryptoProvider<OpensslCryptoProvider, client_builder::BaseConfig>,
>;

/// Light-weight wrapper around a [`mls_rs::ExtensionList`].
#[derive(uniffi::Object, Debug, Clone)]
pub struct ExtensionList {
Expand Down Expand Up @@ -247,7 +243,7 @@ pub async fn generate_signature_keypair(
/// See [`mls_rs::Client`] for details.
#[derive(Clone, Debug, uniffi::Object)]
pub struct Client {
inner: mls_rs::client::Client<Config>,
inner: mls_rs::client::Client<UniFFIConfig>,
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
Expand All @@ -260,21 +256,27 @@ impl Client {
///
/// See [`mls_rs::Client::builder`] for details.
#[uniffi::constructor]
pub fn new(id: Vec<u8>, signature_keypair: SignatureKeypair) -> Self {
pub fn new(
id: Vec<u8>,
signature_keypair: SignatureKeypair,
client_config: ClientConfig,
) -> Self {
let cipher_suite = signature_keypair.cipher_suite;
let public_key = arc_unwrap_or_clone(signature_keypair.public_key);
let secret_key = arc_unwrap_or_clone(signature_keypair.secret_key);
let crypto_provider = OpensslCryptoProvider::new();
let basic_credential = BasicCredential::new(id);
let signing_identity =
identity::SigningIdentity::new(basic_credential.into_credential(), public_key.inner);
Client {
inner: mls_rs::Client::builder()
.crypto_provider(crypto_provider)
.identity_provider(basic::BasicIdentityProvider::new())
.signing_identity(signing_identity, secret_key.inner, cipher_suite.into())
.build(),
}

let client = mls_rs::Client::builder()
.crypto_provider(crypto_provider)
.identity_provider(basic::BasicIdentityProvider::new())
.signing_identity(signing_identity, secret_key.inner, cipher_suite.into())
.group_state_storage(client_config.group_state_storage.into())
.build();

Client { inner: client }
}

/// Generate a new key package for this client.
Expand Down Expand Up @@ -327,6 +329,19 @@ impl Client {
group_info_extensions,
})
}

/// Load an existing group.
///
/// See [`mls_rs::Client::load_group`] for details.
pub async fn load_group(&self, group_id: Vec<u8>) -> Result<Group, Error> {
self.inner
.load_group(&group_id)
.await
.map(|g| Group {
inner: Arc::new(Mutex::new(g)),
})
.map_err(Into::into)
}
}

#[derive(Clone, Debug, uniffi::Object)]
Expand Down Expand Up @@ -379,25 +394,25 @@ impl From<identity::SigningIdentity> for SigningIdentity {
/// See [`mls_rs::Group`] for details.
#[derive(Clone, uniffi::Object)]
pub struct Group {
inner: Arc<Mutex<mls_rs::Group<Config>>>,
inner: Arc<Mutex<mls_rs::Group<UniFFIConfig>>>,
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
impl Group {
#[cfg(not(mls_build_async))]
fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group<Config>> {
fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group<UniFFIConfig>> {
self.inner.lock().unwrap()
}

#[cfg(mls_build_async)]
async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group<Config>> {
async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group<UniFFIConfig>> {
self.inner.lock().await
}
}

/// Find the identity for the member with a given index.
fn index_to_identity(
group: &mls_rs::Group<Config>,
group: &mls_rs::Group<UniFFIConfig>,
index: u32,
) -> Result<identity::SigningIdentity, Error> {
let member = group
Expand All @@ -421,6 +436,13 @@ async fn signing_identity_to_identifier(
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[uniffi::export]
impl Group {
/// Write the current state of the group to storage defined by
/// [`ClientConfig::group_state_storage`]
pub async fn write_to_storage(&self) -> Result<(), Error> {
let mut group = self.inner().await;
group.write_to_storage().await.map_err(Into::into)
}

/// Perform a commit of received proposals (or an empty commit).
///
/// TODO: ensure `path_required` is always set in
Expand Down Expand Up @@ -597,6 +619,7 @@ mod sync_tests {
mod async_tests {
use crate::test_utils::run_python;

#[ignore]
#[test]
fn test_simple_scenario() -> Result<(), Box<dyn std::error::Error>> {
run_python(include_str!("../test_bindings/simple_scenario_async.py"))
Expand Down
19 changes: 0 additions & 19 deletions mls-rs-uniffi/test_bindings/simple_scenario_async.py

This file was deleted.

Loading

0 comments on commit 7d8c898

Please sign in to comment.