diff --git a/mls-rs-crypto-awslc/Cargo.toml b/mls-rs-crypto-awslc/Cargo.toml index 1f1b5b83..5f4a4647 100644 --- a/mls-rs-crypto-awslc/Cargo.toml +++ b/mls-rs-crypto-awslc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mls-rs-crypto-awslc" -version = "0.14.0" +version = "0.14.1" edition = "2021" description = "AWS-LC based CryptoProvider for mls-rs" homepage = "https://github.com/awslabs/mls-rs" diff --git a/mls-rs-crypto-awslc/src/aead.rs b/mls-rs-crypto-awslc/src/aead.rs index 71438508..7892c3ed 100644 --- a/mls-rs-crypto-awslc/src/aead.rs +++ b/mls-rs-crypto-awslc/src/aead.rs @@ -9,7 +9,7 @@ use mls_rs_crypto_traits::AeadId; use crate::AwsLcCryptoError; #[derive(Clone, Copy)] -pub struct AwsLcAead(AeadId); +pub struct AwsLcAead(pub(crate) AeadId); impl AwsLcAead { pub fn new(cipher_suite: CipherSuite) -> Option { diff --git a/mls-rs-crypto-awslc/src/ecdsa.rs b/mls-rs-crypto-awslc/src/ecdsa.rs index ad7baa06..9849ec69 100644 --- a/mls-rs-crypto-awslc/src/ecdsa.rs +++ b/mls-rs-crypto-awslc/src/ecdsa.rs @@ -25,7 +25,7 @@ use crate::{ }; #[derive(Clone)] -pub struct AwsLcEcdsa(Curve); +pub struct AwsLcEcdsa(pub(crate) Curve); impl Deref for AwsLcEcdsa { type Target = Curve; diff --git a/mls-rs-crypto-awslc/src/kdf.rs b/mls-rs-crypto-awslc/src/kdf.rs index aaa2397b..9b731376 100644 --- a/mls-rs-crypto-awslc/src/kdf.rs +++ b/mls-rs-crypto-awslc/src/kdf.rs @@ -14,7 +14,7 @@ use mls_rs_crypto_traits::{Hash, KdfId}; use crate::AwsLcCryptoError; #[derive(Clone, Copy)] -pub struct AwsLcHkdf(KdfId); +pub struct AwsLcHkdf(pub(crate) KdfId); impl AwsLcHkdf { pub fn new(cipher_suite: CipherSuite) -> Option { @@ -93,7 +93,7 @@ impl mls_rs_crypto_traits::KdfType for AwsLcHkdf { #[derive(Clone, Copy, Debug)] pub struct AwsLcHash { - algo: &'static digest::Algorithm, + pub(crate) algo: &'static digest::Algorithm, } impl AwsLcHash { diff --git a/mls-rs-crypto-awslc/src/kem/ml_kem.rs b/mls-rs-crypto-awslc/src/kem/ml_kem.rs index 27093a2c..3c2f823d 100644 --- a/mls-rs-crypto-awslc/src/kem/ml_kem.rs +++ b/mls-rs-crypto-awslc/src/kem/ml_kem.rs @@ -19,8 +19,8 @@ use crate::{check_non_null, kdf::AwsLcHkdf, AwsLcCryptoError}; #[derive(Clone)] pub struct MlKemKem { - kdf: AwsLcHkdf, - ml_kem: MlKem, + pub(crate) kdf: AwsLcHkdf, + pub(crate) ml_kem: MlKem, } impl MlKemKem { diff --git a/mls-rs-crypto-awslc/src/lib.rs b/mls-rs-crypto-awslc/src/lib.rs index 2edf8f1d..743719ab 100644 --- a/mls-rs-crypto-awslc/src/lib.rs +++ b/mls-rs-crypto-awslc/src/lib.rs @@ -34,26 +34,31 @@ use mls_rs_core::{ }; use ecdsa::AwsLcEcdsa; -use kdf::{AwsLcHash, AwsLcHkdf}; +use kdf::AwsLcHkdf; use kem::ecdh::Ecdh; use mls_rs_crypto_hpke::{ context::{ContextR, ContextS}, dhkem::DhKem, hpke::{Hpke, HpkeError}, }; -use mls_rs_crypto_traits::{AeadType, Hash, KdfType, KemId}; +use mls_rs_crypto_traits::{AeadId, AeadType, Curve, Hash, KdfId, KdfType, KemId}; use thiserror::Error; use zeroize::Zeroizing; #[cfg(feature = "post-quantum")] -use self::{ - kdf::{shake::AwsLcShake128, Sha3}, - kem::ml_kem::MlKemKem, -}; +use self::{kdf::shake::AwsLcShake128, kem::ml_kem::MlKemKem}; #[cfg(feature = "post-quantum")] use mls_rs_crypto_hpke::kem_combiner::{CombinedKem, XWingSharedSecretHashInput}; +#[cfg(feature = "post-quantum")] +pub use self::kem::ml_kem::MlKem; + +#[cfg(feature = "post-quantum")] +pub use self::kdf::Sha3; + +pub use self::kdf::AwsLcHash; + #[derive(Clone)] pub struct AwsLcCipherSuite { cipher_suite: CipherSuite, @@ -153,6 +158,151 @@ impl AwsLcCryptoProvider { } } +#[derive(Clone, Default)] +pub struct AwsLcCipherSuiteBuilder { + signing: Option, + aead: Option, + kdf: Option, + hpke: Option, + mac_algo: Option, + hash: Option, + fallback_cipher_suite: Option, +} + +impl AwsLcCipherSuiteBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn signing(self, signing: Curve) -> Self { + Self { + signing: Some(AwsLcEcdsa(signing)), + ..self + } + } + + pub fn aead(self, aead: AeadId) -> Self { + Self { + aead: Some(AwsLcAead(aead)), + ..self + } + } + + pub fn kdf(self, kdf: KdfId) -> Self { + Self { + kdf: Some(AwsLcHkdf(kdf)), + ..self + } + } + + pub fn mac_algo(self, mac_algo: hmac::Algorithm) -> Self { + Self { + mac_algo: Some(mac_algo), + ..self + } + } + + pub fn hash(self, hash: AwsLcHash) -> Self { + Self { + hash: Some(hash), + ..self + } + } + + pub fn hpke(self, cipher_suite: CipherSuite) -> Self { + Self { + hpke: classical_hpke(cipher_suite), + ..self + } + } + + pub fn fallback_cipher_suite(self, cipher_suite: CipherSuite) -> Self { + Self { + fallback_cipher_suite: Some(cipher_suite), + ..self + } + } + + #[cfg(feature = "post-quantum")] + pub fn pq_hpke(self, ml_kem: MlKem, kdf: KdfId, aead: AeadId) -> Self { + let ml_kem = MlKemKem { + ml_kem, + kdf: AwsLcHkdf(kdf), + }; + + Self { + hpke: Some(AwsLcHpke::PostQuantum(Hpke::new( + ml_kem, + AwsLcHkdf(kdf), + Some(AwsLcAead(aead)), + ))), + ..self + } + } + + #[cfg(feature = "post-quantum")] + pub fn combined_hpke( + self, + classical_cipher_suite: CipherSuite, + ml_kem: MlKem, + kdf: KdfId, + aead: AeadId, + hash: AwsLcHash, + ) -> Self { + let ml_kem = MlKemKem { + ml_kem, + kdf: AwsLcHkdf(kdf), + }; + + let ecdh = dhkem(classical_cipher_suite); + + let hpke = ecdh.map(|ecdh| { + let kem = CombinedKem::new_xwing(ml_kem, ecdh, hash, AwsLcShake128); + + AwsLcHpke::Combined(Hpke::new(kem, AwsLcHkdf(kdf), Some(AwsLcAead(aead)))) + }); + + Self { hpke, ..self } + } + + pub fn build(self, cipher_suite: CipherSuite) -> Option { + let fallback_cs = self.fallback_cipher_suite.unwrap_or(cipher_suite); + let hpke = self.hpke.or_else(|| classical_hpke(fallback_cs))?; + let kdf = self.kdf.or_else(|| AwsLcHkdf::new(fallback_cs))?; + let aead = self.aead.or_else(|| AwsLcAead::new(fallback_cs))?; + let signing = self.signing.or_else(|| AwsLcEcdsa::new(fallback_cs))?; + + let mac_algo = self.mac_algo.or(match fallback_cs { + CipherSuite::CURVE25519_AES128 + | CipherSuite::CURVE25519_CHACHA + | CipherSuite::P256_AES128 => Some(hmac::HMAC_SHA256), + CipherSuite::P384_AES256 => Some(hmac::HMAC_SHA384), + CipherSuite::P521_AES256 => Some(hmac::HMAC_SHA512), + _ => None, + })?; + + let hash = self.hash.or_else(|| AwsLcHash::new(fallback_cs))?; + + Some(AwsLcCipherSuite { + cipher_suite, + hpke, + aead, + kdf, + signing, + mac_algo, + hash, + }) + } +} + +fn classical_hpke(cipher_suite: CipherSuite) -> Option { + Some(AwsLcHpke::Classical(Hpke::new( + dhkem(cipher_suite)?, + AwsLcHkdf::new(cipher_suite)?, + Some(AwsLcAead::new(cipher_suite)?), + ))) +} + impl CryptoProvider for AwsLcCryptoProvider { type CipherSuiteProvider = AwsLcCipherSuite; diff --git a/mls-rs-crypto-awslc/tests/cipher_suite_builder.rs b/mls-rs-crypto-awslc/tests/cipher_suite_builder.rs new file mode 100644 index 00000000..370e150a --- /dev/null +++ b/mls-rs-crypto-awslc/tests/cipher_suite_builder.rs @@ -0,0 +1,51 @@ +use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider}; +use mls_rs_crypto_awslc::{AwsLcCipherSuiteBuilder, AwsLcHash}; +use mls_rs_crypto_traits::{AeadId, Curve, KdfId}; + +use aws_lc_rs::hmac; + +#[test] +fn custom_cipher_suite() { + let cs = AwsLcCipherSuiteBuilder::new() + .aead(AeadId::Aes128Gcm) + .hash(AwsLcHash::new(CipherSuite::P256_AES128).unwrap()) + .kdf(KdfId::HkdfSha384) + .mac_algo(hmac::HMAC_SHA384) + .signing(Curve::P521) + .hpke(CipherSuite::P521_AES256) + .build(CipherSuite::new(12345)) + .unwrap(); + + let (sk, pk) = cs.kem_derive(b"12345").unwrap(); + let ctxt = cs.hpke_seal(&pk, b"info", Some(b"aad"), b"pt").unwrap(); + + cs.hpke_open(&ctxt, &sk, &pk, b"info", Some(b"aad")) + .unwrap(); +} + +#[cfg(feature = "post-quantum")] +#[test] +fn custom_pq_cipher_suite() { + use mls_rs_crypto_awslc::{MlKem, Sha3}; + + let hash = AwsLcHash::new_sha3(Sha3::SHA3_384).unwrap(); + + let cs = AwsLcCipherSuiteBuilder::new() + .hash(hash) + .combined_hpke( + CipherSuite::CURVE25519_AES128, + MlKem::MlKem1024, + KdfId::HkdfSha384, + AeadId::Aes256Gcm, + hash, + ) + .fallback_cipher_suite(CipherSuite::P521_AES256) + .build(CipherSuite::new(12345)) + .unwrap(); + + let (sk, pk) = cs.kem_derive(b"12345").unwrap(); + let ctxt = cs.hpke_seal(&pk, b"info", Some(b"aad"), b"pt").unwrap(); + + cs.hpke_open(&ctxt, &sk, &pk, b"info", Some(b"aad")) + .unwrap(); +}