Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPC-TLS actor refactor #405

Merged
merged 14 commits into from
Jan 15, 2024
2 changes: 1 addition & 1 deletion components/prf/hmac-sha256/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::prf::state::StateError;
#[allow(missing_docs)]
pub enum PrfError {
#[error("MPC backend error: {0:?}")]
Mpc(Box<dyn Error + Send>),
Mpc(Box<dyn Error + Send + Sync>),
#[error("role error: {0:?}")]
RoleError(String),
#[error("Invalid state: {0}")]
Expand Down
2 changes: 1 addition & 1 deletion components/tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ rstest = "0.12"

# misc
derive_builder = "0.12"

enum-try-as-inner = "0.1"
web-time = "0.2"
25 changes: 12 additions & 13 deletions components/tls/tls-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]

use std::any::Any;

use async_trait::async_trait;
use tls_core::{
cert::ServerCertDetails,
Expand Down Expand Up @@ -68,11 +66,6 @@ pub enum DecryptMode {
/// and decryption.
#[async_trait]
pub trait Backend: Send {
/// Returns reference to `Any` trait object.
fn as_any(&self) -> &dyn Any;
/// Returns mutable reference to `Any` trait object.
fn as_any_mut(&mut self) -> &mut dyn Any;

/// Signals selected protocol version to implementor.
/// Throws error if version is not supported.
async fn set_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), BackendError>;
Expand All @@ -94,17 +87,23 @@ pub trait Backend: Send {
/// Sets server keyshare.
async fn set_server_key_share(&mut self, key: PublicKey) -> Result<(), BackendError>;
/// Sets the server cert chain
fn set_server_cert_details(&mut self, cert_details: ServerCertDetails);
async fn set_server_cert_details(
&mut self,
cert_details: ServerCertDetails,
) -> Result<(), BackendError>;
/// Sets the server kx details
fn set_server_kx_details(&mut self, kx_details: ServerKxDetails);
async fn set_server_kx_details(
&mut self,
kx_details: ServerKxDetails,
) -> Result<(), BackendError>;
/// Sets handshake hash at ClientKeyExchange for EMS.
async fn set_hs_hash_client_key_exchange(&mut self, hash: &[u8]) -> Result<(), BackendError>;
async fn set_hs_hash_client_key_exchange(&mut self, hash: Vec<u8>) -> Result<(), BackendError>;
/// Sets handshake hash at ServerHello.
async fn set_hs_hash_server_hello(&mut self, hash: &[u8]) -> Result<(), BackendError>;
async fn set_hs_hash_server_hello(&mut self, hash: Vec<u8>) -> Result<(), BackendError>;
/// Returns expected ServerFinished verify_data.
async fn get_server_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError>;
async fn get_server_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError>;
/// Returns ClientFinished verify_data.
async fn get_client_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError>;
async fn get_client_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError>;
/// Prepares the backend for encryption.
async fn prepare_encryption(&mut self) -> Result<(), BackendError>;
/// Perform the encryption over the concerned TLS message.
Expand Down
34 changes: 18 additions & 16 deletions components/tls/tls-client/src/backend/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,6 @@ impl RustCryptoBackend {

#[async_trait]
impl Backend for RustCryptoBackend {
fn as_any(&self) -> &dyn Any {
self
}

fn as_any_mut(&mut self) -> &mut dyn Any {
self
}

async fn set_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), BackendError> {
match version {
ProtocolVersion::TLSv1_2 => {
Expand Down Expand Up @@ -373,42 +365,52 @@ impl Backend for RustCryptoBackend {
Ok(())
}

fn set_server_cert_details(&mut self, _cert_details: ServerCertDetails) {}
async fn set_server_cert_details(
&mut self,
_cert_details: ServerCertDetails,
) -> Result<(), BackendError> {
Ok(())
}

fn set_server_kx_details(&mut self, _kx_details: ServerKxDetails) {}
async fn set_server_kx_details(
&mut self,
_kx_details: ServerKxDetails,
) -> Result<(), BackendError> {
Ok(())
}

async fn set_hs_hash_client_key_exchange(&mut self, hash: &[u8]) -> Result<(), BackendError> {
async fn set_hs_hash_client_key_exchange(&mut self, hash: Vec<u8>) -> Result<(), BackendError> {
self.ems_seed = Some(hash.to_vec());
Ok(())
}

async fn set_hs_hash_server_hello(&mut self, _hash: &[u8]) -> Result<(), BackendError> {
async fn set_hs_hash_server_hello(&mut self, _hash: Vec<u8>) -> Result<(), BackendError> {
Ok(())
}

async fn get_server_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError> {
async fn get_server_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError> {
let ms = self.master_secret.ok_or(BackendError::InvalidState(
"Master secret not set".to_string(),
))?;

let verify_data = match self.protocol_version.ok_or(BackendError::InvalidState(
"Protocol version not set".to_string(),
))? {
ProtocolVersion::TLSv1_2 => self.verify_data_sf_tls12(hash, &ms),
ProtocolVersion::TLSv1_2 => self.verify_data_sf_tls12(&hash, &ms),
_ => unreachable!(),
};
Ok(verify_data.to_vec())
}

async fn get_client_finished_vd(&mut self, hash: &[u8]) -> Result<Vec<u8>, BackendError> {
async fn get_client_finished_vd(&mut self, hash: Vec<u8>) -> Result<Vec<u8>, BackendError> {
let ms = self.master_secret.ok_or(BackendError::InvalidState(
"Master secret not set".to_string(),
))?;

let verify_data = match self.protocol_version.ok_or(BackendError::InvalidState(
"Protocol version not set".to_string(),
))? {
ProtocolVersion::TLSv1_2 => self.verify_data_cf_tls12(hash, &ms),
ProtocolVersion::TLSv1_2 => self.verify_data_cf_tls12(&hash, &ms),
_ => unreachable!(),
};
Ok(verify_data.to_vec())
Expand Down
20 changes: 13 additions & 7 deletions components/tls/tls-client/src/client/tls12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ impl State<ClientConnectionData> for ExpectCertificate {

cx.common
.backend
.set_server_cert_details(server_cert.clone());
.set_server_cert_details(server_cert.clone())
.await?;

Ok(Box::new(ExpectServerKx {
config: self.config,
Expand Down Expand Up @@ -296,7 +297,8 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {

cx.common
.backend
.set_server_cert_details(server_cert_details.clone());
.set_server_cert_details(server_cert_details.clone())
.await?;

Box::new(ExpectServerKx {
config: self.config,
Expand Down Expand Up @@ -387,7 +389,8 @@ impl State<ClientConnectionData> for ExpectCertificateStatus {

cx.common
.backend
.set_server_cert_details(server_cert.clone());
.set_server_cert_details(server_cert.clone())
.await?;

Ok(Box::new(ExpectServerKx {
config: self.config,
Expand Down Expand Up @@ -843,7 +846,7 @@ impl State<ClientConnectionData> for ExpectServerDone {

cx.common
.backend
.set_hs_hash_client_key_exchange(ems_seed.as_ref())
.set_hs_hash_client_key_exchange(ems_seed.as_ref().to_vec())
.await?;

// 5c.
Expand All @@ -858,7 +861,10 @@ impl State<ClientConnectionData> for ExpectServerDone {
let server_key_share =
PublicKey::new(ecdh_params.curve_params.named_group, &ecdh_params.public.0);

cx.common.backend.set_server_kx_details(st.server_kx);
cx.common
.backend
.set_server_kx_details(st.server_kx)
.await?;
cx.common
.backend
.set_server_key_share(server_key_share)
Expand All @@ -877,7 +883,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
let cf = cx
.common
.backend
.get_client_finished_vd(hs.as_ref())
.get_client_finished_vd(hs.as_ref().to_vec())
.await?;
emit_finished(&cf, &mut transcript, cx.common).await?;

Expand Down Expand Up @@ -1089,7 +1095,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let expect_verify_data = cx
.common
.backend
.get_server_finished_vd(vh.as_ref())
.get_server_finished_vd(vh.as_ref().to_vec())
.await?;

// Constant-time verification of this is relatively unimportant: they only
Expand Down
6 changes: 3 additions & 3 deletions components/tls/tls-client/src/client/tls13.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub(super) async fn handle_server_hello(

cx.common
.backend
.set_hs_hash_server_hello(transcript.get_current_hash().as_ref())
.set_hs_hash_server_hello(transcript.get_current_hash().as_ref().to_vec())
.await?;

// Decrypt with the peer's key, encrypt with our own key
Expand Down Expand Up @@ -768,7 +768,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let expect_verify_data = cx
.common
.backend
.get_server_finished_vd(handshake_hash.as_ref())
.get_server_finished_vd(handshake_hash.as_ref().to_vec())
.await?;

let fin = match constant_time::verify_slices_are_equal(
Expand Down Expand Up @@ -829,7 +829,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let client_finished = cx
.common
.backend
.get_client_finished_vd(handshake_hash.as_ref())
.get_client_finished_vd(handshake_hash.as_ref().to_vec())
.await?;
emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await?;

Expand Down
3 changes: 3 additions & 0 deletions components/tls/tls-mpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ edition = "2021"
name = "tls_mpc"

[features]
default = ["tracing"]
tracing = [
"dep:tracing",
"tlsn-block-cipher/tracing",
Expand Down Expand Up @@ -49,8 +50,10 @@ futures.workspace = true
async-trait.workspace = true
serde.workspace = true
derive_builder.workspace = true
enum-try-as-inner.workspace = true
thiserror.workspace = true
tracing = { workspace = true, optional = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "b590de5" }

[dev-dependencies]
tlsn-tls-client = { path = "../tls-client" }
Expand Down
Loading