diff --git a/ipa-step-test/Cargo.toml b/ipa-step-test/Cargo.toml index 18bb52dfe..de5ec4256 100644 --- a/ipa-step-test/Cargo.toml +++ b/ipa-step-test/Cargo.toml @@ -7,6 +7,7 @@ build = "build.rs" [dependencies] ipa-step = { path = "../ipa-step" } ipa-step-derive = { path = "../ipa-step-derive" } +trybuild = "1.0" [build-dependencies] ipa-step = { path = "../ipa-step", features = ["build"] } diff --git a/ipa-step/src/gate.rs b/ipa-step/src/gate.rs index da8571f8d..a5de3f229 100644 --- a/ipa-step/src/gate.rs +++ b/ipa-step/src/gate.rs @@ -7,7 +7,7 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{parse2, parse_str, Ident, Path}; -use crate::{name::GateName, CompactGateIndex, CompactStep}; +use crate::{hashing::HashingSteps, name::GateName, CompactGateIndex, CompactStep}; fn crate_path(p: &str) -> String { let Some((_, p)) = p.split_once("::") else { @@ -45,6 +45,16 @@ fn build_narrows( } /// Write code for the `CompactGate` implementation related to `S` to the determined file. +/// +/// Maps the string representation of a gate to its corresponding type. +/// Rust compiler does not really like long functions, so these were put in place +/// when we saw long compilation times and pinned them to large match statements inside +/// AsRef and FromStr implementations. +/// Maps the string representation of a gate to its corresponding type. This array stores +/// a pair (hash(s), gate) and is sorted by hash value. This allows us to do binary search +/// for a gate by its string representation. The cost of this operation at runtime +/// is: O(hash) + O(log(n)). +// static STR_TO_GATE: [(u64, u32); #step_count] = [#(#from_arms),*]; /// # Panics /// For various reasons when the type of `S` takes a form that is surprising. pub fn build() { @@ -54,50 +64,56 @@ pub fn build() { }; let name_maker = GateName::new(name); let gate_name = name_maker.name(); - let out = PathBuf::from(env::var("OUT_DIR").unwrap()).join(name_maker.filename()); println!("writing Gate implementation {gate_name} (for {step_name}) to {out:?}"); + let gate_impl = compact_gate_impl::(&gate_name); + + write(out, prettyplease::unparse(&parse2(gate_impl).unwrap())).unwrap(); +} +fn compact_gate_impl(gate_name: &str) -> TokenStream { let mut step_narrows = HashMap::new(); step_narrows.insert(std::any::type_name::(), vec![0]); // Add the first step. - let mut as_ref_arms = TokenStream::new(); - let mut from_arms = TokenStream::new(); - let ident: Ident = parse_str(&gate_name).unwrap(); + let step_count = usize::try_from(S::STEP_COUNT).unwrap(); + // this is an array of gate names indexed by the compact gate index. + let mut gate_names = Vec::with_capacity(step_count); + // reverse mapping of gate names to compact gate index. + let mut step_hasher = HashingSteps::default(); + + let ident: Ident = parse_str(gate_name).unwrap(); for i in 1..=S::STEP_COUNT { let s = String::from("/") + &S::step_string(i - 1); - as_ref_arms.extend(quote! { - #i => #s, - }); - from_arms.extend(quote! { - #s => Ok(#ident(#i)), - }); + step_hasher.hash(&s, i); + gate_names.push(s); + if let Some(t) = S::step_narrow_type(i - 1) { step_narrows.entry(t).or_insert_with(Vec::new).push(i); } } let from_panic = format!("unknown string for {gate_name}: \"{{s}}\""); + let str_lookup_type: Ident = step_hasher.lookup_type(); let mut syntax = quote! { + + static GATE_LOOKUP: [&str; #step_count] = [#(#gate_names),*]; + static STR_LOOKUP: #str_lookup_type = #step_hasher + impl ::std::convert::AsRef for #ident { - #[allow(clippy::too_many_lines)] fn as_ref(&self) -> &str { - match self.0 { + match usize::try_from(self.0).unwrap() { 0 => "/", - #as_ref_arms - _ => unreachable!(), + i => GATE_LOOKUP[i - 1], } } } impl ::std::str::FromStr for #ident { type Err = String; - #[allow(clippy::too_many_lines)] fn from_str(s: &str) -> Result { match s { "/" => Ok(Self::default()), - #from_arms - _ => Err(format!(#from_panic)), + v => STR_LOOKUP.find(v).map(#ident).ok_or_else(|| format!(#from_panic)), } } } @@ -108,7 +124,40 @@ pub fn build() { } } }; - build_narrows(&ident, &gate_name, step_narrows, &mut syntax); + build_narrows(&ident, gate_name, step_narrows, &mut syntax); + + syntax +} - write(out, prettyplease::unparse(&parse2(syntax).unwrap())).unwrap(); +#[cfg(test)] +mod test { + use crate::{CompactGateIndex, CompactStep, Step}; + + struct HashCollision; + + impl Step for HashCollision {} + + impl AsRef for HashCollision { + fn as_ref(&self) -> &str { + std::any::type_name::() + } + } + + impl CompactStep for HashCollision { + const STEP_COUNT: CompactGateIndex = 2; + + fn base_index(&self) -> CompactGateIndex { + 0 + } + + fn step_string(_i: CompactGateIndex) -> String { + "same-step".to_string() + } + } + + #[test] + #[should_panic(expected = "Hash collision for /same-step")] + fn panics_on_hash_collision() { + super::compact_gate_impl::("Gate"); + } } diff --git a/ipa-step/src/hashing.rs b/ipa-step/src/hashing.rs new file mode 100644 index 000000000..392c92006 --- /dev/null +++ b/ipa-step/src/hashing.rs @@ -0,0 +1,78 @@ +use std::{ + collections::BTreeMap, + hash::{DefaultHasher, Hash, Hasher}, +}; + +use proc_macro2::{Ident, TokenStream}; +use quote::{quote, ToTokens}; +use syn::parse_str; + +use crate::CompactGateIndex; + +/// Builds a map of step strings to the corresponding compact gate index. Emits an array of tuples +/// containing the hash and the index, sorted by hash. [`std::str::FromStr`] implementation for +/// compact gate uses the same hashing algorithm for input string and uses the provided code to +/// run binary search and find the item. +/// +/// The complexity of this operation at compile time is O(n) and the cost is hash(str)*n. +/// Runtime overhead is proportional to hash(str)+log(n) +#[derive(Default)] +pub(crate) struct HashingSteps { + inner: BTreeMap, +} + +fn hash(s: &str) -> u64 { + let mut hasher = DefaultHasher::default(); + s.hash(&mut hasher); + hasher.finish() +} + +impl HashingSteps { + /// Add a step to the map. + /// ## Panics + /// if the step already added or if there is a hash collision with any of the steps already + /// added. + pub fn hash(&mut self, step: &str, gate: CompactGateIndex) { + let h = hash(step); + if let Some(old_val) = self.inner.insert(h, gate) { + panic!("Hash collision for {step}: {h} => {old_val} and {gate}. Check that there are no duplicate steps defined in the protocol."); + } + } + + pub fn lookup_type(&self) -> Ident { + parse_str("GateLookup").unwrap() + } +} + +impl ToTokens for HashingSteps { + fn to_tokens(&self, tokens: &mut TokenStream) { + let lookup_type = self.lookup_type(); + let sz = self.inner.len(); + let hashes = self.inner.iter().map(|(h, i)| quote! {(#h, #i)}); + + tokens.extend(quote! { + #lookup_type { + #[allow(clippy::unreadable_literal)] + inner: [#(#hashes),*] + }; + + struct #lookup_type { + inner: [(u64, u32); #sz] + } + + impl #lookup_type { + fn find(&self, input: &str) -> Option { + let h = Self::hash(input); + self.inner.binary_search_by_key(&h, |(hash, _)| *hash).ok().map(|i| self.inner[i].1) + } + + /// This must be kept in sync with proc-macro code that generates the hash. + fn hash(s: &str) -> u64 { + let mut hasher = ::std::hash::DefaultHasher::default(); + ::std::hash::Hash::hash(s, &mut hasher); + ::std::hash::Hasher::finish(&hasher) + } + } + }); + } +} diff --git a/ipa-step/src/lib.rs b/ipa-step/src/lib.rs index a5b6a2db9..30851f3e8 100644 --- a/ipa-step/src/lib.rs +++ b/ipa-step/src/lib.rs @@ -3,6 +3,8 @@ pub mod descriptive; #[cfg(feature = "build")] pub mod gate; +#[cfg(feature = "build")] +mod hashing; #[cfg(feature = "name")] pub mod name;