Skip to content

Commit

Permalink
chore: allow passing custom conditions to inlining pass (#7083)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Jan 16, 2025
1 parent 197b02a commit 0512ee6
Showing 1 changed file with 59 additions and 65 deletions.
124 changes: 59 additions & 65 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,48 @@ impl Ssa {
/// This step should run after runtime separation, since it relies on the runtime of the called functions being final.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn inline_functions(self, aggressiveness: i64) -> Ssa {
Self::inline_functions_inner(self, aggressiveness, false)
let inline_sources = get_functions_to_inline_into(&self, false, aggressiveness);
Self::inline_functions_inner(self, &inline_sources, false)
}

// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points
pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa {
Self::inline_functions_inner(self, aggressiveness, true)
let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness);
Self::inline_functions_inner(self, &inline_sources, true)
}

fn inline_functions_inner(
mut self,
aggressiveness: i64,
inline_sources: &BTreeSet<FunctionId>,
inline_no_predicates_functions: bool,
) -> Ssa {
let inline_sources =
get_functions_to_inline_into(&self, inline_no_predicates_functions, aggressiveness);
self.functions = btree_map(&inline_sources, |entry_point| {
let new_function = InlineContext::new(
&self,
*entry_point,
inline_no_predicates_functions,
inline_sources.clone(),
)
.inline_all(&self);
// Note that we clear all functions other than those in `inline_sources`.
// If we decide to do partial inlining then we should change this to preserve those functions which still exist.
self.functions = btree_map(inline_sources, |entry_point| {
let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
let function = &ssa.functions[&called_func_id];

match function.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && function.is_no_predicates();
!inline_type.is_entry_point() && !preserve_function
}
RuntimeType::Brillig(_) => {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[entry_point].runtime().is_brillig()
&& !inline_sources.contains(&called_func_id)
}
}
};

let new_function =
InlineContext::new(&self, *entry_point).inline_all(&self, &should_inline_call);
(*entry_point, new_function)
});
self
Expand All @@ -88,16 +107,6 @@ struct InlineContext {

// The FunctionId of the entry point function we're inlining into in the old, unmodified Ssa.
entry_point: FunctionId,

/// Whether the inlining pass should inline any functions marked with [`InlineType::NoPredicates`]
/// or whether these should be preserved as entrypoint functions.
///
/// This is done as we delay inlining of functions with the attribute `#[no_predicates]` until after
/// the control flow graph has been flattened.
inline_no_predicates_functions: bool,

// These are the functions of the program that we shouldn't inline.
functions_not_to_inline: BTreeSet<FunctionId>,
}

/// The per-function inlining context contains information that is only valid for one function.
Expand Down Expand Up @@ -355,32 +364,23 @@ impl InlineContext {
/// The function being inlined into will always be the main function, although it is
/// actually a copy that is created in case the original main is still needed from a function
/// that could not be inlined calling it.
fn new(
ssa: &Ssa,
entry_point: FunctionId,
inline_no_predicates_functions: bool,
functions_not_to_inline: BTreeSet<FunctionId>,
) -> Self {
fn new(ssa: &Ssa, entry_point: FunctionId) -> Self {
let source = &ssa.functions[&entry_point];
let mut builder = FunctionBuilder::new(source.name().to_owned(), entry_point);
builder.set_runtime(source.runtime());
builder.current_function.set_globals(source.dfg.globals.clone());

Self {
builder,
recursion_level: 0,
entry_point,
call_stack: CallStackId::root(),
inline_no_predicates_functions,
functions_not_to_inline,
}
Self { builder, recursion_level: 0, entry_point, call_stack: CallStackId::root() }
}

/// Start inlining the entry point function and all functions reachable from it.
fn inline_all(mut self, ssa: &Ssa) -> Function {
fn inline_all(
mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) -> Function {
let entry_point = &ssa.functions[&self.entry_point];

// let globals = self.globals;
let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals);
context.inlining_entry = true;

Expand All @@ -401,7 +401,7 @@ impl InlineContext {
}

context.blocks.insert(context.source_function.entry_block(), entry_block);
context.inline_blocks(ssa);
context.inline_blocks(ssa, should_inline_call);
// translate databus values
let databus = entry_point.dfg.data_bus.map_values(|t| context.translate_value(t));

Expand All @@ -420,6 +420,7 @@ impl InlineContext {
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) -> Vec<ValueId> {
self.recursion_level += 1;

Expand All @@ -440,7 +441,7 @@ impl InlineContext {
let current_block = context.context.builder.current_block();
context.blocks.insert(source_function.entry_block(), current_block);

let return_values = context.inline_blocks(ssa);
let return_values = context.inline_blocks(ssa, should_inline_call);
self.recursion_level -= 1;
return_values
}
Expand Down Expand Up @@ -568,7 +569,11 @@ impl<'function> PerFunctionContext<'function> {
}

/// Inline all reachable blocks within the source_function into the destination function.
fn inline_blocks(&mut self, ssa: &Ssa) -> Vec<ValueId> {
fn inline_blocks(
&mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = VecDeque::new();
block_queue.push_back(self.source_function.entry_block());
Expand All @@ -585,7 +590,7 @@ impl<'function> PerFunctionContext<'function> {
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block_instructions(ssa, source_block_id);
self.inline_block_instructions(ssa, source_block_id, should_inline_call);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
Expand Down Expand Up @@ -630,16 +635,21 @@ impl<'function> PerFunctionContext<'function> {

/// Inline each instruction in the given block into the function being inlined into.
/// This may recurse if it finds another function to inline if a call instruction is within this block.
fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
fn inline_block_instructions(
&mut self,
ssa: &Ssa,
block_id: BasicBlockId,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) {
let mut side_effects_enabled: Option<ValueId> = None;

let block = &self.source_function.dfg[block_id];
for id in block.instructions() {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
if self.should_inline_call(ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments);
if should_inline_call(self, ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments, should_inline_call);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
Expand Down Expand Up @@ -667,31 +677,14 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool {
let function = &ssa.functions[&called_func_id];

if let RuntimeType::Acir(inline_type) = function.runtime() {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!self.context.inline_no_predicates_functions && function.is_no_predicates();
!inline_type.is_entry_point() && !preserve_function
} else {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
matches!(ssa.functions[&self.context.entry_point].runtime(), RuntimeType::Brillig(_))
&& !self.context.functions_not_to_inline.contains(&called_func_id)
}
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
ssa: &Ssa,
call_id: InstructionId,
function: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
Expand All @@ -707,7 +700,8 @@ impl<'function> PerFunctionContext<'function> {
.extend_call_stack(self.context.call_stack, &call_stack);

self.context.call_stack = new_call_stack;
let new_results = self.context.inline_function(ssa, function, &arguments);
let new_results =
self.context.inline_function(ssa, function, &arguments, should_inline_call);
self.context.call_stack = self
.context
.builder
Expand Down

0 comments on commit 0512ee6

Please sign in to comment.