From 98dc167c13ebd3823614ff035a717dc58af45ecc Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Wed, 15 Jan 2025 12:50:25 +0000 Subject: [PATCH 01/39] Add pass to preprocess 'serialize' --- compiler/noirc_evaluator/src/ssa.rs | 15 +++ .../noirc_evaluator/src/ssa/opt/inlining.rs | 45 +++++---- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 + .../src/ssa/opt/preprocess_fns.rs | 44 +++++++++ .../noirc_evaluator/src/ssa/opt/unrolling.rs | 93 ++++++++++++------- 5 files changed, 145 insertions(+), 53 deletions(-) create mode 100644 compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 45021fa6158..41621358d41 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -153,6 +153,15 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Self { + println!("AFTER {msg}: functions={}", self.ssa.functions.len()); + for f in self.ssa.functions.values() { + let block_cnt = ir::post_order::PostOrder::with_function(f).into_vec().len(); + println!(" FUNCTION {}: blocks={block_cnt}", f.name()); + } + let print_ssa_pass = match &self.ssa_logging { SsaLogging::None => false, SsaLogging::All => true, diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index b1dd203cfd0..f55f052676e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -49,7 +49,7 @@ impl Ssa { Self::inline_functions_inner(self, aggressiveness, false) } - // Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points + /// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa { Self::inline_functions_inner(self, aggressiveness, true) } @@ -61,26 +61,37 @@ impl Ssa { ) -> Ssa { 
let inline_sources = get_functions_to_inline_into(&self, inline_no_predicates_functions, aggressiveness); + // NOTE: Functions are processed independently of each other, with the final mapping replacing the original, + // instead of inlining the "leaf" functions, moving up towards the entry point. self.functions = btree_map(&inline_sources, |entry_point| { - let new_function = InlineContext::new( - &self, - *entry_point, - inline_no_predicates_functions, - inline_sources.clone(), - ) - .inline_all(&self); + let function = &self.functions[entry_point]; + let new_function = + function.inlined(&self, inline_no_predicates_functions, &inline_sources); (*entry_point, new_function) }); self } } +impl Function { + /// Create a new function which has the functions called by this one inlined into its body. + pub(super) fn inlined( + &self, + ssa: &Ssa, + inline_no_predicates_functions: bool, + functions_not_to_inline: &BTreeSet, + ) -> Function { + InlineContext::new(ssa, self.id(), inline_no_predicates_functions, functions_not_to_inline) + .inline_all(ssa) + } +} + /// The context for the function inlining pass. /// /// This works using an internal FunctionBuilder to build a new main function from scratch. /// Doing it this way properly handles importing instructions between functions and lets us /// reuse the existing API at the cost of essentially cloning each of main's instructions. -struct InlineContext { +struct InlineContext<'a> { recursion_level: u32, builder: FunctionBuilder, @@ -97,7 +108,7 @@ struct InlineContext { inline_no_predicates_functions: bool, // These are the functions of the program that we shouldn't inline. - functions_not_to_inline: BTreeSet, + functions_not_to_inline: &'a BTreeSet, } /// The per-function inlining context contains information that is only valid for one function. @@ -105,13 +116,13 @@ struct InlineContext { /// layer to translate between BlockId to BlockId for the current function and the function to /// inline into. 
The same goes for ValueIds, InstructionIds, and for storing other data like /// parameter to argument mappings. -struct PerFunctionContext<'function> { +struct PerFunctionContext<'function, 'a> { /// The source function is the function we're currently inlining into the function being built. source_function: &'function Function, /// The shared inlining context for all functions. This notably contains the FunctionBuilder used /// to build the function we're inlining into. - context: &'function mut InlineContext, + context: &'function mut InlineContext<'a>, /// Maps ValueIds in the function being inlined to the new ValueIds to use in the function /// being inlined into. This mapping also contains the mapping from parameter values to @@ -161,7 +172,7 @@ fn called_functions(func: &Function) -> BTreeSet { /// - Any Brillig function called from Acir /// - Some Brillig functions depending on aggressiveness and some metrics /// - Any Acir functions with a [fold inline type][InlineType::Fold], -fn get_functions_to_inline_into( +pub(super) fn get_functions_to_inline_into( ssa: &Ssa, inline_no_predicates_functions: bool, aggressiveness: i64, @@ -349,7 +360,7 @@ fn compute_function_interface_cost(func: &Function) -> usize { func.parameters().len() + func.returns().len() } -impl InlineContext { +impl<'a> InlineContext<'a> { /// Create a new context object for the function inlining pass. /// This starts off with an empty mapping of instructions for main's parameters. 
/// The function being inlined into will always be the main function, although it is @@ -359,7 +370,7 @@ impl InlineContext { ssa: &Ssa, entry_point: FunctionId, inline_no_predicates_functions: bool, - functions_not_to_inline: BTreeSet, + functions_not_to_inline: &'a BTreeSet, ) -> Self { let source = &ssa.functions[&entry_point]; let mut builder = FunctionBuilder::new(source.name().to_owned(), entry_point); @@ -444,13 +455,13 @@ impl InlineContext { } } -impl<'function> PerFunctionContext<'function> { +impl<'function, 'ctx> PerFunctionContext<'function, 'ctx> { /// Create a new PerFunctionContext from the source function. /// The value and block mappings for this context are initially empty except /// for containing the mapping between parameters in the source_function and /// the arguments of the destination function. fn new( - context: &'function mut InlineContext, + context: &'function mut InlineContext<'ctx>, source_function: &'function Function, globals: &'function Function, ) -> Self { diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 1105e15c30e..476cc660c04 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -16,6 +16,7 @@ mod inlining; mod loop_invariant; mod mem2reg; mod normalize_value_ids; +mod preprocess_fns; mod rc; mod remove_bit_shifts; mod remove_enable_side_effects; diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs new file mode 100644 index 00000000000..bebfc28cdd3 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -0,0 +1,44 @@ +//! Pre-process functions before inlining them into others. + +use crate::ssa::Ssa; + +impl Ssa { + /// Run pre-processing steps on functions in isolation. 
+ pub(crate) fn preprocess_functions( + mut self, + aggressiveness: i64, + max_bytecode_increase_percent: Option, + ) -> Ssa { + // Ok(self + // .inline_functions_limited(aggressiveness) + // .as_slice_optimization() + // .try_unroll_loops_iteratively(max_bytecode_increase_percent, true)? + // .mem2reg()) + + // TODO: Ideally we would go bottom-up, starting with the "leaf" functions, so we inline already optimized code into + // the ones that call them, but for now just see what happens if we pre-process the "serialize" function. + let to_preprocess = self + .functions + .iter() + .filter_map(|(id, f)| (f.name() == "serialize").then_some(*id)) + .collect::>(); + + let not_to_inline = + super::inlining::get_functions_to_inline_into(&self, false, aggressiveness); + + for id in to_preprocess { + let function = &self.functions[&id]; + let mut function = function.inlined(&self, false, ¬_to_inline); + // Help unrolling determine bounds. + function.as_slice_optimization(); + // We might not be able to unroll all loops without fully inlining them, so ignore errors. + let _ = function.try_unroll_loops_iteratively(max_bytecode_increase_percent, true); + // Reduce the number of redundant stores/loads after unrolling + function.mem2reg(); + // Put it back into the SSA, so the next functions can pick it up. + self.functions.insert(id, function); + } + + self + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 79181b7e74e..0d6ff5d9a05 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -54,52 +54,73 @@ impl Ssa { /// fewer SSA instructions, but that can still result in more Brillig opcodes. 
#[tracing::instrument(level = "trace", skip(self))] pub(crate) fn unroll_loops_iteratively( - mut self: Ssa, + mut self, max_bytecode_increase_percent: Option, ) -> Result { - for (_, function) in self.functions.iter_mut() { - // Take a snapshot of the function to compare byte size increase, - // but only if the setting indicates we have to, otherwise skip it. - let orig_func_and_max_incr_pct = max_bytecode_increase_percent - .filter(|_| function.runtime().is_brillig()) - .map(|max_incr_pct| (function.clone(), max_incr_pct)); - - // Try to unroll loops first: - let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops(); - - // Keep unrolling until no more errors are found - while !unroll_errors.is_empty() { - let prev_unroll_err_count = unroll_errors.len(); - - // Simplify the SSA before retrying - simplify_between_unrolls(function); - - // Unroll again - let (new_unrolled, new_errors) = function.try_unroll_loops(); - unroll_errors = new_errors; - has_unrolled |= new_unrolled; - - // If we didn't manage to unroll any more loops, exit - if unroll_errors.len() >= prev_unroll_err_count { - return Err(unroll_errors.swap_remove(0)); + for function in self.functions.values_mut() { + // We must be able to unroll ACIR loops at this point. + function.try_unroll_loops_iteratively(max_bytecode_increase_percent, false)?; + } + Ok(self) + } +} + +impl Function { + /// Try to unroll loops in the function. + /// + /// Returns an `Err` if it cannot be done, for example because the loop bounds + /// cannot be determined at compile time. This can happen during pre-processing, + /// in which case `restore_on_error` can be used to leave the function untouched. + pub(super) fn try_unroll_loops_iteratively( + &mut self, + max_bytecode_increase_percent: Option, + restore_on_error: bool, + ) -> Result<(), RuntimeError> { + // Take a snapshot in case we have to restore it. 
+ let orig_function = (max_bytecode_increase_percent.is_some() + && self.runtime().is_brillig() + || restore_on_error) + .then(|| self.clone()); + + // Try to unroll loops first: + let (mut has_unrolled, mut unroll_errors) = self.try_unroll_loops(); + + // Keep unrolling until no more errors are found + while !unroll_errors.is_empty() { + let prev_unroll_err_count = unroll_errors.len(); + + // Simplify the SSA before retrying + simplify_between_unrolls(self); + + // Unroll again + let (new_unrolled, new_errors) = self.try_unroll_loops(); + unroll_errors = new_errors; + has_unrolled |= new_unrolled; + + // If we didn't manage to unroll any more loops, exit + if unroll_errors.len() >= prev_unroll_err_count { + if restore_on_error { + *self = orig_function.expect("took snapshot to restore"); } + return Err(unroll_errors.swap_remove(0)); } + } - if has_unrolled { - if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct { - let new_size = brillig_bytecode_size(function); - let orig_size = brillig_bytecode_size(&orig_function); - if !is_new_size_ok(orig_size, new_size, max_incr_pct) { - *function = orig_function; - } + // Check if the size increase is acceptable + if has_unrolled && self.runtime().is_brillig() { + if let Some(max_incr_pct) = max_bytecode_increase_percent { + let orig_function = orig_function.expect("took snapshot to compare"); + let new_size = brillig_bytecode_size(self); + let orig_size = brillig_bytecode_size(&orig_function); + if !is_new_size_ok(orig_size, new_size, max_incr_pct) { + *self = orig_function; } } } - Ok(self) + + Ok(()) } -} -impl Function { // Loop unrolling in brillig can lead to a code explosion currently. // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. 
// Brillig also generally prefers smaller code rather than faster code, From 6943765f15e31583b325b22df5a6020736bf2994 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Wed, 15 Jan 2025 13:49:58 +0000 Subject: [PATCH 02/39] Simplify a CFG and run DIE as well --- .../noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index bebfc28cdd3..57617843b21 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -9,14 +9,8 @@ impl Ssa { aggressiveness: i64, max_bytecode_increase_percent: Option, ) -> Ssa { - // Ok(self - // .inline_functions_limited(aggressiveness) - // .as_slice_optimization() - // .try_unroll_loops_iteratively(max_bytecode_increase_percent, true)? - // .mem2reg()) - // TODO: Ideally we would go bottom-up, starting with the "leaf" functions, so we inline already optimized code into - // the ones that call them, but for now just see what happens if we pre-process the "serialize" function. + // the ones that call them, but for now just see what happens if we pre-process the "serialize" function instances. let to_preprocess = self .functions .iter() @@ -35,6 +29,10 @@ impl Ssa { let _ = function.try_unroll_loops_iteratively(max_bytecode_increase_percent, true); // Reduce the number of redundant stores/loads after unrolling function.mem2reg(); + // Try to reduce the number of blocks. + function.simplify_function(); + // Remove leftover instructions. + function.dead_instruction_elimination(true); // Put it back into the SSA, so the next functions can pick it up. 
self.functions.insert(id, function); } From 850fc11429c6556289e003059c5a9c10fd4e8400 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Wed, 15 Jan 2025 20:37:01 +0000 Subject: [PATCH 03/39] Add docstrings to inlining --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index f55f052676e..217271686fa 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -145,6 +145,8 @@ struct PerFunctionContext<'function, 'a> { } /// Utility function to find out the direct calls of a function. +/// +/// Returns the function IDs from all `Call` instructions without deduplication. fn called_functions_vec(func: &Function) -> Vec { let mut called_function_ids = Vec::new(); for block_id in func.reachable_blocks() { @@ -162,7 +164,7 @@ fn called_functions_vec(func: &Function) -> Vec { called_function_ids } -/// Utility function to find out the deduplicated direct calls of a function. +/// Utility function to find out the deduplicated direct calls made from a function. fn called_functions(func: &Function) -> BTreeSet { called_functions_vec(func).into_iter().collect() } @@ -222,6 +224,7 @@ pub(super) fn get_functions_to_inline_into( .collect() } +/// Compute the time each function is called from any other function. 
fn compute_times_called(ssa: &Ssa) -> HashMap { ssa.functions .iter() @@ -236,10 +239,14 @@ fn compute_times_called(ssa: &Ssa) -> HashMap { }) } +/// Traverse the call graph starting from a given function, marking function to be retained if they are: +/// * recursive functions, or +/// * the cost of inlining outweighs the cost of not doing so fn should_retain_recursive( ssa: &Ssa, func: FunctionId, times_called: &HashMap, + // FunctionId -> (should_retain, weight) should_retain_function: &mut HashMap, mut explored_functions: im::HashSet, inline_no_predicates_functions: bool, @@ -280,6 +287,8 @@ fn should_retain_recursive( // And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns) // We then can compute an approximation of the cost of inlining vs the cost of retaining the function // We do this computation using saturating i64s to avoid overflows + + // Total weight of functions called by this one, unless we decided not to inline them. let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, called_function| { let (should_retain, weight) = should_retain_function[called_function]; if should_retain { @@ -311,6 +320,7 @@ fn should_retain_recursive( should_retain_function.insert(func, (!should_inline, this_function_weight)); } +/// Gather the functions that should not be inlined. fn compute_functions_to_retain( ssa: &Ssa, entry_points: &BTreeSet, @@ -346,6 +356,7 @@ fn compute_functions_to_retain( .collect() } +/// Compute a weight of a function based on the number of instructions in its reachable blocks. fn compute_function_own_weight(func: &Function) -> usize { let mut weight = 0; for block_id in func.reachable_blocks() { @@ -356,6 +367,7 @@ fn compute_function_own_weight(func: &Function) -> usize { weight } +/// Compute interface cost of a function based on the number of inputs and outputs. 
fn compute_function_interface_cost(func: &Function) -> usize { func.parameters().len() + func.returns().len() } From e5c49afa97d191981b712da9ef32cb7966531554 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Wed, 15 Jan 2025 23:43:05 +0000 Subject: [PATCH 04/39] See what happens if we go bottom up --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 88 ++++++++++++++++++- .../src/ssa/opt/preprocess_fns.rs | 17 ++-- cspell.json | 1 + 3 files changed, 92 insertions(+), 14 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 217271686fa..70fc7bc9a74 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -2,9 +2,10 @@ //! The purpose of this pass is to inline the instructions of each function call //! within the function caller. If all function calls are known, there will only //! be a single function remaining when the pass finishes. -use std::collections::{BTreeSet, HashSet, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque}; use acvm::acir::AcirField; +use im::HashMap; use iter_extended::{btree_map, vecmap}; use crate::ssa::{ @@ -19,7 +20,6 @@ use crate::ssa::{ }, ssa_gen::Ssa, }; -use fxhash::FxHashMap as HashMap; /// An arbitrary limit to the maximum number of recursive call /// frames at any point in time. @@ -225,7 +225,7 @@ pub(super) fn get_functions_to_inline_into( } /// Compute the time each function is called from any other function. -fn compute_times_called(ssa: &Ssa) -> HashMap { +pub(super) fn compute_times_called(ssa: &Ssa) -> HashMap { ssa.functions .iter() .flat_map(|(_caller_id, function)| { @@ -239,6 +239,88 @@ fn compute_times_called(ssa: &Ssa) -> HashMap { }) } +/// Compute for each function the set of functions that call it. 
+fn compute_callers(ssa: &Ssa) -> BTreeMap> { + ssa.functions + .iter() + .flat_map(|(caller_id, function)| { + let called_functions = called_functions(function); + called_functions.into_iter().map(|called_id| (called_id, *caller_id)) + }) + .fold( + ssa.functions.keys().map(|id| (*id, BTreeSet::new())).collect(), + |mut acc, (called_id, caller_id)| { + let callers = acc.entry(called_id).or_default(); + callers.insert(caller_id); + acc + }, + ) +} + +/// Compute for each function the set of functions called by it. +fn compute_callees(ssa: &Ssa) -> BTreeMap> { + ssa.functions + .iter() + .map(|(caller_id, function)| { + let called_functions = called_functions(function); + (*caller_id, called_functions) + }) + .collect() +} + +/// Compute something like a topological order of the functions, starting with the ones +/// that do not call any other functions, going towards the entry points. When cycles +/// are detected, take the one which are called by the most to break the ties. +/// +/// This can be used to simplify the most often called functions first. +pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec { + let mut order = Vec::new(); + let mut visited = HashSet::new(); + + // Number of times a function is called, to break cycles. + let mut times_called = compute_times_called(ssa).into_iter().collect::>(); + times_called.sort_by_key(|(id, cnt)| (*cnt, *id)); + + let callers = compute_callers(ssa); + let mut callees = compute_callees(ssa); + + // Seed the queue with functions that don't call anything. + let mut queue = callees + .iter() + .filter_map(|(id, callees)| callees.is_empty().then_some(*id)) + .collect::>(); + + loop { + if times_called.is_empty() && queue.is_empty() { + break; + } + while let Some(id) = queue.pop_front() { + order.push(id); + visited.insert(id); + // Remove this function from all of its callers. 
+ for caller in &callers[&id] { + let callees = callees.get_mut(caller).expect("all callees computed"); + // If the caller doesn't call any other function, enqueue it. + if callees.is_empty() && !visited.contains(caller) { + queue.push_back(*caller); + } + } + } + // If we ran out of the queue, maybe there is a cycle; take the next most called function. + loop { + let Some((id, _)) = times_called.pop() else { + break; + }; + if !visited.contains(&id) { + queue.push_back(id); + break; + } + } + } + + order +} + /// Traverse the call graph starting from a given function, marking function to be retained if they are: /// * recursive functions, or /// * the cost of inlining outweighs the cost of not doing so diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 57617843b21..dd4d7bdc2c1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -2,6 +2,8 @@ use crate::ssa::Ssa; +use super::inlining; + impl Ssa { /// Run pre-processing steps on functions in isolation. pub(crate) fn preprocess_functions( @@ -9,18 +11,11 @@ impl Ssa { aggressiveness: i64, max_bytecode_increase_percent: Option, ) -> Ssa { - // TODO: Ideally we would go bottom-up, starting with the "leaf" functions, so we inline already optimized code into - // the ones that call them, but for now just see what happens if we pre-process the "serialize" function instances. - let to_preprocess = self - .functions - .iter() - .filter_map(|(id, f)| (f.name() == "serialize").then_some(*id)) - .collect::>(); - - let not_to_inline = - super::inlining::get_functions_to_inline_into(&self, false, aggressiveness); + // Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them. 
+ let bottom_up = inlining::compute_bottom_up_order(&self); + let not_to_inline = inlining::get_functions_to_inline_into(&self, false, aggressiveness); - for id in to_preprocess { + for id in bottom_up.into_iter().filter(|id| !not_to_inline.contains(id)) { let function = &self.functions[&id]; let mut function = function.inlined(&self, false, ¬_to_inline); // Help unrolling determine bounds. diff --git a/cspell.json b/cspell.json index ed9f7427c6f..a42b90d2e8c 100644 --- a/cspell.json +++ b/cspell.json @@ -35,6 +35,7 @@ "bunx", "bytecount", "cachix", + "callees", "callsite", "callsites", "callstack", From 13b42f2f365b95e8ad00e85be9632ab0032cb823 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 11:16:26 +0000 Subject: [PATCH 05/39] Use a cutoff weight to skip processing large functions --- compiler/noirc_evaluator/src/ssa.rs | 2 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 66 ++++++++++++------- .../src/ssa/opt/preprocess_fns.rs | 19 +++++- 3 files changed, 62 insertions(+), 25 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 41621358d41..8e0585fa618 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -491,7 +491,7 @@ impl SsaBuilder { println!("AFTER {msg}: functions={}", self.ssa.functions.len()); for f in self.ssa.functions.values() { let block_cnt = ir::post_order::PostOrder::with_function(f).into_vec().len(); - println!(" FUNCTION {}: blocks={block_cnt}", f.name()); + println!(" fn {} {}: blocks={block_cnt}", f.name(), f.id()); } let print_ssa_pass = match &self.ssa_logging { diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index affae768fac..bdcc3cbdc70 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -239,33 +239,42 @@ pub(super) fn compute_times_called(ssa: &Ssa) -> HashMap { }) } -/// Compute for each 
function the set of functions that call it. -fn compute_callers(ssa: &Ssa) -> BTreeMap> { +/// Compute for each function the set of functions that call it, and how many times they do so. +fn compute_callers(ssa: &Ssa) -> BTreeMap> { ssa.functions .iter() .flat_map(|(caller_id, function)| { - let called_functions = called_functions(function); - called_functions.into_iter().map(|called_id| (called_id, *caller_id)) + let called_functions = called_functions_vec(function); + called_functions.into_iter().map(|callee_id| (*caller_id, callee_id)) }) .fold( - ssa.functions.keys().map(|id| (*id, BTreeSet::new())).collect(), - |mut acc, (called_id, caller_id)| { - let callers = acc.entry(called_id).or_default(); - callers.insert(caller_id); + // Make sure an entry exists even for ones that don't get called. + ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(), + |mut acc, (caller_id, callee_id)| { + let callers = acc.entry(callee_id).or_default(); + *callers.entry(caller_id).or_default() += 1; acc }, ) } -/// Compute for each function the set of functions called by it. -fn compute_callees(ssa: &Ssa) -> BTreeMap> { +/// Compute for each function the set of functions called by it, and how many times it does so. +fn compute_callees(ssa: &Ssa) -> BTreeMap> { ssa.functions .iter() - .map(|(caller_id, function)| { - let called_functions = called_functions(function); - (*caller_id, called_functions) + .flat_map(|(caller_id, function)| { + let called_functions = called_functions_vec(function); + called_functions.into_iter().map(|callee_id| (*caller_id, callee_id)) }) - .collect() + .fold( + // Make sure an entry exists even for ones that don't call anything. 
+ ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(), + |mut acc, (caller_id, callee_id)| { + let callees = acc.entry(caller_id).or_default(); + *callees.entry(callee_id).or_default() += 1; + acc + }, + ) } /// Compute something like a topological order of the functions, starting with the ones @@ -273,7 +282,10 @@ fn compute_callees(ssa: &Ssa) -> BTreeMap> { /// are detected, take the one which are called by the most to break the ties. /// /// This can be used to simplify the most often called functions first. -pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec { +/// +/// Returns the functions paired with their transitive weight, which accumulates +/// the weight of all the functions they call. +pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { let mut order = Vec::new(); let mut visited = HashSet::new(); @@ -281,6 +293,13 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec { let mut times_called = compute_times_called(ssa).into_iter().collect::>(); times_called.sort_by_key(|(id, cnt)| (*cnt, *id)); + // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call. + let mut weights = ssa + .functions + .iter() + .map(|(id, f)| (*id, compute_function_own_weight(f))) + .collect::>(); + let callers = compute_callers(ssa); let mut callees = compute_callees(ssa); @@ -292,14 +311,19 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec { loop { if times_called.is_empty() && queue.is_empty() { - break; + return order; } while let Some(id) = queue.pop_front() { - order.push(id); + let weight = weights[&id]; + order.push((id, weight)); visited.insert(id); - // Remove this function from all of its callers. - for caller in &callers[&id] { - let callees = callees.get_mut(caller).expect("all callees computed"); + // Update the callers of this function. + for (caller, call_count) in &callers[&id] { + // Update the weight of the caller with the weight of this function. 
+ weights[caller] = weights[caller].saturating_add(call_count.saturating_mul(weight)); + // Remove this function from the callees of the caller. + let callees = callees.get_mut(caller).unwrap(); + callees.remove(&id); // If the caller doesn't call any other function, enqueue it. if callees.is_empty() && !visited.contains(caller) { queue.push_back(*caller); @@ -317,8 +341,6 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec { } } } - - order } /// Traverse the call graph starting from a given function, marking function to be retained if they are: diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index dd4d7bdc2c1..d186c3d08ae 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -11,12 +11,27 @@ impl Ssa { aggressiveness: i64, max_bytecode_increase_percent: Option, ) -> Ssa { + // No point pre-processing the functions that will never be inlined into others. + let not_to_inline = inlining::get_functions_to_inline_into(&self, false, aggressiveness); // Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them. let bottom_up = inlining::compute_bottom_up_order(&self); - let not_to_inline = inlining::get_functions_to_inline_into(&self, false, aggressiveness); - for id in bottom_up.into_iter().filter(|id| !not_to_inline.contains(id)) { + // As a heuristic to avoid optimizing functions near the entry point, find a cutoff weight. 
+ let total_weight = bottom_up.iter().fold(0usize, |acc, (_, w)| acc.saturating_add(*w)); + let mean_weight = total_weight / bottom_up.len(); + let cutoff_weight = mean_weight; + + for (id, weight) in bottom_up.into_iter().filter(|(id, _)| !not_to_inline.contains(id)) { let function = &self.functions[&id]; + if weight <= cutoff_weight { + println!("PREPROCESSING fn {} {id} with weight {weight}", function.name()); + } else { + println!( + "SKIP PREPROCESSING fn {} {id} with weight {weight} > {cutoff_weight}", + function.name() + ); + continue; + } let mut function = function.inlined(&self, false, ¬_to_inline); // Help unrolling determine bounds. function.as_slice_optimization(); From 8f7eb9ce7753956035ca4ecae34384a46fed1b88 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 11:54:12 +0000 Subject: [PATCH 06/39] Add CLI option to turn off preprocessing --- compiler/noirc_driver/src/lib.rs | 5 +++++ compiler/noirc_evaluator/src/ssa.rs | 16 +++++++++++----- compiler/noirc_evaluator/src/ssa/opt/hint.rs | 1 + .../src/ssa/opt/preprocess_fns.rs | 13 ++++++------- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index a7e7e2d4e2f..2646b13a33a 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -141,6 +141,10 @@ pub struct CompileOptions { #[arg(long)] pub skip_brillig_constraints_check: bool, + /// Flag to turn off preprocessing functions during SSA passes. + #[arg(long)] + pub skip_preprocess_fns: bool, + /// Setting to decide on an inlining strategy for Brillig functions. 
/// A more aggressive inliner should generate larger programs but more optimized /// A less aggressive inliner should generate smaller programs @@ -679,6 +683,7 @@ pub fn compile_no_check( emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, skip_underconstrained_check: options.skip_underconstrained_check, skip_brillig_constraints_check: options.skip_brillig_constraints_check, + skip_preprocess_fns: options.skip_preprocess_fns, inliner_aggressiveness: options.inliner_aggressiveness, max_bytecode_increase_percent: options.max_bytecode_increase_percent, }; diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 8e0585fa618..4e8ecc78b1b 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -71,6 +71,9 @@ pub struct SsaEvaluatorOptions { /// Skip the missing Brillig call constraints check pub skip_brillig_constraints_check: bool, + /// Skip preprocessing functions. + pub skip_preprocess_fns: bool, + /// The higher the value, the more inlined Brillig functions will be. 
pub inliner_aggressiveness: i64, @@ -155,6 +158,9 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Self { - println!("AFTER {msg}: functions={}", self.ssa.functions.len()); - for f in self.ssa.functions.values() { - let block_cnt = ir::post_order::PostOrder::with_function(f).into_vec().len(); - println!(" fn {} {}: blocks={block_cnt}", f.name(), f.id()); - } + // println!("AFTER {msg}: functions={}", self.ssa.functions.len()); + // for f in self.ssa.functions.values() { + // let block_cnt = ir::post_order::PostOrder::with_function(f).into_vec().len(); + // println!(" fn {} {}: blocks={block_cnt}", f.name(), f.id()); + // } let print_ssa_pass = match &self.ssa_logging { SsaLogging::None => false, diff --git a/compiler/noirc_evaluator/src/ssa/opt/hint.rs b/compiler/noirc_evaluator/src/ssa/opt/hint.rs index 1326c2cc010..3012fae9dbb 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/hint.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/hint.rs @@ -19,6 +19,7 @@ mod tests { emit_ssa: None, skip_underconstrained_check: true, skip_brillig_constraints_check: true, + skip_preprocessing_fns: true, inliner_aggressiveness: 0, max_bytecode_increase_percent: None, }; diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index d186c3d08ae..f395a93226f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -23,15 +23,14 @@ impl Ssa { for (id, weight) in bottom_up.into_iter().filter(|(id, _)| !not_to_inline.contains(id)) { let function = &self.functions[&id]; - if weight <= cutoff_weight { - println!("PREPROCESSING fn {} {id} with weight {weight}", function.name()); - } else { - println!( - "SKIP PREPROCESSING fn {} {id} with weight {weight} > {cutoff_weight}", - function.name() - ); + if weight >= cutoff_weight { + // println!( + // "SKIP PREPROCESSING fn {} {id} with weight {weight} > {cutoff_weight}", + // 
function.name() + // ); continue; } + // println!("PREPROCESSING fn {} {id} with weight {weight}", function.name()); let mut function = function.inlined(&self, false, &not_to_inline); // Help unrolling determine bounds. function.as_slice_optimization(); From 02cb85646a97221fc83e5b1e09491431c2880b78 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 12:08:16 +0000 Subject: [PATCH 07/39] Don't call DIE, to fix the tests --- compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index f395a93226f..42739f8c9e0 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -41,7 +41,9 @@ impl Ssa { // Try to reduce the number of blocks. function.simplify_function(); // Remove leftover instructions. - function.dead_instruction_elimination(true); + // XXX: Leaving this in causes integration test failures, + // for example with `traits_in_crates_1` it eliminates a store to a mutable input reference. + // function.dead_instruction_elimination(true); // Put it back into the SSA, so the next functions can pick it up.
self.functions.insert(id, function); } From 5aa3db7a9a475f87b3d795b9562c052b0e943a51 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 12:13:59 +0000 Subject: [PATCH 08/39] Fix field name --- compiler/noirc_evaluator/src/ssa/opt/hint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/hint.rs b/compiler/noirc_evaluator/src/ssa/opt/hint.rs index 3012fae9dbb..3f913614e76 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/hint.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/hint.rs @@ -19,7 +19,7 @@ mod tests { emit_ssa: None, skip_underconstrained_check: true, skip_brillig_constraints_check: true, - skip_preprocessing_fns: true, + skip_preprocess_fns: true, inliner_aggressiveness: 0, max_bytecode_increase_percent: None, }; From 3c3e215ed1de02b19472d9a1e795fa7f73dd662d Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 12:17:54 +0000 Subject: [PATCH 09/39] Remove restore_on_error from unrolling --- .../noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 2 +- compiler/noirc_evaluator/src/ssa/opt/unrolling.rs | 15 +++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 42739f8c9e0..fe934fe2c6b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -35,7 +35,7 @@ impl Ssa { // Help unrolling determine bounds. function.as_slice_optimization(); // We might not be able to unroll all loops without fully inlining them, so ignore errors. - let _ = function.try_unroll_loops_iteratively(max_bytecode_increase_percent, true); + let _ = function.unroll_loops_iteratively(max_bytecode_increase_percent); // Reduce the number of redundant stores/loads after unrolling function.mem2reg(); // Try to reduce the number of blocks. 
diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 0d6ff5d9a05..bbe345b03d4 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -59,7 +59,7 @@ impl Ssa { ) -> Result { for function in self.functions.values_mut() { // We must be able to unroll ACIR loops at this point. - function.try_unroll_loops_iteratively(max_bytecode_increase_percent, false)?; + function.unroll_loops_iteratively(max_bytecode_increase_percent)?; } Ok(self) } @@ -70,17 +70,15 @@ impl Function { /// /// Returns an `Err` if it cannot be done, for example because the loop bounds /// cannot be determined at compile time. This can happen during pre-processing, - /// in which case `restore_on_error` can be used to leave the function untouched. - pub(super) fn try_unroll_loops_iteratively( + /// but it should still leave the function in a partially unrolled, but valid state. + pub(super) fn unroll_loops_iteratively( &mut self, max_bytecode_increase_percent: Option, - restore_on_error: bool, ) -> Result<(), RuntimeError> { // Take a snapshot in case we have to restore it. 
let orig_function = (max_bytecode_increase_percent.is_some() - && self.runtime().is_brillig() - || restore_on_error) - .then(|| self.clone()); + && self.runtime().is_brillig()) + .then(|| self.clone()); // Try to unroll loops first: let (mut has_unrolled, mut unroll_errors) = self.try_unroll_loops(); @@ -99,9 +97,6 @@ impl Function { // If we didn't manage to unroll any more loops, exit if unroll_errors.len() >= prev_unroll_err_count { - if restore_on_error { - *self = orig_function.expect("took snapshot to restore"); - } return Err(unroll_errors.swap_remove(0)); } } From 1b9ffde83265ed1d87a5399e68e339da7dcedbb7 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 12:21:57 +0000 Subject: [PATCH 10/39] Remove prints --- compiler/noirc_evaluator/src/ssa.rs | 6 ------ .../src/ssa/opt/preprocess_fns.rs | 17 +++++++---------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 4e8ecc78b1b..ae77665c373 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -494,12 +494,6 @@ impl SsaBuilder { } fn print(mut self, msg: &str) -> Self { - // println!("AFTER {msg}: functions={}", self.ssa.functions.len()); - // for f in self.ssa.functions.values() { - // let block_cnt = ir::post_order::PostOrder::with_function(f).into_vec().len(); - // println!(" fn {} {}: blocks={block_cnt}", f.name(), f.id()); - // } - let print_ssa_pass = match &self.ssa_logging { SsaLogging::None => false, SsaLogging::All => true, diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index fe934fe2c6b..b90151785a5 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -21,16 +21,11 @@ impl Ssa { let mean_weight = total_weight / bottom_up.len(); let cutoff_weight = mean_weight; - for (id, weight) in 
bottom_up.into_iter().filter(|(id, _)| !not_to_inline.contains(id)) { + for (id, weight) in bottom_up + .into_iter() + .filter(|(id, weight)| !not_to_inline.contains(id) && *weight < cutoff_weight) + { let function = &self.functions[&id]; - if weight >= cutoff_weight { - // println!( - // "SKIP PREPROCESSING fn {} {id} with weight {weight} > {cutoff_weight}", - // function.name() - // ); - continue; - } - // println!("PREPROCESSING fn {} {id} with weight {weight}", function.name()); let mut function = function.inlined(&self, false, &not_to_inline); // Help unrolling determine bounds. function.as_slice_optimization(); @@ -40,10 +35,12 @@ impl Ssa { function.mem2reg(); // Try to reduce the number of blocks. function.simplify_function(); + // Remove leftover instructions. - // XXX: Leaving this in causes integration test failures, + // XXX: Doing this would currently integration test failures, // for example with `traits_in_crates_1` it eliminates a store to a mutable input reference. // function.dead_instruction_elimination(true); + // Put it back into the SSA, so the next functions can pick it up.
self.functions.insert(id, function); } From 1e07b7a400876f300bc24611828f876f29a8806e Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 12:25:27 +0000 Subject: [PATCH 11/39] Fix clippy --- compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index b90151785a5..85452c5ea0d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -21,9 +21,10 @@ impl Ssa { let mean_weight = total_weight / bottom_up.len(); let cutoff_weight = mean_weight; - for (id, weight) in bottom_up + for (id, _) in bottom_up .into_iter() - .filter(|(id, weight)| !not_to_inline.contains(id) && *weight < cutoff_weight) + .filter(|(id, _)| !not_to_inline.contains(id)) + .filter(|(_, weight)| *weight < cutoff_weight) { let function = &self.functions[&id]; let mut function = function.inlined(&self, false, &not_to_inline); // Help unrolling determine bounds. function.as_slice_optimization(); From de5176a255a005129224a216e1fcd7abb5f180a2 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 13:49:11 +0000 Subject: [PATCH 12/39] Do not inline self-recursive entries --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 6 ++++++ compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 4410bfc1bb8..3dd1e9378b6 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -83,6 +83,12 @@ impl Function { ) -> Function { let should_inline_call = |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { + // Do not inline self-recursive functions on the top level.
+ // Inlining a self-recursive function works when there is something to inline into + // by importing all the recursive blocks, but for the entry function there is no wrapper. + if called_func_id == self.id() { + return false; + } let function = &ssa.functions[&called_func_id]; match function.runtime() { diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 85452c5ea0d..e5754425d8b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -13,6 +13,7 @@ impl Ssa { ) -> Ssa { // No point pre-processing the functions that will never be inlined into others. let not_to_inline = inlining::get_functions_to_inline_into(&self, false, aggressiveness); + // Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them. let bottom_up = inlining::compute_bottom_up_order(&self); @@ -38,7 +39,7 @@ impl Ssa { function.simplify_function(); // Remove leftover instructions. - // XXX: Doing this would currently integration test failures, + // XXX: Doing this would currently cause integration test failures, // for example with `traits_in_crates_1` it eliminates a store to a mutable input reference. 
// function.dead_instruction_elimination(true); From deaa31130f60f2eb50c36974f20ee419bfcd85bc Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 14:02:25 +0000 Subject: [PATCH 13/39] Rename pass to Preprocess from Pre-process --- compiler/noirc_evaluator/src/ssa.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index ae77665c373..70135a3fc28 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -166,7 +166,7 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Date: Thu, 16 Jan 2025 14:14:46 +0000 Subject: [PATCH 14/39] Fix inlining recursion test --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 36 +++++++++---------- compiler/noirc_frontend/src/tests.rs | 3 ++ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 3dd1e9378b6..f5f19002f48 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -563,7 +563,7 @@ impl InlineContext { if self.recursion_level > RECURSION_LIMIT { panic!( - "Attempted to recur more than {RECURSION_LIMIT} times during inlining function '{}': {}", source_function.name(), source_function + "Attempted to recur more than {RECURSION_LIMIT} times during inlining function '{}':\n{}", source_function.name(), source_function ); } @@ -1021,6 +1021,7 @@ mod test { map::Id, types::{NumericType, Type}, }, + Ssa, }; #[test] @@ -1293,26 +1294,25 @@ mod test { #[test] #[should_panic( - expected = "Attempted to recur more than 1000 times during inlining function 'main': acir(inline) fn main f0 {" + expected = "Attempted to recur more than 1000 times during inlining function 'foo':\nacir(inline) fn foo f1 {" )] fn unconditional_recursion() { - // fn main f1 { - // b0(): - // call f1() - // return - // } - 
let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let main = builder.import_function(main_id); - let results = builder.insert_call(main, Vec::new(), vec![]).to_vec(); - builder.terminate_with_return(results); - - let ssa = builder.finish(); - assert_eq!(ssa.functions.len(), 1); + let src = " + acir(inline) fn main f0 { + b0(): + call f1() + return + } + acir(inline) fn foo f1 { + b0(): + call f1() + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + assert_eq!(ssa.functions.len(), 2); - let inlined = ssa.inline_functions(i64::MAX); - assert_eq!(inlined.functions.len(), 0); + let _ = ssa.inline_functions(i64::MAX); } #[test] diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 637b15e7197..2e4bc134977 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3456,6 +3456,9 @@ fn arithmetic_generics_rounding_fail_on_struct() { #[test] fn unconditional_recursion_fail() { + // These examples are self recursive top level functions, which actually + // would not be inlined now, but this error comes from the compilation checks, + // which is different from what the SSA would try to inline. let srcs = vec![ r#" fn main() { From f337a04d5c1598d59cfdcc5f268c906f1d4891aa Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 23:37:34 +0000 Subject: [PATCH 15/39] Refactor to use InlineInfo --- compiler/noirc_evaluator/src/ssa.rs | 6 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 221 ++++++++++-------- .../src/ssa/opt/preprocess_fns.rs | 14 +- 3 files changed, 142 insertions(+), 99 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 70135a3fc28..4ceedba56a5 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -153,7 +153,7 @@ pub(crate) fn optimize_into_acir( /// Run all SSA passes. 
fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result { Ok(builder - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions") + .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)") .run_pass(Ssa::defunctionalize, "Defunctionalization") .run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs") .run_pass( @@ -173,7 +173,7 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Result Ssa { - let inline_sources = get_functions_to_inline_into(&self, false, aggressiveness); - Self::inline_functions_inner(self, &inline_sources, false) + let inline_infos = compute_inline_infos(&self, false, aggressiveness); + Self::inline_functions_inner(self, &inline_infos, false) } /// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa { - let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness); - Self::inline_functions_inner(self, &inline_sources, true) + let inline_infos = compute_inline_infos(&self, true, aggressiveness); + Self::inline_functions_inner(self, &inline_infos, true) } fn inline_functions_inner( mut self, - inline_sources: &BTreeSet, + inline_infos: &InlineInfos, inline_no_predicates_functions: bool, ) -> Ssa { + let inline_targets = + inline_infos.iter().filter_map(|(id, info)| info.is_inline_target().then_some(*id)); + // NOTE: Functions are processed independently of each other, with the final mapping replacing the original, // instead of inlining the "leaf" functions, moving up towards the entry point. 
- self.functions = btree_map(inline_sources, |entry_point| { - let function = &self.functions[entry_point]; + self.functions = btree_map(inline_targets, |entry_point| { + let function = &self.functions[&entry_point]; let new_function = - function.inlined(&self, inline_no_predicates_functions, inline_sources); - (*entry_point, new_function) + function.inlined(&self, inline_no_predicates_functions, inline_infos); + (entry_point, new_function) }); self } @@ -79,7 +82,7 @@ impl Function { &self, ssa: &Ssa, inline_no_predicates_functions: bool, - functions_not_to_inline: &BTreeSet, + inline_infos: &InlineInfos, ) -> Function { let should_inline_call = |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { @@ -89,22 +92,26 @@ impl Function { if called_func_id == self.id() { return false; } - let function = &ssa.functions[&called_func_id]; + let callee = &ssa.functions[&called_func_id]; - match function.runtime() { + match callee.runtime() { RuntimeType::Acir(inline_type) => { // If the called function is acir, we inline if it's not an entry point // If we have not already finished the flattening pass, functions marked // to not have predicates should be preserved. let preserve_function = - !inline_no_predicates_functions && function.is_no_predicates(); + !inline_no_predicates_functions && callee.is_no_predicates(); + !inline_type.is_entry_point() && !preserve_function } RuntimeType::Brillig(_) => { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive + // If the called function is brillig, we inline only if it's into brillig and the function is not recursive and not too costly. 
self.runtime().is_brillig() - && !functions_not_to_inline.contains(&called_func_id) + && inline_infos + .get(&called_func_id) + .map(|info| info.should_inline) + .unwrap_or_default() } } }; @@ -186,27 +193,56 @@ fn called_functions(func: &Function) -> BTreeSet { called_functions_vec(func).into_iter().collect() } +/// Information about a function to aid the decision about whether to inline it or not. +/// The final decision depends on what we're inlining it into. +#[derive(Default, Debug)] +pub(super) struct InlineInfo { + is_brillig_entry_point: bool, + is_acir_entry_point: bool, + is_recursive: bool, + should_inline: bool, + weight: i64, + cost: i64, +} + +impl InlineInfo { + /// Functions which are to be retained, not inlined. + pub(super) fn is_inline_target(&self) -> bool { + self.is_brillig_entry_point + || self.is_acir_entry_point + || self.is_recursive + || !self.should_inline + } +} + +type InlineInfos = BTreeMap; + /// The functions we should inline into (and that should be left in the final program) are: /// - main /// - Any Brillig function called from Acir /// - Some Brillig functions depending on aggressiveness and some metrics /// - Any Acir functions with a [fold inline type][InlineType::Fold], -pub(super) fn get_functions_to_inline_into( +/// +/// The returned `InlineInfos` won't have every function in it, only the ones which the algorithm visited. 
+pub(super) fn compute_inline_infos( ssa: &Ssa, inline_no_predicates_functions: bool, aggressiveness: i64, -) -> BTreeSet { - let mut brillig_entry_points = BTreeSet::default(); - let mut acir_entry_points = BTreeSet::default(); - - if matches!(ssa.main().runtime(), RuntimeType::Brillig(_)) { - brillig_entry_points.insert(ssa.main_id); - } else { - acir_entry_points.insert(ssa.main_id); - } +) -> InlineInfos { + let mut inline_infos = InlineInfos::default(); + + inline_infos.insert( + ssa.main_id, + InlineInfo { + is_acir_entry_point: ssa.main().runtime().is_acir(), + is_brillig_entry_point: ssa.main().runtime().is_brillig(), + ..Default::default() + }, + ); + // Handle ACIR functions. for (func_id, function) in ssa.functions.iter() { - if matches!(function.runtime(), RuntimeType::Brillig(_)) { + if function.runtime().is_brillig() { continue; } @@ -214,31 +250,28 @@ pub(super) fn get_functions_to_inline_into( // to not have predicates should be preserved. let preserve_function = !inline_no_predicates_functions && function.is_no_predicates(); if function.runtime().is_entry_point() || preserve_function { - acir_entry_points.insert(*func_id); + inline_infos.entry(*func_id).or_default().is_acir_entry_point = true; } - for called_function_id in called_functions(function) { - if matches!(ssa.functions[&called_function_id].runtime(), RuntimeType::Brillig(_)) { - brillig_entry_points.insert(called_function_id); + // Any Brillig function called from ACIR is an entry into the Brillig VM. 
+ for called_func_id in called_functions(function) { + if ssa.functions[&called_func_id].runtime().is_brillig() { + inline_infos.entry(called_func_id).or_default().is_brillig_entry_point = true; } } } let times_called = compute_times_called(ssa); - let brillig_functions_to_retain: BTreeSet<_> = compute_functions_to_retain( + mark_brillig_functions_to_retain( ssa, - &brillig_entry_points, - &times_called, inline_no_predicates_functions, aggressiveness, + &times_called, + &mut inline_infos, ); - acir_entry_points - .into_iter() - .chain(brillig_entry_points) - .chain(brillig_functions_to_retain) - .collect() + inline_infos } /// Compute the time each function is called from any other function. @@ -363,43 +396,52 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { /// Traverse the call graph starting from a given function, marking function to be retained if they are: /// * recursive functions, or /// * the cost of inlining outweighs the cost of not doing so -fn should_retain_recursive( +fn mark_functions_to_retain_recursive( ssa: &Ssa, - func: FunctionId, - times_called: &HashMap, - // FunctionId -> (should_retain, weight) - should_retain_function: &mut HashMap, - mut explored_functions: im::HashSet, inline_no_predicates_functions: bool, aggressiveness: i64, + times_called: &HashMap, + inline_infos: &mut InlineInfos, + mut explored_functions: im::HashSet, + func: FunctionId, ) { - // We have already decided on this function - if should_retain_function.get(&func).is_some() { + // Check if we have set any of the fields this method touches.
+ let decided = |inline_infos: &InlineInfos| { + inline_infos + .get(&func) + .map(|info| info.is_recursive || info.should_inline || info.weight != 0) + .unwrap_or_default() + }; + + // Check if we have already decided on this function + if decided(inline_infos) { return; } - // Recursive, this function won't be inlined + + // If recursive, this function won't be inlined if explored_functions.contains(&func) { - should_retain_function.insert(func, (true, 0)); + inline_infos.entry(func).or_default().is_recursive = true; return; } explored_functions.insert(func); - // Decide on dependencies first - let called_functions = called_functions(&ssa.functions[&func]); - for function in called_functions.iter() { - should_retain_recursive( + // Decide on dependencies first, so we know their weight. + let called_functions = called_functions_vec(&ssa.functions[&func]); + for callee in &called_functions { + mark_functions_to_retain_recursive( ssa, - *function, - times_called, - should_retain_function, - explored_functions.clone(), inline_no_predicates_functions, aggressiveness, + times_called, + inline_infos, + explored_functions.clone(), + *callee, ); } + // We could have decided on this function while deciding on dependencies - // If the function is recursive - if should_retain_function.get(&func).is_some() { + // if the function is recursive. + if decided(inline_infos) { return; } @@ -407,15 +449,18 @@ fn should_retain_recursive( // We compute the weight (roughly the number of instructions) of the function after inlining // And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns) // We then can compute an approximation of the cost of inlining vs the cost of retaining the function - // We do this computation using saturating i64s to avoid overflows + // We do this computation using saturating i64s to avoid overflows, + // and because we want to calculate a difference which can be negative. 
// Total weight of functions called by this one, unless we decided not to inline them. - let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, called_function| { - let (should_retain, weight) = should_retain_function[called_function]; - if should_retain { - acc + // Callees which appear multiple times would be inlined multiple times. + let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, callee| { + let info = &inline_infos[callee]; + // If the callee is not going to be inlined then we can ignore its cost. + if info.should_inline { + acc.saturating_add(info.weight) } else { - acc.saturating_add(weight) + acc } }); @@ -428,53 +473,47 @@ fn should_retain_recursive( let inline_cost = times_called.saturating_mul(this_function_weight); let retain_cost = times_called.saturating_mul(interface_cost) + this_function_weight; + let net_cost = inline_cost.saturating_sub(retain_cost); let runtime = ssa.functions[&func].runtime(); // We inline if the aggressiveness is higher than inline cost minus the retain cost // If aggressiveness is infinite, we'll always inline // If aggressiveness is 0, we'll inline when the inline cost is lower than the retain cost // If aggressiveness is minus infinity, we'll never inline (other than in the mandatory cases) - let should_inline = ((inline_cost.saturating_sub(retain_cost)) < aggressiveness) + let should_inline = (net_cost < aggressiveness) || runtime.is_inline_always() || (runtime.is_no_predicates() && inline_no_predicates_functions); - should_retain_function.insert(func, (!should_inline, this_function_weight)); + let info = inline_infos.entry(func).or_default(); + info.should_inline = should_inline; + info.weight = this_function_weight; + info.cost = net_cost; } -/// Gather the functions that should not be inlined. -fn compute_functions_to_retain( +/// Mark Brillig functions that should not be inlined because they are recursive or expensive. 
+fn mark_brillig_functions_to_retain( ssa: &Ssa, - entry_points: &BTreeSet, - times_called: &HashMap, inline_no_predicates_functions: bool, aggressiveness: i64, -) -> BTreeSet { - let mut should_retain_function = HashMap::default(); + times_called: &HashMap, + inline_infos: &mut BTreeMap, +) { + let brillig_entry_points = inline_infos + .iter() + .filter_map(|(id, info)| info.is_brillig_entry_point.then_some(*id)) + .collect::>(); - for entry_point in entry_points.iter() { - should_retain_recursive( + for entry_point in brillig_entry_points { + mark_functions_to_retain_recursive( ssa, - *entry_point, - times_called, - &mut should_retain_function, - im::HashSet::default(), inline_no_predicates_functions, aggressiveness, + times_called, + inline_infos, + im::HashSet::default(), + entry_point, ); } - - should_retain_function - .into_iter() - .filter_map( - |(func_id, (should_retain, _))| { - if should_retain { - Some(func_id) - } else { - None - } - }, - ) - .collect() } /// Compute a weight of a function based on the number of instructions in its reachable blocks. diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index e5754425d8b..515e73ac51a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -11,9 +11,6 @@ impl Ssa { aggressiveness: i64, max_bytecode_increase_percent: Option, ) -> Ssa { - // No point pre-processing the functions that will never be inlined into others. - let not_to_inline = inlining::get_functions_to_inline_into(&self, false, aggressiveness); - // Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them. let bottom_up = inlining::compute_bottom_up_order(&self); @@ -22,13 +19,20 @@ impl Ssa { let mean_weight = total_weight / bottom_up.len(); let cutoff_weight = mean_weight; + // Preliminary inlining decisions. 
+ // Functions which are inline targets will be processed in later passes. + // Here we want to treat the functions which will be inlined into them. + let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness); + for (id, _) in bottom_up .into_iter() - .filter(|(id, _)| !not_to_inline.contains(id)) + .filter(|(id, _)| { + inline_infos.get(id).map(|info| !info.is_inline_target()).unwrap_or(true) + }) .filter(|(_, weight)| *weight < cutoff_weight) { let function = &self.functions[&id]; - let mut function = function.inlined(&self, false, &not_to_inline); + let mut function = function.inlined(&self, false, &inline_infos); // Help unrolling determine bounds. function.as_slice_optimization(); // We might not be able to unroll all loops without fully inlining them, so ignore errors. From ca7ba9e2e9ef81a1c716d0eb6338d078b7703c01 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 16 Jan 2025 23:50:50 +0000 Subject: [PATCH 16/39] Fix remove unreachable functions to remove uncalled Brillig --- .../noirc_evaluator/src/ssa/opt/remove_unreachable.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs index 41023b5f376..f435256f261 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs @@ -19,9 +19,13 @@ impl Ssa { pub(crate) fn remove_unreachable_functions(mut self) -> Self { let mut used_functions = HashSet::default(); - for function_id in self.functions.keys() { - if self.is_entry_point(*function_id) { - collect_reachable_functions(&self, *function_id, &mut used_functions); + for (id, function) in self.functions.iter() { + // XXX: `self.is_entry_point(*id)` could leave Brillig functions that nobody calls in the SSA.
+ let is_entry_point = function.id() == self.main_id + || function.runtime().is_acir() && function.runtime().is_entry_point(); + + if is_entry_point { + collect_reachable_functions(&self, *id, &mut used_functions); } } From 64841c486e9f76e913487875012387ad40b739ce Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 01:01:31 +0000 Subject: [PATCH 17/39] Fix defunctionalization to inherit runtime of caller --- .../noirc_evaluator/src/ssa/ir/function.rs | 2 +- .../src/ssa/opt/defunctionalize.rs | 48 ++++++++++--------- .../src/monomorphization/ast.rs | 4 +- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index b59b0c18a10..b21a84d16dc 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -12,7 +12,7 @@ use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir(InlineType), diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 186f10c53e6..ca18f556906 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -13,7 +13,7 @@ use crate::ssa::{ function_builder::FunctionBuilder, ir::{ basic_block::BasicBlockId, - function::{Function, FunctionId, Signature}, + function::{Function, FunctionId, RuntimeType, Signature}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, value::{Value, ValueId}, @@ -43,12 +43,15 @@ struct ApplyFunction { dispatches_to_multiple_functions: bool, } +type Variants = BTreeMap<(Signature, RuntimeType), Vec>; +type 
ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>; + /// Performs defunctionalization on all functions /// This is done by changing all functions as value to be a number (FieldElement) /// And creating apply functions that dispatch to the correct target by runtime comparisons with constants #[derive(Debug, Clone)] struct DefunctionalizationContext { - apply_functions: HashMap, + apply_functions: ApplyFunctions, } impl Ssa { @@ -104,7 +107,7 @@ impl DefunctionalizationContext { }; // Find the correct apply function - let apply_function = self.get_apply_function(&signature); + let apply_function = self.get_apply_function(signature, func.runtime()); // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); @@ -152,19 +155,21 @@ impl DefunctionalizationContext { } /// Returns the apply function for the given signature - fn get_apply_function(&self, signature: &Signature) -> ApplyFunction { - *self.apply_functions.get(signature).expect("Could not find apply function") + fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction { + *self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function") } } /// Collects all functions used as values that can be called by their signatures -fn find_variants(ssa: &Ssa) -> BTreeMap> { - let mut dynamic_dispatches: BTreeSet = BTreeSet::new(); +fn find_variants(ssa: &Ssa) -> Variants { + let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); let mut functions_as_values: BTreeSet = BTreeSet::new(); for function in ssa.functions.values() { functions_as_values.extend(find_functions_as_values(function)); - dynamic_dispatches.extend(find_dynamic_dispatches(function)); + dynamic_dispatches.extend( + find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())), + ); } let mut signature_to_functions_as_value: BTreeMap> = BTreeMap::new(); @@ -174,16 +179,12 @@ 
fn find_variants(ssa: &Ssa) -> BTreeMap> { signature_to_functions_as_value.entry(signature).or_default().push(function_id); } - let mut variants = BTreeMap::new(); + let mut variants: Variants = BTreeMap::new(); - for dispatch_signature in dynamic_dispatches { - let mut target_fns = vec![]; - for (target_signature, functions) in &signature_to_functions_as_value { - if &dispatch_signature == target_signature { - target_fns.extend(functions); - } - } - variants.insert(dispatch_signature, target_fns); + for (dispatch_signature, caller_runtime) in dynamic_dispatches { + let target_fns = + signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default(); + variants.insert((dispatch_signature, caller_runtime), target_fns); } variants @@ -247,10 +248,10 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet { fn create_apply_functions( ssa: &mut Ssa, - variants_map: BTreeMap>, -) -> HashMap { + variants_map: BTreeMap<(Signature, RuntimeType), Vec>, +) -> ApplyFunctions { let mut apply_functions = HashMap::default(); - for (signature, variants) in variants_map.into_iter() { + for ((signature, runtime), variants) in variants_map.into_iter() { assert!( !variants.is_empty(), "ICE: at least one variant should exist for a dynamic call {signature:?}" @@ -258,11 +259,12 @@ fn create_apply_functions( let dispatches_to_multiple_functions = variants.len() > 1; let id = if dispatches_to_multiple_functions { - create_apply_function(ssa, signature.clone(), variants) + create_apply_function(ssa, signature.clone(), runtime, variants) } else { variants[0] }; - apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions }); + apply_functions + .insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions }); } apply_functions } @@ -275,6 +277,7 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement { fn create_apply_function( ssa: &mut Ssa, signature: Signature, + runtime: RuntimeType, function_ids: Vec, 
) -> FunctionId { assert!(!function_ids.is_empty()); @@ -282,6 +285,7 @@ fn create_apply_function( ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); + function_builder.set_runtime(runtime); let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); diff --git a/compiler/noirc_frontend/src/monomorphization/ast.rs b/compiler/noirc_frontend/src/monomorphization/ast.rs index d219e8f7c2d..05df3887848 100644 --- a/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -227,7 +227,9 @@ pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>; /// Represents how an Acir function should be inlined. /// This type is only relevant for ACIR functions as we do not inline any Brillig functions -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive( + Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord, +)] pub enum InlineType { /// The most basic entry point can expect all its functions to be inlined. /// All function calls are expected to be inlined into a single ACIR. 
From e9f7bad958b062a4b137468d735f2956d708cac7 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 01:33:42 +0000 Subject: [PATCH 18/39] Refactor defunctionalization to not create intermediate blocks before return --- .../noirc_evaluator/src/ssa/opt/defunctionalize.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index ca18f556906..a2151f9288f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -343,22 +343,21 @@ fn create_apply_function( }) } -/// Crates a return block, if no previous return exists, it will create a final return -/// Else, it will create a bypass return block that points to the previous return block +/// If no previous return target exists, it will create a final return, +/// otherwise returns the existing return block to jump to. 
fn build_return_block( builder: &mut FunctionBuilder, previous_block: BasicBlockId, passed_types: &[Type], target: Option, ) -> BasicBlockId { + if let Some(return_block) = target { + return return_block; + } let return_block = builder.insert_block(); builder.switch_to_block(return_block); - let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone())); - match target { - None => builder.terminate_with_return(params), - Some(target) => builder.terminate_with_jmp(target, params), - } + builder.terminate_with_return(params); builder.switch_to_block(previous_block); return_block } From da1fe6af621edb325db6650a6b6ea03d8b850730 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 10:11:37 +0000 Subject: [PATCH 19/39] Create test to show defunctionalization not handling runtime --- .../src/ssa/opt/defunctionalize.rs | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 186f10c53e6..da4cbfa89b5 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -358,3 +358,94 @@ fn build_return_block( builder.switch_to_block(previous_block); return_block } + +#[cfg(test)] +mod tests { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn apply_inherits_caller_runtime() { + // Extracted from `execution_success/brillig_fns_as_values` with `--force-brillig` + let src = " + brillig(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(f2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f1(f3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 
+ } + brillig(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(Field 2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f1(Field 3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: Field, v1: u32): + v3 = call f4(v0, v1) -> u32 + return v3 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn apply f4 { + b0(v0: Field, v1: u32): + v5 = eq v0, Field 2 + jmpif v5 then: b3, else: b1 + b1(): + constrain v0 == Field 3 + v8 = call f3(v1) -> u32 + jmp b2(v8) + b2(v2: u32): + jmp b4(v2) + b3(): + v10 = call f2(v1) -> u32 + jmp b4(v10) + b4(v3: u32): + return v3 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } +} From 7f2cdefd62e7ffd31b59ea13918bc790b439593a Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 13:25:16 +0000 Subject: [PATCH 20/39] Fix SSA parser to handle function as value --- .../src/ssa/parser/into_ssa.rs | 42 +++++++++++++------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index fcaaf74f533..d5d5593b884 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -24,7 +24,7 @@ impl ParsedSsa { struct Translator { builder: FunctionBuilder, - /// Maps function names to their IDs + /// Maps internal function names (e.g. 
"f1") to their IDs functions: HashMap, /// Maps block names to their IDs @@ -135,14 +135,14 @@ impl Translator { match block.terminator { ParsedTerminator::Jmp { destination, arguments } => { - let block_id = self.lookup_block(destination)?; + let block_id = self.lookup_block(&destination)?; let arguments = self.translate_values(arguments)?; self.builder.terminate_with_jmp(block_id, arguments); } ParsedTerminator::Jmpif { condition, then_block, else_block } => { let condition = self.translate_value(condition)?; - let then_destination = self.lookup_block(then_block)?; - let else_destination = self.lookup_block(else_block)?; + let then_destination = self.lookup_block(&then_block)?; + let else_destination = self.lookup_block(&else_block)?; self.builder.terminate_with_jmpif(condition, then_destination, else_destination); } ParsedTerminator::Return(values) => { @@ -187,8 +187,17 @@ impl Translator { let function_id = if let Some(id) = self.builder.import_intrinsic(&function.name) { id } else { - let function_id = self.lookup_function(function)?; - self.builder.import_function(function_id) + match self.lookup_function(&function) { + Ok(f) => self.builder.import_function(f), + Err(e) => { + if let Ok(v) = self.lookup_variable(&function) { + // e.g. `v2 = call v0(v1) -> u32`, a lambda passed as a parameter + v + } else { + return Err(e); + } + } + } }; let arguments = self.translate_values(arguments)?; @@ -293,7 +302,14 @@ impl Translator { ParsedValue::NumericConstant { constant, typ } => { Ok(self.builder.numeric_constant(constant, typ.unwrap_numeric())) } - ParsedValue::Variable(identifier) => self.lookup_variable(identifier), + ParsedValue::Variable(identifier) => self.lookup_variable(&identifier).or_else(|e| { + if let Ok(f) = self.lookup_function(&identifier) { + // e.g. 
`v3 = call f1(f2, v0) -> u32` + Ok(self.builder.import_function(f)) + } else { + Err(e) + } + }), } } @@ -314,27 +330,27 @@ impl Translator { Ok(()) } - fn lookup_variable(&mut self, identifier: Identifier) -> Result { + fn lookup_variable(&mut self, identifier: &Identifier) -> Result { if let Some(value_id) = self.variables[&self.current_function_id()].get(&identifier.name) { Ok(*value_id) } else { - Err(SsaError::UnknownVariable(identifier)) + Err(SsaError::UnknownVariable(identifier.clone())) } } - fn lookup_block(&mut self, identifier: Identifier) -> Result { + fn lookup_block(&mut self, identifier: &Identifier) -> Result { if let Some(block_id) = self.blocks[&self.current_function_id()].get(&identifier.name) { Ok(*block_id) } else { - Err(SsaError::UnknownBlock(identifier)) + Err(SsaError::UnknownBlock(identifier.clone())) } } - fn lookup_function(&mut self, identifier: Identifier) -> Result { + fn lookup_function(&mut self, identifier: &Identifier) -> Result { if let Some(function_id) = self.functions.get(&identifier.name) { Ok(*function_id) } else { - Err(SsaError::UnknownFunction(identifier)) + Err(SsaError::UnknownFunction(identifier.clone())) } } From bd4f90b054038e52af56f21be828efebd4c0d005 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 01:01:31 +0000 Subject: [PATCH 21/39] Fix defunctionalization to inherit runtime of caller --- .../noirc_evaluator/src/ssa/ir/function.rs | 2 +- .../src/ssa/opt/defunctionalize.rs | 48 ++++++++++--------- .../src/monomorphization/ast.rs | 4 +- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index b59b0c18a10..b21a84d16dc 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -12,7 +12,7 @@ use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, 
Deserialize)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir(InlineType), diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index da4cbfa89b5..135bbf84d09 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -13,7 +13,7 @@ use crate::ssa::{ function_builder::FunctionBuilder, ir::{ basic_block::BasicBlockId, - function::{Function, FunctionId, Signature}, + function::{Function, FunctionId, RuntimeType, Signature}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, value::{Value, ValueId}, @@ -43,12 +43,15 @@ struct ApplyFunction { dispatches_to_multiple_functions: bool, } +type Variants = BTreeMap<(Signature, RuntimeType), Vec>; +type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>; + /// Performs defunctionalization on all functions /// This is done by changing all functions as value to be a number (FieldElement) /// And creating apply functions that dispatch to the correct target by runtime comparisons with constants #[derive(Debug, Clone)] struct DefunctionalizationContext { - apply_functions: HashMap, + apply_functions: ApplyFunctions, } impl Ssa { @@ -104,7 +107,7 @@ impl DefunctionalizationContext { }; // Find the correct apply function - let apply_function = self.get_apply_function(&signature); + let apply_function = self.get_apply_function(signature, func.runtime()); // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); @@ -152,19 +155,21 @@ impl DefunctionalizationContext { } /// Returns the apply function for the given signature - fn get_apply_function(&self, signature: &Signature) -> ApplyFunction { - *self.apply_functions.get(signature).expect("Could not 
find apply function") + fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction { + *self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function") } } /// Collects all functions used as values that can be called by their signatures -fn find_variants(ssa: &Ssa) -> BTreeMap> { - let mut dynamic_dispatches: BTreeSet = BTreeSet::new(); +fn find_variants(ssa: &Ssa) -> Variants { + let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); let mut functions_as_values: BTreeSet = BTreeSet::new(); for function in ssa.functions.values() { functions_as_values.extend(find_functions_as_values(function)); - dynamic_dispatches.extend(find_dynamic_dispatches(function)); + dynamic_dispatches.extend( + find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())), + ); } let mut signature_to_functions_as_value: BTreeMap> = BTreeMap::new(); @@ -174,16 +179,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap> { signature_to_functions_as_value.entry(signature).or_default().push(function_id); } - let mut variants = BTreeMap::new(); + let mut variants: Variants = BTreeMap::new(); - for dispatch_signature in dynamic_dispatches { - let mut target_fns = vec![]; - for (target_signature, functions) in &signature_to_functions_as_value { - if &dispatch_signature == target_signature { - target_fns.extend(functions); - } - } - variants.insert(dispatch_signature, target_fns); + for (dispatch_signature, caller_runtime) in dynamic_dispatches { + let target_fns = + signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default(); + variants.insert((dispatch_signature, caller_runtime), target_fns); } variants @@ -247,10 +248,10 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet { fn create_apply_functions( ssa: &mut Ssa, - variants_map: BTreeMap>, -) -> HashMap { + variants_map: BTreeMap<(Signature, RuntimeType), Vec>, +) -> ApplyFunctions { let mut apply_functions = 
HashMap::default(); - for (signature, variants) in variants_map.into_iter() { + for ((signature, runtime), variants) in variants_map.into_iter() { assert!( !variants.is_empty(), "ICE: at least one variant should exist for a dynamic call {signature:?}" @@ -258,11 +259,12 @@ fn create_apply_functions( let dispatches_to_multiple_functions = variants.len() > 1; let id = if dispatches_to_multiple_functions { - create_apply_function(ssa, signature.clone(), variants) + create_apply_function(ssa, signature.clone(), runtime, variants) } else { variants[0] }; - apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions }); + apply_functions + .insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions }); } apply_functions } @@ -275,6 +277,7 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement { fn create_apply_function( ssa: &mut Ssa, signature: Signature, + runtime: RuntimeType, function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); @@ -282,6 +285,7 @@ fn create_apply_function( ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); + function_builder.set_runtime(runtime); let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); diff --git a/compiler/noirc_frontend/src/monomorphization/ast.rs b/compiler/noirc_frontend/src/monomorphization/ast.rs index d219e8f7c2d..05df3887848 100644 --- a/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -227,7 +227,9 @@ pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>; /// Represents how an Acir function should be inlined. 
/// This type is only relevant for ACIR functions as we do not inline any Brillig functions -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive( + Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord, +)] pub enum InlineType { /// The most basic entry point can expect all its functions to be inlined. /// All function calls are expected to be inlined into a single ACIR. From cfdbcad9e50577e3aec208366d744c7094dcb03b Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 13:35:02 +0000 Subject: [PATCH 22/39] Add another test to check 2 runtimes are created --- .../src/ssa/opt/defunctionalize.rs | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 135bbf84d09..1dd5fcd6b0f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -452,4 +452,50 @@ mod tests { "; assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn apply_created_per_caller_runtime() { + let src = " + acir(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(f2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f4(f3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + acir(inline) fn wrapper_acir f4 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + acir(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + let applies = ssa.functions.values().filter(|f| f.name() == "apply").collect::>(); + 
assert_eq!(applies.len(), 2); + assert!(applies.iter().any(|f| f.runtime().is_acir())); + assert!(applies.iter().any(|f| f.runtime().is_brillig())); + } } From 1b7dc2e644e35b07f5905cac123f886fb302e3d9 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 14:04:10 +0000 Subject: [PATCH 23/39] fix: Simplify defunctionalize return (#7101) --- .../src/ssa/opt/defunctionalize.rs | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 1dd5fcd6b0f..754e550af98 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -343,22 +343,21 @@ fn create_apply_function( }) } -/// Crates a return block, if no previous return exists, it will create a final return -/// Else, it will create a bypass return block that points to the previous return block +/// If no previous return target exists, it will create a final return, +/// otherwise returns the existing return block to jump to. 
fn build_return_block( builder: &mut FunctionBuilder, previous_block: BasicBlockId, passed_types: &[Type], target: Option, ) -> BasicBlockId { + if let Some(return_block) = target { + return return_block; + } let return_block = builder.insert_block(); builder.switch_to_block(return_block); - let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone())); - match target { - None => builder.terminate_with_return(params), - Some(target) => builder.terminate_with_jmp(target, params), - } + builder.terminate_with_return(params); builder.switch_to_block(previous_block); return_block } @@ -435,19 +434,17 @@ mod tests { } brillig(inline) fn apply f4 { b0(v0: Field, v1: u32): - v5 = eq v0, Field 2 - jmpif v5 then: b3, else: b1 + v4 = eq v0, Field 2 + jmpif v4 then: b2, else: b1 b1(): constrain v0 == Field 3 - v8 = call f3(v1) -> u32 - jmp b2(v8) - b2(v2: u32): - jmp b4(v2) - b3(): - v10 = call f2(v1) -> u32 - jmp b4(v10) - b4(v3: u32): - return v3 + v7 = call f3(v1) -> u32 + jmp b3(v7) + b2(): + v9 = call f2(v1) -> u32 + jmp b3(v9) + b3(v2: u32): + return v2 } "; assert_normalized_ssa_equals(ssa, expected); From 8102dd2b7adcf0b75ce3537ab944fd116de365cb Mon Sep 17 00:00:00 2001 From: Tom French Date: Fri, 17 Jan 2025 14:36:33 +0000 Subject: [PATCH 24/39] . 
--- .../src/ssa/opt/defunctionalize.rs | 9 ++++++++- .../noirc_evaluator/src/ssa/opt/inlining.rs | 17 +++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 754e550af98..3062edfe60c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -8,6 +8,7 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use acvm::FieldElement; use iter_extended::vecmap; +use noirc_frontend::monomorphization::ast::InlineType; use crate::ssa::{ function_builder::FunctionBuilder, @@ -277,7 +278,7 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement { fn create_apply_function( ssa: &mut Ssa, signature: Signature, - runtime: RuntimeType, + caller_runtime: RuntimeType, function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); @@ -285,6 +286,12 @@ fn create_apply_function( ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); + + // We want to push for apply functions to be inlined more aggressively. 
+ let runtime = match caller_runtime { + RuntimeType::Acir(_) => RuntimeType::Acir(InlineType::InlineAlways), + RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways), + }; function_builder.set_runtime(runtime); let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index b5cbc90e30d..5d97a416166 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -66,22 +66,27 @@ impl Ssa { self.functions = btree_map(inline_sources, |entry_point| { let should_inline_call = |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { - let function = &ssa.functions[&called_func_id]; + let callee = &ssa.functions[&called_func_id]; + let caller_runtime = ssa.functions[entry_point].runtime(); - match function.runtime() { + match callee.runtime() { RuntimeType::Acir(inline_type) => { // If the called function is acir, we inline if it's not an entry point // If we have not already finished the flattening pass, functions marked // to not have predicates should be preserved. let preserve_function = - !inline_no_predicates_functions && function.is_no_predicates(); + !inline_no_predicates_functions && callee.is_no_predicates(); !inline_type.is_entry_point() && !preserve_function } RuntimeType::Brillig(_) => { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive - ssa.functions[entry_point].runtime().is_brillig() - && !inline_sources.contains(&called_func_id) + if caller_runtime.is_acir() { + // We never inline a brillig function into an ACIR function. + return false; + } + + // Avoid inlining recursive functions. 
+ !inline_sources.contains(&called_func_id) } } }; From 14057335c4ac4dff2560b5e59243670ff18fb32b Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 14:45:01 +0000 Subject: [PATCH 25/39] Add test with expected SSA --- compiler/noirc_evaluator/src/ssa/opt/die.rs | 48 +++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index eed1af8251b..b21ea65cb52 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -941,4 +941,52 @@ mod test { let ssa = ssa.dead_instruction_elimination(); assert_normalized_ssa_equals(ssa, src); } + + #[test] + fn do_not_remove_mutable_reference_params() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v2 = allocate -> &mut Field + store v0 at v2 + call f1(v2) + v4 = load v2 -> Field + v5 = eq v4, v1 + constrain v4 == v1 + return + } + acir(inline) fn Add10 f1 { + b0(v0: &mut Field): + v1 = load v0 -> Field + v2 = load v0 -> Field + v4 = add v2, Field 10 + store v4 at v0 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.dead_instruction_elimination(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v2 = allocate -> &mut Field + store v0 at v2 + call f1(v2) + v4 = load v2 -> Field + v5 = eq v4, v1 + constrain v4 == v1 + return + } + acir(inline) fn Add10 f1 { + b0(v0: &mut Field): + v1 = load v0 -> Field + v2 = add v1, Field 10 + store v2 at v0 + return + } + "; + assert_normalized_ssa_equals(ssa, &expected); + } } From f206178552e8f9db18f53f4346c78286ca435830 Mon Sep 17 00:00:00 2001 From: Tom French Date: Fri, 17 Jan 2025 14:49:51 +0000 Subject: [PATCH 26/39] . 
--- .../src/ssa/parser/into_ssa.rs | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index d5d5593b884..4d8bca9a1a9 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -187,17 +187,13 @@ impl Translator { let function_id = if let Some(id) = self.builder.import_intrinsic(&function.name) { id } else { - match self.lookup_function(&function) { - Ok(f) => self.builder.import_function(f), - Err(e) => { - if let Ok(v) = self.lookup_variable(&function) { - // e.g. `v2 = call v0(v1) -> u32`, a lambda passed as a parameter - v - } else { - return Err(e); - } - } - } + let maybe_func = + self.lookup_function(&function).map(|f| self.builder.import_function(f)); + + maybe_func.or_else(|e| { + // e.g. `v2 = call v0(v1) -> u32`, a lambda passed as a parameter + self.lookup_variable(&function).map_err(|_| e) + })? }; let arguments = self.translate_values(arguments)?; @@ -303,12 +299,12 @@ impl Translator { Ok(self.builder.numeric_constant(constant, typ.unwrap_numeric())) } ParsedValue::Variable(identifier) => self.lookup_variable(&identifier).or_else(|e| { - if let Ok(f) = self.lookup_function(&identifier) { - // e.g. `v3 = call f1(f2, v0) -> u32` - Ok(self.builder.import_function(f)) - } else { - Err(e) - } + self.lookup_function(&identifier) + .map(|f| { + // e.g. 
`v3 = call f1(f2, v0) -> u32` + self.builder.import_function(f) + }) + .map_err(|_| e) }), } } From 939cf847813b3ba44c4bbc02523b10f7a4316391 Mon Sep 17 00:00:00 2001 From: Tom French Date: Fri, 17 Jan 2025 15:03:45 +0000 Subject: [PATCH 27/39] chore: fix tests --- compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 3062edfe60c..a6e04332c0a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -439,7 +439,7 @@ mod tests { v2 = add v0, u32 1 return v2 } - brillig(inline) fn apply f4 { + brillig(inline_always) fn apply f4 { b0(v0: Field, v1: u32): v4 = eq v0, Field 2 jmpif v4 then: b2, else: b1 From fc36f522231fd753fafaeb5c211ebe7b14ce760f Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 15:03:33 +0000 Subject: [PATCH 28/39] Add flag to tell the DIE not to remove STORE yet --- .../noirc_evaluator/src/ssa/ir/instruction.rs | 7 ++-- compiler/noirc_evaluator/src/ssa/opt/die.rs | 40 ++++++++++++++----- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 2 +- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 171ca30f5f4..5806e62bf95 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -10,7 +10,7 @@ use fxhash::FxHasher64; use iter_extended::vecmap; use noirc_frontend::hir_def::types::Type as HirType; -use crate::ssa::{ir::function::RuntimeType, opt::flatten_cfg::value_merger::ValueMerger}; +use crate::ssa::opt::flatten_cfg::value_merger::ValueMerger; use super::{ basic_block::BasicBlockId, @@ -506,7 +506,7 @@ impl Instruction { } } - pub(crate) fn can_eliminate_if_unused(&self, function: &Function) -> 
bool { + pub(crate) fn can_eliminate_if_unused(&self, function: &Function, flattened: bool) -> bool { use Instruction::*; match self { Binary(binary) => { @@ -539,8 +539,7 @@ impl Instruction { // pass where this check is done, but does mean that we cannot perform mem2reg // after the DIE pass. Store { .. } => { - matches!(function.runtime(), RuntimeType::Acir(_)) - && function.reachable_blocks().len() == 1 + flattened && function.runtime().is_acir() && function.reachable_blocks().len() == 1 } Constrain(..) diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index b21ea65cb52..48e55bc49e5 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -22,9 +22,17 @@ use super::rc::{pop_rc_for, RcInstruction}; impl Ssa { /// Performs Dead Instruction Elimination (DIE) to remove any instructions with /// unused results. + /// + /// This step should come after the flattening of the CFG and mem2reg. #[tracing::instrument(level = "trace", skip(self))] - pub(crate) fn dead_instruction_elimination(mut self) -> Ssa { - self.functions.par_iter_mut().for_each(|(_, func)| func.dead_instruction_elimination(true)); + pub(crate) fn dead_instruction_elimination(self) -> Ssa { + self.dead_instruction_elimination_inner(true) + } + + fn dead_instruction_elimination_inner(mut self, flattened: bool) -> Ssa { + self.functions + .par_iter_mut() + .for_each(|(_, func)| func.dead_instruction_elimination(true, flattened)); self } @@ -37,8 +45,12 @@ impl Function { /// instructions that reference results from an instruction in another block are evaluated first. /// If we did not iterate blocks in this order we could not safely say whether or not the results /// of its instructions are needed elsewhere. 
- pub(crate) fn dead_instruction_elimination(&mut self, insert_out_of_bounds_checks: bool) { - let mut context = Context::default(); + pub(crate) fn dead_instruction_elimination( + &mut self, + insert_out_of_bounds_checks: bool, + flattened: bool, + ) { + let mut context = Context { flattened, ..Default::default() }; for call_data in &self.dfg.data_bus.call_data { context.mark_used_instruction_results(&self.dfg, call_data.array_id); } @@ -58,7 +70,7 @@ impl Function { // instructions (we don't want to remove those checks, or instructions that are // dependencies of those checks) if inserted_out_of_bounds_checks { - self.dead_instruction_elimination(false); + self.dead_instruction_elimination(false, flattened); return; } @@ -76,6 +88,11 @@ struct Context { /// they technically contain side-effects but we still want to remove them if their /// `value` parameter is not used elsewhere. rc_instructions: Vec<(InstructionId, BasicBlockId)>, + + /// The elimination of certain unused instructions assumes that the DIE pass runs after + /// the flattening of the CFG, but if that's not the case then we should not eliminate + /// them just yet. + flattened: bool, } impl Context { @@ -172,7 +189,7 @@ impl Context { fn is_unused(&self, instruction_id: InstructionId, function: &Function) -> bool { let instruction = &function.dfg[instruction_id]; - if instruction.can_eliminate_if_unused(function) { + if instruction.can_eliminate_if_unused(function, self.flattened) { let results = function.dfg.instruction_results(instruction_id); results.iter().all(|result| !self.used_values.contains(result)) } else if let Instruction::Call { func, arguments } = instruction { @@ -966,7 +983,9 @@ mod test { "; let ssa = Ssa::from_str(src).unwrap(); - let ssa = ssa.dead_instruction_elimination(); + + // Even though these ACIR functions only have 1 block, we have not inlined and flattened anything yet. 
+ let ssa = ssa.dead_instruction_elimination_inner(false); let expected = " acir(inline) fn main f0 { @@ -975,18 +994,17 @@ mod test { store v0 at v2 call f1(v2) v4 = load v2 -> Field - v5 = eq v4, v1 constrain v4 == v1 return } acir(inline) fn Add10 f1 { b0(v0: &mut Field): v1 = load v0 -> Field - v2 = add v1, Field 10 - store v2 at v0 + v3 = add v1, Field 10 + store v3 at v0 return } "; - assert_normalized_ssa_equals(ssa, &expected); + assert_normalized_ssa_equals(ssa, expected); } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 79181b7e74e..37718ac0904 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -990,7 +990,7 @@ fn brillig_bytecode_size(function: &Function) -> usize { simplify_between_unrolls(&mut temp); // This is to try to prevent hitting ICE. - temp.dead_instruction_elimination(false); + temp.dead_instruction_elimination(false, true); convert_ssa_function(&temp, false).byte_code.len() } From dd345472157c14ed61ec23c9bce9f2a6378db584 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 15:29:12 +0000 Subject: [PATCH 29/39] Re-enable the DIE --- compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 515e73ac51a..6be3fec54a3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -43,9 +43,7 @@ impl Ssa { function.simplify_function(); // Remove leftover instructions. - // XXX: Doing this would currently cause integration test failures, - // for example with `traits_in_crates_1` it eliminates a store to a mutable input reference. 
- // function.dead_instruction_elimination(true); + function.dead_instruction_elimination(true, false); // Put it back into the SSA, so the next functions can pick it up. self.functions.insert(id, function); From c8687456a9a65faffb4111677e71d572e18cb665 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 20:36:52 +0000 Subject: [PATCH 30/39] Call loop invariant motion --- compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs | 2 +- compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 125cf3a12ca..224916c95e9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -36,7 +36,7 @@ impl Ssa { } impl Function { - fn loop_invariant_code_motion(&mut self) { + pub(super) fn loop_invariant_code_motion(&mut self) { Loops::find_all(self).hoist_loop_invariants(self); } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 076ebab78ad..5671844fcf2 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -31,13 +31,14 @@ impl Ssa { let mut function = function.inlined(&self, false, &inline_infos); // Help unrolling determine bounds. function.as_slice_optimization(); + // Prepare for unrolling + function.loop_invariant_code_motion(); // We might not be able to unroll all loops without fully inlining them, so ignore errors. let _ = function.unroll_loops_iteratively(); // Reduce the number of redundant stores/loads after unrolling function.mem2reg(); // Try to reduce the number of blocks. function.simplify_function(); - // Remove leftover instructions. 
function.dead_instruction_elimination(true, false); From b6cb13f73711c8e79253310c44e77550516d5973 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 20:38:51 +0000 Subject: [PATCH 31/39] Improve comment --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index abf69502eb3..1194e95746a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -106,7 +106,8 @@ impl Function { !inline_type.is_entry_point() && !preserve_function } RuntimeType::Brillig(_) => { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive and not too costly. + // If the called function is brillig, we inline only if the caller is Brillig, + // and the function called wasn't ruled out as too costly to inline or recursive. 
self.runtime().is_brillig() && inline_infos .get(&called_func_id) From ce7e4125bac4715f7114bf9c29554a0c82231741 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 20:48:14 +0000 Subject: [PATCH 32/39] Rewrite compute_times_called in to use the output of compute_callers --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 1194e95746a..3a1bf0ab520 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -262,7 +262,8 @@ pub(super) fn compute_inline_infos( } } - let times_called = compute_times_called(ssa); + let callers = compute_callers(ssa); + let times_called = compute_times_called(&callers); mark_brillig_functions_to_retain( ssa, @@ -276,18 +277,16 @@ pub(super) fn compute_inline_infos( } /// Compute the time each function is called from any other function. -pub(super) fn compute_times_called(ssa: &Ssa) -> HashMap { - ssa.functions +fn compute_times_called( + callers: &BTreeMap>, +) -> HashMap { + callers .iter() - .flat_map(|(_caller_id, function)| { - let called_functions_vec = called_functions_vec(function); - called_functions_vec.into_iter() - }) - .chain(std::iter::once(ssa.main_id)) - .fold(HashMap::default(), |mut map, func_id| { - *map.entry(func_id).or_insert(0) += 1; - map + .map(|(callee, callers)| { + let total_calls = callers.iter().fold(0, |acc, (_caller, calls)| acc + calls); + (*callee, total_calls) }) + .collect() } /// Compute for each function the set of functions that call it, and how many times they do so. @@ -340,8 +339,13 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { let mut order = Vec::new(); let mut visited = HashSet::new(); - // Number of times a function is called, to break cycles. 
- let mut times_called = compute_times_called(ssa).into_iter().collect::>(); + // Call graph which we'll repeatedly prune to find the "leaves". + let mut callees = compute_callees(ssa); + let callers = compute_callers(ssa); + + // Number of times a function is called, to break cycles in the call graph. + let mut times_called = compute_times_called(&callers).into_iter().collect::>(); + // Sort by number of calls ascending, so popping yields the next most called; break ties by ID. times_called.sort_by_key(|(id, cnt)| (*cnt, *id)); // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call. @@ -351,9 +355,6 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { .map(|(id, f)| (*id, compute_function_own_weight(f))) .collect::>(); - let callers = compute_callers(ssa); - let mut callees = compute_callees(ssa); - // Seed the queue with functions that don't call anything. let mut queue = callees .iter() From a46c3e5c6571f7e47042917ff0ce0d885e2ee409 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 20:53:44 +0000 Subject: [PATCH 33/39] Reword comment --- compiler/noirc_frontend/src/tests.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 65390b20ed9..087e34fcc64 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3457,9 +3457,11 @@ fn arithmetic_generics_rounding_fail_on_struct() { #[test] fn unconditional_recursion_fail() { - // These examples are self recursive top level functions, which actually - // would not be inlined now, but this error comes from the compilation checks, - // which is different from what the SSA would try to inline. 
+ // These examples are self recursive top level functions, which would actually + // not be inlined in the SSA (there is nothing to inline into but self), so it + // wouldn't panic due to infinite recursion, but the errors asserted here + // come from the compilation checks, which do static analysis to catch the + // problem before it even has a chance to cause a panic. let srcs = vec![ r#" fn main() { From 4826682cc6688dda841dc30fdfd44b9ecebcd67f Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 21:05:08 +0000 Subject: [PATCH 34/39] Update compiler/noirc_evaluator/src/ssa/opt/inlining.rs Co-authored-by: Maxim Vezenov --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 3a1bf0ab520..5ad5abb559b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -283,7 +283,7 @@ fn compute_times_called( callers .iter() .map(|(callee, callers)| { - let total_calls = callers.iter().fold(0, |acc, (_caller, calls)| acc + calls); + let total_calls = callers.values().sum(); (*callee, total_calls) }) .collect() From 04c8395326d3965193cbbcf2f54b1fa3c6c6939b Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Sat, 18 Jan 2025 07:47:08 +0000 Subject: [PATCH 35/39] Simplify loop --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index da4fda68264..d7066371d5c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -366,9 +366,6 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { .collect::>(); loop { - if times_called.is_empty() && queue.is_empty() { - return order; -
} while let Some(id) = queue.pop_front() { let weight = weights[&id]; order.push((id, weight)); @@ -387,15 +384,16 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { } } // If we ran out of the queue, maybe there is a cycle; take the next most called function. - loop { - let Some((id, _)) = times_called.pop() else { - break; - }; + while let Some((id, _)) = times_called.pop() { if !visited.contains(&id) { queue.push_back(id); break; } } + if times_called.is_empty() && queue.is_empty() { + assert_eq!(order.len(), callers.len()); + return order; + } } } From 64e50f81f5c2a6cb4b704224f9f9d277669e9fcf Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Mon, 20 Jan 2025 12:30:48 +0000 Subject: [PATCH 36/39] Add test for order, tweak weights so the results on the test make sense --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 122 +++++++++++++++++- cspell.json | 1 + 2 files changed, 116 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index d7066371d5c..f73e986dd42 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -347,10 +347,16 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { let mut callees = compute_callees(ssa); let callers = compute_callers(ssa); - // Number of times a function is called, to break cycles in the call graph. + // Number of times a function is called, used to break cycles in the call graph by popping the next candidate. let mut times_called = compute_times_called(&callers).into_iter().collect::>(); - // Sort by number of calls ascending, so popping yields the next most called; break ties by ID. - times_called.sort_by_key(|(id, cnt)| (*cnt, *id)); + times_called.sort_by_key(|(id, cnt)| { + // Sort by called the *least* by others, as these are less likely to cut the graph when removed. 
+ let called_desc = -(*cnt as i64); + // Sort entries first (last to be popped). + let is_entry_asc = -called_desc.signum(); + // Finally break ties by ID. + (is_entry_asc, called_desc, *id) + }); // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call. let mut weights = ssa @@ -367,18 +373,29 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { loop { while let Some(id) = queue.pop_front() { + // Pull the current weight of yet-to-be emitted callees (a nod to mutual recursion). + for (callee, cnt) in &callees[&id] { + if *callee != id { + weights[&id] = weights[&id].saturating_add(cnt.saturating_mul(weights[callee])); + } + } + // Own weight plus the weights accumulated from callees. let weight = weights[&id]; + + // Emit the function. order.push((id, weight)); visited.insert(id); + // Update the callers of this function. - for (caller, call_count) in &callers[&id] { + for (caller, cnt) in &callers[&id] { // Update the weight of the caller with the weight of this function. - weights[caller] = weights[caller].saturating_add(call_count.saturating_mul(weight)); + weights[caller] = weights[caller].saturating_add(cnt.saturating_mul(weight)); // Remove this function from the callees of the caller. let callees = callees.get_mut(caller).unwrap(); callees.remove(&id); - // If the caller doesn't call any other function, enqueue it. - if callees.is_empty() && !visited.contains(caller) { + // If the caller doesn't call any other function, enqueue it, + // unless it's the entry function, which is never called by anything, so it should be last. 
+ if callees.is_empty() && !visited.contains(caller) && !callers[caller].is_empty() { queue.push_back(*caller); } } @@ -1060,6 +1077,8 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { + use std::cmp::max; + use acvm::{acir::AcirField, FieldElement}; use noirc_frontend::monomorphization::ast::InlineType; @@ -1075,6 +1094,8 @@ mod test { Ssa, }; + use super::compute_bottom_up_order; + #[test] fn basic_inlining() { // fn foo { @@ -1453,4 +1474,91 @@ mod test { // No inlining has happened assert_eq!(inlined.functions.len(), 2); } + + #[test] + fn bottom_up_order_and_weights() { + let src = " + brillig(inline) fn main f0 { + b0(v0: u32, v1: u1): + v3 = call f2(v0) -> u1 + v4 = eq v3, v1 + constrain v3 == v1 + return + } + brillig(inline) fn is_even f1 { + b0(v0: u32): + v3 = eq v0, u32 0 + jmpif v3 then: b2, else: b1 + b1(): + v5 = call f3(v0) -> u32 + v7 = call f2(v5) -> u1 + jmp b3(v7) + b2(): + jmp b3(u1 1) + b3(v1: u1): + return v1 + } + brillig(inline) fn is_odd f2 { + b0(v0: u32): + v3 = eq v0, u32 0 + jmpif v3 then: b2, else: b1 + b1(): + v5 = call f3(v0) -> u32 + v7 = call f1(v5) -> u1 + jmp b3(v7) + b2(): + jmp b3(u1 0) + b3(v1: u1): + return v1 + } + brillig(inline) fn decr f3 { + b0(v0: u32): + v2 = sub v0, u32 1 + return v2 + } + "; + // main + // | + // V + // is_odd <-> is_even + // | | + // V V + // decr + + let ssa = Ssa::from_str(src).unwrap(); + let order = compute_bottom_up_order(&ssa); + + assert_eq!(order.len(), 4); + let (ids, ws): (Vec<_>, Vec<_>) = order.into_iter().map(|(id, w)| (id.to_u32(), w)).unzip(); + + // Check order + assert_eq!(ids[0], 3, "decr: first, it doesn't call anything"); + assert_eq!(ids[1], 1, "is_even: called by is_odd; removing first avoids cutting the graph"); + assert_eq!(ids[2], 2, "is_odd: called by is_odd and main"); + assert_eq!(ids[3], 0, "main: last, it's the entry"); + + // Check weights + assert_eq!(ws[0], 2, "decr"); + assert_eq!( + ws[1], + 7 + // own + ws[0] + // pushed from decr + (7 + 
ws[0]), // pulled from is_odd at the time is_even is emitted + "is_even" + ); + assert_eq!( + ws[2], + 7 + // own + ws[0] + // pushed from decr + ws[1], // pushed from is_even + "is_odd" + ); + assert_eq!( + ws[3], + 4 + // own + ws[2], // pushed from is_odd + "main" + ); + assert!(ws[3] > max(ws[1], ws[2]), "ideally 'main' has the most weight"); + } } diff --git a/cspell.json b/cspell.json index a42b90d2e8c..25a0cc91f52 100644 --- a/cspell.json +++ b/cspell.json @@ -205,6 +205,7 @@ "Secpr", "signedness", "signorecello", + "signum", "smallvec", "smol", "splitn", From 915bf252339823002754b6d1c94b1b7d5bbd1e7c Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Mon, 20 Jan 2025 12:35:40 +0000 Subject: [PATCH 37/39] Remove unused after preprocessing --- compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 5671844fcf2..5a13ba8185e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -46,6 +46,7 @@ impl Ssa { self.functions.insert(id, function); } - self + // Remove any functions that have been inlined into others already. 
+ self.remove_unreachable_functions() } } From ed8b10b2798dd77b482109342a4d80e8bd062adb Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Mon, 20 Jan 2025 12:50:06 +0000 Subject: [PATCH 38/39] Remove --skip-preprocess-fns --- compiler/noirc_driver/src/lib.rs | 5 ----- compiler/noirc_evaluator/src/ssa.rs | 10 +--------- compiler/noirc_evaluator/src/ssa/opt/hint.rs | 1 - 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index 2646b13a33a..a7e7e2d4e2f 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -141,10 +141,6 @@ pub struct CompileOptions { #[arg(long)] pub skip_brillig_constraints_check: bool, - /// Flag to turn off preprocessing functions during SSA passes. - #[arg(long)] - pub skip_preprocess_fns: bool, - /// Setting to decide on an inlining strategy for Brillig functions. /// A more aggressive inliner should generate larger programs but more optimized /// A less aggressive inliner should generate smaller programs @@ -683,7 +679,6 @@ pub fn compile_no_check( emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, skip_underconstrained_check: options.skip_underconstrained_check, skip_brillig_constraints_check: options.skip_brillig_constraints_check, - skip_preprocess_fns: options.skip_preprocess_fns, inliner_aggressiveness: options.inliner_aggressiveness, max_bytecode_increase_percent: options.max_bytecode_increase_percent, }; diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index dcf401ab586..12ea04daebd 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -71,9 +71,6 @@ pub struct SsaEvaluatorOptions { /// Skip the missing Brillig call constraints check pub skip_brillig_constraints_check: bool, - /// Skip preprocessing functions. 
- pub skip_preprocess_fns: bool, - /// The higher the value, the more inlined Brillig functions will be. pub inliner_aggressiveness: i64, @@ -157,12 +154,7 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Date: Mon, 20 Jan 2025 13:30:33 +0000 Subject: [PATCH 39/39] Do not skip heavy functions unless it mostly comes from its own weight --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 50 +++++++++++-------- .../src/ssa/opt/preprocess_fns.rs | 26 ++++++---- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index f73e986dd42..c3b771d9102 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -337,9 +337,9 @@ fn compute_callees(ssa: &Ssa) -> BTreeMap Vec<(FunctionId, usize)> { +/// Returns the functions paired with their own as well as transitive weight, +/// which accumulates the weight of all the functions they call, as well as own. +pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, (usize, usize))> { let mut order = Vec::new(); let mut visited = HashSet::new(); @@ -359,11 +359,12 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { }); // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call. - let mut weights = ssa + let own_weights = ssa .functions .iter() .map(|(id, f)| (*id, compute_function_own_weight(f))) .collect::>(); + let mut weights = own_weights.clone(); // Seed the queue with functions that don't call anything. let mut queue = callees @@ -381,9 +382,10 @@ pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { } // Own weight plus the weights accumulated from callees. let weight = weights[&id]; + let own_weight = own_weights[&id]; // Emit the function. 
- order.push((id, weight)); + order.push((id, (own_weight, weight))); visited.insert(id); // Update the callers of this function. @@ -1511,7 +1513,7 @@ mod test { b3(v1: u1): return v1 } - brillig(inline) fn decr f3 { + brillig(inline) fn decrement f3 { b0(v0: u32): v2 = sub v0, u32 1 return v2 @@ -1523,42 +1525,46 @@ mod test { // is_odd <-> is_even // | | // V V - // decr + // decrement let ssa = Ssa::from_str(src).unwrap(); let order = compute_bottom_up_order(&ssa); assert_eq!(order.len(), 4); let (ids, ws): (Vec<_>, Vec<_>) = order.into_iter().map(|(id, w)| (id.to_u32(), w)).unzip(); + let (ows, tws): (Vec<_>, Vec<_>) = ws.into_iter().unzip(); // Check order - assert_eq!(ids[0], 3, "decr: first, it doesn't call anything"); + assert_eq!(ids[0], 3, "decrement: first, it doesn't call anything"); assert_eq!(ids[1], 1, "is_even: called by is_odd; removing first avoids cutting the graph"); assert_eq!(ids[2], 2, "is_odd: called by is_odd and main"); assert_eq!(ids[3], 0, "main: last, it's the entry"); - // Check weights - assert_eq!(ws[0], 2, "decr"); + // Check own weights + assert_eq!(ows, [2, 7, 7, 4]); + + // Check transitive weights + assert_eq!(tws[0], ows[0], "decrement"); assert_eq!( - ws[1], - 7 + // own - ws[0] + // pushed from decr - (7 + ws[0]), // pulled from is_odd at the time is_even is emitted + tws[1], + ows[1] + // own + tws[0] + // pushed from decrement + (ows[2] + tws[0]), // pulled from is_odd at the time is_even is emitted "is_even" ); assert_eq!( - ws[2], - 7 + // own - ws[0] + // pushed from decr - ws[1], // pushed from is_even + tws[2], + ows[2] + // own + tws[0] + // pushed from decrement + tws[1], // pushed from is_even "is_odd" ); assert_eq!( - ws[3], - 4 + // own - ws[2], // pushed from is_odd + tws[3], + ows[3] + // own + tws[2], // pushed from is_odd "main" ); - assert!(ws[3] > max(ws[1], ws[2]), "ideally 'main' has the most weight"); + assert!(tws[3] > max(tws[1], tws[2]), "ideally 'main' has the most weight"); } } diff --git 
a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 5a13ba8185e..439c2da5a2d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -11,23 +11,29 @@ impl Ssa { let bottom_up = inlining::compute_bottom_up_order(&self); // As a heuristic to avoid optimizing functions near the entry point, find a cutoff weight. - let total_weight = bottom_up.iter().fold(0usize, |acc, (_, w)| acc.saturating_add(*w)); + let total_weight = + bottom_up.iter().fold(0usize, |acc, (_, (_, w))| (acc.saturating_add(*w))); let mean_weight = total_weight / bottom_up.len(); let cutoff_weight = mean_weight; // Preliminary inlining decisions. - // Functions which are inline targets will be processed in later passes. - // Here we want to treat the functions which will be inlined into them. let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness); - for (id, _) in bottom_up - .into_iter() - .filter(|(id, _)| { - inline_infos.get(id).map(|info| !info.is_inline_target()).unwrap_or(true) - }) - .filter(|(_, weight)| *weight < cutoff_weight) - { + for (id, (own_weight, transitive_weight)) in bottom_up { + // Skip preprocessing heavy functions that gained most of their weight from transitive accumulation. + // These can be processed later by the regular SSA passes. + if transitive_weight >= cutoff_weight && transitive_weight > own_weight * 2 { + continue; + } + // Functions which are inline targets will be processed in later passes. + // Here we want to treat the functions which will be inlined into them. + if let Some(info) = inline_infos.get(&id) { + if info.is_inline_target() { + continue; + } + } let function = &self.functions[&id]; + // Start with an inline pass. let mut function = function.inlined(&self, false, &inline_infos); // Help unrolling determine bounds. function.as_slice_optimization();