Skip to content

Commit

Permalink
ludi refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sinui0 committed Jan 10, 2024
1 parent ffaac39 commit 41b1e70
Show file tree
Hide file tree
Showing 15 changed files with 939 additions and 560 deletions.
25 changes: 12 additions & 13 deletions components/tls/tls-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]

use std::any::Any;

use async_trait::async_trait;
use tls_core::{
cert::ServerCertDetails,
Expand Down Expand Up @@ -68,11 +66,6 @@ pub enum DecryptMode {
/// and decryption.
#[async_trait]
pub trait Backend: Send {
/// Returns reference to `Any` trait object.
fn as_any(&self) -> &dyn Any;
/// Returns mutable reference to `Any` trait object.
fn as_any_mut(&mut self) -> &mut dyn Any;

/// Signals selected protocol version to implementor.
/// Throws error if version is not supported.
async fn set_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), BackendError>;
Expand All @@ -94,17 +87,23 @@ pub trait Backend: Send {
/// Sets server keyshare.
async fn set_server_key_share(&mut self, key: PublicKey) -> Result<(), BackendError>;
/// Sets the server cert chain
fn set_server_cert_details(&mut self, cert_details: ServerCertDetails);
async fn set_server_cert_details(
&mut self,
cert_details: ServerCertDetails,
) -> Result<(), BackendError>;
/// Sets the server kx details
fn set_server_kx_details(&mut self, kx_details: ServerKxDetails);
async fn set_server_kx_details(
&mut self,
kx_details: ServerKxDetails,
) -> Result<(), BackendError>;
/// Sets handshake hash at ClientKeyExchange for EMS.
async fn set_hs_hash_client_key_exchange(&mut self, hash: &[u8]) -> Result<(), BackendError>;
async fn set_hs_hash_client_key_exchange(&mut self, hash: Vec<u8>) -> Result<(), BackendError>;
/// Sets handshake hash at ServerHello.
async fn set_hs_hash_server_hello(&mut self, hash: &[u8]) -> Result<(), BackendError>;
async fn set_hs_hash_server_hello(&mut self, hash: Vec<u8>) -> Result<(), BackendError>;
/// Returns expected ServerFinished verify_data.
async fn get_server_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError>;
async fn get_server_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError>;
/// Returns ClientFinished verify_data.
async fn get_client_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError>;
async fn get_client_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError>;
/// Prepares the backend for encryption.
async fn prepare_encryption(&mut self) -> Result<(), BackendError>;
/// Perform the encryption over the concerned TLS message.
Expand Down
34 changes: 18 additions & 16 deletions components/tls/tls-client/src/backend/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,6 @@ impl RustCryptoBackend {

#[async_trait]
impl Backend for RustCryptoBackend {
fn as_any(&self) -> &dyn Any {
self
}

fn as_any_mut(&mut self) -> &mut dyn Any {
self
}

async fn set_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), BackendError> {
match version {
ProtocolVersion::TLSv1_2 => {
Expand Down Expand Up @@ -373,42 +365,52 @@ impl Backend for RustCryptoBackend {
Ok(())
}

fn set_server_cert_details(&mut self, _cert_details: ServerCertDetails) {}
async fn set_server_cert_details(
&mut self,
_cert_details: ServerCertDetails,
) -> Result<(), BackendError> {
Ok(())
}

fn set_server_kx_details(&mut self, _kx_details: ServerKxDetails) {}
async fn set_server_kx_details(
&mut self,
_kx_details: ServerKxDetails,
) -> Result<(), BackendError> {
Ok(())
}

async fn set_hs_hash_client_key_exchange(&mut self, hash: &[u8]) -> Result<(), BackendError> {
async fn set_hs_hash_client_key_exchange(&mut self, hash: Vec<u8>) -> Result<(), BackendError> {
self.ems_seed = Some(hash.to_vec());
Ok(())
}

async fn set_hs_hash_server_hello(&mut self, _hash: &[u8]) -> Result<(), BackendError> {
async fn set_hs_hash_server_hello(&mut self, _hash: Vec<u8>) -> Result<(), BackendError> {
Ok(())
}

async fn get_server_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError> {
async fn get_server_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError> {
let ms = self.master_secret.ok_or(BackendError::InvalidState(
"Master secret not set".to_string(),
))?;

let verify_data = match self.protocol_version.ok_or(BackendError::InvalidState(
"Protocol version not set".to_string(),
))? {
ProtocolVersion::TLSv1_2 => self.verify_data_sf_tls12(hash, &ms),
ProtocolVersion::TLSv1_2 => self.verify_data_sf_tls12(&hash, &ms),
_ => unreachable!(),
};
Ok(verify_data.to_vec())
}

async fn get_client_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError> {
async fn get_client_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError> {
let ms = self.master_secret.ok_or(BackendError::InvalidState(
"Master secret not set".to_string(),
))?;

let verify_data = match self.protocol_version.ok_or(BackendError::InvalidState(
"Protocol version not set".to_string(),
))? {
ProtocolVersion::TLSv1_2 => self.verify_data_cf_tls12(hash, &ms),
ProtocolVersion::TLSv1_2 => self.verify_data_cf_tls12(&hash, &ms),
_ => unreachable!(),
};
Ok(verify_data.to_vec())
Expand Down
20 changes: 13 additions & 7 deletions components/tls/tls-client/src/client/tls12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ impl State<ClientConnectionData> for ExpectCertificate {

cx.common
.backend
.set_server_cert_details(server_cert.clone());
.set_server_cert_details(server_cert.clone())
.await?;

Ok(Box::new(ExpectServerKx {
config: self.config,
Expand Down Expand Up @@ -296,7 +297,8 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {

cx.common
.backend
.set_server_cert_details(server_cert_details.clone());
.set_server_cert_details(server_cert_details.clone())
.await?;

Box::new(ExpectServerKx {
config: self.config,
Expand Down Expand Up @@ -387,7 +389,8 @@ impl State<ClientConnectionData> for ExpectCertificateStatus {

cx.common
.backend
.set_server_cert_details(server_cert.clone());
.set_server_cert_details(server_cert.clone())
.await?;

Ok(Box::new(ExpectServerKx {
config: self.config,
Expand Down Expand Up @@ -843,7 +846,7 @@ impl State<ClientConnectionData> for ExpectServerDone {

cx.common
.backend
.set_hs_hash_client_key_exchange(ems_seed.as_ref())
.set_hs_hash_client_key_exchange(ems_seed.as_ref().to_vec())
.await?;

// 5c.
Expand All @@ -858,7 +861,10 @@ impl State<ClientConnectionData> for ExpectServerDone {
let server_key_share =
PublicKey::new(ecdh_params.curve_params.named_group, &ecdh_params.public.0);

cx.common.backend.set_server_kx_details(st.server_kx);
cx.common
.backend
.set_server_kx_details(st.server_kx)
.await?;
cx.common
.backend
.set_server_key_share(server_key_share)
Expand All @@ -877,7 +883,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
let cf = cx
.common
.backend
.get_client_finished_vd(hs.as_ref())
.get_client_finished_vd(hs.as_ref().to_vec())
.await?;
emit_finished(&cf, &mut transcript, cx.common).await?;

Expand Down Expand Up @@ -1089,7 +1095,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let expect_verify_data = cx
.common
.backend
.get_server_finished_vd(vh.as_ref())
.get_server_finished_vd(vh.as_ref().to_vec())
.await?;

// Constant-time verification of this is relatively unimportant: they only
Expand Down
6 changes: 3 additions & 3 deletions components/tls/tls-client/src/client/tls13.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub(super) async fn handle_server_hello(

cx.common
.backend
.set_hs_hash_server_hello(transcript.get_current_hash().as_ref())
.set_hs_hash_server_hello(transcript.get_current_hash().as_ref().to_vec())
.await?;

// Decrypt with the peer's key, encrypt with our own key
Expand Down Expand Up @@ -768,7 +768,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let expect_verify_data = cx
.common
.backend
.get_server_finished_vd(handshake_hash.as_ref())
.get_server_finished_vd(handshake_hash.as_ref().to_vec())
.await?;

let fin = match constant_time::verify_slices_are_equal(
Expand Down Expand Up @@ -829,7 +829,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let client_finished = cx
.common
.backend
.get_client_finished_vd(handshake_hash.as_ref())
.get_client_finished_vd(handshake_hash.as_ref().to_vec())
.await?;
emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await?;

Expand Down
1 change: 0 additions & 1 deletion components/tls/tls-core/src/msgs/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ impl MessagePayload {
/// buffers as well as for fragmenting, joining and encryption/decryption. It can be converted
/// into a `Message` by decoding the payload.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct OpaqueMessage {
pub typ: ContentType,
pub version: ProtocolVersion,
Expand Down
1 change: 1 addition & 0 deletions components/tls/tls-mpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ derive_builder.workspace = true
enum-try-as-inner.workspace = true
thiserror.workspace = true
tracing = { workspace = true, optional = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "dcfe639" }

[dev-dependencies]
tlsn-tls-client = { path = "../tls-client" }
Expand Down
79 changes: 60 additions & 19 deletions components/tls/tls-mpc/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
use tls_core::msgs::enums::{ContentType, NamedGroup};
use std::error::Error;

use crate::msg::MpcTlsMessageError;
use tls_backend::BackendError;
use tls_core::msgs::enums::{ContentType, NamedGroup};

/// An error type for this crate
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum MpcTlsError {
#[error(transparent)]
#[error("io error")]
IOError(#[from] std::io::Error),
#[error(transparent)]
MuxerError(#[from] utils_aio::mux::MuxerError),
#[error(transparent)]
VmError(#[from] mpz_garble::VmError),
#[error(transparent)]
KeyExchangeError(#[from] key_exchange::KeyExchangeError),
#[error(transparent)]
PrfError(#[from] hmac_sha256::PrfError),
#[error(transparent)]
AeadError(#[from] aead::AeadError),
#[error("no committed message")]
#[error("mpc error")]
Mpc { source: Box<dyn Error + Send> },
#[error("key exchange error")]
KeyExchange { source: Box<dyn Error + Send> },
#[error("prf error")]
Prf { source: Box<dyn Error + Send> },
#[error("encryption error")]
Encryption { source: Box<dyn Error + Send> },
#[error("decryption error")]
Decryption { source: Box<dyn Error + Send> },
#[error("peer misbehaved")]
PeerMisbehaved,
#[error("invalid state: {0}")]
StateError(String),
#[error("missing handshake commitment")]
NoHandshakeCommitment,
#[error("missing committed message")]
NoCommittedMessage,
#[error("unexpected content type")]
UnexpectedContentType(ContentType),
Expand All @@ -30,6 +37,10 @@ pub enum MpcTlsError {
UnexpectedSequenceNumber(u64),
#[error("not set up")]
NotSetUp,
#[error("protocol version not set")]
ProtocolVersionNotSet,
#[error("cipher suite not set")]
CipherSuiteNotSet,
#[error("server key not set")]
ServerKeyNotSet,
#[error("server cert not set")]
Expand All @@ -48,12 +59,42 @@ pub enum MpcTlsError {
ReceivedFatalAlert,
#[error("payload decoding error")]
PayloadDecodingError,
#[error("leader closed the connection abruptly")]
LeaderClosedAbruptly,
#[error("actor error")]
ActorError,
}

impl From<ludi::MessageError> for MpcTlsError {
fn from(_: ludi::MessageError) -> Self {
MpcTlsError::ActorError
}
}

impl From<mpz_garble::VmError> for MpcTlsError {
fn from(err: mpz_garble::VmError) -> Self {
MpcTlsError::Mpc {
source: Box::new(err),
}
}
}

impl From<key_exchange::KeyExchangeError> for MpcTlsError {
fn from(err: key_exchange::KeyExchangeError) -> Self {
MpcTlsError::KeyExchange {
source: Box::new(err),
}
}
}

impl From<hmac_sha256::PrfError> for MpcTlsError {
fn from(err: hmac_sha256::PrfError) -> Self {
MpcTlsError::Prf {
source: Box::new(err),
}
}
}

impl From<MpcTlsMessageError> for MpcTlsError {
fn from(err: MpcTlsMessageError) -> Self {
std::io::Error::new(std::io::ErrorKind::InvalidData, err).into()
impl From<MpcTlsError> for BackendError {
fn from(err: MpcTlsError) -> Self {
BackendError::InternalError(err.to_string())
}
}
Loading

0 comments on commit 41b1e70

Please sign in to comment.