Skip to content

Commit

Permalink
ConstantTimeEq proc macro derive
Browse files Browse the repository at this point in the history
  • Loading branch information
varsha888 committed Mar 24, 2023
1 parent 45d890a commit 850018f
Show file tree
Hide file tree
Showing 28 changed files with 876 additions and 52 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"capable/sys",
"capable/sys/types",
"capable/types",
"constant_time_derive",
"core/build",
"core/sys/types",
"core/types",
Expand Down
14 changes: 14 additions & 0 deletions constant_time_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "constant_time_derive"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0.8"
quote = "1.0"
subtle = { version = "2.4.0", default-features = false }
syn = "1.0"
99 changes: 99 additions & 0 deletions constant_time_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (c) 2023 The MobileCoin Foundation

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Fields, GenericParam, Generics};

#[proc_macro_derive(ConstantTimeEq)]
pub fn constant_time_eq(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
derive_ct_eq(&input)
}
// TODO: Check or remove padding and align decorators on the struct
fn parse_fields(fields: &Fields) -> Result<proc_macro2::TokenStream, &'static str> {
match &fields {
Fields::Named(fields_named) => {
let mut token_stream = quote!();
let mut iter = fields_named.named.iter().peekable();

while let Some(field) = iter.next() {
let ident = &field.ident;
match iter.peek() {
None => token_stream.extend(quote! { {self.#ident}.ct_eq(&{other.#ident}) }),
Some(_) => {
token_stream.extend(quote! { {self.#ident}.ct_eq(&{other.#ident}) & })
}
}
}
Ok(token_stream)
}
Fields::Unnamed(_) => {
let token_stream = quote! { {self.0}.ct_eq({&other.0}) };

Ok(token_stream)
}
Fields::Unit => Err("Constant time cannot be derived for unit fields"),
}
}

fn parse_enum(data_enum: &DataEnum) -> Result<proc_macro2::TokenStream, &'static str> {
for variant in data_enum.variants.iter() {
if let Fields::Unnamed(_) = variant.fields {
panic!("Cannot derive ct_eq for fields in enums")
}
}
let token_stream = quote! {
if self == other {
::subtle::Choice::from(1)
}
else {
::subtle::Choice::from(0)
}
};

Ok(token_stream)
}

fn parse_data(data: &Data) -> Result<proc_macro2::TokenStream, &'static str> {
match data {
Data::Struct(variant_data) => parse_fields(&variant_data.fields),
Data::Enum(data_enum) => parse_enum(data_enum),
Data::Union(..) => Err("Constant time cannot be derived for a union"),
}
}

fn parse_lifetime(generics: &Generics) -> bool {
for i in generics.params.iter() {
if let GenericParam::Lifetime(_) = i {
return true;
}
}
false
}

fn derive_ct_eq(input: &DeriveInput) -> TokenStream {
let ident = &input.ident;
let data = &input.data;
let generics = &input.generics;

let is_lifetime = parse_lifetime(generics);
let ct_eq_stream: proc_macro2::TokenStream =
parse_data(data).expect("Failed to parse DeriveInput data");
let data_ident = if is_lifetime {
format!("{}<'_>", ident)
} else {
ident.to_string()
};
let ident_stream: proc_macro2::TokenStream = data_ident.parse().unwrap();

let expanded: proc_macro2::TokenStream = quote! {
impl ::subtle::ConstantTimeEq for #ident_stream {
fn ct_eq(&self, other: &Self) -> ::subtle::Choice {
use ::subtle::ConstantTimeEq;
return #ct_eq_stream
}
}
};

expanded.into()
}
20 changes: 20 additions & 0 deletions core/build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ pub struct SgxParseCallbacks {

// Dynamically Sized types
dynamically_sized_types: Vec<String>,

constant_time_types: Vec<String>,
}

impl SgxParseCallbacks {
Expand Down Expand Up @@ -176,6 +178,20 @@ impl SgxParseCallbacks {
self
}

/// Types to derive constant time for
///
/// # Arguments
/// * `default_types` - Types to derive default for.
pub fn derive_constant_time<'a, E, I>(mut self, constant_time_types: I) -> Self
where
I: IntoIterator<Item = &'a E>,
E: ToString + 'a + ?Sized,
{
self.constant_time_types
.extend(constant_time_types.into_iter().map(ToString::to_string));
self
}

/// Dynamically Sized Types
///
/// # Arguments
Expand Down Expand Up @@ -205,6 +221,10 @@ impl ParseCallbacks for SgxParseCallbacks {
attributes.push("Default");
}

if self.constant_time_types.iter().any(|n| *n == name) {
attributes.push("constant_time_derive::ConstantTimeEq");
}

// The [enum_types] method adds enums to the [copyable_types]
if self.copyable_types.iter().any(|n| *n == name) {
attributes.push("Copy");
Expand Down
4 changes: 4 additions & 0 deletions core/sys/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ rust-version = "1.62.1"
[lib]
doctest = false

[dependencies]
constant_time_derive = { path = "../../../constant_time_derive", version = "0.1.0" }
subtle = { version = "2.4.0", default-features = false }

[build-dependencies]
bindgen = "0.64.0"
cargo-emit = "0.2.1"
Expand Down
30 changes: 30 additions & 0 deletions core/sys/types/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ fn main() {
"sgx_measurement_t",
"sgx_report_data_t",
"sgx_attributes_t",
"sgx_key_request_t",
"sgx_platform_info_t",
"sgx_basename_t",
])
.dynamically_sized_types(["sgx_quote_t"])
.derive_default([
Expand All @@ -108,6 +111,33 @@ fn main() {
"sgx_quote_nonce_t",
"sgx_update_info_bit_t",
"sgx_qe_report_info_t",
])
.derive_constant_time([
"sgx_att_key_id_ext_t",
"sgx_ql_att_key_id_t",
"sgx_measurement_t",
"sgx_attributes_t",
"sgx_key_id_t",
"sgx_config_id_t",
"sgx_key_request_t",
"sgx_qe_report_info_t",
"sgx_platform_info_t",
"sgx_update_info_bit_t",
"sgx_epid_group_id_t",
"sgx_basename_t",
"sgx_quote_nonce_t",
"sgx_report_t",
"sgx_target_info_t",
"sgx_cpu_svn_t",
"sgx_mac_t",
"sgx_report_data_t",
"sgx_isvfamily_id_t",
"sgx_isvext_prod_id_t",
"sgx_prod_id_t",
"sgx_config_svn_t",
"sgx_isv_svn_t",
"sgx_cpu_svn_t",
"sgx_report_body_t",
]);
let mut builder = mc_sgx_core_build::sgx_builder()
.header("wrapper.h")
Expand Down
2 changes: 2 additions & 0 deletions core/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ alloc = []

[dependencies]
bitflags = "2.0.0"
constant_time_derive = { path = "../../constant_time_derive", version = "0.1.0" }
displaydoc = { version = "0.2.3", default-features = false }
mc-sgx-core-sys-types = { path = "../sys/types", version = "=0.5.1-beta.0" }
mc-sgx-util = { path = "../../util", version = "=0.5.1-beta.0" }
nom = { version = "7.1.2", default-features = false }
rand_core = { version = "0.6.4", default-features = false }
serde = { version = "1.0.152", default-features = false, features = ["derive"], optional = true }
subtle = { version = "2.4.0", default-features = false, features = ["i128"] }

# `getrandom` is pulled in by `rand_core` we only need to access it directly when registering a custom spng,
# `register_custom_getrandom`, which only happens for target_os = none
Expand Down
Loading

0 comments on commit 850018f

Please sign in to comment.