feat(ssa): Pass to preprocess functions #7072

Merged
merged 51 commits into from
Jan 20, 2025
Changes from 8 commits
98dc167
Add pass to preprocess 'serialize'
aakoshh Jan 15, 2025
6943765
Simplify a CFG and run DIE as well
aakoshh Jan 15, 2025
ed56da8
Merge branch 'master' into af/ssa-preprocess-fns
TomAFrench Jan 15, 2025
850fc11
Add docstrings to inlining
aakoshh Jan 15, 2025
e5c49af
See what happens if we go bottom up
aakoshh Jan 15, 2025
ea4633c
Merge branch 'af/ssa-preprocess-fns' of github.com:noir-lang/noir int…
aakoshh Jan 15, 2025
13b42f2
Use a cutoff weight to skip processing large functions
aakoshh Jan 16, 2025
a9d754f
Merge remote-tracking branch 'origin/master' into af/ssa-preprocess-fns
aakoshh Jan 16, 2025
8f7eb9c
Add CLI option to turn off preprocessing
aakoshh Jan 16, 2025
02cb856
Don't call DIE, to fix the tests
aakoshh Jan 16, 2025
5aa3db7
Fix field name
aakoshh Jan 16, 2025
3c3e215
Remove restore_on_error from unrolling
aakoshh Jan 16, 2025
1b9ffde
Remove prints
aakoshh Jan 16, 2025
1e07b7a
Fix clippy
aakoshh Jan 16, 2025
de5176a
Do not inline self-recursive entries
aakoshh Jan 16, 2025
deaa311
Rename pass to Preprocess from Pre-process
aakoshh Jan 16, 2025
80872e5
Fix inlining recursion test
aakoshh Jan 16, 2025
f337a04
Refactor to use InlineInfo
aakoshh Jan 16, 2025
ca7ba9e
Fix remove unreachable functions to remove uncalled Brillig
aakoshh Jan 16, 2025
64841c4
Fix defunctionalization to inherit runtime of caller
aakoshh Jan 17, 2025
e9f7bad
Refactor defunctionalization to not create intermediate blocks before…
aakoshh Jan 17, 2025
8cf44ad
Merge branch 'master' into af/ssa-preprocess-fns
aakoshh Jan 17, 2025
da1fe6a
Create test to show defunctionalization not handling runtime
aakoshh Jan 17, 2025
7f2cdef
Fix SSA parser to handle function as value
aakoshh Jan 17, 2025
bd4f90b
Fix defunctionalization to inherit runtime of caller
aakoshh Jan 17, 2025
cfdbcad
Add another test to check 2 runtimes are created
aakoshh Jan 17, 2025
0532794
Merge branch 'master' into fix-defunctionalize-runtime
aakoshh Jan 17, 2025
1b7dc2e
fix: Simplify defunctionalize return (#7101)
aakoshh Jan 17, 2025
8102dd2
.
TomAFrench Jan 17, 2025
1405733
Add test with expected SSA
aakoshh Jan 17, 2025
f206178
.
TomAFrench Jan 17, 2025
939cf84
chore: fix tests
TomAFrench Jan 17, 2025
fc36f52
Add flag to tell the DIE not to remove STORE yet
aakoshh Jan 17, 2025
c00613e
Merge branch '7104-fix-die-mut-ref-param' into af/ssa-preprocess-fns
aakoshh Jan 17, 2025
d6c2318
Merge branch 'af/ssa-preprocess-fns' of github.com:noir-lang/noir int…
aakoshh Jan 17, 2025
dd34547
Re-enable the DIE
aakoshh Jan 17, 2025
9ec8d62
Merge remote-tracking branch 'origin/master' into af/ssa-preprocess-fns
aakoshh Jan 17, 2025
c868745
Call loop invariant motion
aakoshh Jan 17, 2025
b6cb13f
Improve comment
aakoshh Jan 17, 2025
ce7e412
Rewrite compute_times_called in to use the output of compute_callers
aakoshh Jan 17, 2025
a46c3e5
Reword comment
aakoshh Jan 17, 2025
4826682
Update compiler/noirc_evaluator/src/ssa/opt/inlining.rs
aakoshh Jan 17, 2025
44dbfd2
Merge remote-tracking branch 'origin/fix-defunctionalize-runtime' int…
aakoshh Jan 17, 2025
382280d
Merge branch 'af/ssa-preprocess-fns' of github.com:noir-lang/noir int…
aakoshh Jan 17, 2025
b7183a8
Merge remote-tracking branch 'origin/master' into af/ssa-preprocess-fns
aakoshh Jan 17, 2025
04c8395
Simplify loop
aakoshh Jan 18, 2025
64e50f8
Add test for order, tweak weights so the results on the test make sense
aakoshh Jan 20, 2025
915bf25
Remove unused after preprocessing
aakoshh Jan 20, 2025
ed8b10b
Remove --skip-preprocess-fns
aakoshh Jan 20, 2025
8f0fcc9
Do not skip heavy functions unless it mostly comes from its own weight
aakoshh Jan 20, 2025
a506fab
Merge branch 'master' into af/ssa-preprocess-fns
aakoshh Jan 20, 2025
15 changes: 15 additions & 0 deletions compiler/noirc_evaluator/src/ssa.rs
@@ -153,6 +153,15 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions")
.run_pass(Ssa::defunctionalize, "Defunctionalization")
.run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs")
.run_pass(
|ssa| {
ssa.preprocess_functions(
options.inliner_aggressiveness,
options.max_bytecode_increase_percent,
)
},
"Pre-processing Functions",
)
.run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)")
// Run mem2reg with the CFG separated into blocks
.run_pass(Ssa::mem2reg, "Mem2Reg (1st)")
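
The new step is wrapped in a closure because `preprocess_functions` takes extra arguments, while the neighbouring passes are handed to `run_pass` as plain method paths. A minimal sketch of that pattern follows, assuming `run_pass` accepts any `FnOnce` from the program to the program. `Program`, `Pipeline`, and their fields are hypothetical stand-ins for illustration, not the actual `Ssa` or `SsaBuilder` types.

```rust
/// Hypothetical stand-in for the SSA program; it only records which passes touched it.
#[derive(Debug, Default)]
struct Program {
    applied: Vec<String>,
}

impl Program {
    /// A pass with no extra arguments can be passed to `run_pass` by path.
    fn defunctionalize(mut self) -> Program {
        self.applied.push("defunctionalize".to_string());
        self
    }

    /// A pass with extra arguments has to be adapted with a closure.
    fn preprocess_functions(mut self, aggressiveness: i64, max_increase: Option<i32>) -> Program {
        self.applied.push(format!("preprocess({aggressiveness}, {max_increase:?})"));
        self
    }
}

/// Hypothetical stand-in for the pass builder.
struct Pipeline {
    program: Program,
    messages: Vec<String>,
}

impl Pipeline {
    /// Assumed shape of `run_pass`: apply the pass, remember its label, return the builder.
    fn run_pass<F>(mut self, pass: F, msg: &str) -> Self
    where
        F: FnOnce(Program) -> Program,
    {
        self.program = pass(self.program);
        self.messages.push(msg.to_string());
        self
    }
}

fn optimize(program: Program, aggressiveness: i64, max_increase: Option<i32>) -> Pipeline {
    Pipeline { program, messages: Vec::new() }
        .run_pass(Program::defunctionalize, "Defunctionalization")
        // The closure captures the options so the pass still looks like `Program -> Program`.
        .run_pass(
            |p| p.preprocess_functions(aggressiveness, max_increase),
            "Pre-processing Functions",
        )
}
```

Calling `optimize(Program::default(), 1, Some(50))` runs both passes in order and records both labels.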
@@ -479,6 +488,12 @@ impl SsaBuilder {
}

fn print(mut self, msg: &str) -> Self {
println!("AFTER {msg}: functions={}", self.ssa.functions.len());
for f in self.ssa.functions.values() {
let block_cnt = ir::post_order::PostOrder::with_function(f).into_vec().len();
println!(" fn {} {}: blocks={block_cnt}", f.name(), f.id());
}

let print_ssa_pass = match &self.ssa_logging {
SsaLogging::None => false,
SsaLogging::All => true,
191 changes: 160 additions & 31 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
@@ -2,9 +2,10 @@
//! The purpose of this pass is to inline the instructions of each function call
//! within the function caller. If all function calls are known, there will only
//! be a single function remaining when the pass finishes.
-use std::collections::{BTreeSet, HashSet, VecDeque};
+use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque};

use acvm::acir::AcirField;
+use im::HashMap;
use iter_extended::{btree_map, vecmap};

use crate::ssa::{
@@ -19,7 +20,6 @@ use crate::ssa::{
},
ssa_gen::Ssa,
};
-use fxhash::FxHashMap as HashMap;

/// An arbitrary limit to the maximum number of recursive call
/// frames at any point in time.
@@ -50,7 +50,7 @@ impl Ssa {
Self::inline_functions_inner(self, &inline_sources, false)
}

-// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points
+/// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points
pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa {
let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness);
Self::inline_functions_inner(self, &inline_sources, true)
@@ -61,39 +61,52 @@ impl Ssa {
inline_sources: &BTreeSet<FunctionId>,
inline_no_predicates_functions: bool,
) -> Ssa {
-// Note that we clear all functions other than those in `inline_sources`.
-// If we decide to do partial inlining then we should change this to preserve those functions which still exist.
+// NOTE: Functions are processed independently of each other, with the final mapping replacing the original,
+// instead of inlining the "leaf" functions, moving up towards the entry point.
self.functions = btree_map(inline_sources, |entry_point| {
-let should_inline_call =
-|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
-let function = &ssa.functions[&called_func_id];
-
-match function.runtime() {
-RuntimeType::Acir(inline_type) => {
-// If the called function is acir, we inline if it's not an entry point
-
-// If we have not already finished the flattening pass, functions marked
-// to not have predicates should be preserved.
-let preserve_function =
-!inline_no_predicates_functions && function.is_no_predicates();
-!inline_type.is_entry_point() && !preserve_function
-}
-RuntimeType::Brillig(_) => {
-// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
-ssa.functions[entry_point].runtime().is_brillig()
-&& !inline_sources.contains(&called_func_id)
-}
-}
-};
-
+let function = &self.functions[entry_point];
let new_function =
-InlineContext::new(&self, *entry_point).inline_all(&self, &should_inline_call);
+function.inlined(&self, inline_no_predicates_functions, inline_sources);
(*entry_point, new_function)
});
self
}
}
+
+impl Function {
+/// Create a new function which has the functions called by this one inlined into its body.
+pub(super) fn inlined(
+&self,
+ssa: &Ssa,
+inline_no_predicates_functions: bool,
+functions_not_to_inline: &BTreeSet<FunctionId>,
+) -> Function {
+let should_inline_call =
+|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
+let function = &ssa.functions[&called_func_id];
+
+match function.runtime() {
+RuntimeType::Acir(inline_type) => {
+// If the called function is acir, we inline if it's not an entry point
+
+// If we have not already finished the flattening pass, functions marked
+// to not have predicates should be preserved.
+let preserve_function =
+!inline_no_predicates_functions && function.is_no_predicates();
+!inline_type.is_entry_point() && !preserve_function
+}
+RuntimeType::Brillig(_) => {
+// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
+self.runtime().is_brillig()
+&& !functions_not_to_inline.contains(&called_func_id)
+}
+}
+};
+
+InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call)
+}
+}

/// The context for the function inlining pass.
///
/// This works using an internal FunctionBuilder to build a new main function from scratch.
@@ -143,6 +156,8 @@ struct PerFunctionContext<'function> {
}

/// Utility function to find out the direct calls of a function.
///
/// Returns the function IDs from all `Call` instructions without deduplication.
fn called_functions_vec(func: &Function) -> Vec<FunctionId> {
let mut called_function_ids = Vec::new();
for block_id in func.reachable_blocks() {
@@ -160,7 +175,7 @@ fn called_functions_vec(func: &Function) -> Vec<FunctionId> {
called_function_ids
}

-/// Utility function to find out the deduplicated direct calls of a function.
+/// Utility function to find out the deduplicated direct calls made from a function.
fn called_functions(func: &Function) -> BTreeSet<FunctionId> {
called_functions_vec(func).into_iter().collect()
}
@@ -170,7 +185,7 @@ fn called_functions(func: &Function) -> BTreeSet<FunctionId> {
/// - Any Brillig function called from Acir
/// - Some Brillig functions depending on aggressiveness and some metrics
/// - Any Acir functions with a [fold inline type][InlineType::Fold],
-fn get_functions_to_inline_into(
+pub(super) fn get_functions_to_inline_into(
ssa: &Ssa,
inline_no_predicates_functions: bool,
aggressiveness: i64,
@@ -220,7 +235,8 @@ fn get_functions_to_inline_into(
.collect()
}

-fn compute_times_called(ssa: &Ssa) -> HashMap<FunctionId, usize> {
+/// Compute the number of times each function is called from any other function.
+pub(super) fn compute_times_called(ssa: &Ssa) -> HashMap<FunctionId, usize> {
ssa.functions
.iter()
.flat_map(|(_caller_id, function)| {
@@ -234,10 +250,118 @@ fn compute_times_called(ssa: &Ssa) -> HashMap<FunctionId, usize> {
})
}

/// Compute for each function the set of functions that call it, and how many times they do so.
fn compute_callers(ssa: &Ssa) -> BTreeMap<FunctionId, BTreeMap<FunctionId, usize>> {
ssa.functions
.iter()
.flat_map(|(caller_id, function)| {
let called_functions = called_functions_vec(function);
called_functions.into_iter().map(|callee_id| (*caller_id, callee_id))
})
.fold(
// Make sure an entry exists even for ones that don't get called.
ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(),
|mut acc, (caller_id, callee_id)| {
let callers = acc.entry(callee_id).or_default();
*callers.entry(caller_id).or_default() += 1;
acc
},
)
}

/// Compute for each function the set of functions called by it, and how many times it does so.
fn compute_callees(ssa: &Ssa) -> BTreeMap<FunctionId, BTreeMap<FunctionId, usize>> {
ssa.functions
.iter()
.flat_map(|(caller_id, function)| {
let called_functions = called_functions_vec(function);
called_functions.into_iter().map(|callee_id| (*caller_id, callee_id))
})
.fold(
// Make sure an entry exists even for ones that don't call anything.
ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(),
|mut acc, (caller_id, callee_id)| {
let callees = acc.entry(caller_id).or_default();
*callees.entry(callee_id).or_default() += 1;
acc
},
)
}

/// Compute something like a topological order of the functions, starting with the ones
/// that do not call any other functions, going towards the entry points. When cycles
/// are detected, take the function which is called the most to break the tie.
///
/// This can be used to simplify the most often called functions first.
///
/// Returns the functions paired with their transitive weight, which accumulates
/// the weight of all the functions they call.
pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> {
let mut order = Vec::new();
let mut visited = HashSet::new();

// Number of times a function is called, to break cycles.
let mut times_called = compute_times_called(ssa).into_iter().collect::<Vec<_>>();
times_called.sort_by_key(|(id, cnt)| (*cnt, *id));

// Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call.
let mut weights = ssa
.functions
.iter()
.map(|(id, f)| (*id, compute_function_own_weight(f)))
.collect::<HashMap<_, _>>();

let callers = compute_callers(ssa);
let mut callees = compute_callees(ssa);

// Seed the queue with functions that don't call anything.
let mut queue = callees
.iter()
.filter_map(|(id, callees)| callees.is_empty().then_some(*id))
.collect::<VecDeque<_>>();

loop {
if times_called.is_empty() && queue.is_empty() {
return order;
}
while let Some(id) = queue.pop_front() {
let weight = weights[&id];
order.push((id, weight));
visited.insert(id);
// Update the callers of this function.
for (caller, call_count) in &callers[&id] {
// Update the weight of the caller with the weight of this function.
weights[caller] = weights[caller].saturating_add(call_count.saturating_mul(weight));
// Remove this function from the callees of the caller.
let callees = callees.get_mut(caller).unwrap();
callees.remove(&id);
// If the caller doesn't call any other function, enqueue it.
if callees.is_empty() && !visited.contains(caller) {
queue.push_back(*caller);
}
}
}
// If we ran out of the queue, maybe there is a cycle; take the next most called function.
loop {
let Some((id, _)) = times_called.pop() else {
break;
};
if !visited.contains(&id) {
queue.push_back(id);
break;
}
}
}
}

/// Traverse the call graph starting from a given function, marking functions to be retained if they are:
/// * recursive functions, or
/// * functions for which the cost of inlining outweighs the cost of not doing so
fn should_retain_recursive(
ssa: &Ssa,
func: FunctionId,
times_called: &HashMap<FunctionId, usize>,
// FunctionId -> (should_retain, weight)
should_retain_function: &mut HashMap<FunctionId, (bool, i64)>,
mut explored_functions: im::HashSet<FunctionId>,
inline_no_predicates_functions: bool,
@@ -278,6 +402,8 @@ fn should_retain_recursive(
// And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns)
// We then can compute an approximation of the cost of inlining vs the cost of retaining the function
// We do this computation using saturating i64s to avoid overflows

// Total weight of functions called by this one, unless we decided not to inline them.
let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, called_function| {
let (should_retain, weight) = should_retain_function[called_function];
if should_retain {
@@ -309,6 +435,7 @@ fn should_retain_recursive(
should_retain_function.insert(func, (!should_inline, this_function_weight));
}

/// Gather the functions that should not be inlined.
fn compute_functions_to_retain(
ssa: &Ssa,
entry_points: &BTreeSet<FunctionId>,
@@ -344,6 +471,7 @@ fn compute_functions_to_retain(
.collect()
}

/// Compute a weight of a function based on the number of instructions in its reachable blocks.
fn compute_function_own_weight(func: &Function) -> usize {
let mut weight = 0;
for block_id in func.reachable_blocks() {
@@ -354,6 +482,7 @@ fn compute_function_own_weight(func: &Function) -> usize {
weight
}

/// Compute interface cost of a function based on the number of inputs and outputs.
fn compute_function_interface_cost(func: &Function) -> usize {
func.parameters().len() + func.returns().len()
}
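
The bottom-up ordering added in this file can be hard to follow in diff form, so here is a standalone sketch of the same idea using only the standard library. It is an approximation under stated assumptions, not the PR's code: `Id`, `bottom_up_order`, and the edge-list input are hypothetical, and every ID that appears in `calls` is assumed to have an entry in `own_weights`. Leaves are emitted first, each emitted function adds `call_count * weight` to its callers, and when only cycles remain the most-called unvisited function is emitted to break the cycle.

```rust
use std::collections::{BTreeMap, BTreeSet, VecDeque};

/// Hypothetical stand-in for a function ID.
type Id = u32;

/// Returns `(id, transitive_weight)` pairs, leaf functions first.
/// `calls` holds one `(caller, callee)` entry per call site.
fn bottom_up_order(own_weights: &BTreeMap<Id, usize>, calls: &[(Id, Id)]) -> Vec<(Id, usize)> {
    let mut weights = own_weights.clone();
    // callee -> caller -> number of call sites
    let mut callers: BTreeMap<Id, BTreeMap<Id, usize>> =
        own_weights.keys().map(|id| (*id, BTreeMap::new())).collect();
    // caller -> callees that have not been emitted yet
    let mut pending: BTreeMap<Id, BTreeSet<Id>> =
        own_weights.keys().map(|id| (*id, BTreeSet::new())).collect();
    let mut times_called: BTreeMap<Id, usize> = BTreeMap::new();
    for &(caller, callee) in calls {
        *callers.entry(callee).or_default().entry(caller).or_default() += 1;
        pending.entry(caller).or_default().insert(callee);
        *times_called.entry(callee).or_default() += 1;
    }
    // Ascending by (count, id), so `pop` yields the most-called function.
    let mut by_calls: Vec<(usize, Id)> =
        times_called.into_iter().map(|(id, n)| (n, id)).collect();
    by_calls.sort();

    // Seed the queue with functions that do not call anything.
    let mut queue: VecDeque<Id> =
        pending.iter().filter(|(_, c)| c.is_empty()).map(|(id, _)| *id).collect();
    let mut visited = BTreeSet::new();
    let mut order = Vec::new();

    loop {
        while let Some(id) = queue.pop_front() {
            if !visited.insert(id) {
                continue;
            }
            let weight = weights[&id];
            order.push((id, weight));
            for (caller, count) in &callers[&id] {
                // Accumulate this function's weight into the caller, once per call site.
                let w = weights.get_mut(caller).unwrap();
                *w = w.saturating_add(count.saturating_mul(weight));
                let rest = pending.get_mut(caller).unwrap();
                rest.remove(&id);
                if rest.is_empty() && !visited.contains(caller) {
                    queue.push_back(*caller);
                }
            }
        }
        // Queue drained: either we are done, or a cycle remains and must be broken.
        match by_calls.pop() {
            Some((_, id)) if !visited.contains(&id) => queue.push_back(id),
            Some(_) => continue,
            None => return order,
        }
    }
}
```

For an acyclic example where function 0 (own weight 10) calls 1 twice and 2 once, 1 (weight 3) calls 2 once, and 2 (weight 5) is a leaf, this yields `[(2, 5), (1, 8), (0, 31)]`.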
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa/opt/mod.rs
@@ -16,6 +16,7 @@ mod inlining;
mod loop_invariant;
mod mem2reg;
mod normalize_value_ids;
mod preprocess_fns;
mod rc;
mod remove_bit_shifts;
mod remove_enable_side_effects;
52 changes: 52 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
@@ -0,0 +1,52 @@
//! Pre-process functions before inlining them into others.

use crate::ssa::Ssa;

use super::inlining;

impl Ssa {
/// Run pre-processing steps on functions in isolation.
pub(crate) fn preprocess_functions(
mut self,
aggressiveness: i64,
max_bytecode_increase_percent: Option<i32>,
) -> Ssa {
// No point pre-processing the functions that will never be inlined into others.
let not_to_inline = inlining::get_functions_to_inline_into(&self, false, aggressiveness);
// Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them.
let bottom_up = inlining::compute_bottom_up_order(&self);

// As a heuristic to avoid optimizing functions near the entry point, find a cutoff weight.
let total_weight = bottom_up.iter().fold(0usize, |acc, (_, w)| acc.saturating_add(*w));
let mean_weight = total_weight / bottom_up.len();
let cutoff_weight = mean_weight;

for (id, weight) in bottom_up.into_iter().filter(|(id, _)| !not_to_inline.contains(id)) {
let function = &self.functions[&id];
if weight <= cutoff_weight {
println!("PREPROCESSING fn {} {id} with weight {weight}", function.name());
} else {
println!(
"SKIP PREPROCESSING fn {} {id} with weight {weight} > {cutoff_weight}",
function.name()
);
continue;
}
let mut function = function.inlined(&self, false, &not_to_inline);
// Help unrolling determine bounds.
function.as_slice_optimization();
// We might not be able to unroll all loops without fully inlining them, so ignore errors.
let _ = function.try_unroll_loops_iteratively(max_bytecode_increase_percent, true);
// Reduce the number of redundant stores/loads after unrolling
function.mem2reg();
// Try to reduce the number of blocks.
function.simplify_function();
// Remove leftover instructions.
function.dead_instruction_elimination(true);
// Put it back into the SSA, so the next functions can pick it up.
self.functions.insert(id, function);
}

self
}
}
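
The cutoff heuristic in this revision uses the mean transitive weight as the threshold: functions at or below the mean are preprocessed, heavier ones are skipped. A minimal sketch of just that selection step is below. `below_mean_cutoff` and the plain `u32` IDs are hypothetical, and the guard against an empty input is an addition of this sketch rather than something the diff contains.

```rust
/// Hypothetical helper: given `(id, transitive_weight)` pairs, keep the IDs
/// whose weight does not exceed the integer mean of all weights.
fn below_mean_cutoff(order: &[(u32, usize)]) -> Vec<u32> {
    // Assumption of this sketch: an empty input selects nothing (avoids dividing by zero).
    if order.is_empty() {
        return Vec::new();
    }
    // Saturating sum, mirroring the overflow-safe accumulation in the diff.
    let total = order.iter().fold(0usize, |acc, (_, w)| acc.saturating_add(*w));
    let cutoff = total / order.len();
    order.iter().filter(|(_, w)| *w <= cutoff).map(|(id, _)| *id).collect()
}
```

With weights 5, 8, and 31 the integer mean is 14, so only the first two functions are selected. A single heavy function near the entry point pulls the mean up, which is why the lighter leaf functions tend to fall under the threshold.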