diff --git a/Cargo.toml b/Cargo.toml index 8a59b64c..794ddb93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,10 +17,12 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -default = ["libolm-compat"] +default = ["libolm-compat", "interolm"] js = ["getrandom/js"] strict-signatures = [] libolm-compat = [] +interolm = [] + # The low-level-api feature exposes extra APIs that are only useful in advanced # use cases and require extra care to use. low-level-api = [] @@ -35,7 +37,7 @@ ed25519-dalek = { version = "2.0.0", default-features = false, features = ["rand getrandom = "0.2.10" hkdf = "0.12.3" hmac = "0.12.1" -matrix-pickle = { version = "0.1.1" } +matrix-pickle = { git = "https://github.com/matrix-org/matrix-pickle.git", branch = "dkasak/decode-for-option-rebased" } pkcs7 = "0.4.1" prost = "0.12.1" rand = "0.8.5" @@ -46,6 +48,7 @@ sha2 = "0.10.8" subtle = "2.5.0" thiserror = "1.0.49" x25519-dalek = { version = "2.0.0", features = ["serde", "reusable_secrets", "static_secrets"] } +xeddsa = "1.0.2" zeroize = "1.6.0" [dev-dependencies] diff --git a/README.md b/README.md index 7425db76..a368fbf0 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,11 @@ A Rust implementation of Olm and Megolm vodozemac is a Rust reimplementation of [libolm](https://gitlab.matrix.org/matrix-org/olm), a cryptographic library used for end-to-end encryption in [Matrix](https://matrix.org). At its core, it -is an implementation of the [Olm][olm-docs] and [Megolm][megolm-docs] cryptographic ratchets, -along with a high-level API to easily establish cryptographic communication -channels employing those ratchets with other parties. It also implements some -other miscellaneous cryptographic functionality which is useful for building -Matrix clients, such as [SAS][sas]. +is an implementation of the [Olm][olm-docs] and [Megolm][megolm-docs] +cryptographic ratchets, along with a high-level API to easily establish +cryptographic communication channels employing those ratchets with other +parties. It also implements some other miscellaneous cryptographic +functionality which is useful for building Matrix clients, such as [SAS][sas]. [olm-docs]: diff --git a/src/cipher/key.rs b/src/cipher/key.rs index bcf6fb1e..ebeacb68 100644 --- a/src/cipher/key.rs +++ b/src/cipher/key.rs @@ -34,11 +34,18 @@ struct ExpandedKeys(Box<[u8; 80]>); impl ExpandedKeys { const OLM_HKDF_INFO: &'static [u8] = b"OLM_KEYS"; const MEGOLM_HKDF_INFO: &'static [u8] = b"MEGOLM_KEYS"; + #[cfg(feature = "interolm")] + const INTEROLM_HKDF_INFO: &'static [u8] = b"OLM_KEYS"; fn new(message_key: &[u8; 32]) -> Self { Self::new_helper(message_key, Self::OLM_HKDF_INFO) } + #[cfg(feature = "interolm")] + fn new_interolm(message_key: &[u8; 32]) -> Self { + Self::new_helper(message_key, Self::INTEROLM_HKDF_INFO) + } + fn new_megolm(message_key: &[u8; 128]) -> Self { Self::new_helper(message_key, Self::MEGOLM_HKDF_INFO) } @@ -74,6 +81,13 @@ impl CipherKeys { Self::from_expanded_keys(expanded_keys) } + #[cfg(feature = "interolm")] + pub fn new_interolm(message_key: &[u8; 32]) -> Self { + let expanded_keys = ExpandedKeys::new_interolm(message_key); + + Self::from_expanded_keys(expanded_keys) + } + pub fn new_megolm(message_key: &[u8; 128]) -> Self { let expanded_keys = ExpandedKeys::new_megolm(message_key); diff --git a/src/cipher/mod.rs b/src/cipher/mod.rs index 432ee729..99f7dfeb 100644 --- a/src/cipher/mod.rs +++ b/src/cipher/mod.rs @@ -23,10 +23,11 @@ use aes::{ Aes256, }; use hmac::{digest::MacError, Hmac, Mac as MacT}; -use key::CipherKeys; use sha2::Sha256; use thiserror::Error; +use crate::{cipher::key::CipherKeys, Curve25519PublicKey}; + type Aes256CbcEnc = cbc::Encryptor; type Aes256CbcDec = cbc::Decryptor; type HmacSha256 = Hmac; @@ -77,6 +78,38 @@ impl From<[u8; Mac::TRUNCATED_LEN]> for MessageMac { } } +#[cfg(feature = "interolm")] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InterolmMessageMac(pub(crate) [u8; Mac::TRUNCATED_LEN]); + +#[cfg(feature = "interolm")] +impl InterolmMessageMac { + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +#[cfg(feature = "interolm")] +impl From for InterolmMessageMac { + fn from(m: Mac) -> Self { + Self(m.truncate()) + } +} + +#[cfg(feature = "interolm")] +impl From<[u8; Mac::TRUNCATED_LEN]> for InterolmMessageMac { + fn from(m: [u8; Mac::TRUNCATED_LEN]) -> Self { + Self(m) + } +} + +#[cfg(feature = "interolm")] +impl From for MessageMac { + fn from(value: InterolmMessageMac) -> Self { + Self::Truncated(value.0) + } +} + #[derive(Debug, Error)] pub enum DecryptionError { #[error("Failed decrypting, invalid padding")] @@ -99,6 +132,13 @@ impl Cipher { Self { keys } } + #[cfg(feature = "interolm")] + pub fn new_interolm(key: &[u8; 32]) -> Self { + let keys = CipherKeys::new_interolm(key); + + Self { keys } + } + pub fn new_megolm(&key: &[u8; 128]) -> Self { let keys = CipherKeys::new_megolm(&key); @@ -114,7 +154,7 @@ impl Cipher { fn get_hmac(&self) -> HmacSha256 { // We don't use HmacSha256::new() here because it expects a 64-byte - // large HMAC key while the Olm spec defines a 32-byte one instead. + // HMAC key while the Olm spec uses a 32-byte one instead. // // https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md#version-1 HmacSha256::new_from_slice(self.keys.mac_key()).expect("Invalid HMAC key size") @@ -129,10 +169,24 @@ impl Cipher { let mut hmac = self.get_hmac(); hmac.update(message); - let mac_bytes = hmac.finalize().into_bytes(); + let mac = hmac.finalize().into_bytes().into(); + + Mac(mac) + } + + pub fn mac_interolm( + &self, + sender_identity: Curve25519PublicKey, + receiver_identity: Curve25519PublicKey, + message: &[u8], + ) -> Mac { + let mut hmac = self.get_hmac(); - let mut mac = [0u8; 32]; - mac.copy_from_slice(&mac_bytes); + hmac.update(&sender_identity.to_interolm_bytes()); + hmac.update(&receiver_identity.to_interolm_bytes()); + hmac.update(message); + + let mac = hmac.finalize().into_bytes().into(); Mac(mac) } @@ -178,6 +232,22 @@ impl Cipher { hmac.verify_truncated_left(tag) } + #[cfg(not(fuzzing))] + pub fn verify_interolm_mac( + &self, + message: &[u8], + sender_identity: Curve25519PublicKey, + receiver_identity: Curve25519PublicKey, + tag: &[u8], + ) -> Result<(), MacError> { + let mut hmac = self.get_hmac(); + + hmac.update(&sender_identity.to_interolm_bytes()); + hmac.update(&receiver_identity.to_interolm_bytes()); + hmac.update(message); + hmac.verify_truncated_left(tag) + } + /// A verify_mac method that always succeeds. /// /// Useful if we're fuzzing vodozemac, since MAC verification discards a lot @@ -191,4 +261,15 @@ impl Cipher { pub fn verify_truncated_mac(&self, _: &[u8], _: &[u8]) -> Result<(), MacError> { Ok(()) } + + #[cfg(fuzzing)] + pub fn verify_interolm_mac( + &self, + _: &[u8], + _: Curve25519PublicKey, + _: Curve25519PublicKey, + _: &[u8], + ) -> Result<(), MacError> { + Ok(()) + } } diff --git a/src/lib.rs b/src/lib.rs index 606b4d81..cb74635d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -196,7 +196,7 @@ #![deny( clippy::mem_forget, clippy::unwrap_used, - dead_code, + //dead_code, trivial_casts, trivial_numeric_casts, unsafe_code, @@ -211,6 +211,7 @@ mod cipher; mod types; mod utilities; +mod xeddsa; pub mod hazmat; pub mod megolm; @@ -225,6 +226,8 @@ pub use types::{ }; pub use utilities::{base64_decode, base64_encode}; +pub use crate::xeddsa::{SignatureError as XEdDsaSignatureError, XEdDsaSignature}; + /// Error type describing the various ways Vodozemac pickles can fail to be /// decoded. #[derive(Debug, thiserror::Error)] @@ -235,7 +238,7 @@ pub enum PickleError { /// The encrypted pickle could not have been decrypted. #[error("The pickle couldn't be decrypted: {0}")] Decryption(#[from] crate::cipher::DecryptionError), - /// The serialized Vodozemac object couldn't be deserialized. + /// The serialized Vodozemac object couldn't be deserialised. #[error("The pickle couldn't be deserialized: {0}")] Serialization(#[from] serde_json::Error), } diff --git a/src/megolm/inbound_group_session.rs b/src/megolm/inbound_group_session.rs index 233d52c7..916b7a88 100644 --- a/src/megolm/inbound_group_session.rs +++ b/src/megolm/inbound_group_session.rs @@ -99,7 +99,7 @@ pub struct DecryptedMessage { } impl InboundGroupSession { - pub fn new(key: &SessionKey, session_config: SessionConfig) -> Self { + pub fn new(key: &SessionKey, config: SessionConfig) -> Self { let initial_ratchet = Ratchet::from_bytes(key.session_key.ratchet.clone(), key.session_key.ratchet_index); let latest_ratchet = initial_ratchet.clone(); @@ -109,11 +109,11 @@ impl InboundGroupSession { latest_ratchet, signing_key: key.session_key.signing_key, signing_key_verified: true, - config: session_config, + config, } } - pub fn import(session_key: &ExportedSessionKey, session_config: SessionConfig) -> Self { + pub fn import(session_key: &ExportedSessionKey, config: SessionConfig) -> Self { let initial_ratchet = Ratchet::from_bytes(session_key.ratchet.clone(), session_key.ratchet_index); let latest_ratchet = initial_ratchet.clone(); @@ -123,7 +123,7 @@ impl InboundGroupSession { latest_ratchet, signing_key: session_key.signing_key, signing_key_verified: false, - config: session_config, + config, } } diff --git a/src/olm/account/fallback_keys.rs b/src/olm/account/fallback_keys.rs deleted file mode 100644 index 49a9c0e3..00000000 --- a/src/olm/account/fallback_keys.rs +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2021 Damir Jelić, Denis Kasak -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use serde::{Deserialize, Serialize}; - -use crate::{ - types::{Curve25519SecretKey, KeyId}, - Curve25519PublicKey, -}; - -#[derive(Serialize, Deserialize, Clone)] -pub(super) struct FallbackKey { - pub key_id: KeyId, - pub key: Curve25519SecretKey, - pub published: bool, -} - -impl FallbackKey { - fn new(key_id: KeyId) -> Self { - let key = Curve25519SecretKey::new(); - - Self { key_id, key, published: false } - } - - pub fn public_key(&self) -> Curve25519PublicKey { - Curve25519PublicKey::from(&self.key) - } - - pub fn secret_key(&self) -> &Curve25519SecretKey { - &self.key - } - - pub fn key_id(&self) -> KeyId { - self.key_id - } - - pub fn mark_as_published(&mut self) { - self.published = true; - } - - pub fn published(&self) -> bool { - self.published - } -} - -#[derive(Serialize, Deserialize, Clone)] -pub(super) struct FallbackKeys { - pub key_id: u64, - pub fallback_key: Option, - pub previous_fallback_key: Option, -} - -impl FallbackKeys { - pub fn new() -> Self { - Self { key_id: 0, fallback_key: None, previous_fallback_key: None } - } - - pub fn mark_as_published(&mut self) { - if let Some(f) = self.fallback_key.as_mut() { - f.mark_as_published() - } - } - - pub fn generate_fallback_key(&mut self) -> Option { - let key_id = KeyId(self.key_id); - self.key_id += 1; - - let ret = self.previous_fallback_key.take().map(|f| f.public_key()); - - self.previous_fallback_key = self.fallback_key.take(); - self.fallback_key = Some(FallbackKey::new(key_id)); - - ret - } - - pub fn get_secret_key(&self, public_key: &Curve25519PublicKey) -> Option<&Curve25519SecretKey> { - self.fallback_key - .as_ref() - .filter(|f| f.public_key() == *public_key) - .or_else(|| { - self.previous_fallback_key.as_ref().filter(|f| f.public_key() == *public_key) - }) - .map(|f| f.secret_key()) - } - - pub fn forget_previous_fallback_key(&mut self) -> Option { - self.previous_fallback_key.take() - } - - pub fn unpublished_fallback_key(&self) -> Option<&FallbackKey> { - self.fallback_key.as_ref().filter(|f| !f.published()) - } -} - -#[cfg(test)] -mod test { - use super::FallbackKeys; - - #[test] - fn fallback_key_fetching() { - let err = "Missing fallback key"; - let mut fallback_keys = FallbackKeys::new(); - - fallback_keys.generate_fallback_key(); - - let public_key = fallback_keys.fallback_key.as_ref().expect(err).public_key(); - let secret_bytes = fallback_keys.fallback_key.as_ref().expect(err).key.to_bytes(); - - let fetched_key = fallback_keys.get_secret_key(&public_key).expect(err); - - assert_eq!(secret_bytes, fetched_key.to_bytes()); - - fallback_keys.generate_fallback_key(); - - let fetched_key = fallback_keys.get_secret_key(&public_key).expect(err); - assert_eq!(secret_bytes, fetched_key.to_bytes()); - - let public_key = fallback_keys.fallback_key.as_ref().expect(err).public_key(); - let secret_bytes = fallback_keys.fallback_key.as_ref().expect(err).key.to_bytes(); - - let fetched_key = fallback_keys.get_secret_key(&public_key).expect(err); - - assert_eq!(secret_bytes, fetched_key.to_bytes()); - } -} diff --git a/src/olm/account/mod.rs b/src/olm/account/mod.rs index 673c3340..bf808b28 100644 --- a/src/olm/account/mod.rs +++ b/src/olm/account/mod.rs @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod fallback_keys; mod one_time_keys; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use rand::thread_rng; use serde::{Deserialize, Serialize}; @@ -23,16 +22,17 @@ use thiserror::Error; use x25519_dalek::ReusableSecret; pub use self::one_time_keys::OneTimeKeyGenerationResult; -use self::{ - fallback_keys::FallbackKeys, - one_time_keys::{OneTimeKeys, OneTimeKeysPickle}, -}; +use self::one_time_keys::{OneTimeKeys, OneTimeKeysPickle}; use super::{ - messages::PreKeyMessage, - session::{DecryptionError, Session}, + messages::AnyPreKeyMessage, + session::{ + ratchet::{RatchetKey, RemoteRatchetKey}, + DecryptionError, Session, + }, + session_config::Version, session_keys::SessionKeys, shared_secret::{RemoteShared3DHSecret, Shared3DHSecret}, - SessionConfig, + AnyNormalMessage, SessionConfig, }; use crate::{ types::{ @@ -40,6 +40,7 @@ use crate::{ Ed25519Keypair, Ed25519KeypairPickle, Ed25519PublicKey, KeyId, }, utilities::{pickle, unpickle}, + xeddsa::XEdDsaSignature, Ed25519Signature, PickleError, }; @@ -54,6 +55,9 @@ pub enum SessionCreationError { /// already been used up. #[error("The pre-key message contained an unknown one-time key: {0}")] MissingOneTimeKey(Curve25519PublicKey), + /// The pre-key message contained a reference to an unknown pre-key ID. + #[error("The pre-key message contained an unknown one-time key ID: {0}")] + MissingOneTimeKeyID(KeyId), /// The pre-key message contains a curve25519 identity key that doesn't /// match to the identity key that was given. #[error( @@ -91,15 +95,21 @@ pub struct Account { /// A permanent Ed25519 key used for signing. Also known as the fingerprint /// key. signing_key: Ed25519Keypair, - /// The permanent Curve25519 key used for 3DH. Also known as the sender key - /// or the identity key. + /// The permanent Curve25519 key used for (X)3DH. Also known as the sender + /// key or the identity key. + /// + /// In case of interolm, it can also be used for signing using + /// XEdDSA. diffie_hellman_key: Curve25519Keypair, - /// The ephemeral (one-time) Curve25519 keys used as part of the 3DH. - one_time_keys: OneTimeKeys, + /// The ephemeral (one-time) Curve25519 keys used as part of the (X)3DH. + one_time_keys: OneTimeKeys<{ 100 * PUBLIC_MAX_ONE_TIME_KEYS }>, /// The ephemeral Curve25519 keys used in lieu of a one-time key as part of /// the 3DH, in case we run out of those. We keep track of both the current /// and the previous fallback key in any given moment. - fallback_keys: FallbackKeys, + /// + /// When using interolm, we do X3DH instead of 3DH, so the fallback + /// key is always used and treated as a signed pre-key. + fallback_keys: OneTimeKeys<10>, } impl Account { @@ -109,7 +119,7 @@ impl Account { signing_key: Ed25519Keypair::new(), diffie_hellman_key: Curve25519Keypair::new(), one_time_keys: OneTimeKeys::new(), - fallback_keys: FallbackKeys::new(), + fallback_keys: OneTimeKeys::new(), } } @@ -133,6 +143,12 @@ impl Account { self.signing_key.sign(message.as_bytes()) } + /// Sign the given message using our Curve25519 key using XEdDSA. + #[cfg(feature = "interolm")] + pub fn sign_interolm(&self, message: &[u8]) -> XEdDsaSignature { + self.diffie_hellman_key.secret_key.sign(message) + } + /// Get the maximum number of one-time keys the client should keep on the /// server. /// @@ -156,8 +172,9 @@ impl Account { pub fn create_outbound_session( &self, session_config: SessionConfig, - identity_key: Curve25519PublicKey, - one_time_key: Curve25519PublicKey, + remote_identity_key: Curve25519PublicKey, + signed_pre_key: Curve25519PublicKey, + one_time_key: Option, ) -> Session { let rng = thread_rng(); @@ -165,19 +182,28 @@ impl Account { let public_base_key = Curve25519PublicKey::from(&base_key); let shared_secret = Shared3DHSecret::new( + &session_config, self.diffie_hellman_key.secret_key(), &base_key, - &identity_key, - &one_time_key, + &remote_identity_key, + &signed_pre_key, + one_time_key.as_ref(), ); let session_keys = SessionKeys { identity_key: self.curve25519_key(), + other_identity_key: remote_identity_key, base_key: public_base_key, + signed_pre_key, one_time_key, }; - Session::new(session_config, shared_secret, session_keys) + match session_config.version { + Version::V1 | Version::V2 => Session::new(session_config, shared_secret, session_keys), + Version::VInterolm(_) => { + Session::new_interolm(session_config, shared_secret, session_keys) + } + } } fn find_one_time_key(&self, public_key: &Curve25519PublicKey) -> Option<&Curve25519SecretKey> { @@ -186,6 +212,16 @@ impl Account { .or_else(|| self.fallback_keys.get_secret_key(public_key)) } + fn find_public_fallback_key(&self, key_id: KeyId) -> Option { + self.fallback_keys + .get_public_key_by_id(&key_id) + .or_else(|| self.one_time_keys.get_public_key_by_id(&key_id)) + } + + fn find_public_one_time_key_by_id(&self, key_id: KeyId) -> Option { + self.one_time_keys.get_public_key_by_id(&key_id) + } + /// Remove a one-time key that has previously been published but not yet /// used. /// @@ -208,72 +244,139 @@ impl Account { self.one_time_keys.remove_secret_key(&public_key) } - /// Create a [`Session`] from the given pre-key message and identity key + /// Create a [`Session`] from the given pre-key message and identity key. pub fn create_inbound_session( &mut self, their_identity_key: Curve25519PublicKey, - pre_key_message: &PreKeyMessage, + pre_key_message: &AnyPreKeyMessage, ) -> Result { - if their_identity_key != pre_key_message.identity_key() { + let (config, initiator_identity_key, base_key, signed_pre_key, one_time_key, message) = + match pre_key_message { + AnyPreKeyMessage::Native(msg) => { + let config = if msg.message.mac_truncated() { + SessionConfig::version_1() + } else { + SessionConfig::version_2() + }; + + ( + config, + msg.session_keys.identity_key, + msg.session_keys.base_key, + // The OTK is signed in Olm and thus treated as a signed pre-key here. + msg.session_keys.one_time_key, + // There is no unsigned one-time key in Olm. + None, + AnyNormalMessage::Native(&msg.message), + ) + } + AnyPreKeyMessage::Interolm(msg) => { + let signed_pre_key_id = KeyId(msg.signed_pre_key_id.into()); + let otk_id = msg.pre_key_id.map(|i| KeyId(i.into())); + let registration_id = msg.registration_id; + + let their_signed_prekey = self + .find_public_fallback_key(signed_pre_key_id) + .ok_or(SessionCreationError::MissingOneTimeKeyID(signed_pre_key_id))?; + + let their_otk = otk_id.and_then(|id| self.find_public_one_time_key_by_id(id)); + + ( + SessionConfig::version_interolm(registration_id, signed_pre_key_id, otk_id), + msg.identity_key, + msg.base_key, + their_signed_prekey, + their_otk, + AnyNormalMessage::Interolm(&msg.message), + ) + } + }; + + if their_identity_key != initiator_identity_key { Err(SessionCreationError::MismatchedIdentityKey( their_identity_key, - pre_key_message.identity_key(), + initiator_identity_key, )) } else { // Find the matching private part of the OTK that the message claims // was used to create the session that encrypted it. - let public_otk = pre_key_message.one_time_key(); - let private_otk = self - .find_one_time_key(&public_otk) - .ok_or(SessionCreationError::MissingOneTimeKey(public_otk))?; + let public_otk = one_time_key; + let private_otk = public_otk.and_then(|p| self.find_one_time_key(&p)); + + // Find the matching private part of the signed pre-key ("fallback" + // key) that the message claims was used to create the session that + // encrypted it. + let public_signed_pre_key = signed_pre_key; + let private_signed_pre_key = self + .find_one_time_key(&public_signed_pre_key) + .ok_or(SessionCreationError::MissingOneTimeKey(public_signed_pre_key))?; // Construct a 3DH shared secret from the various curve25519 keys. let shared_secret = RemoteShared3DHSecret::new( + &config, self.diffie_hellman_key.secret_key(), + private_signed_pre_key, private_otk, - &pre_key_message.identity_key(), - &pre_key_message.base_key(), + &initiator_identity_key, + &base_key, ); // These will be used to uniquely identify the Session. let session_keys = SessionKeys { - identity_key: pre_key_message.identity_key(), - base_key: pre_key_message.base_key(), - one_time_key: pre_key_message.one_time_key(), - }; - - let config = if pre_key_message.message.mac_truncated() { - SessionConfig::version_1() - } else { - SessionConfig::version_2() + identity_key: initiator_identity_key, + base_key, + other_identity_key: self.curve25519_key(), + signed_pre_key, + one_time_key, }; // Create a Session, AKA a double ratchet, this one will have an // inactive sending chain until we decide to encrypt a message. - let mut session = Session::new_remote( - config, - shared_secret, - pre_key_message.message.ratchet_key, - session_keys, - ); + let mut session: Session = match config.version { + Version::V1 | Version::V2 => { + Session::new_remote(&config, shared_secret, message.ratchet_key(), session_keys) + } + #[cfg(feature = "interolm")] + Version::VInterolm(..) => Session::new_interolm_remote( + &config, + shared_secret, + RemoteRatchetKey(message.ratchet_key()), + session_keys, + RatchetKey(private_signed_pre_key.clone()), + ), + }; // Decrypt the message to check if the Session is actually valid. - let plaintext = session.decrypt_decoded(&pre_key_message.message)?; + let plaintext = session.decrypt_decoded(message)?; + + // We only drop the one-time key now, which is why we can't use a + // one-time key type that consumes `self`. If we didn't do it like + // this, someone could maliciously pretend to use up our one-time + // key and make us drop the private part. Unsuspecting users that + // actually try to use such an one-time key won't be able to + // communicate with us. This is strictly worse than the one-time key + // exhaustion scenario. + + // In native Olm/3DH, the pre-key is always signed and there's no + // unsigned pre-key. So the key to drop is actually in the + // `signed_pre_key` field. This oddity is a result of us growing + // libsignal interoperability. + let key_to_drop = match config.version { + Version::V1 | Version::V2 => Some(session_keys.signed_pre_key), + #[cfg(feature = "interolm")] + Version::VInterolm(..) => session_keys.one_time_key, + }; - // We only drop the one-time key now, this is why we can't use a - // one-time key type that takes `self`. If we didn't do this, - // someone could maliciously pretend to use up our one-time key and - // make us drop the private part. Unsuspecting users that actually - // try to use such an one-time key won't be able to commnuicate with - // us. This is strictly worse than the one-time key exhaustion - // scenario. - self.remove_one_time_key_helper(pre_key_message.one_time_key()); + if let Some(otk) = key_to_drop { + self.remove_one_time_key_helper(otk); + } Ok(InboundCreationResult { session, plaintext }) } } /// Generates the supplied number of one time keys. + /// /// Returns the public parts of the one-time keys that were created and /// discarded. /// @@ -293,7 +396,7 @@ impl Account { /// /// The one-time keys should be published to a server and marked as /// published using the `mark_keys_as_published()` method. - pub fn one_time_keys(&self) -> HashMap { + pub fn one_time_keys(&self) -> BTreeMap { self.one_time_keys .unpublished_public_keys .iter() @@ -301,16 +404,16 @@ impl Account { .collect() } + pub fn one_time_keys_private(&self) -> HashMap { + self.one_time_keys.private_keys.iter().map(|(key_id, key)| (*key_id, key.clone())).collect() + } + /// Generate a single new fallback key. /// /// The fallback key will be used by other users to establish a `Session` if /// all the one-time keys on the server have been used up. - /// - /// Returns the public Curve25519 key of the *previous* fallback key, that - /// is, the one that will get removed from the [`Account`] when this method - /// is called. This return value is mostly useful for logging purposes. - pub fn generate_fallback_key(&mut self) -> Option { - self.fallback_keys.generate_fallback_key() + pub fn generate_fallback_key(&mut self) { + self.fallback_keys.generate(1); } /// Get the currently unpublished fallback key. @@ -318,20 +421,12 @@ impl Account { /// The fallback key should be published just like the one-time keys, after /// it has been successfully published it needs to be marked as published /// using the `mark_keys_as_published()` method as well. - pub fn fallback_key(&self) -> HashMap { - let fallback_key = self.fallback_keys.unpublished_fallback_key(); - - if let Some(fallback_key) = fallback_key { - HashMap::from([(fallback_key.key_id(), fallback_key.public_key())]) - } else { - HashMap::new() - } - } - - /// The `Account` stores at most two private parts of the fallback key. This - /// method lets us forget the previously used fallback key. - pub fn forget_fallback_key(&mut self) -> bool { - self.fallback_keys.forget_previous_fallback_key().is_some() + pub fn fallback_keys(&self) -> BTreeMap { + self.fallback_keys + .unpublished_public_keys + .iter() + .map(|(key_id, key)| (*key_id, *key)) + .collect() } /// Mark all currently unpublished one-time and fallback keys as published. @@ -347,7 +442,7 @@ impl Account { signing_key: self.signing_key.clone().into(), diffie_hellman_key: self.diffie_hellman_key.clone().into(), one_time_keys: self.one_time_keys.clone().into(), - fallback_keys: self.fallback_keys.clone(), + fallback_keys: self.fallback_keys.clone().into(), } } @@ -442,7 +537,7 @@ pub struct AccountPickle { signing_key: Ed25519KeypairPickle, diffie_hellman_key: Curve25519KeypairPickle, one_time_keys: OneTimeKeysPickle, - fallback_keys: FallbackKeys, + fallback_keys: OneTimeKeysPickle, } /// A format suitable for serialization which implements [`serde::Serialize`] @@ -470,7 +565,7 @@ impl From for Account { signing_key: pickle.signing_key.into(), diffie_hellman_key: pickle.diffie_hellman_key.into(), one_time_keys: pickle.one_time_keys.into(), - fallback_keys: pickle.fallback_keys, + fallback_keys: pickle.fallback_keys.into(), } } } @@ -480,11 +575,7 @@ mod libolm { use matrix_pickle::{Decode, DecodeError, Encode, EncodeError}; use zeroize::Zeroize; - use super::{ - fallback_keys::{FallbackKey, FallbackKeys}, - one_time_keys::OneTimeKeys, - Account, - }; + use super::{one_time_keys::OneTimeKeys, Account}; use crate::{ types::{Curve25519Keypair, Curve25519SecretKey}, utilities::LibolmEd25519Keypair, @@ -500,16 +591,6 @@ mod libolm { private_key: Box<[u8; 32]>, } - impl From<&OneTimeKey> for FallbackKey { - fn from(key: &OneTimeKey) -> Self { - FallbackKey { - key_id: KeyId(key.key_id.into()), - key: Curve25519SecretKey::from_slice(&key.private_key), - published: key.published, - } - } - } - #[derive(Debug, Zeroize)] #[zeroize(drop)] struct FallbackKeysArray { @@ -571,46 +652,35 @@ mod libolm { next_key_id: u32, } - impl TryFrom<&FallbackKey> for OneTimeKey { - type Error = (); - - fn try_from(key: &FallbackKey) -> Result { - Ok(OneTimeKey { - key_id: key.key_id.0.try_into().map_err(|_| ())?, - published: key.published(), - public_key: key.public_key().to_bytes(), - private_key: key.secret_key().to_bytes(), - }) - } - } - impl From<&Account> for Pickle { fn from(account: &Account) -> Self { - let one_time_keys: Vec<_> = account - .one_time_keys - .secret_keys() - .iter() - .filter_map(|(key_id, secret_key)| { + let try_into_libolm_otk = + |(key_id, secret_key): (&KeyId, &Curve25519SecretKey)| -> Option { Some(OneTimeKey { key_id: key_id.0.try_into().ok()?, published: account.one_time_keys.is_secret_key_published(key_id), public_key: Curve25519PublicKey::from(secret_key).to_bytes(), private_key: secret_key.to_bytes(), }) - }) + }; + + let one_time_keys: Vec<_> = account + .one_time_keys + .secret_keys() + .iter() + .filter_map(try_into_libolm_otk) .collect(); + let mut published_fallback_keys = account + .fallback_keys + .secret_keys() + .iter() + .rev() + .filter(|(id, _)| account.fallback_keys.is_secret_key_published(id)); + let fallback_keys = FallbackKeysArray { - fallback_key: account - .fallback_keys - .fallback_key - .as_ref() - .and_then(|f| f.try_into().ok()), - previous_fallback_key: account - .fallback_keys - .previous_fallback_key - .as_ref() - .and_then(|f| f.try_into().ok()), + fallback_key: published_fallback_keys.next().and_then(try_into_libolm_otk), + previous_fallback_key: published_fallback_keys.next().and_then(try_into_libolm_otk), }; let next_key_id = account.one_time_keys.next_key_id.try_into().unwrap_or_default(); @@ -644,20 +714,22 @@ mod libolm { one_time_keys.next_key_id = pickle.next_key_id.into(); - let fallback_keys = FallbackKeys { - key_id: pickle - .fallback_keys - .fallback_key - .as_ref() - .map(|k| k.key_id.wrapping_add(1)) - .unwrap_or(0) as u64, - fallback_key: pickle.fallback_keys.fallback_key.as_ref().map(|k| k.into()), - previous_fallback_key: pickle - .fallback_keys - .previous_fallback_key - .as_ref() - .map(|k| k.into()), - }; + let mut fallback_keys = OneTimeKeys::new(); + + if let Some(key) = &pickle.fallback_keys.fallback_key { + let secret_key = Curve25519SecretKey::from_slice(&key.private_key); + let key_id = KeyId(key.key_id.into()); + fallback_keys.insert_secret_key(key_id, secret_key, key.published); + + let next_key_id = key.key_id.wrapping_add(1); + fallback_keys.next_key_id = next_key_id.into(); + } + + if let Some(key) = &pickle.fallback_keys.previous_fallback_key { + let secret_key = Curve25519SecretKey::from_slice(&key.private_key); + let key_id = KeyId(key.key_id.into()); + fallback_keys.insert_secret_key(key_id, secret_key, key.published); + } Ok(Self { signing_key: Ed25519Keypair::from_expanded_key( @@ -682,7 +754,7 @@ mod test { use crate::{ cipher::Mac, olm::{ - messages::{OlmMessage, PreKeyMessage}, + messages::{AnyNativeMessage, PreKeyMessage}, AccountPickle, }, run_corpus, Curve25519PublicKey as PublicKey, Ed25519Signature, @@ -712,8 +784,12 @@ mod test { let identity_keys = bob.parsed_identity_keys(); let curve25519_key = PublicKey::from_base64(identity_keys.curve25519())?; let one_time_key = PublicKey::from_base64(&one_time_key)?; - let mut alice_session = - alice.create_outbound_session(SessionConfig::version_1(), curve25519_key, one_time_key); + let mut alice_session = alice.create_outbound_session( + SessionConfig::version_1(), + curve25519_key, + one_time_key, + None, + ); let message = "It's a secret to everybody"; let olm_message: LibolmOlmMessage = alice_session.encrypt(message).into(); @@ -771,6 +847,7 @@ mod test { .next() .context("Failed getting bob's OTK, which should never happen here.")? .1, + None, ); bob.mark_keys_as_published(); @@ -778,13 +855,13 @@ mod test { let message = "It's a secret to everybody"; let olm_message = alice_session.encrypt(message); - if let OlmMessage::PreKey(m) = olm_message { - assert_eq!(m.session_keys(), alice_session.session_keys()); + if let AnyNativeMessage::PreKey(m) = olm_message { + assert_eq!(m.session_keys(), alice_session.session_keys().into()); let InboundCreationResult { session: mut bob_session, plaintext } = - bob.create_inbound_session(alice.curve25519_key(), &m)?; + bob.create_inbound_session(alice.curve25519_key(), &m.clone().into())?; assert_eq!(alice_session.session_id(), bob_session.session_id()); - assert_eq!(m.session_keys(), bob_session.session_keys()); + assert_eq!(m.session_keys(), bob_session.session_keys().into()); assert_eq!(message.as_bytes(), plaintext); @@ -835,11 +912,12 @@ mod test { let identity_key = PublicKey::from_base64(alice.parsed_identity_keys().curve25519())?; - let InboundCreationResult { session, plaintext } = if let OlmMessage::PreKey(m) = &message { - bob.create_inbound_session(identity_key, m)? - } else { - bail!("Got invalid message type from olm_rs {:?}", message); - }; + let InboundCreationResult { session, plaintext } = + if let AnyNativeMessage::PreKey(m) = message { + bob.create_inbound_session(identity_key, &m.clone().into())? + } else { + bail!("Got invalid message type from olm_rs {:?}", message); + }; assert_eq!(alice_session.session_id(), session.session_id()); assert!(bob.one_time_keys.private_keys.is_empty()); @@ -857,8 +935,9 @@ mod test { bob.generate_fallback_key(); let one_time_key = - bob.fallback_key().values().next().cloned().expect("Didn't find a valid fallback key"); + bob.fallback_keys().values().next().cloned().expect("Didn't find a valid fallback key"); assert!(bob.one_time_keys.private_keys.is_empty()); + assert_eq!(bob.fallback_keys.private_keys.len(), 1); let alice_session = alice.create_outbound_session( &bob.curve25519_key().to_base64(), @@ -870,13 +949,17 @@ mod test { let message = alice_session.encrypt(text).into(); let identity_key = PublicKey::from_base64(alice.parsed_identity_keys().curve25519())?; - if let OlmMessage::PreKey(m) = &message { + if let AnyNativeMessage::PreKey(m) = message { let InboundCreationResult { session, plaintext } = - bob.create_inbound_session(identity_key, m)?; + bob.create_inbound_session(identity_key, &m.clone().into())?; - assert_eq!(m.session_keys(), session.session_keys()); + assert_eq!(m.session_keys(), session.session_keys().into()); assert_eq!(alice_session.session_id(), session.session_id()); - assert!(bob.fallback_keys.fallback_key.is_some()); + assert_eq!( + bob.fallback_keys.private_keys.len(), + 1, + "We should still have one fallback key" + ); assert_eq!(text.as_bytes(), plaintext); } else { @@ -946,7 +1029,7 @@ mod test { assert_eq!( olm_fallback_key.curve25519(), unpickled - .fallback_key() + .fallback_keys() .values() .next() .expect("We should have a fallback key") @@ -987,11 +1070,12 @@ mod test { SessionConfig::default(), alice.curve25519_key(), *alice.one_time_keys().values().next().expect("Should have one-time key"), + None, ); let message = session.encrypt("Test"); - if let OlmMessage::PreKey(m) = message { + if let AnyNativeMessage::PreKey(m) = message { let mut message = m.to_bytes(); let message_len = message.len(); @@ -1002,7 +1086,7 @@ mod test { let message = PreKeyMessage::try_from(message)?; - match alice.create_inbound_session(malory.curve25519_key(), &message) { + match alice.create_inbound_session(malory.curve25519_key(), &message.into()) { Err(SessionCreationError::Decryption(_)) => {} e => bail!("Expected a decryption error, got {:?}", e), } @@ -1017,6 +1101,28 @@ mod test { } } + #[test] + #[cfg(feature = "interolm")] + fn test_signing() { + let account = Account::new(); + let message = "sahasrahla"; + + let signature = account.sign_interolm(message.as_bytes()); + account + .diffie_hellman_key + .public_key + .verify_signature(message.as_bytes(), signature) + .expect("The signature should be valid"); + + let corrupted_message = message.to_owned() + "!"; + + account + .diffie_hellman_key + .public_key + .verify_signature(corrupted_message.as_bytes(), signature) + .expect_err("The signature should be invalid"); + } + #[test] #[cfg(feature = "libolm-compat")] fn fuzz_corpus_unpickling() { diff --git a/src/olm/account/one_time_keys.rs b/src/olm/account/one_time_keys.rs index 1ec2a9f2..4cbde5ec 100644 --- a/src/olm/account/one_time_keys.rs +++ b/src/olm/account/one_time_keys.rs @@ -16,7 +16,6 @@ use std::collections::{BTreeMap, HashMap}; use serde::{Deserialize, Serialize}; -use super::PUBLIC_MAX_ONE_TIME_KEYS; use crate::{ types::{Curve25519SecretKey, KeyId}, Curve25519PublicKey, @@ -25,11 +24,14 @@ use crate::{ #[derive(Serialize, Deserialize, Clone)] #[serde(from = "OneTimeKeysPickle")] #[serde(into = "OneTimeKeysPickle")] -pub(super) struct OneTimeKeys { +pub(super) struct OneTimeKeys { pub next_key_id: u64, pub unpublished_public_keys: BTreeMap, + // XXX: This is now a bit of a mess. We can probably rationalize away some + // of these maps. pub private_keys: BTreeMap, pub key_ids_by_key: HashMap, + pub keys_by_key_id: HashMap, } /// The result type for the one-time key generation operation. @@ -41,8 +43,8 @@ pub struct OneTimeKeyGenerationResult { pub removed: Vec, } -impl OneTimeKeys { - const MAX_ONE_TIME_KEYS: usize = 100 * PUBLIC_MAX_ONE_TIME_KEYS; +impl OneTimeKeys { + const MAX_ONE_TIME_KEYS: usize = N; pub fn new() -> Self { Self { @@ -50,6 +52,7 @@ impl OneTimeKeys { unpublished_public_keys: Default::default(), private_keys: Default::default(), key_ids_by_key: Default::default(), + keys_by_key_id: Default::default(), } } @@ -57,6 +60,10 @@ impl OneTimeKeys { self.unpublished_public_keys.clear(); } + pub fn get_public_key_by_id(&self, key_id: &KeyId) -> Option { + self.keys_by_key_id.get(key_id).copied() + } + pub fn get_secret_key(&self, public_key: &Curve25519PublicKey) -> Option<&Curve25519SecretKey> { self.key_ids_by_key.get(public_key).and_then(|key_id| self.private_keys.get(key_id)) } @@ -66,8 +73,10 @@ impl OneTimeKeys { public_key: &Curve25519PublicKey, ) -> Option { self.key_ids_by_key.remove(public_key).and_then(|key_id| { - self.unpublished_public_keys.remove(&key_id); - self.private_keys.remove(&key_id) + self.keys_by_key_id.remove(&key_id).and_then(|_| { + self.unpublished_public_keys.remove(&key_id); + self.private_keys.remove(&key_id) + }) }) } @@ -84,6 +93,7 @@ impl OneTimeKeys { let public_key = if let Some(private_key) = self.private_keys.remove(&key_id) { let public_key = Curve25519PublicKey::from(&private_key); self.key_ids_by_key.remove(&public_key); + self.keys_by_key_id.remove(&key_id); Some(public_key) } else { @@ -104,6 +114,7 @@ impl OneTimeKeys { self.private_keys.insert(key_id, key); self.key_ids_by_key.insert(public_key, key_id); + self.keys_by_key_id.insert(key_id, public_key); if !published { self.unpublished_public_keys.insert(key_id, public_key); @@ -153,12 +164,14 @@ pub(super) struct OneTimeKeysPickle { private_keys: BTreeMap, } -impl From for OneTimeKeys { +impl From for OneTimeKeys { fn from(pickle: OneTimeKeysPickle) -> Self { let mut key_ids_by_key = HashMap::new(); + let mut keys_by_key_id = HashMap::new(); for (k, v) in pickle.private_keys.iter() { key_ids_by_key.insert(v.into(), *k); + keys_by_key_id.insert(*k, v.into()); } Self { @@ -166,12 +179,13 @@ impl From for OneTimeKeys { unpublished_public_keys: pickle.public_keys.iter().map(|(&k, &v)| (k, v)).collect(), private_keys: pickle.private_keys, key_ids_by_key, + keys_by_key_id, } } } -impl From for OneTimeKeysPickle { - fn from(keys: OneTimeKeys) -> Self { +impl From> for OneTimeKeysPickle { + fn from(keys: OneTimeKeys) -> Self { OneTimeKeysPickle { next_key_id: keys.next_key_id, public_keys: keys.unpublished_public_keys.iter().map(|(&k, &v)| (k, v)).collect(), @@ -187,19 +201,20 @@ mod test { #[test] fn store_limit() { - let mut store = OneTimeKeys::new(); + const MAX_ONE_TIME_KEYS: usize = 50; + let mut store: OneTimeKeys = OneTimeKeys::new(); assert!(store.private_keys.is_empty()); - store.generate(OneTimeKeys::MAX_ONE_TIME_KEYS); - assert_eq!(store.private_keys.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); - assert_eq!(store.unpublished_public_keys.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); - assert_eq!(store.key_ids_by_key.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); + store.generate(MAX_ONE_TIME_KEYS); + assert_eq!(store.private_keys.len(), MAX_ONE_TIME_KEYS); + assert_eq!(store.unpublished_public_keys.len(), MAX_ONE_TIME_KEYS); + assert_eq!(store.key_ids_by_key.len(), MAX_ONE_TIME_KEYS); store.mark_as_published(); assert!(store.unpublished_public_keys.is_empty()); - assert_eq!(store.private_keys.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); - assert_eq!(store.key_ids_by_key.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); + assert_eq!(store.private_keys.len(), MAX_ONE_TIME_KEYS); + assert_eq!(store.key_ids_by_key.len(), MAX_ONE_TIME_KEYS); let oldest_key_id = store.private_keys.keys().next().copied().expect("Couldn't get the first key ID"); @@ -207,8 +222,8 @@ mod test { store.generate(10); assert_eq!(store.unpublished_public_keys.len(), 10); - assert_eq!(store.private_keys.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); - assert_eq!(store.key_ids_by_key.len(), OneTimeKeys::MAX_ONE_TIME_KEYS); + assert_eq!(store.private_keys.len(), MAX_ONE_TIME_KEYS); + assert_eq!(store.key_ids_by_key.len(), MAX_ONE_TIME_KEYS); let oldest_key_id = store.private_keys.keys().next().copied().expect("Couldn't get the first key ID"); diff --git a/src/olm/messages/message.rs b/src/olm/messages/message.rs index 802f52d1..14063bd4 100644 --- a/src/olm/messages/message.rs +++ b/src/olm/messages/message.rs @@ -17,8 +17,9 @@ use std::fmt::Debug; use prost::Message as ProstMessage; use serde::{Deserialize, Serialize}; +use super::AnyNormalMessage; use crate::{ - cipher::{Mac, MessageMac}, + cipher::{InterolmMessageMac, Mac, MessageMac}, utilities::{base64_decode, base64_encode, extract_mac, VarInt}, Curve25519PublicKey, DecodeError, }; @@ -167,6 +168,12 @@ impl Message { } } +impl<'a> From<&'a Message> for AnyNormalMessage<'a> { + fn from(m: &'a Message) -> Self { + Self::Native(m) + } +} + impl Serialize for Message { fn serialize(&self, serializer: S) -> Result where @@ -254,6 +261,124 @@ impl Debug for Message { } } +#[cfg(feature = "interolm")] +#[derive(Clone, PartialEq, Eq)] +pub struct InterolmMessage { + pub(crate) version: u8, + pub(crate) ratchet_key: Curve25519PublicKey, + pub(crate) counter: u32, + pub(crate) previous_counter: u32, + pub(crate) ciphertext: Vec, + pub(crate) mac: InterolmMessageMac, +} + +#[cfg(feature = "interolm")] +impl InterolmMessage { + const VERSION: u8 = 51; + + pub(crate) fn new( + ratchet_key: Curve25519PublicKey, + counter: u32, + previous_counter: u32, + ciphertext: Vec, + ) -> Self { + Self { + version: Self::VERSION, + ratchet_key, + counter, + previous_counter, + ciphertext, + mac: InterolmMessageMac([0u8; Mac::TRUNCATED_LEN]), + } + } + + pub fn from_bytes(bytes: &[u8]) -> Result { + let version = *bytes.first().ok_or(DecodeError::MissingVersion)?; + + if version != Self::VERSION { + Err(DecodeError::InvalidVersion(Self::VERSION, version)) + } else if bytes.len() < Mac::TRUNCATED_LEN + 2 { + Err(DecodeError::MessageTooShort(bytes.len())) + } else { + let decoded = InterolmProtoBufMessage::decode( + bytes + .get(1..bytes.len() - Mac::TRUNCATED_LEN) + .ok_or_else(|| DecodeError::MessageTooShort(bytes.len()))?, + )?; + + let mac_slice = &bytes[bytes.len() - Mac::TRUNCATED_LEN..]; + + if mac_slice.len() != Mac::TRUNCATED_LEN { + Err(DecodeError::InvalidMacLength(Mac::TRUNCATED_LEN, mac_slice.len())) + } else { + let mac = InterolmMessageMac(mac_slice.try_into().expect("Can never happen")); + let ratchet_key = Curve25519PublicKey::from_slice(&decoded.ratchet_key)?; + let counter = decoded.counter; + let previous_counter = decoded.previous_counter; + let ciphertext = decoded.ciphertext; + + Ok(InterolmMessage { + version, + ratchet_key, + counter, + previous_counter, + ciphertext, + mac, + }) + } + } + } + + pub fn to_bytes(&self) -> Vec { + let mut message = self.to_mac_bytes(); + message.extend(self.mac.as_bytes()); + + message + } + + pub fn from_base64(message: &str) -> Result { + let decoded = base64_decode(message)?; + Self::from_bytes(&decoded) + } + + pub fn to_base64(&self) -> String { + base64_encode(self.to_bytes()) + } + + pub fn to_mac_bytes(&self) -> Vec { + InterolmProtoBufMessage { + ratchet_key: self.ratchet_key.to_interolm_bytes().to_vec(), + counter: self.counter, + previous_counter: self.previous_counter, + ciphertext: self.ciphertext.clone(), + } + .encode_manual() + } + + pub(crate) fn set_mac(&mut self, mac: Mac) { + self.mac.0 = mac.truncate(); + } +} + +impl<'a> From<&'a InterolmMessage> for AnyNormalMessage<'a> { + fn from(m: &'a InterolmMessage) -> Self { + Self::Interolm(m) + } +} + +impl Debug for InterolmMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { version, ratchet_key, counter, previous_counter, ciphertext: _, mac: _ } = self; + + f.debug_struct("InterolmMessage") + .field("version", version) + .field("ratchet_key", ratchet_key) + .field("counter", counter) + .field("previous_counter", previous_counter) + .finish() + } +} + #[derive(ProstMessage, PartialEq, Eq)] struct ProtoBufMessage { #[prost(bytes, tag = "1")] @@ -264,6 +389,49 @@ struct ProtoBufMessage { ciphertext: Vec, } +#[cfg(feature = "interolm")] +#[derive(PartialEq, Eq, ProstMessage)] +struct InterolmProtoBufMessage { + #[prost(bytes, tag = "1")] + ratchet_key: Vec, + #[prost(uint32, tag = "2")] + counter: u32, + #[prost(uint32, tag = "3")] + previous_counter: u32, + #[prost(bytes, tag = "4")] + ciphertext: Vec, +} + +#[cfg(feature = "interolm")] +impl InterolmProtoBufMessage { + const RATCHET_TAG: &'static [u8; 1] = b"\x0A"; + const INDEX_TAG: &'static [u8; 1] = b"\x10"; + const PREVIOUS_INDEX_TAG: &'static [u8; 1] = b"\x18"; + const CIPHER_TAG: &'static [u8; 1] = b"\x22"; + + fn encode_manual(&self) -> Vec { + let counter = self.counter.to_var_int(); + let previous_counter = self.previous_counter.to_var_int(); + let ratchet_len = self.ratchet_key.len().to_var_int(); + let ciphertext_len = self.ciphertext.len().to_var_int(); + + [ + [InterolmMessage::VERSION].as_ref(), + Self::RATCHET_TAG.as_ref(), + &ratchet_len, + &self.ratchet_key, + Self::INDEX_TAG.as_ref(), + &counter, + Self::PREVIOUS_INDEX_TAG.as_ref(), + &previous_counter, + Self::CIPHER_TAG.as_ref(), + &ciphertext_len, + &self.ciphertext, + ] + .concat() + } +} + impl ProtoBufMessage { const RATCHET_TAG: &'static [u8; 1] = b"\x0A"; const INDEX_TAG: &'static [u8; 1] = b"\x10"; @@ -292,7 +460,7 @@ impl ProtoBufMessage { #[cfg(test)] mod test { use super::Message; - use crate::Curve25519PublicKey; + use crate::{olm::InterolmMessage, Curve25519PublicKey}; #[test] fn encode() { @@ -309,4 +477,32 @@ mod test { assert_eq!(encoded.to_mac_bytes(), message.as_ref()); assert_eq!(encoded.to_bytes(), message_mac.as_ref()); } + + #[test] + fn interolm_re_encode() { + let message = &[ + 51, 10, 33, 5, 190, 36, 85, 201, 27, 92, 134, 42, 25, 250, 119, 63, 8, 146, 237, 196, + 47, 189, 116, 179, 143, 41, 171, 119, 96, 182, 250, 30, 175, 30, 104, 26, 16, 2, 24, 0, + 34, 64, 186, 87, 78, 176, 178, 217, 29, 185, 227, 41, 209, 55, 212, 24, 24, 96, 51, + 126, 53, 57, 42, 104, 132, 165, 184, 183, 167, 231, 84, 9, 117, 73, 131, 95, 7, 215, + 133, 34, 111, 40, 21, 115, 74, 154, 253, 184, 187, 237, 133, 32, 231, 2, 74, 56, 216, + 17, 200, 91, 74, 55, 33, 193, 89, 193, 35, 196, 248, 166, 3, 98, 194, 158, + ]; + + let decoded = InterolmMessage::from_bytes(message) + .expect("We should be able to decode the Interolm message"); + + let encoded = decoded.to_bytes(); + + assert_eq!( + message.as_slice(), + &encoded, + "Re-encoding the message should yield the same bytes" + ); + + let expected_mac_bytes = &message[0..message.len() - 8]; + let mac_bytes = decoded.to_mac_bytes(); + + assert_eq!(expected_mac_bytes, mac_bytes, "The MAC bytes should be correctly gathered"); + } } diff --git a/src/olm/messages/mod.rs b/src/olm/messages/mod.rs index 33471177..1870c898 100644 --- a/src/olm/messages/mod.rs +++ b/src/olm/messages/mod.rs @@ -15,23 +15,104 @@ mod message; mod pre_key; -pub use message::Message; -pub use pre_key::PreKeyMessage; +pub use message::{InterolmMessage, Message}; +pub use pre_key::{InterolmPreKeyMessage, PreKeyMessage}; use serde::{Deserialize, Serialize}; -use crate::DecodeError; +use crate::{Curve25519PublicKey, DecodeError}; -/// Enum over the different Olm message types. +/// A type covering all possible messages supported by vodozemac. +/// +/// Includes both normal and pre-key messages of both the native and Interolm +/// message variants. +#[allow(dead_code)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AnyMessage { + Native(AnyNativeMessage), + #[cfg(feature = "interolm")] + Interolm(AnyInterolmMessage), +} + +impl From for AnyMessage { + fn from(value: AnyNativeMessage) -> Self { + Self::Native(value) + } +} + +impl From for AnyMessage { + fn from(value: AnyInterolmMessage) -> Self { + Self::Interolm(value) + } +} + +/// A type covering all possible "normal" (non-prekey) messages supported by +/// vodozemac. +/// +/// Includes both the native and Interolm message variants. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AnyNormalMessage<'a> { + Native(&'a Message), + #[cfg(feature = "interolm")] + Interolm(&'a InterolmMessage), +} + +impl AnyNormalMessage<'_> { + pub(crate) fn ratchet_key(&self) -> Curve25519PublicKey { + match self { + AnyNormalMessage::Native(m) => m.ratchet_key, + AnyNormalMessage::Interolm(m) => m.ratchet_key, + } + } + + pub(crate) fn chain_index(&self) -> u64 { + match self { + AnyNormalMessage::Native(m) => m.chain_index, + AnyNormalMessage::Interolm(m) => m.counter.into(), + } + } +} + +/// A type covering all pre-key messages supported by vodozemac. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AnyPreKeyMessage { + Native(PreKeyMessage), + #[cfg(feature = "interolm")] + Interolm(InterolmPreKeyMessage), +} + +impl From for AnyPreKeyMessage { + fn from(value: PreKeyMessage) -> Self { + AnyPreKeyMessage::Native(value) + } +} + +#[cfg(feature = "interolm")] +impl From for AnyPreKeyMessage { + fn from(value: InterolmPreKeyMessage) -> Self { + AnyPreKeyMessage::Interolm(value) + } +} + +#[cfg(feature = "interolm")] +impl From for AnyInterolmMessage { + fn from(value: InterolmPreKeyMessage) -> Self { + Self::PreKey(value) + } +} + +/// A type representing the native Olm message types. /// /// Olm uses two types of messages. The underlying transport protocol must /// provide a means for recipients to distinguish between them. /// -/// [`OlmMessage`] provides [`Serialize`] and [`Deserialize`] implementations -/// that are compatible with [Matrix]. +/// [`AnyNativeMessage`] provides [`Serialize`] and [`Deserialize`] +/// implementations that are compatible with [Matrix]. +/// +/// The type is called "native" because we also support Interolm messages. /// /// [Matrix]: https://spec.matrix.org/latest/client-server-api/#molmv1curve25519-aes-sha2 #[derive(Debug, Clone, PartialEq, Eq)] -pub enum OlmMessage { +pub enum AnyNativeMessage { /// A normal message, contains only the ciphertext and metadata to decrypt /// it. Normal(Message), @@ -42,13 +123,13 @@ pub enum OlmMessage { PreKey(PreKeyMessage), } -impl From for OlmMessage { +impl From for AnyNativeMessage { fn from(m: Message) -> Self { Self::Normal(m) } } -impl From for OlmMessage { +impl From for AnyNativeMessage { fn from(m: PreKeyMessage) -> Self { Self::PreKey(m) } @@ -62,7 +143,7 @@ struct MessageSerdeHelper { ciphertext: String, } -impl Serialize for OlmMessage { +impl Serialize for AnyNativeMessage { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -75,17 +156,17 @@ impl Serialize for OlmMessage { } } -impl<'de> Deserialize<'de> for OlmMessage { +impl<'de> Deserialize<'de> for AnyNativeMessage { fn deserialize>(d: D) -> Result { let value = MessageSerdeHelper::deserialize(d)?; - OlmMessage::from_parts(value.message_type, &value.ciphertext) + AnyNativeMessage::from_parts(value.message_type, &value.ciphertext) .map_err(serde::de::Error::custom) } } -impl OlmMessage { - /// Create a `OlmMessage` from a message type and a ciphertext. +impl AnyNativeMessage { + /// Create an `AnyNativeMessage` from a message type and a ciphertext. pub fn from_parts(message_type: usize, ciphertext: &str) -> Result { match message_type { 0 => Ok(Self::PreKey(PreKeyMessage::try_from(ciphertext)?)), @@ -97,27 +178,27 @@ impl OlmMessage { /// Get the message as a byte array. pub fn message(&self) -> &[u8] { match self { - OlmMessage::Normal(m) => &m.ciphertext, - OlmMessage::PreKey(m) => &m.message.ciphertext, + AnyNativeMessage::Normal(m) => &m.ciphertext, + AnyNativeMessage::PreKey(m) => &m.message.ciphertext, } } /// Get the type of the message. pub fn message_type(&self) -> MessageType { match self { - OlmMessage::Normal(_) => MessageType::Normal, - OlmMessage::PreKey(_) => MessageType::PreKey, + AnyNativeMessage::Normal(_) => MessageType::Normal, + AnyNativeMessage::PreKey(_) => MessageType::PreKey, } } - /// Convert the `OlmMessage` into a message type, and base64 encoded message - /// tuple. + /// Convert the `AnyNativeMessage` into a message type, and base64 encoded + /// message tuple. pub fn to_parts(self) -> (usize, String) { let message_type = self.message_type(); match self { - OlmMessage::Normal(m) => (message_type.into(), m.to_base64()), - OlmMessage::PreKey(m) => (message_type.into(), m.to_base64()), + AnyNativeMessage::Normal(m) => (message_type.into(), m.to_base64()), + AnyNativeMessage::PreKey(m) => (message_type.into(), m.to_base64()), } } } @@ -149,11 +230,24 @@ impl From for usize { } } +#[cfg(feature = "interolm")] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AnyInterolmMessage { + /// A normal message, contains only the ciphertext and metadata to decrypt + /// it. + Normal(InterolmMessage), + /// A pre-key message, contains metadata to establish a [`Session`] as well + /// as a [`Message`]. + /// + /// [`Session`]: crate::olm::Session + PreKey(InterolmPreKeyMessage), +} + #[cfg(test)] use olm_rs::session::OlmMessage as LibolmMessage; #[cfg(test)] -impl From for OlmMessage { +impl From for AnyNativeMessage { fn from(other: LibolmMessage) -> Self { let (message_type, ciphertext) = other.to_tuple(); @@ -162,13 +256,17 @@ impl From for OlmMessage { } #[cfg(test)] -impl From for LibolmMessage { - fn from(value: OlmMessage) -> LibolmMessage { +impl From for LibolmMessage { + fn from(value: AnyNativeMessage) -> LibolmMessage { match value { - OlmMessage::Normal(m) => LibolmMessage::from_type_and_ciphertext(1, m.to_base64()) - .expect("Can't create a valid libolm message"), - OlmMessage::PreKey(m) => LibolmMessage::from_type_and_ciphertext(0, m.to_base64()) - .expect("Can't create a valid libolm pre-key message"), + AnyNativeMessage::Normal(m) => { + LibolmMessage::from_type_and_ciphertext(1, m.to_base64()) + .expect("Can't create a valid libolm message") + } + AnyNativeMessage::PreKey(m) => { + LibolmMessage::from_type_and_ciphertext(0, m.to_base64()) + .expect("Can't create a valid libolm pre-key message") + } } } } @@ -180,7 +278,7 @@ mod tests { use serde_json::json; use super::*; - use crate::run_corpus; + use crate::{run_corpus, utilities::base64_decode, Curve25519PublicKey}; const PRE_KEY_MESSAGE: &str = "AwoghAEuxPZ+w7M3pgUae4tDNiggUpOsQ/zci457VAti\ AEYSIO3xOKRDBWKicIfxjSmYCYZ9DD4RMLjvvclbMlE5\ @@ -210,6 +308,31 @@ mod tests { ); } + #[test] + fn from_interolm() { + let message = + "MwgCEiEF/VRCSPW3XOxQK75pnA18atUmaj4KSP5E3Fhk8QZMdkAaIQVzGcWwnUJF3Y83c3E7V/B1\ + sdAdPO0Igal5I2ak4xw9fCJCMwohBQH58JyqI+8NqoaTYKB/4h4GCtiXpRvg+WLm6JTgRsNgEAAY\ + ACIQ2N/SJfeTaikQb8DmRWja6Vkzmhm1yBq8KAEwAQ=="; + + let identity_key = + Curve25519PublicKey::from_base64("BXMZxbCdQkXdjzdzcTtX8HWx0B087QiBqXkjZqTjHD18") + .expect("The type-prefixed Curve25519 can be decoded"); + + let parsed = InterolmPreKeyMessage::from_base64(message) + .expect("We can parse Interolm pre-key messages"); + + assert_eq!( + identity_key, parsed.identity_key, + "The identity key from the message matches the static identity key" + ); + + let bytes = base64_decode(message).unwrap(); + let encoded = parsed.to_bytes(); + + assert_eq!(bytes, encoded); + } + #[test] fn from_json() -> Result<()> { let value = json!({ @@ -217,8 +340,8 @@ mod tests { "body": PRE_KEY_MESSAGE, }); - let message: OlmMessage = serde_json::from_value(value.clone())?; - assert_matches!(message, OlmMessage::PreKey(_)); + let message: AnyNativeMessage = serde_json::from_value(value.clone())?; + assert_matches!(message, AnyNativeMessage::PreKey(_)); let serialized = serde_json::to_value(message)?; assert_eq!(value, serialized, "The serialization cycle isn't a noop"); @@ -228,8 +351,8 @@ mod tests { "body": MESSAGE, }); - let message: OlmMessage = serde_json::from_value(value.clone())?; - assert_matches!(message, OlmMessage::Normal(_)); + let message: AnyNativeMessage = serde_json::from_value(value.clone())?; + assert_matches!(message, AnyNativeMessage::Normal(_)); let serialized = serde_json::to_value(message)?; assert_eq!(value, serialized, "The serialization cycle isn't a noop"); @@ -239,8 +362,8 @@ mod tests { #[test] fn from_parts() -> Result<()> { - let message = OlmMessage::from_parts(0, PRE_KEY_MESSAGE)?; - assert_matches!(message, OlmMessage::PreKey(_)); + let message = AnyNativeMessage::from_parts(0, PRE_KEY_MESSAGE)?; + assert_matches!(message, AnyNativeMessage::PreKey(_)); assert_eq!( message.message_type(), MessageType::PreKey, @@ -249,7 +372,7 @@ mod tests { assert_eq!(message.to_parts(), (0, PRE_KEY_MESSAGE.to_string()), "Roundtrip not identity."); - let message = OlmMessage::from_parts(1, MESSAGE)?; + let message = AnyNativeMessage::from_parts(1, MESSAGE)?; assert_eq!( message.message_type(), MessageType::Normal, @@ -257,7 +380,7 @@ mod tests { ); assert_eq!(message.to_parts(), (1, MESSAGE.to_string()), "Roundtrip not identity."); - OlmMessage::from_parts(3, PRE_KEY_MESSAGE) + AnyNativeMessage::from_parts(3, PRE_KEY_MESSAGE) .expect_err("Unknown message types can't be parsed"); Ok(()) diff --git a/src/olm/messages/pre_key.rs b/src/olm/messages/pre_key.rs index 409ae9c4..2cfd97aa 100644 --- a/src/olm/messages/pre_key.rs +++ b/src/olm/messages/pre_key.rs @@ -17,9 +17,9 @@ use std::fmt::Debug; use prost::Message as ProstMessage; use serde::{Deserialize, Serialize}; -use super::Message; +use super::{message::InterolmMessage, Message}; use crate::{ - olm::SessionKeys, + olm::{session_config::InterolmSessionMetadata, OlmSessionKeys, SessionKeys}, utilities::{base64_decode, base64_encode}, Curve25519PublicKey, DecodeError, }; @@ -32,7 +32,7 @@ use crate::{ /// [`Session`]: crate::olm::Session #[derive(Clone, Debug, PartialEq, Eq)] pub struct PreKeyMessage { - pub(crate) session_keys: SessionKeys, + pub(crate) session_keys: OlmSessionKeys, pub(crate) message: Message, } @@ -70,7 +70,7 @@ impl PreKeyMessage { /// can be used to retrieve individual keys from this collection. /// /// [`Session`]: crate::olm::Session - pub fn session_keys(&self) -> SessionKeys { + pub fn session_keys(&self) -> OlmSessionKeys { self.session_keys } @@ -157,7 +157,7 @@ impl PreKeyMessage { PreKeyMessage::new(session_keys, message) } - pub(crate) fn new(session_keys: SessionKeys, message: Message) -> Self { + pub(crate) fn new(session_keys: OlmSessionKeys, message: Message) -> Self { Self { session_keys, message } } } @@ -213,13 +213,98 @@ impl TryFrom<&[u8]> for PreKeyMessage { let message = decoded.message.try_into()?; - let session_keys = SessionKeys { one_time_key, identity_key, base_key }; + let session_keys = OlmSessionKeys { one_time_key, identity_key, base_key }; Ok(Self { session_keys, message }) } } } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct InterolmPreKeyMessage { + pub(crate) registration_id: u32, + pub(crate) pre_key_id: Option, + pub(crate) signed_pre_key_id: u32, + pub(crate) identity_key: Curve25519PublicKey, + pub(crate) base_key: Curve25519PublicKey, + pub(crate) message: InterolmMessage, +} + +impl InterolmPreKeyMessage { + const VERSION: u8 = 51; + + pub fn new( + session_keys: SessionKeys, + metadata: InterolmSessionMetadata, + message: InterolmMessage, + ) -> Self { + let registration_id = metadata.registration_id; + let signed_pre_key_id = metadata + .signed_pre_key_id + .0 + .try_into() + .expect("Key IDs will always be < 2^32 for Interolm"); + let pre_key_id = metadata + .one_time_key_id + .map(|id| id.0.try_into().expect("Key IDs will always be < 2^32 for Interolm")); + let identity_key = session_keys.identity_key; + let base_key = session_keys.base_key; + + Self { registration_id, pre_key_id, signed_pre_key_id, base_key, identity_key, message } + } + + pub fn from_base64(message: &str) -> Result { + let decoded = base64_decode(message)?; + Self::from_bytes(&decoded) + } + + pub fn to_base64(&self) -> String { + base64_encode(self.to_bytes()) + } + + pub fn from_bytes(message: &[u8]) -> Result { + let version = *message.first().ok_or(DecodeError::MissingVersion)?; + + if version != Self::VERSION { + Err(DecodeError::InvalidVersion(Self::VERSION, version)) + } else { + let decoded = InterolmProtoBufPrekeyMessage::decode(&message[1..message.len()])?; + let base_key = Curve25519PublicKey::from_slice(&decoded.base_key)?; + let identity_key = Curve25519PublicKey::from_slice(&decoded.identity_key)?; + let message = InterolmMessage::from_bytes(&decoded.message)?; + + Ok(Self { + registration_id: decoded.registration_id, + pre_key_id: decoded.pre_key_id, + signed_pre_key_id: decoded.signed_pre_key_id, + identity_key, + base_key, + message, + }) + } + } + + pub fn to_bytes(&self) -> Vec { + let message = InterolmProtoBufPrekeyMessage { + registration_id: self.registration_id, + pre_key_id: self.pre_key_id, + signed_pre_key_id: self.signed_pre_key_id, + base_key: self.base_key.to_interolm_bytes().to_vec(), + identity_key: self.identity_key.to_interolm_bytes().to_vec(), + message: self.message.to_bytes(), + }; + + let mut bytes = Vec::with_capacity(1 + message.encoded_len()); + bytes.push(Self::VERSION); + + message + .encode(&mut bytes) + .expect("Couldn't encode a pre-key Interolm message message into protobuf"); + + bytes + } +} + #[derive(Clone, ProstMessage)] struct ProtoBufPreKeyMessage { #[prost(bytes, tag = "1")] @@ -231,3 +316,19 @@ struct ProtoBufPreKeyMessage { #[prost(bytes, tag = "4")] message: Vec, } + +#[derive(Clone, ProstMessage)] +pub struct InterolmProtoBufPrekeyMessage { + #[prost(uint32, tag = "5")] + registration_id: u32, + #[prost(uint32, optional, tag = "1")] + pre_key_id: Option, + #[prost(uint32, tag = "6")] + signed_pre_key_id: u32, + #[prost(bytes, tag = "2")] + base_key: Vec, + #[prost(bytes, tag = "3")] + identity_key: Vec, + #[prost(bytes, tag = "4")] + message: Vec, +} diff --git a/src/olm/mod.rs b/src/olm/mod.rs index 16e68ca5..97f8c8f5 100644 --- a/src/olm/mod.rs +++ b/src/olm/mod.rs @@ -50,7 +50,7 @@ //! //! ```rust //! use anyhow::Result; -//! use vodozemac::olm::{Account, InboundCreationResult, OlmMessage, SessionConfig}; +//! use vodozemac::olm::{Account, InboundCreationResult, AnyNativeMessage, SessionConfig}; //! //! fn main() -> Result<()> { //! let alice = Account::new(); @@ -60,15 +60,15 @@ //! let bob_otk = *bob.one_time_keys().values().next().unwrap(); //! //! let mut alice_session = alice -//! .create_outbound_session(SessionConfig::version_2(), bob.curve25519_key(), bob_otk); +//! .create_outbound_session(SessionConfig::version_2(), bob.curve25519_key(), bob_otk, None); //! //! bob.mark_keys_as_published(); //! //! let message = "Keep it between us, OK?"; //! let alice_msg = alice_session.encrypt(message); //! -//! if let OlmMessage::PreKey(m) = alice_msg.clone() { -//! let result = bob.create_inbound_session(alice.curve25519_key(), &m)?; +//! if let AnyNativeMessage::PreKey(m) = alice_msg.clone() { +//! let result = bob.create_inbound_session(alice.curve25519_key(), &m.into())?; //! //! let mut bob_session = result.session; //! let what_bob_received = result.plaintext; @@ -92,10 +92,10 @@ //! ## Sending messages //! //! To encrypt a message, just call `Session::encrypt(msg_content)`. This will -//! either produce an `OlmMessage::PreKey(..)` or `OlmMessage::Normal(..)` -//! depending on whether the session is fully established. A session is fully -//! established once you receive (and decrypt) at least one message from the -//! other side. +//! either produce an `AnyNativeMessage::PreKey(..)` or +//! `AnyNativeMessage::Normal(..)` depending on whether the session is fully +//! established. A session is fully established once you receive (and decrypt) +//! at least one message from the other side. mod account; mod messages; @@ -108,7 +108,10 @@ pub use account::{ Account, AccountPickle, IdentityKeys, InboundCreationResult, OneTimeKeyGenerationResult, SessionCreationError, }; -pub use messages::{Message, MessageType, OlmMessage, PreKeyMessage}; +pub use messages::{ + AnyInterolmMessage, AnyMessage, AnyNativeMessage, AnyNormalMessage, AnyPreKeyMessage, + InterolmMessage, InterolmPreKeyMessage, Message, MessageType, PreKeyMessage, +}; pub use session::{ratchet::RatchetPublicKey, DecryptionError, Session, SessionPickle}; -pub use session_config::SessionConfig; -pub use session_keys::SessionKeys; +pub use session_config::{InterolmSessionMetadata, SessionConfig}; +pub use session_keys::{OlmSessionKeys, SessionKeys}; diff --git a/src/olm/session/double_ratchet.rs b/src/olm/session/double_ratchet.rs index d76820dc..d48b5ae7 100644 --- a/src/olm/session/double_ratchet.rs +++ b/src/olm/session/double_ratchet.rs @@ -15,13 +15,18 @@ use serde::{Deserialize, Serialize}; use super::{ - chain_key::ChainKey, + chain_key::{ChainKey, RemoteChainKey}, message_key::MessageKey, - ratchet::{Ratchet, RatchetPublicKey, RemoteRatchetKey}, + ratchet::{Ratchet, RatchetKey, RatchetPublicKey, RemoteRatchetKey}, receiver_chain::ReceiverChain, root_key::{RemoteRootKey, RootKey}, }; -use crate::olm::{messages::Message, shared_secret::Shared3DHSecret}; +use crate::olm::{ + messages::Message, + session_config::SessionCreator, + shared_secret::{RemoteShared3DHSecret, Shared3DHSecret}, + InterolmMessage, SessionConfig, SessionKeys, +}; #[derive(Serialize, Deserialize, Clone)] #[serde(transparent)] @@ -37,10 +42,10 @@ impl DoubleRatchet { } } - pub fn next_message_key(&mut self) -> MessageKey { + pub fn next_message_key(&mut self, config: &SessionConfig) -> MessageKey { match &mut self.inner { DoubleRatchetState::Inactive(ratchet) => { - let mut ratchet = ratchet.activate(); + let mut ratchet = ratchet.activate(config); let message_key = ratchet.next_message_key(); self.inner = DoubleRatchetState::Active(ratchet); @@ -51,16 +56,33 @@ impl DoubleRatchet { } } - pub fn encrypt(&mut self, plaintext: &[u8]) -> Message { - self.next_message_key().encrypt(plaintext) + pub fn encrypt(&mut self, config: &SessionConfig, plaintext: &[u8]) -> Message { + self.next_message_key(config).encrypt(plaintext) + } + + pub fn encrypt_truncated_mac(&mut self, config: &SessionConfig, plaintext: &[u8]) -> Message { + self.next_message_key(config).encrypt_truncated_mac(plaintext) } - pub fn encrypt_truncated_mac(&mut self, plaintext: &[u8]) -> Message { - self.next_message_key().encrypt_truncated_mac(plaintext) + #[cfg(feature = "interolm")] + pub fn encrypt_interolm( + &mut self, + config: &SessionConfig, + session_creator: SessionCreator, + session_keys: &SessionKeys, + previous_counter: u32, + plaintext: &[u8], + ) -> InterolmMessage { + self.next_message_key(config).encrypt_interolm( + session_keys, + session_creator, + previous_counter, + plaintext, + ) } - pub fn active(shared_secret: Shared3DHSecret) -> Self { - let (root_key, chain_key) = shared_secret.expand(); + pub fn active(config: &SessionConfig, shared_secret: Shared3DHSecret) -> Self { + let (root_key, chain_key) = shared_secret.expand(config); let root_key = RootKey::new(root_key); let chain_key = ChainKey::new(chain_key); @@ -73,35 +95,77 @@ impl DoubleRatchet { Self { inner: ratchet.into() } } - #[cfg(feature = "libolm-compat")] - pub fn from_ratchet_and_chain_key(ratchet: Ratchet, chain_key: ChainKey) -> Self { - Self { - inner: ActiveDoubleRatchet { - active_ratchet: ratchet, - symmetric_key_ratchet: chain_key, - } - .into(), - } - } - pub fn inactive(root_key: RemoteRootKey, ratchet_key: RemoteRatchetKey) -> Self { let ratchet = InactiveDoubleRatchet { root_key, ratchet_key }; Self { inner: ratchet.into() } } - pub fn advance(&mut self, ratchet_key: RemoteRatchetKey) -> (DoubleRatchet, ReceiverChain) { + #[cfg(feature = "interolm")] + pub fn active_interolm( + config: &SessionConfig, + shared_secret: Shared3DHSecret, + their_ratchet_key: RemoteRatchetKey, + ) -> (Self, ReceiverChain) { + // Interolm considers the second item of this KDF expansion to be the receiver + // chain key, and therefore the ratchet is created in the inactive + // state. This is different from Olm where the ratchet starts in the + // active state since we derive the sender chain key directly from the + // shared secret. Therefore when talking to an Interolm implementation, + // to obtain an active ratchet, we start off in the inactive state and + // then immediately advance a step. + let (remote_root_key, remote_chain_key) = shared_secret.expand(config); + let remote_root_key = RemoteRootKey::new(remote_root_key); + let remote_chain_key = RemoteChainKey::new(remote_chain_key); + let receiver_chain = ReceiverChain::new(their_ratchet_key, remote_chain_key); + + let inactive_ratchet = + InactiveDoubleRatchet { root_key: remote_root_key, ratchet_key: their_ratchet_key }; + let active_ratchet = inactive_ratchet.activate(config); + + let dh_ratchet = Self { inner: active_ratchet.into() }; + + (dh_ratchet, receiver_chain) + } + + #[cfg(feature = "interolm")] + pub fn inactive_interolm( + config: &SessionConfig, + shared_secret: RemoteShared3DHSecret, + our_ratchet_key: RatchetKey, + their_ratchet_key: RemoteRatchetKey, + ) -> (Self, ReceiverChain) { + let (root_key, chain_key) = shared_secret.expand(config); + + let root_key = RootKey::new(root_key); + let chain_key = ChainKey::new(chain_key); + + let ratchet = ActiveDoubleRatchet { + active_ratchet: Ratchet::new_with_ratchet_key(root_key, our_ratchet_key), + symmetric_key_ratchet: chain_key, + }; + + let (inner_ratchet, receiver_chain) = ratchet.advance(config, their_ratchet_key); + + (Self { inner: inner_ratchet.into() }, receiver_chain) + } + + pub fn advance( + &mut self, + config: &SessionConfig, + ratchet_key: RemoteRatchetKey, + ) -> (DoubleRatchet, ReceiverChain) { let (ratchet, receiver_chain) = match &self.inner { - DoubleRatchetState::Active(r) => r.advance(ratchet_key), + DoubleRatchetState::Active(r) => r.advance(config, ratchet_key), DoubleRatchetState::Inactive(r) => { - let ratchet = r.activate(); + let ratchet = r.activate(config); // Advancing an inactive ratchet shouldn't be possible since the // other side did not yet receive our new ratchet key. // // This will likely end up in a decryption error but for // consistency sake and avoiding the leakage of our internal // state it's better to error out there. - let ret = ratchet.advance(ratchet_key); + let ret = ratchet.advance(config, ratchet_key); self.inner = ratchet.into(); @@ -111,6 +175,17 @@ impl DoubleRatchet { (Self { inner: DoubleRatchetState::Inactive(ratchet) }, receiver_chain) } + + #[cfg(feature = "libolm-compat")] + pub fn from_ratchet_and_chain_key(ratchet: Ratchet, chain_key: ChainKey) -> Self { + Self { + inner: ActiveDoubleRatchet { + active_ratchet: ratchet, + symmetric_key_ratchet: chain_key, + } + .into(), + } + } } #[derive(Serialize, Deserialize, Clone)] @@ -140,8 +215,8 @@ struct InactiveDoubleRatchet { } impl InactiveDoubleRatchet { - fn activate(&self) -> ActiveDoubleRatchet { - let (root_key, chain_key, ratchet_key) = self.root_key.advance(&self.ratchet_key); + fn activate(&self, config: &SessionConfig) -> ActiveDoubleRatchet { + let (root_key, chain_key, ratchet_key) = self.root_key.advance(config, &self.ratchet_key); let active_ratchet = Ratchet::new_with_ratchet_key(root_key, ratchet_key); ActiveDoubleRatchet { active_ratchet, symmetric_key_ratchet: chain_key } @@ -155,8 +230,12 @@ struct ActiveDoubleRatchet { } impl ActiveDoubleRatchet { - fn advance(&self, ratchet_key: RemoteRatchetKey) -> (InactiveDoubleRatchet, ReceiverChain) { - let (root_key, remote_chain) = self.active_ratchet.advance(ratchet_key); + fn advance( + &self, + config: &SessionConfig, + ratchet_key: RemoteRatchetKey, + ) -> (InactiveDoubleRatchet, ReceiverChain) { + let (root_key, remote_chain) = self.active_ratchet.advance(config, ratchet_key); let ratchet = InactiveDoubleRatchet { root_key, ratchet_key }; let receiver_chain = ReceiverChain::new(ratchet_key, remote_chain); diff --git a/src/olm/session/message_key.rs b/src/olm/session/message_key.rs index 8dfaff55..7d940c17 100644 --- a/src/olm/session/message_key.rs +++ b/src/olm/session/message_key.rs @@ -19,8 +19,8 @@ use zeroize::Zeroize; use super::{ratchet::RatchetPublicKey, DecryptionError}; use crate::{ - cipher::{Cipher, Mac}, - olm::messages::Message, + cipher::{Cipher, InterolmMessageMac, Mac}, + olm::{messages::Message, session_config::SessionCreator, InterolmMessage, SessionKeys}, }; pub struct MessageKey { @@ -87,6 +87,45 @@ impl MessageKey { message } + #[cfg(feature = "interolm")] + pub fn encrypt_interolm( + self, + session_keys: &SessionKeys, + session_creator: SessionCreator, + previous_counter: u32, + plaintext: &[u8], + ) -> InterolmMessage { + let cipher = Cipher::new_interolm(&self.key); + + let ciphertext = cipher.encrypt(plaintext); + + let mut message = InterolmMessage::new( + *self.ratchet_key.as_ref(), + self.index.try_into().expect("Interolm doesn't support encrypting more than 2^32 messages with a single sender chain"), + previous_counter, + ciphertext, + ); + + let sender_identity; + let receiver_identity; + + match session_creator { + SessionCreator::Us => { + sender_identity = session_keys.identity_key; + receiver_identity = session_keys.other_identity_key; + } + SessionCreator::Them => { + sender_identity = session_keys.other_identity_key; + receiver_identity = session_keys.identity_key; + } + }; + + let mac = cipher.mac_interolm(sender_identity, receiver_identity, &message.to_mac_bytes()); + message.set_mac(mac); + + message + } + /// Get a reference to the message key's key. #[cfg(feature = "low-level-api")] pub fn key(&self) -> &[u8; 32] { @@ -115,25 +154,59 @@ impl RemoteMessageKey { self.index } - pub fn decrypt_truncated_mac(&self, message: &Message) -> Result, DecryptionError> { + pub fn decrypt(&self, message: &Message) -> Result, DecryptionError> { let cipher = Cipher::new(&self.key); - if let crate::cipher::MessageMac::Truncated(m) = &message.mac { - cipher.verify_truncated_mac(&message.to_mac_bytes(), m)?; + if let crate::cipher::MessageMac::Full(m) = &message.mac { + cipher.verify_mac(&message.to_mac_bytes(), m)?; Ok(cipher.decrypt(&message.ciphertext)?) } else { - Err(DecryptionError::InvalidMACLength(Mac::TRUNCATED_LEN, Mac::LENGTH)) + Err(DecryptionError::InvalidMACLength(Mac::LENGTH, Mac::TRUNCATED_LEN)) } } - pub fn decrypt(&self, message: &Message) -> Result, DecryptionError> { + pub fn decrypt_truncated_mac(&self, message: &Message) -> Result, DecryptionError> { let cipher = Cipher::new(&self.key); - if let crate::cipher::MessageMac::Full(m) = &message.mac { - cipher.verify_mac(&message.to_mac_bytes(), m)?; + if let crate::cipher::MessageMac::Truncated(m) = &message.mac { + cipher.verify_truncated_mac(&message.to_mac_bytes(), m)?; Ok(cipher.decrypt(&message.ciphertext)?) } else { - Err(DecryptionError::InvalidMACLength(Mac::LENGTH, Mac::TRUNCATED_LEN)) + Err(DecryptionError::InvalidMACLength(Mac::TRUNCATED_LEN, Mac::LENGTH)) } } + + #[cfg(feature = "interolm")] + pub fn decrypt_interolm( + &self, + session_keys: &SessionKeys, + session_creator: SessionCreator, + message: &InterolmMessage, + ) -> Result, DecryptionError> { + let cipher = Cipher::new_interolm(&self.key); + + let sender_identity; + let receiver_identity; + + match session_creator { + SessionCreator::Us => { + sender_identity = session_keys.other_identity_key; + receiver_identity = session_keys.identity_key; + } + SessionCreator::Them => { + sender_identity = session_keys.identity_key; + receiver_identity = session_keys.other_identity_key; + } + }; + + let InterolmMessageMac(m) = &message.mac; + cipher.verify_interolm_mac( + &message.to_mac_bytes(), + sender_identity, + receiver_identity, + m, + )?; + + Ok(cipher.decrypt(&message.ciphertext)?) + } } diff --git a/src/olm/session/mod.rs b/src/olm/session/mod.rs index 7125bc5d..bea69664 100644 --- a/src/olm/session/mod.rs +++ b/src/olm/session/mod.rs @@ -29,39 +29,43 @@ use double_ratchet::DoubleRatchet; use hmac::digest::MacError; use ratchet::RemoteRatchetKey; use receiver_chain::ReceiverChain; -use root_key::RemoteRootKey; use serde::{Deserialize, Serialize}; use thiserror::Error; use zeroize::Zeroize; +use self::ratchet::RatchetKey; use super::{ - session_config::Version, + session_config::{SessionCreator, Version}, session_keys::SessionKeys, shared_secret::{RemoteShared3DHSecret, Shared3DHSecret}, - SessionConfig, + AnyMessage, AnyNormalMessage, InterolmPreKeyMessage, SessionConfig, }; #[cfg(feature = "low-level-api")] use crate::hazmat::olm::MessageKey; use crate::{ - olm::messages::{Message, OlmMessage, PreKeyMessage}, + olm::{ + messages::{AnyInterolmMessage, AnyNativeMessage, PreKeyMessage}, + session::root_key::{RemoteRootKey, RootKey}, + }, + types::Curve25519SecretKey, utilities::{pickle, unpickle}, Curve25519PublicKey, PickleError, }; const MAX_RECEIVING_CHAINS: usize = 5; -/// Error type for Olm-based decryption failures. +/// Error type for decryption failures. #[derive(Error, Debug)] pub enum DecryptionError { /// The message authentication code of the message was invalid. - #[error("Failed decrypting Olm message, invalid MAC: {0}")] + #[error("Failed decrypting message, invalid MAC: {0}")] InvalidMAC(#[from] MacError), /// The length of the message authentication code of the message did not /// match our expected length. - #[error("Failed decrypting Olm message, invalid MAC length: expected {0}, got {1}")] + #[error("Failed decrypting message, invalid MAC length: expected {0}, got {1}")] InvalidMACLength(usize, usize), /// The ciphertext of the message isn't padded correctly. - #[error("Failed decrypting Olm message, invalid padding")] + #[error("Failed decrypting message, invalid padding")] InvalidPadding(#[from] UnpadError), /// The session is missing the correct message key to decrypt the message, /// either because it was already used up, or because the Session has been @@ -71,6 +75,9 @@ pub enum DecryptionError { /// Too many messages have been skipped to attempt decrypting this message. #[error("The message gap was too big, got {0}, max allowed {1}")] TooBigMessageGap(u64, u64), + /// We were expecting one algorithm but the message was in another. + #[error("The message had an unexpected algorithm: expected {0}, got {1}")] + WrongAlgorithm(String, String), } #[derive(Serialize, Deserialize, Clone)] @@ -91,10 +98,6 @@ impl ChainStore { self.inner.push(ratchet) } - fn is_empty(&self) -> bool { - self.inner.is_empty() - } - #[cfg(test)] pub fn len(&self) -> usize { self.inner.len() @@ -108,6 +111,33 @@ impl ChainStore { fn find_ratchet(&mut self, ratchet_key: &RemoteRatchetKey) -> Option<&mut ReceiverChain> { self.inner.iter_mut().find(|r| r.belongs_to(ratchet_key)) } + + #[cfg(feature = "interolm")] + fn previous_chain(&self) -> Option { + let num_chains = self.inner.len(); + + if num_chains >= 2 { + self.inner.get(num_chains - 2).cloned() + } else { + None + } + } + + #[cfg(feature = "interolm")] + fn previous_counter(&self) -> u32 { + match self.previous_chain() { + Some(chain) => { + if chain.hkdf_ratchet.chain_index() > 0 { + (chain.hkdf_ratchet.chain_index() - 1) + .try_into() + .expect("Interolm counter should fit into u32") + } else { + 0 + } + } + None => 0, + } + } } impl Default for ChainStore { @@ -125,13 +155,13 @@ impl Default for ChainStore { /// Olm sessions have two important properties: /// /// 1. They are based on a double ratchet algorithm which continuously -/// introduces new entropy into the channel as messages are sent and -/// received. This imbues the channel with *self-healing* properties, -/// allowing it to recover from a momentary loss of confidentiality in the event -/// of a key compromise. +/// introduces new entropy into the channel as messages are sent and +/// received. This imbues the channel with *self-healing* properties, +/// allowing it to recover from a momentary loss of confidentiality in the +/// event of a key compromise. /// 2. They are *asynchronous*, allowing the participant to start sending -/// messages to the other side even if the other participant is not online at -/// the moment. +/// messages to the other side even if the other participant is not online at +/// the moment. /// /// An Olm [`Session`] is acquired from an [`Account`], by calling either /// @@ -150,16 +180,20 @@ pub struct Session { sending_ratchet: DoubleRatchet, receiving_chains: ChainStore, config: SessionConfig, + session_creator: SessionCreator, } impl Debug for Session { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let Self { session_keys: _, sending_ratchet, receiving_chains, config } = self; + let Self { session_keys: _, sending_ratchet, receiving_chains, config, session_creator } = + self; f.debug_struct("Session") .field("session_id", &self.session_id()) .field("sending_chain_index", &sending_ratchet.chain_index()) .field("receiving_chains", &receiving_chains.inner) + .field("message_received", &self.has_received_message()) + .field("session_creator", session_creator) .field("config", config) .finish_non_exhaustive() } @@ -171,23 +205,24 @@ impl Session { shared_secret: Shared3DHSecret, session_keys: SessionKeys, ) -> Self { - let local_ratchet = DoubleRatchet::active(shared_secret); + let local_ratchet = DoubleRatchet::active(&config, shared_secret); Self { session_keys, sending_ratchet: local_ratchet, receiving_chains: Default::default(), config, + session_creator: SessionCreator::Us, } } pub(super) fn new_remote( - config: SessionConfig, + config: &SessionConfig, shared_secret: RemoteShared3DHSecret, remote_ratchet_key: Curve25519PublicKey, session_keys: SessionKeys, ) -> Self { - let (root_key, remote_chain_key) = shared_secret.expand(); + let (root_key, remote_chain_key) = shared_secret.expand(config); let remote_ratchet_key = RemoteRatchetKey::from(remote_ratchet_key); let root_key = RemoteRootKey::new(root_key); @@ -203,7 +238,58 @@ impl Session { session_keys, sending_ratchet: local_ratchet, receiving_chains: ratchet_store, + config: *config, + session_creator: SessionCreator::Them, + } + } + + #[cfg(feature = "interolm")] + pub(super) fn new_interolm( + config: SessionConfig, + shared_secret: Shared3DHSecret, + session_keys: SessionKeys, + ) -> Self { + let their_ratchet_key = RemoteRatchetKey(session_keys.signed_pre_key); + + let (local_ratchet, receiver_chain) = + DoubleRatchet::active_interolm(&config, shared_secret, their_ratchet_key); + + let mut ratchet_store = ChainStore::new(); + ratchet_store.push(receiver_chain); + + Self { + session_keys, + sending_ratchet: local_ratchet, + receiving_chains: ratchet_store, + config, + session_creator: SessionCreator::Us, + } + } + + #[cfg(feature = "interolm")] + pub(super) fn new_interolm_remote( + config: &SessionConfig, + shared_secret: RemoteShared3DHSecret, + remote_ratchet_key: RemoteRatchetKey, + session_keys: SessionKeys, + our_ratchet_key: RatchetKey, + ) -> Self { + let (local_ratchet, receiver_chain) = DoubleRatchet::inactive_interolm( config, + shared_secret, + our_ratchet_key, + remote_ratchet_key, + ); + + let mut ratchet_store = ChainStore::new(); + ratchet_store.push(receiver_chain); + + Self { + session_keys, + sending_ratchet: local_ratchet, + receiving_chains: ratchet_store, + config: *config, + session_creator: SessionCreator::Them, } } @@ -219,27 +305,109 @@ impl Session { /// Used to decide if outgoing messages should be sent as normal or pre-key /// messages. pub fn has_received_message(&self) -> bool { - !self.receiving_chains.is_empty() + let initial_ratchet_key = RemoteRatchetKey(self.session_keys().signed_pre_key); + + // Interolm immediately initializes a receiving chain, using the signed prekey + // as the initial (remote) ratchet key, even though it never received a + // message from the other side. Therefore we need to filter that chain + // out when trying to determine whether we've ever received a message + // from the other side. + let is_empty = self + .receiving_chains + .inner + .iter() + .filter(|c| !c.belongs_to(&initial_ratchet_key)) + .next() + .is_none(); + + !is_empty + } + + pub fn is_message_for_this_session(&self, message: &AnyMessage) -> Option { + match message { + AnyMessage::Native(AnyNativeMessage::PreKey(n)) => { + Some(n.session_id() == self.session_id()) + } + AnyMessage::Interolm(AnyInterolmMessage::PreKey(s)) => { + if let Version::VInterolm(meta_data) = self.config.version { + let pre_key_id = meta_data + .one_time_key_id + .map(|k| k.0.try_into().expect("Interolm key IDs are bound to 32 bits")); + let signed_pre_key_id: u32 = meta_data + .signed_pre_key_id + .0 + .try_into() + .expect("Interolm key IDs are bound to 32 bits"); + + Some( + pre_key_id == s.pre_key_id + && signed_pre_key_id == s.signed_pre_key_id + && meta_data.registration_id == s.registration_id, + ) + } else { + Some(false) + } + } + _ => None, + } } - /// Encrypt the `plaintext` and construct an [`OlmMessage`]. + /// Encrypt the `plaintext` and construct an [`AnyNativeMessage`]. /// /// The message will either be a pre-key message or a normal message, /// depending on whether the session is fully established. A session is /// fully established once you receive (and decrypt) at least one /// message from the other side. - pub fn encrypt(&mut self, plaintext: impl AsRef<[u8]>) -> OlmMessage { + pub fn encrypt(&mut self, plaintext: impl AsRef<[u8]>) -> AnyNativeMessage { let message = match self.config.version { - Version::V1 => self.sending_ratchet.encrypt_truncated_mac(plaintext.as_ref()), - Version::V2 => self.sending_ratchet.encrypt(plaintext.as_ref()), + Version::V1 => { + self.sending_ratchet.encrypt_truncated_mac(&self.config, plaintext.as_ref()) + } + Version::V2 => self.sending_ratchet.encrypt(&self.config, plaintext.as_ref()), + #[cfg(feature = "interolm")] + Version::VInterolm(..) => panic!("`Session::encrypt` called on an Interolm session!"), + }; + + if self.has_received_message() { + AnyNativeMessage::Normal(message) + } else { + let message = PreKeyMessage::new(self.session_keys.into(), message); + + AnyNativeMessage::PreKey(message) + } + } + + /// Encrypt the `plaintext` for Interolm and construct an + /// [`AnyInterolmMessage`]. + /// + /// The message will either be a pre-key message or a normal message, + /// depending on whether the session is fully established. A session is + /// fully established once you receive (and decrypt) at least one + /// message from the other side. + #[cfg(feature = "interolm")] + pub fn encrypt_interolm(&mut self, plaintext: impl AsRef<[u8]>) -> AnyInterolmMessage { + let (metadata, message) = match self.config.version { + Version::V1 | Version::V2 => { + panic!("`Session::encrypt_interolm` called on a non-Interolm session!") + } + Version::VInterolm(metadata) => ( + metadata, + self.sending_ratchet.encrypt_interolm( + &self.config, + self.session_creator, + &self.session_keys, + self.receiving_chains.previous_counter(), + plaintext.as_ref(), + ), + ), }; if self.has_received_message() { - OlmMessage::Normal(message) + AnyInterolmMessage::Normal(message) } else { - let message = PreKeyMessage::new(self.session_keys, message); + let message = InterolmPreKeyMessage::new(self.session_keys, metadata, message); - OlmMessage::PreKey(message) + AnyInterolmMessage::PreKey(message) } } @@ -270,10 +438,31 @@ impl Session { /// result in a [`DecryptionError`]. /// /// [`DecryptionError`]: self::DecryptionError - pub fn decrypt(&mut self, message: &OlmMessage) -> Result, DecryptionError> { + pub fn decrypt(&mut self, message: &AnyNativeMessage) -> Result, DecryptionError> { + let decrypted = match message { + AnyNativeMessage::Normal(m) => self.decrypt_decoded(AnyNormalMessage::Native(m))?, + AnyNativeMessage::PreKey(m) => { + self.decrypt_decoded(AnyNormalMessage::Native(&m.message))? + } + }; + + Ok(decrypted) + } + + /// Try to decrypt an Interolm message, which will either return the + /// plaintext or result in a [`DecryptionError`]. + /// + /// [`DecryptionError`]: self::DecryptionError + #[cfg(feature = "interolm")] + pub fn decrypt_interolm( + &mut self, + message: &AnyInterolmMessage, + ) -> Result, DecryptionError> { let decrypted = match message { - OlmMessage::Normal(m) => self.decrypt_decoded(m)?, - OlmMessage::PreKey(m) => self.decrypt_decoded(&m.message)?, + AnyInterolmMessage::Normal(m) => self.decrypt_decoded(AnyNormalMessage::Interolm(m))?, + AnyInterolmMessage::PreKey(m) => { + self.decrypt_decoded(AnyNormalMessage::Interolm(&m.message))? + } }; Ok(decrypted) @@ -281,16 +470,22 @@ impl Session { pub(super) fn decrypt_decoded( &mut self, - message: &Message, + message: AnyNormalMessage<'_>, ) -> Result, DecryptionError> { - let ratchet_key = RemoteRatchetKey::from(message.ratchet_key); + let ratchet_key = RemoteRatchetKey::from(message.ratchet_key()); if let Some(ratchet) = self.receiving_chains.find_ratchet(&ratchet_key) { - ratchet.decrypt(message, &self.config) + ratchet.decrypt(&self.config, &self.session_keys, self.session_creator, message) } else { - let (sending_ratchet, mut remote_ratchet) = self.sending_ratchet.advance(ratchet_key); + let (sending_ratchet, mut remote_ratchet) = + self.sending_ratchet.advance(&self.config, ratchet_key); - let plaintext = remote_ratchet.decrypt(message, &self.config)?; + let plaintext = remote_ratchet.decrypt( + &self.config, + &self.session_keys, + self.session_creator, + message, + )?; self.sending_ratchet = sending_ratchet; self.receiving_chains.push(remote_ratchet); @@ -307,6 +502,7 @@ impl Session { sending_ratchet: self.sending_ratchet.clone(), receiving_chains: self.receiving_chains.clone(), config: self.config, + session_creator: self.session_creator, } } @@ -328,10 +524,8 @@ impl Session { use chain_key::ChainKey; use matrix_pickle::Decode; use message_key::RemoteMessageKey; - use ratchet::{Ratchet, RatchetKey}; - use root_key::RootKey; - use crate::{types::Curve25519SecretKey, utilities::unpickle_libolm}; + use crate::{olm::session::ratchet::Ratchet, utilities::unpickle_libolm}; #[derive(Debug, Decode, Zeroize)] #[zeroize(drop)] @@ -373,6 +567,19 @@ impl Session { index: u32, } + /// The set of keys that were used to establish the Olm Session. + // XXX: Could probably be removed (in favour of) when SessionKeysWire is renamed to + // OlmSessionKeys. + #[derive(Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Decode)] + pub struct OlmSessionKeys { + /// Alice's identity key. + pub identity_key: Curve25519PublicKey, + /// Alice's ephemeral (base) key. + pub base_key: Curve25519PublicKey, + /// Bob's OTK which Alice used. + pub one_time_key: Curve25519PublicKey, + } + impl From<&MessageKey> for RemoteMessageKey { fn from(key: &MessageKey) -> Self { RemoteMessageKey { key: key.message_key.clone(), index: key.index.into() } @@ -385,7 +592,7 @@ impl Session { version: u32, #[allow(dead_code)] received_message: bool, - session_keys: SessionKeys, + session_keys: OlmSessionKeys, #[secret] root_key: Box<[u8; 32]>, sender_chains: Vec, @@ -421,6 +628,15 @@ impl Session { } } + let session_keys = SessionKeys { + identity_key: pickle.session_keys.identity_key, + base_key: pickle.session_keys.base_key, + signed_pre_key: pickle.session_keys.one_time_key, + one_time_key: None, + other_identity_key: todo!("libolm session pickles don't contain this information, \ + so there's not enough information to reconstruct a `Session`"), + }; + if let Some(chain) = pickle.sender_chains.first() { // XXX: Passing in secret array as value. let ratchet_key = RatchetKey::from(Curve25519SecretKey::from_slice( @@ -438,10 +654,12 @@ impl Session { DoubleRatchet::from_ratchet_and_chain_key(ratchet, chain_key); Ok(Self { - session_keys: pickle.session_keys, + session_keys, sending_ratchet, receiving_chains, config: SessionConfig::version_1(), + session_creator: todo!("libolm session pickles don't contain this information, \ + so there's not enough information to reconstruct a `Session`") }) } else if let Some(chain) = receiving_chains.get(0) { let sending_ratchet = DoubleRatchet::inactive( @@ -450,10 +668,12 @@ impl Session { ); Ok(Self { - session_keys: pickle.session_keys, + session_keys, sending_ratchet, receiving_chains, config: SessionConfig::version_1(), + session_creator: todo!("libolm session pickles don't contain this information, \ + so there's not enough information to reconstruct a `Session`") }) } else { Err(crate::LibolmPickleError::InvalidSession) @@ -475,6 +695,7 @@ pub struct SessionPickle { receiving_chains: ChainStore, #[serde(default = "default_config")] config: SessionConfig, + session_creator: SessionCreator, } fn default_config() -> SessionConfig { @@ -505,6 +726,7 @@ impl From for Session { sending_ratchet: pickle.sending_ratchet, receiving_chains: pickle.receiving_chains, config: pickle.config, + session_creator: pickle.session_creator, } } } @@ -519,8 +741,11 @@ mod test { use super::Session; use crate::{ - olm::{Account, SessionConfig, SessionPickle}, - Curve25519PublicKey, + olm::{ + Account, AnyInterolmMessage, AnyMessage, InboundCreationResult, SessionConfig, + SessionPickle, + }, + Curve25519PublicKey, KeyId, }; const PICKLE_KEY: [u8; 32] = [0u8; 32]; @@ -541,8 +766,12 @@ mod test { let identity_keys = bob.parsed_identity_keys(); let curve25519_key = Curve25519PublicKey::from_base64(identity_keys.curve25519())?; let one_time_key = Curve25519PublicKey::from_base64(&one_time_key)?; - let mut alice_session = - alice.create_outbound_session(SessionConfig::version_1(), curve25519_key, one_time_key); + let mut alice_session = alice.create_outbound_session( + SessionConfig::version_1(), + curve25519_key, + one_time_key, + None, + ); let message = "It's a secret to everybody"; @@ -559,6 +788,46 @@ mod test { } } + fn interolm_sessions() -> Result<(Account, Account, Session, Session)> { + let alice = Account::new(); + let mut bob = Account::new(); + bob.generate_one_time_keys(2); + + let mut bob_prekeys: Vec<(KeyId, Curve25519PublicKey)> = + bob.one_time_keys().iter().map(|(t1, t2)| (t1.clone(), t2.clone())).take(2).collect(); + let (otk_id, otk) = + bob_prekeys.pop().expect("Bob should have an OTK because we just generated it"); + let (skey_id, skey) = bob_prekeys + .pop() + .expect("Bob should have a signed prekey because we just generated it"); + + bob.mark_keys_as_published(); + + let identity_keys = bob.identity_keys(); + let curve25519_key = identity_keys.curve25519; + + let mut alice_session = alice.create_outbound_session( + SessionConfig::version_interolm(0, skey_id, Some(otk_id)), + curve25519_key, + skey, + Some(otk), + ); + + let message = "It's a secret to everybody"; + let ciphertext = alice_session.encrypt_interolm(message); + + if let AnyMessage::Interolm(AnyInterolmMessage::PreKey(m)) = ciphertext.into() { + let InboundCreationResult { session, .. } = bob.create_inbound_session( + alice.identity_keys().curve25519, + &m.try_into().expect("We should be able to establish the session"), + )?; + + Ok((alice, bob, alice_session, session)) + } else { + bail!("Invalid message type"); + } + } + #[test] fn out_of_order_decryption() -> Result<()> { let (_, _, mut alice_session, bob_session) = sessions()?; @@ -650,4 +919,32 @@ mod test { Ok(()) } + + #[test] + fn message_received_flag_survives_pickling_roundtrip() -> Result<()> { + let (_, _, alice_session, mut bob_session) = interolm_sessions()?; + + assert!(!alice_session.has_received_message()); + + let pickle = alice_session.pickle().encrypt(&PICKLE_KEY); + let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY)?; + let mut alice_session = Session::from_pickle(decrypted_pickle); + + assert!(!alice_session.has_received_message()); + + let bob_msg = bob_session.encrypt_interolm("Hello Alice!"); + let _ = alice_session + .decrypt_interolm(&bob_msg) + .expect("Alice should be able to decrypt Bob's message"); + + assert!(alice_session.has_received_message()); + + let pickle = alice_session.pickle().encrypt(&PICKLE_KEY); + let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY)?; + let alice_session = Session::from_pickle(decrypted_pickle); + + assert!(alice_session.has_received_message()); + + Ok(()) + } } diff --git a/src/olm/session/ratchet.rs b/src/olm/session/ratchet.rs index a1a6b4e7..233fccfa 100644 --- a/src/olm/session/ratchet.rs +++ b/src/olm/session/ratchet.rs @@ -22,18 +22,18 @@ use super::{ chain_key::RemoteChainKey, root_key::{RemoteRootKey, RootKey}, }; -use crate::{types::Curve25519SecretKey, Curve25519PublicKey}; +use crate::{olm::SessionConfig, types::Curve25519SecretKey, Curve25519PublicKey}; #[derive(Serialize, Deserialize, Clone)] #[serde(transparent)] -pub(super) struct RatchetKey(Curve25519SecretKey); +pub(crate) struct RatchetKey(pub(crate) Curve25519SecretKey); #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct RatchetPublicKey(Curve25519PublicKey); #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Decode)] #[serde(transparent)] -pub struct RemoteRatchetKey(Curve25519PublicKey); +pub(crate) struct RemoteRatchetKey(pub(crate) Curve25519PublicKey); impl Debug for RemoteRatchetKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -110,9 +110,13 @@ impl Ratchet { Self { root_key, ratchet_key } } - pub fn advance(&self, remote_key: RemoteRatchetKey) -> (RemoteRootKey, RemoteChainKey) { + pub fn advance( + &self, + config: &SessionConfig, + remote_key: RemoteRatchetKey, + ) -> (RemoteRootKey, RemoteChainKey) { let (remote_root_key, remote_chain_key) = - self.root_key.advance(&self.ratchet_key, &remote_key); + self.root_key.advance(config, &self.ratchet_key, &remote_key); (remote_root_key, remote_chain_key) } diff --git a/src/olm/session/receiver_chain.rs b/src/olm/session/receiver_chain.rs index 96b2c1f4..02140ed2 100644 --- a/src/olm/session/receiver_chain.rs +++ b/src/olm/session/receiver_chain.rs @@ -21,7 +21,10 @@ use super::{ chain_key::RemoteChainKey, message_key::RemoteMessageKey, ratchet::RemoteRatchetKey, DecryptionError, }; -use crate::olm::{messages::Message, session_config::Version, SessionConfig}; +use crate::olm::{ + session_config::{SessionCreator, Version}, + AnyNormalMessage, SessionConfig, SessionKeys, +}; const MAX_MESSAGE_GAP: u64 = 2000; const MAX_MESSAGE_KEYS: usize = 40; @@ -73,17 +76,34 @@ enum FoundMessageKey<'a> { impl FoundMessageKey<'_> { fn decrypt( &self, - message: &Message, config: &SessionConfig, + session_keys: &SessionKeys, + session_creator: SessionCreator, + message: AnyNormalMessage<'_>, ) -> Result, DecryptionError> { let message_key = match self { FoundMessageKey::Existing(m) => m, FoundMessageKey::New(m) => &m.2, }; - match config.version { - Version::V1 => message_key.decrypt_truncated_mac(message), - Version::V2 => message_key.decrypt(message), + match message { + AnyNormalMessage::Native(message) => match config.version { + Version::V1 => message_key.decrypt_truncated_mac(message), + Version::V2 => message_key.decrypt(message), + #[cfg(feature = "interolm")] + Version::VInterolm(..) => { + Err(DecryptionError::WrongAlgorithm("Interolm".into(), "Olm".into())) + } + }, + #[cfg(feature = "interolm")] + AnyNormalMessage::Interolm(message) => match config.version { + Version::V1 | Version::V2 => { + Err(DecryptionError::WrongAlgorithm("Olm".into(), "Interolm".into())) + } + Version::VInterolm(..) => { + message_key.decrypt_interolm(session_keys, session_creator, message) + } + }, } } } @@ -91,7 +111,7 @@ impl FoundMessageKey<'_> { #[derive(Serialize, Deserialize, Clone)] pub(super) struct ReceiverChain { ratchet_key: RemoteRatchetKey, - hkdf_ratchet: RemoteChainKey, + pub(super) hkdf_ratchet: RemoteChainKey, skipped_message_keys: MessageKeyStore, } @@ -149,13 +169,15 @@ impl ReceiverChain { pub fn decrypt( &mut self, - message: &Message, config: &SessionConfig, + session_keys: &SessionKeys, + session_creator: SessionCreator, + message: AnyNormalMessage<'_>, ) -> Result, DecryptionError> { - let chain_index = message.chain_index; + let chain_index = message.chain_index(); let message_key = self.find_message_key(chain_index)?; - let plaintext = message_key.decrypt(message, config)?; + let plaintext = message_key.decrypt(config, session_keys, session_creator, message)?; match message_key { FoundMessageKey::Existing(m) => { diff --git a/src/olm/session/root_key.rs b/src/olm/session/root_key.rs index 859624ab..bae35466 100644 --- a/src/olm/session/root_key.rs +++ b/src/olm/session/root_key.rs @@ -21,8 +21,10 @@ use super::{ chain_key::{ChainKey, RemoteChainKey}, ratchet::{RatchetKey, RemoteRatchetKey}, }; +use crate::olm::{session_config::Version, SessionConfig}; -const ADVANCEMENT_SEED: &[u8; 11] = b"OLM_RATCHET"; +const ADVANCEMENT_SEED_OLM: &[u8; 11] = b"OLM_RATCHET"; +const ADVANCEMENT_SEED_INTEROLM: &[u8; 11] = b"OLM_RATCHET"; #[derive(Serialize, Deserialize, Clone, Zeroize)] #[serde(transparent)] @@ -38,6 +40,7 @@ pub(crate) struct RemoteRootKey { } fn kdf( + info: &[u8], root_key: &[u8; 32], ratchet_key: &RatchetKey, remote_ratchet_key: &RemoteRatchetKey, @@ -46,11 +49,27 @@ fn kdf( let hkdf: Hkdf = Hkdf::new(Some(root_key.as_ref()), shared_secret.as_bytes()); let mut output = Box::new([0u8; 64]); - hkdf.expand(ADVANCEMENT_SEED, output.as_mut_slice()).expect("Can't expand"); + hkdf.expand(info, output.as_mut_slice()).expect("Can't expand"); output } +fn kdf_olm( + root_key: &[u8; 32], + ratchet_key: &RatchetKey, + remote_ratchet_key: &RemoteRatchetKey, +) -> Box<[u8; 64]> { + kdf(ADVANCEMENT_SEED_OLM, root_key, ratchet_key, remote_ratchet_key) +} + +fn kdf_interolm( + root_key: &[u8; 32], + ratchet_key: &RatchetKey, + remote_ratchet_key: &RemoteRatchetKey, +) -> Box<[u8; 64]> { + kdf(ADVANCEMENT_SEED_INTEROLM, root_key, ratchet_key, remote_ratchet_key) +} + impl RemoteRootKey { pub(super) fn new(bytes: Box<[u8; 32]>) -> Self { Self { key: bytes } @@ -58,16 +77,21 @@ impl RemoteRootKey { pub(super) fn advance( &self, + config: &SessionConfig, remote_ratchet_key: &RemoteRatchetKey, ) -> (RootKey, ChainKey, RatchetKey) { let ratchet_key = RatchetKey::new(); - let output = kdf(&self.key, &ratchet_key, remote_ratchet_key); + + let output = match config.version { + Version::V1 | Version::V2 => kdf_olm(&self.key, &ratchet_key, remote_ratchet_key), + Version::VInterolm(..) => kdf_interolm(&self.key, &ratchet_key, remote_ratchet_key), + }; let mut chain_key = Box::new([0u8; 32]); let mut root_key = Box::new([0u8; 32]); - chain_key.copy_from_slice(&output[32..]); root_key.copy_from_slice(&output[..32]); + chain_key.copy_from_slice(&output[32..]); let chain_key = ChainKey::new(chain_key); let root_key = RootKey::new(root_key); @@ -83,10 +107,14 @@ impl RootKey { pub(super) fn advance( &self, + config: &SessionConfig, old_ratchet_key: &RatchetKey, remote_ratchet_key: &RemoteRatchetKey, ) -> (RemoteRootKey, RemoteChainKey) { - let output = kdf(&self.key, old_ratchet_key, remote_ratchet_key); + let output = match config.version { + Version::V1 | Version::V2 => kdf_olm(&self.key, old_ratchet_key, remote_ratchet_key), + Version::VInterolm(..) => kdf_interolm(&self.key, old_ratchet_key, remote_ratchet_key), + }; let mut chain_key = Box::new([0u8; 32]); let mut root_key = Box::new([0u8; 32]); diff --git a/src/olm/session_config.rs b/src/olm/session_config.rs index 9310871c..2a08fa51 100644 --- a/src/olm/session_config.rs +++ b/src/olm/session_config.rs @@ -14,38 +14,78 @@ use serde::{Deserialize, Serialize}; -/// A struct to configure how Olm sessions should work under the hood. -/// Currently only the MAC truncation behaviour can be configured. +use crate::KeyId; + +/// Knobs for protocol configuration. Currently only used for switching between +/// different protocol versions (Olm v1, Olm v2 and Interolm). #[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct SessionConfig { pub(super) version: Version, } +#[cfg(feature = "interolm")] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct InterolmSessionMetadata { + pub signed_pre_key_id: KeyId, + pub one_time_key_id: Option, + pub registration_id: u32, +} + +#[cfg(feature = "interolm")] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub enum SessionCreator { + Us, + Them, +} + #[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] pub(super) enum Version { - V1 = 1, - V2 = 2, + V1, + V2, + #[cfg(feature = "interolm")] + VInterolm(InterolmSessionMetadata), } impl SessionConfig { - /// Get the numeric version of this `SessionConfig`. + /// Get the numeric representation of the session version. pub fn version(&self) -> u8 { - self.version as u8 + match self.version { + Version::V1 => 1, + Version::V2 => 2, + Version::VInterolm(_) => 3, + } } - /// Create a `SessionConfig` for the Olm version 1. This version of Olm will - /// use AES-256 and HMAC with a truncated MAC to encrypt individual - /// messages. The MAC will be truncated to 8 bytes. + /// Create a `SessionConfig` for the Olm version 1. This version of Olm uses + /// AES-256 and HMAC with an 8-byte truncated MAC for individual message + /// encryption. pub fn version_1() -> Self { SessionConfig { version: Version::V1 } } - /// Create a `SessionConfig` for the Olm version 2. This version of Olm will - /// use AES-256 and HMAC to encrypt individual messages. The MAC won't be - /// truncated. + /// Create a `SessionConfig` for the Olm version 2. This version of Olm uses + /// AES-256 and HMAC to encrypt individual messages. The MAC is left + /// untruncated (32 bytes). pub fn version_2() -> Self { SessionConfig { version: Version::V2 } } + + /// Create a `SessionConfig` for the Interolm protocol. Similarly to Olm v1, + /// this uses AES-256 and a truncated 8-byte MAC. + #[cfg(feature = "interolm")] + pub fn version_interolm( + registration_id: u32, + signed_pre_key_id: KeyId, + one_time_key_id: Option, + ) -> Self { + SessionConfig { + version: Version::VInterolm(InterolmSessionMetadata { + signed_pre_key_id, + one_time_key_id, + registration_id, + }), + } + } } impl Default for SessionConfig { diff --git a/src/olm/session_keys.rs b/src/olm/session_keys.rs index 4fe6a9c4..4bae62d2 100644 --- a/src/olm/session_keys.rs +++ b/src/olm/session_keys.rs @@ -13,27 +13,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -use matrix_pickle::Decode; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use crate::{utilities::base64_encode, Curve25519PublicKey}; -/// The set of keys that were used to establish the Olm Session, -#[derive(Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Decode)] +/// The set of keys that were used to establish the session. +#[derive(Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] pub struct SessionKeys { + /// Alice's identity key. pub identity_key: Curve25519PublicKey, + /// Alice's ephemeral (base) key. pub base_key: Curve25519PublicKey, - pub one_time_key: Curve25519PublicKey, + /// Bob's identity key. + pub other_identity_key: Curve25519PublicKey, + /// Bob's OTK which Alice used. + pub signed_pre_key: Curve25519PublicKey, + /// Bob's OTK which Alice used, if any. + pub one_time_key: Option, } impl SessionKeys { - /// Returns the globally unique session ID which these [`SessionKeys`] will - /// produce. + /// Returns the globally unique session ID which these [`SessionKeys`] + /// will produce. /// - /// A session ID is the SHA256 of the concatenation of three `SessionKeys`, - /// the account's identity key, the ephemeral base key and the one-time - /// key which is used to establish the session. + /// A session ID is the SHA256 of the concatenation of the session keys + /// which were used to establish the session: the account's identity key, + /// the ephemeral base key, the signed pre-key and the one-time key (if + /// any). /// /// Due to the construction, every session ID is (probabilistically) /// globally unique. @@ -43,8 +50,15 @@ impl SessionKeys { let digest = sha .chain_update(self.identity_key.as_bytes()) .chain_update(self.base_key.as_bytes()) - .chain_update(self.one_time_key.as_bytes()) - .finalize(); + .chain_update(self.signed_pre_key.as_bytes()); + + let digest = if let Some(otk) = self.one_time_key { + digest.chain_update(otk.as_bytes()) + } else { + digest + }; + + let digest = digest.finalize(); base64_encode(digest) } @@ -55,7 +69,65 @@ impl std::fmt::Debug for SessionKeys { f.debug_struct("SessionKeys") .field("identity_key", &self.identity_key.to_base64()) .field("base_key", &self.base_key.to_base64()) - .field("one_time_key", &self.one_time_key.to_base64()) + .field("other_identity_key", &self.other_identity_key.to_base64()) + .field("signed_pre_key", &self.signed_pre_key.to_base64()) + .field("one_time_key", &self.one_time_key.map(|x| x.to_base64())) + .finish() + } +} + +/// Represents the session keys as received over the network in the Olm and +/// Interolm protocols. +#[derive(Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] +pub struct OlmSessionKeys { + /// Alice's identity key. + pub identity_key: Curve25519PublicKey, + /// Alice's ephemeral (base) key. + pub base_key: Curve25519PublicKey, + /// Bob's OTK which Alice used. + pub one_time_key: Curve25519PublicKey, +} + +impl OlmSessionKeys { + /// Returns the globally unique session ID which these [`SessionKeysWire`] + /// will produce. + /// + /// A session ID is the SHA256 of the concatenation of three session keys + /// which were used to establish the session: the account's identity key, + /// the ephemeral base key and the one-time key. + /// + /// Due to the construction, every session ID is (probabilistically) + /// globally unique. + pub fn session_id(&self) -> String { + let sha = Sha256::new(); + + let digest = sha + .chain_update(self.identity_key.as_bytes()) + .chain_update(self.base_key.as_bytes()) + .chain_update(self.one_time_key.as_bytes()); + + let digest = digest.finalize(); + + base64_encode(digest) + } +} + +impl std::fmt::Debug for OlmSessionKeys { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionKeysWire") + .field("identity_key", &self.identity_key.to_base64()) + .field("base_key", &self.base_key.to_base64()) + .field("signed_pre_key", &self.one_time_key.to_base64()) .finish() } } + +impl From for OlmSessionKeys { + fn from(value: SessionKeys) -> Self { + Self { + identity_key: value.identity_key, + base_key: value.base_key, + one_time_key: value.signed_pre_key, + } + } +} diff --git a/src/olm/shared_secret.rs b/src/olm/shared_secret.rs index 9507dbf4..eaa9d5fc 100644 --- a/src/olm/shared_secret.rs +++ b/src/olm/shared_secret.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! A 3DH implementation following the Olm [spec]. +//! A 3DH and X3DH implementation following the [Olm] and [Signal] specs. +//! +//! # Olm //! //! The setup takes four Curve25519 inputs: Identity keys for Alice and Bob, //! (Ia, Ib), and one-time keys for Alice and Bob (Ea, Eb). @@ -29,31 +31,40 @@ //! R0, C0,0 = HKDF(0, S, "OLM_ROOT", 64) //! ``` //! -//! [spec]: https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md#initial-setup +//! # Signal +//! +//! Rather than repeating the contents here, we refer you to the [Signal] X3DH +//! spec. +//! +//! [Olm]: https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md#initial-setup +//! [Signal]: https://signal.org/docs/specifications/x3dh/ use hkdf::Hkdf; use sha2::Sha256; use x25519_dalek::{ReusableSecret, SharedSecret}; use zeroize::Zeroize; +use super::{session_config::Version, SessionConfig}; use crate::{types::Curve25519SecretKey as StaticSecret, Curve25519PublicKey as PublicKey}; #[derive(Zeroize)] #[zeroize(drop)] -pub struct Shared3DHSecret(Box<[u8; 96]>); +pub struct Shared3DHSecret(Vec); #[derive(Zeroize)] #[zeroize(drop)] -pub struct RemoteShared3DHSecret(Box<[u8; 96]>); +pub struct RemoteShared3DHSecret(Vec); -fn expand(shared_secret: &[u8; 96]) -> (Box<[u8; 32]>, Box<[u8; 32]>) { - let hkdf: Hkdf = Hkdf::new(Some(&[0]), shared_secret); +/// Expands secret input derived from the (X)3DH handshake into a root key and +/// chain key. +fn expand(secret_input: &[u8], info: &[u8]) -> (Box<[u8; 32]>, Box<[u8; 32]>) { + let hkdf: Hkdf = Hkdf::new(Some(&[0]), secret_input); let mut root_key = Box::new([0u8; 32]); let mut chain_key = Box::new([0u8; 32]); let mut expanded_keys = [0u8; 64]; - hkdf.expand(b"OLM_ROOT", &mut expanded_keys) + hkdf.expand(info, &mut expanded_keys) .expect("Can't expand the shared 3DH secret into the Olm root"); root_key.copy_from_slice(&expanded_keys[0..32]); @@ -64,55 +75,114 @@ fn expand(shared_secret: &[u8; 96]) -> (Box<[u8; 32]>, Box<[u8; 32]>) { (root_key, chain_key) } -fn merge_secrets( +/// Expands secret input derived from the (X)3DH handshake into an Olm root key +/// and chain key. +fn expand_olm(secret_input: &[u8]) -> (Box<[u8; 32]>, Box<[u8; 32]>) { + expand(secret_input, b"OLM_ROOT") +} + +fn merge_secrets_olm( first_secret: SharedSecret, second_secret: SharedSecret, third_secret: SharedSecret, -) -> Box<[u8; 96]> { - let mut secret = Box::new([0u8; 96]); +) -> Vec { + let mut secret = Vec::with_capacity(4 * 32); - secret[0..32].copy_from_slice(first_secret.as_bytes()); - secret[32..64].copy_from_slice(second_secret.as_bytes()); - secret[64..96].copy_from_slice(third_secret.as_bytes()); + secret.extend_from_slice(first_secret.as_bytes()); + secret.extend_from_slice(second_secret.as_bytes()); + secret.extend_from_slice(third_secret.as_bytes()); + + secret +} + +#[cfg(feature = "interolm")] +fn merge_secrets_x3dh( + first_secret: SharedSecret, + second_secret: SharedSecret, + third_secret: SharedSecret, + fourth_secret: Option, +) -> Vec { + let mut secret = Vec::with_capacity(5 * 32); + + secret.extend_from_slice(&[0xFFu8; 32]); + secret.extend_from_slice(first_secret.as_bytes()); + secret.extend_from_slice(second_secret.as_bytes()); + secret.extend_from_slice(third_secret.as_bytes()); + + if let Some(s) = fourth_secret { + secret.extend_from_slice(s.as_bytes()); + } secret } impl RemoteShared3DHSecret { pub(crate) fn new( + config: &SessionConfig, identity_key: &StaticSecret, - one_time_key: &StaticSecret, + signed_prekey: &StaticSecret, + one_time_key: Option<&StaticSecret>, remote_identity_key: &PublicKey, - remote_one_time_key: &PublicKey, + remote_base_key: &PublicKey, ) -> Self { - let first_secret = one_time_key.diffie_hellman(remote_identity_key); - let second_secret = identity_key.diffie_hellman(remote_one_time_key); - let third_secret = one_time_key.diffie_hellman(remote_one_time_key); - - Self(merge_secrets(first_secret, second_secret, third_secret)) + let first_secret = signed_prekey.diffie_hellman(remote_identity_key); + let second_secret = identity_key.diffie_hellman(remote_base_key); + let third_secret = signed_prekey.diffie_hellman(remote_base_key); + let fourth_secret = one_time_key.map(|otk| otk.diffie_hellman(remote_base_key)); + + match config.version { + Version::V1 | Version::V2 => { + Self(merge_secrets_olm(first_secret, second_secret, third_secret)) + } + #[cfg(feature = "interolm")] + Version::VInterolm(..) => { + Self(merge_secrets_x3dh(first_secret, second_secret, third_secret, fourth_secret)) + } + } } - pub fn expand(self) -> (Box<[u8; 32]>, Box<[u8; 32]>) { - expand(&self.0) + /// Expands secret input derived from the (X)3DH handshake into a root key + /// and chain key. + pub fn expand(self, config: &SessionConfig) -> (Box<[u8; 32]>, Box<[u8; 32]>) { + match config.version { + Version::V1 | Version::V2 => expand_olm(&self.0), + Version::VInterolm(..) => expand_olm(&self.0), + } } } impl Shared3DHSecret { pub(crate) fn new( + config: &SessionConfig, identity_key: &StaticSecret, - one_time_key: &ReusableSecret, + base_key: &ReusableSecret, remote_identity_key: &PublicKey, - remote_one_time_key: &PublicKey, + remote_signed_prekey: &PublicKey, + remote_one_time_key: Option<&PublicKey>, ) -> Self { - let first_secret = identity_key.diffie_hellman(remote_one_time_key); - let second_secret = one_time_key.diffie_hellman(&remote_identity_key.inner); - let third_secret = one_time_key.diffie_hellman(&remote_one_time_key.inner); - - Self(merge_secrets(first_secret, second_secret, third_secret)) + let first_secret = identity_key.diffie_hellman(remote_signed_prekey); + let second_secret = base_key.diffie_hellman(&remote_identity_key.inner); + let third_secret = base_key.diffie_hellman(&remote_signed_prekey.inner); + let fourth_secret = remote_one_time_key.map(|otk| base_key.diffie_hellman(&otk.inner)); + + match config.version { + Version::V1 | Version::V2 => { + Self(merge_secrets_olm(first_secret, second_secret, third_secret)) + } + #[cfg(feature = "interolm")] + Version::VInterolm(..) => { + Self(merge_secrets_x3dh(first_secret, second_secret, third_secret, fourth_secret)) + } + } } - pub fn expand(self) -> (Box<[u8; 32]>, Box<[u8; 32]>) { - expand(&self.0) + /// Expands secret input derived from the (X)3DH handshake into a root key + /// and chain key. + pub fn expand(self, config: &SessionConfig) -> (Box<[u8; 32]>, Box<[u8; 32]>) { + match config.version { + Version::V1 | Version::V2 => expand_olm(&self.0), + Version::VInterolm(..) => expand_olm(&self.0), + } } } @@ -122,11 +192,15 @@ mod test { use x25519_dalek::ReusableSecret; use super::{RemoteShared3DHSecret, Shared3DHSecret}; - use crate::{types::Curve25519SecretKey as StaticSecret, Curve25519PublicKey as PublicKey}; + use crate::{ + olm::SessionConfig, types::Curve25519SecretKey as StaticSecret, + Curve25519PublicKey as PublicKey, + }; #[test] fn triple_diffie_hellman() { let rng = thread_rng(); + let config = SessionConfig::default(); let alice_identity = StaticSecret::new(); let alice_one_time = ReusableSecret::random_from_rng(rng); @@ -135,23 +209,27 @@ mod test { let bob_one_time = StaticSecret::new(); let alice_secret = Shared3DHSecret::new( + &config, &alice_identity, &alice_one_time, &PublicKey::from(&bob_identity), &PublicKey::from(&bob_one_time), + None, ); let bob_secret = RemoteShared3DHSecret::new( + &config, &bob_identity, &bob_one_time, + None, &PublicKey::from(&alice_identity), &PublicKey::from(&alice_one_time), ); assert_eq!(alice_secret.0, bob_secret.0); - let alice_result = alice_secret.expand(); - let bob_result = bob_secret.expand(); + let alice_result = alice_secret.expand(&config); + let bob_result = bob_secret.expand(&config); assert_eq!(alice_result, bob_result); } diff --git a/src/types/curve25519.rs b/src/types/curve25519.rs index dd094a53..52d8eabc 100644 --- a/src/types/curve25519.rs +++ b/src/types/curve25519.rs @@ -22,7 +22,11 @@ use x25519_dalek::{EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, Sta use zeroize::Zeroize; use super::KeyError; -use crate::utilities::{base64_decode, base64_encode}; +use crate::{ + utilities::{base64_decode, base64_encode}, + xeddsa, + xeddsa::XEdDsaSignature, +}; /// Struct representing a Curve25519 secret key. #[derive(Clone, Deserialize, Serialize)] @@ -63,6 +67,12 @@ impl Curve25519SecretKey { key } + + /// Produce a XEdDSA signature for some message. + #[cfg(feature = "interolm")] + pub fn sign(&self, message: &[u8]) -> XEdDsaSignature { + xeddsa::sign(&self.0.to_bytes(), message) + } } impl Default for Curve25519SecretKey { @@ -120,9 +130,13 @@ impl Decode for Curve25519PublicKey { } impl Curve25519PublicKey { - /// The number of bytes a Curve25519 public key has. + /// Raw Curve25519 key length. pub const LENGTH: usize = 32; + /// Length of the alternative Curve25519 key format with a leading type byte + /// (such as used by Signal). + pub const ALTERNATIVE_LENGTH: usize = 33; + const BASE64_LENGTH: usize = 43; const PADDED_BASE64_LENGTH: usize = 44; @@ -144,10 +158,29 @@ impl Curve25519PublicKey { } /// Create a `Curve25519PublicKey` from a byte array. - pub fn from_bytes(bytes: [u8; 32]) -> Self { + pub fn from_bytes(bytes: [u8; Self::LENGTH]) -> Self { Self { inner: PublicKey::from(bytes) } } + /// Create a `Curve25519PublicKey` from a byte array representation which + /// includes a type byte marker (0x5) at the beginning. + pub fn from_bytes_interolm(bytes: [u8; Self::ALTERNATIVE_LENGTH]) -> Result { + if bytes[0] != 0x5 { + Err(KeyError::InvalidKeyFormat(bytes[0])) + } else { + Ok(Self::from_bytes(bytes[1..].try_into().expect("Must succeed as 33 - 1 is 32"))) + } + } + + pub fn to_interolm_bytes(&self) -> [u8; Self::ALTERNATIVE_LENGTH] { + let mut ret = [0u8; Self::ALTERNATIVE_LENGTH]; + + ret[0] = 0x5; + ret[1..].copy_from_slice(self.as_bytes()); + + ret + } + /// Instantiate a Curve25519 public key from an unpadded base64 /// representation. pub fn from_base64(input: &str) -> Result { @@ -172,6 +205,11 @@ impl Curve25519PublicKey { key.copy_from_slice(slice); Ok(Self::from(key)) + } else if key_len == Self::ALTERNATIVE_LENGTH { + let mut key = [0u8; Self::ALTERNATIVE_LENGTH]; + key.copy_from_slice(slice); + + Ok(Self::try_from(key)?) } else { Err(KeyError::InvalidKeyLength { key_type: "Curve25519", @@ -185,6 +223,16 @@ impl Curve25519PublicKey { pub fn to_base64(&self) -> String { base64_encode(self.inner.as_bytes()) } + + /// Verify XEdDSA signature. + #[cfg(feature = "interolm")] + pub fn verify_signature( + &self, + message: &[u8], + signature: XEdDsaSignature, + ) -> Result<(), xeddsa::SignatureError> { + xeddsa::verify(self.inner.as_bytes(), message, signature) + } } impl Display for Curve25519PublicKey { @@ -206,6 +254,14 @@ impl From<[u8; Self::LENGTH]> for Curve25519PublicKey { } } +impl TryFrom<[u8; Self::ALTERNATIVE_LENGTH]> for Curve25519PublicKey { + type Error = KeyError; + + fn try_from(bytes: [u8; Self::ALTERNATIVE_LENGTH]) -> Result { + Self::from_bytes_interolm(bytes) + } +} + impl<'a> From<&'a Curve25519SecretKey> for Curve25519PublicKey { fn from(secret: &'a Curve25519SecretKey) -> Curve25519PublicKey { Curve25519PublicKey { inner: PublicKey::from(secret.0.as_ref()) } @@ -283,4 +339,19 @@ mod tests { let base64_payload = "MDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDA"; assert!(matches!(Curve25519PublicKey::from_base64(base64_payload), Ok(..))); } + + #[test] + fn decoding_of_interolm_key_format_succeeds() { + let base64_payload = "BXMZxbCdQkXdjzdzcTtX8HWx0B087QiBqXkjZqTjHD18"; + assert!(matches!(Curve25519PublicKey::from_base64(base64_payload), Ok(..))); + } + + #[test] + fn decoding_of_interolm_key_format_with_wrong_marker_byte_fails() { + let base64_payload = "CXMZxbCdQkXdjzdzcTtX8HWx0B087QiBqXkjZqTjHD18"; + assert!(matches!( + Curve25519PublicKey::from_base64(base64_payload), + Err(KeyError::InvalidKeyFormat(..)) + )); + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 1e1b1e2e..e53b9a9b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::fmt::Display; + mod curve25519; mod ed25519; @@ -34,10 +36,42 @@ impl From for String { } } +impl From for KeyId { + fn from(value: u32) -> Self { + Self(value.into()) + } +} + impl KeyId { pub fn to_base64(self) -> String { crate::utilities::base64_encode(self.0.to_be_bytes()) } + + pub fn from_base64(base64: &str) -> Result { + let id = u64::from_be_bytes( + crate::utilities::base64_decode(base64)?.try_into().map_err(KeyIdError::OutOfRange)?, + ); + Ok(Self(id)) + } + + pub fn value(&self) -> u64 { + self.0 + } +} + +/// Error type describing failures when decoding a key ID. +#[derive(Error, Debug)] +pub enum KeyIdError { + #[error("Failed decoding key ID from base64: {}", .0)] + Base64Error(#[from] base64::DecodeError), + #[error("The key ID was not a valid u64 integer. Key ID bytes: {:?}", .0)] + OutOfRange(Vec), +} + +impl Display for KeyId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("KeyId({0})", self.0)) + } } /// Error type describing failures that can happen when we try decode or use a @@ -51,6 +85,8 @@ pub enum KeyError { Invalid number of bytes for {key_type}, expected {expected_length}, got {length}." )] InvalidKeyLength { key_type: &'static str, expected_length: usize, length: usize }, + #[error("The key is in the 33-byte format but the marker byte is wrong: expect 0x5, got {}", .0)] + InvalidKeyFormat(u8), #[error(transparent)] Signature(#[from] SignatureError), /// At least one of the keys did not have contributory behaviour and the diff --git a/src/xeddsa/mod.rs b/src/xeddsa/mod.rs new file mode 100644 index 00000000..0ce8f81d --- /dev/null +++ b/src/xeddsa/mod.rs @@ -0,0 +1,153 @@ +//! XEdDSA signing algorithm implementation +//! +//! Reference: + +use std::fmt::Debug; + +use rand::thread_rng; +use thiserror::Error; +use xeddsa::{ + xed25519::{PrivateKey, PublicKey}, + xeddsa::Error as XEdDsaError, + Sign, Verify, +}; + +use crate::utilities::{base64_decode, base64_encode}; + +pub const SIGNATURE_LENGTH: usize = 64; + +const CURVE25519_PUBLIC_KEY_LENGTH: usize = 32; +const CURVE25519_SECRET_KEY_LENGTH: usize = 32; + +/// An XEdDSA digital signature, can be used to verify the authenticity of a +/// message. +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct XEdDsaSignature(pub(crate) [u8; SIGNATURE_LENGTH]); + +impl XEdDsaSignature { + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } + + pub fn from_base64(signature: &str) -> Result { + base64_decode(signature)?.as_slice().try_into() + } + + pub fn to_base64(&self) -> String { + base64_encode(self.0) + } +} + +impl TryFrom<&[u8]> for XEdDsaSignature { + type Error = SignatureError; + + fn try_from(value: &[u8]) -> Result { + let signature: [u8; SIGNATURE_LENGTH] = + value.try_into().map_err(|_| SignatureError::InvalidSignatureLength(value.len()))?; + Ok(Self(signature)) + } +} + +impl From<[u8; SIGNATURE_LENGTH]> for XEdDsaSignature { + fn from(value: [u8; SIGNATURE_LENGTH]) -> Self { + Self(value) + } +} + +impl Debug for XEdDsaSignature { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("XEdDSASignature").field(&self.to_base64()).finish() + } +} + +impl Verify for PublicKey { + fn verify(&self, message: &[u8], signature: &XEdDsaSignature) -> Result<(), xeddsa::Error> { + self.verify(message, &signature.0) + } +} + +/// Error type describing XEdDSA signature verification failures. +#[cfg(feature = "interolm")] +#[derive(Debug, Error)] +pub enum SignatureError { + /// The signature wasn't valid base64. + #[error("The signature couldn't be decoded: {0}")] + Base64(#[from] base64::DecodeError), + /// The decoded signature was of invalid length. + #[error("The signature has an invalid length: expected {}, got {0}", SIGNATURE_LENGTH)] + InvalidSignatureLength(usize), + /// The signature failed to be verified. + #[error("The signature was decoded successfully but is invalid.")] + InvalidSignature(#[from] XEdDsaError), +} + +pub(crate) fn sign(key: &[u8; CURVE25519_SECRET_KEY_LENGTH], message: &[u8]) -> XEdDsaSignature { + let key = PrivateKey(*key); + let rng = thread_rng(); + let result = key.sign(message, rng); + XEdDsaSignature(result) +} + +pub(crate) fn verify( + public_key: &[u8; CURVE25519_PUBLIC_KEY_LENGTH], + message: &[u8], + signature: XEdDsaSignature, +) -> Result<(), SignatureError> { + let key = PublicKey(*public_key); + + match key.verify(message, &signature) { + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } +} + +#[cfg(test)] +mod test { + use super::{sign, verify}; + use crate::{ + types::{Curve25519Keypair, Curve25519SecretKey}, + Curve25519PublicKey, XEdDsaSignature, + }; + + #[test] + pub fn test_signature_verification() { + let message = "sahasrahla"; + let key_pair = Curve25519Keypair::new(); + + let signature = sign(&key_pair.secret_key().to_bytes(), message.as_bytes()); + + verify(key_pair.public_key().as_bytes(), message.as_bytes(), signature) + .expect("The signature should be valid"); + + let corrupted_message = message.to_owned() + "!"; + + verify(key_pair.public_key().as_bytes(), corrupted_message.as_bytes(), signature) + .expect_err("The signature should be invalid"); + + let mut corrupted_signature = signature.clone(); + corrupted_signature.0[0] = signature.0[0] + 1; + verify(key_pair.public_key().as_bytes(), message.as_bytes(), corrupted_signature) + .expect_err("The signature should be invalid"); + } + + #[test] + pub fn test_known_signature() { + let message = "sahasrahla"; + let secret_key = [ + 219, 209, 232, 97, 65, 93, 1, 89, 16, 37, 173, 21, 224, 61, 51, 34, 114, 154, 249, 245, + 60, 88, 187, 216, 102, 250, 99, 184, 106, 38, 33, 139, + ]; + let signing_key = Curve25519SecretKey::from_slice(&secret_key); + let verification_key = Curve25519PublicKey::from(&signing_key); + + let signature = XEdDsaSignature([ + 10, 129, 186, 162, 96, 123, 226, 104, 147, 200, 65, 38, 35, 123, 77, 4, 195, 122, 160, + 107, 135, 83, 121, 191, 226, 9, 240, 208, 100, 126, 206, 81, 243, 31, 78, 56, 246, 235, + 244, 199, 40, 178, 96, 72, 138, 96, 47, 205, 234, 107, 101, 79, 121, 125, 178, 46, 142, + 215, 145, 247, 221, 235, 220, 3, + ]); + + verify(verification_key.as_bytes(), message.as_bytes(), signature) + .expect("The known signature should be valid."); + } +}