diff --git a/.aztec-sync-commit b/.aztec-sync-commit index dc4d9c58bda..063664dceac 100644 --- a/.aztec-sync-commit +++ b/.aztec-sync-commit @@ -1 +1 @@ -58761fcf181d0a26c4e387105b1bc8a3a75b0368 +b8bace9a00c3a8eb93f42682e8cbfa351fc5238c diff --git a/Cargo.lock b/Cargo.lock index 6b24c0f8c67..0f4866f8c7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -805,9 +805,9 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.5.37" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11611dca53440593f38e6b25ec629de50b14cdfa63adc0fb856115a2c6d97595" +checksum = "d9647a559c112175f17cf724dc72d3645680a883c58481332779192b0d8e7a01" dependencies = [ "clap", ] diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index 9d0d5b3af60..0a55e4ca17c 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -50,31 +50,36 @@ impl MergeExpressionsOptimizer { let mut new_circuit = Vec::new(); let mut new_acir_opcode_positions = Vec::new(); // For each opcode, try to get a target opcode to merge with - for (i, opcode) in circuit.opcodes.iter().enumerate() { + for (i, (opcode, opcode_position)) in + circuit.opcodes.iter().zip(acir_opcode_positions).enumerate() + { if !matches!(opcode, Opcode::AssertZero(_)) { new_circuit.push(opcode.clone()); - new_acir_opcode_positions.push(acir_opcode_positions[i]); + new_acir_opcode_positions.push(opcode_position); continue; } let opcode = modified_gates.get(&i).unwrap_or(opcode).clone(); let mut to_keep = true; let input_witnesses = self.witness_inputs(&opcode); - for w in input_witnesses.clone() { - let empty_gates = BTreeSet::new(); - let gates_using_w = used_witness.get(&w).unwrap_or(&empty_gates); + for w in input_witnesses { + let Some(gates_using_w) = used_witness.get(&w) else { + continue; + }; // We only consider witness which are used in exactly two arithmetic gates if gates_using_w.len() == 2 { - let gates_using_w: Vec<_> = gates_using_w.iter().collect(); - let mut b = *gates_using_w[1]; - if b == i { - b = *gates_using_w[0]; + let first = *gates_using_w.first().expect("gates_using_w.len == 2"); + let second = *gates_using_w.last().expect("gates_using_w.len == 2"); + let b = if second == i { + first } else { // sanity check - assert!(i == *gates_using_w[0]); - } - let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]).clone(); + assert!(i == first); + second + }; + + let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]); if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) = - (opcode.clone(), second_gate) + (&opcode, second_gate) { // We cannot merge an expression into an earlier opcode, because this // would break the 'execution ordering' of the opcodes @@ -85,16 +90,15 @@ impl MergeExpressionsOptimizer { // - doing this pass again until there is no change, or // - merging 'b' into 'i' instead if i < b { - if let Some(expr) = Self::merge(&expr_use, &expr_define, w) { + if let Some(expr) = Self::merge(expr_use, expr_define, w) { modified_gates.insert(b, Opcode::AssertZero(expr)); to_keep = false; // Update the 'used_witness' map to account for the merge. - for w2 in CircuitSimulator::expr_wit(&expr_define) { + for w2 in CircuitSimulator::expr_wit(expr_define) { if !circuit_inputs.contains(&w2) { - let mut v = used_witness[&w2].clone(); + let v = used_witness.entry(w2).or_default(); v.insert(b); v.remove(&i); - used_witness.insert(w2, v); } } // We need to stop here and continue with the next opcode @@ -107,12 +111,9 @@ impl MergeExpressionsOptimizer { } if to_keep { - if modified_gates.contains_key(&i) { - new_circuit.push(modified_gates[&i].clone()); - } else { - new_circuit.push(opcode.clone()); - } - new_acir_opcode_positions.push(acir_opcode_positions[i]); + let opcode = modified_gates.get(&i).cloned().unwrap_or(opcode); + new_circuit.push(opcode); + new_acir_opcode_positions.push(opcode_position); } } (new_circuit, new_acir_opcode_positions) diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index 5c7899b5035..46d0924b322 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -24,12 +24,10 @@ mod big_int; mod brillig_directive; mod generated_acir; +use crate::brillig::brillig_gen::gen_brillig_for; use crate::brillig::{ brillig_gen::brillig_fn::FunctionContext as BrilligFunctionContext, - brillig_ir::{ - artifact::{BrilligParameter, GeneratedBrillig}, - BrilligContext, - }, + brillig_ir::artifact::{BrilligParameter, GeneratedBrillig}, Brillig, }; use crate::errors::{InternalError, InternalWarning, RuntimeError, SsaReport}; @@ -518,7 +516,7 @@ impl<'a> Context<'a> { let outputs: Vec = vecmap(main_func.returns(), |result_id| dfg.type_of_value(*result_id).into()); - let code = self.gen_brillig_for(main_func, arguments.clone(), brillig)?; + let code = gen_brillig_for(main_func, arguments.clone(), brillig)?; // We specifically do not attempt execution of the brillig code being generated as this can result in it being // replaced with constraints on witnesses to the program outputs. @@ -878,8 +876,7 @@ impl<'a> Context<'a> { None, )? } else { - let code = - self.gen_brillig_for(func, arguments.clone(), brillig)?; + let code = gen_brillig_for(func, arguments.clone(), brillig)?; let generated_pointer = self.shared_context.new_generated_pointer(); let output_values = self.acir_context.brillig_call( @@ -999,47 +996,6 @@ impl<'a> Context<'a> { .collect() } - fn gen_brillig_for( - &self, - func: &Function, - arguments: Vec, - brillig: &Brillig, - ) -> Result, InternalError> { - // Create the entry point artifact - let mut entry_point = BrilligContext::new_entry_point_artifact( - arguments, - BrilligFunctionContext::return_values(func), - func.id(), - ); - entry_point.name = func.name().to_string(); - - // Link the entry point with all dependencies - while let Some(unresolved_fn_label) = entry_point.first_unresolved_function_call() { - let artifact = &brillig.find_by_label(unresolved_fn_label.clone()); - let artifact = match artifact { - Some(artifact) => artifact, - None => { - return Err(InternalError::General { - message: format!("Cannot find linked fn {unresolved_fn_label}"), - call_stack: CallStack::new(), - }) - } - }; - entry_point.link_with(artifact); - // Insert the range of opcode locations occupied by a procedure - if let Some(procedure_id) = &artifact.procedure { - let num_opcodes = entry_point.byte_code.len(); - let previous_num_opcodes = entry_point.byte_code.len() - artifact.byte_code.len(); - // We subtract one as to keep the range inclusive on both ends - entry_point - .procedure_locations - .insert(procedure_id.clone(), (previous_num_opcodes, num_opcodes - 1)); - } - } - // Generate the final bytecode - Ok(entry_point.finish()) - } - /// Handles an ArrayGet or ArraySet instruction. /// To set an index of the array (and create a new array in doing so), pass Some(value) for /// store_value. To just retrieve an index of the array, pass None for store_value. diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index 786a03031d6..ca4e783aa93 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -9,11 +9,17 @@ mod variable_liveness; use acvm::FieldElement; use self::{brillig_block::BrilligBlock, brillig_fn::FunctionContext}; -use super::brillig_ir::{ - artifact::{BrilligArtifact, Label}, - BrilligContext, +use super::{ + brillig_ir::{ + artifact::{BrilligArtifact, BrilligParameter, GeneratedBrillig, Label}, + BrilligContext, + }, + Brillig, +}; +use crate::{ + errors::InternalError, + ssa::ir::{dfg::CallStack, function::Function}, }; -use crate::ssa::ir::function::Function; /// Converting an SSA function into Brillig bytecode. pub(crate) fn convert_ssa_function( @@ -36,3 +42,43 @@ pub(crate) fn convert_ssa_function( artifact.name = func.name().to_string(); artifact } + +pub(crate) fn gen_brillig_for( + func: &Function, + arguments: Vec, + brillig: &Brillig, +) -> Result, InternalError> { + // Create the entry point artifact + let mut entry_point = BrilligContext::new_entry_point_artifact( + arguments, + FunctionContext::return_values(func), + func.id(), + ); + entry_point.name = func.name().to_string(); + + // Link the entry point with all dependencies + while let Some(unresolved_fn_label) = entry_point.first_unresolved_function_call() { + let artifact = &brillig.find_by_label(unresolved_fn_label.clone()); + let artifact = match artifact { + Some(artifact) => artifact, + None => { + return Err(InternalError::General { + message: format!("Cannot find linked fn {unresolved_fn_label}"), + call_stack: CallStack::new(), + }) + } + }; + entry_point.link_with(artifact); + // Insert the range of opcode locations occupied by a procedure + if let Some(procedure_id) = &artifact.procedure { + let num_opcodes = entry_point.byte_code.len(); + let previous_num_opcodes = entry_point.byte_code.len() - artifact.byte_code.len(); + // We subtract one as to keep the range inclusive on both ends + entry_point + .procedure_locations + .insert(procedure_id.clone(), (previous_num_opcodes, num_opcodes - 1)); + } + } + // Generate the final bytecode + Ok(entry_point.finish()) +} diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 9e11441caf4..f0a6fcd986f 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -140,6 +140,23 @@ pub(crate) fn optimize_into_acir( ssa.to_brillig(options.enable_brillig_logging) }); + let ssa_gen_span = span!(Level::TRACE, "ssa_generation"); + let ssa_gen_span_guard = ssa_gen_span.enter(); + + let ssa = SsaBuilder { + ssa, + print_ssa_passes: options.enable_ssa_logging, + print_codegen_timings: options.print_codegen_timings, + } + .run_pass( + |ssa| ssa.fold_constants_with_brillig(&brillig), + "After Constant Folding with Brillig:", + ) + .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") + .finish(); + + drop(ssa_gen_span_guard); + let artifacts = time("SSA to ACIR", options.print_codegen_timings, || { ssa.into_acir(&brillig, options.expression_width) })?; diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 95447f941b0..7e41512fd8f 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -806,7 +806,8 @@ fn simplify_derive_generators( results.push(is_infinite); } let len = results.len(); - let typ = Type::Array(vec![Type::field()].into(), len); + let typ = + Type::Array(vec![Type::field(), Type::field(), Type::unsigned(1)].into(), len / 3); let result = make_array(dfg, results.into(), typ, block, call_stack); SimplifyResult::SimplifiedTo(result) } else { @@ -816,3 +817,34 @@ fn simplify_derive_generators( unreachable!("Unexpected number of arguments to derive_generators"); } } + +#[cfg(test)] +mod tests { + use crate::ssa::{opt::assert_normalized_ssa_equals, Ssa}; + + #[test] + fn simplify_derive_generators_has_correct_type() { + let src = " + brillig(inline) fn main f0 { + b0(): + v0 = make_array [u8 68, u8 69, u8 70, u8 65, u8 85, u8 76, u8 84, u8 95, u8 68, u8 79, u8 77, u8 65, u8 73, u8 78, u8 95, u8 83, u8 69, u8 80, u8 65, u8 82, u8 65, u8 84, u8 79, u8 82] : [u8; 24] + + // This call was previously incorrectly simplified to something that returned `[Field; 3]` + v2 = call derive_pedersen_generators(v0, u32 0) -> [(Field, Field, u1); 1] + + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + brillig(inline) fn main f0 { + b0(): + v15 = make_array [u8 68, u8 69, u8 70, u8 65, u8 85, u8 76, u8 84, u8 95, u8 68, u8 79, u8 77, u8 65, u8 73, u8 78, u8 95, u8 83, u8 69, u8 80, u8 65, u8 82, u8 65, u8 84, u8 79, u8 82] : [u8; 24] + v19 = make_array [Field 3728882899078719075161482178784387565366481897740339799480980287259621149274, Field -9903063709032878667290627648209915537972247634463802596148419711785767431332, u1 0] : [(Field, Field, u1); 1] + return v19 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 32f66e5a0f0..019bace33a3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -19,22 +19,35 @@ //! //! This is the only pass which removes duplicated pure [`Instruction`]s however and so is needed when //! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass. -use std::collections::{HashSet, VecDeque}; +use std::collections::{BTreeMap, HashSet, VecDeque}; -use acvm::{acir::AcirField, FieldElement}; +use acvm::{ + acir::AcirField, + brillig_vm::{MemoryValue, VMStatus, VM}, + FieldElement, +}; +use bn254_blackbox_solver::Bn254BlackBoxSolver; +use im::Vector; use iter_extended::vecmap; -use crate::ssa::{ - ir::{ - basic_block::BasicBlockId, - dfg::{DataFlowGraph, InsertInstructionResult}, - dom::DominatorTree, - function::Function, - instruction::{Instruction, InstructionId}, - types::Type, - value::{Value, ValueId}, +use crate::{ + brillig::{ + brillig_gen::gen_brillig_for, + brillig_ir::{artifact::BrilligParameter, brillig_variable::get_bit_size_from_ssa_type}, + Brillig, + }, + ssa::{ + ir::{ + basic_block::BasicBlockId, + dfg::{DataFlowGraph, InsertInstructionResult}, + dom::DominatorTree, + function::{Function, FunctionId, RuntimeType}, + instruction::{Instruction, InstructionId}, + types::Type, + value::{Value, ValueId}, + }, + ssa_gen::Ssa, }, - ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; @@ -45,7 +58,7 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn fold_constants(mut self) -> Ssa { for function in self.functions.values_mut() { - function.constant_fold(false); + function.constant_fold(false, None); } self } @@ -58,17 +71,82 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn fold_constants_using_constraints(mut self) -> Ssa { for function in self.functions.values_mut() { - function.constant_fold(true); + function.constant_fold(true, None); } self } + + /// Performs constant folding on each instruction while also replacing calls to brillig functions + /// with all constant arguments by trying to evaluate those calls. + #[tracing::instrument(level = "trace", skip(self, brillig))] + pub(crate) fn fold_constants_with_brillig(mut self, brillig: &Brillig) -> Ssa { + // Collect all brillig functions so that later we can find them when processing a call instruction + let mut brillig_functions: BTreeMap = BTreeMap::new(); + for (func_id, func) in &self.functions { + if let RuntimeType::Brillig(..) = func.runtime() { + let cloned_function = Function::clone_with_id(*func_id, func); + brillig_functions.insert(*func_id, cloned_function); + }; + } + + let brillig_info = Some(BrilligInfo { brillig, brillig_functions: &brillig_functions }); + + for function in self.functions.values_mut() { + function.constant_fold(false, brillig_info); + } + + // It could happen that we inlined all calls to a given brillig function. + // In that case it's unused so we can remove it. This is what we check next. + self.remove_unused_brillig_functions(brillig_functions) + } + + fn remove_unused_brillig_functions( + mut self, + mut brillig_functions: BTreeMap, + ) -> Ssa { + // Remove from the above map functions that are called + for function in self.functions.values() { + for block_id in function.reachable_blocks() { + for instruction_id in function.dfg[block_id].instructions() { + let instruction = &function.dfg[*instruction_id]; + let Instruction::Call { func: func_id, arguments: _ } = instruction else { + continue; + }; + + let func_value = &function.dfg[*func_id]; + let Value::Function(func_id) = func_value else { continue }; + + brillig_functions.remove(func_id); + } + } + } + + // The ones that remain are never called: let's remove them. + for func_id in brillig_functions.keys() { + // We never want to remove the main function (it could be `unconstrained` or it + // could have been turned into brillig if `--force-brillig` was given). + // We also don't want to remove entry points. + if self.main_id == *func_id || self.entry_point_to_generated_index.contains_key(func_id) + { + continue; + } + + self.functions.remove(func_id); + } + + self + } } impl Function { /// The structure of this pass is simple: /// Go through each block and re-insert all instructions. - pub(crate) fn constant_fold(&mut self, use_constraint_info: bool) { - let mut context = Context::new(self, use_constraint_info); + pub(crate) fn constant_fold( + &mut self, + use_constraint_info: bool, + brillig_info: Option, + ) { + let mut context = Context::new(self, use_constraint_info, brillig_info); context.block_queue.push_back(self.entry_block()); while let Some(block) = context.block_queue.pop_front() { @@ -82,8 +160,9 @@ impl Function { } } -struct Context { +struct Context<'a> { use_constraint_info: bool, + brillig_info: Option>, /// Maps pre-folded ValueIds to the new ValueIds obtained by re-inserting the instruction. visited_blocks: HashSet, block_queue: VecDeque, @@ -103,6 +182,12 @@ struct Context { dom: DominatorTree, } +#[derive(Copy, Clone)] +pub(crate) struct BrilligInfo<'a> { + brillig: &'a Brillig, + brillig_functions: &'a BTreeMap, +} + /// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction. /// Stored as a two-level map to avoid cloning Instructions during the `.get` call. /// @@ -118,10 +203,15 @@ struct ResultCache { result: Option<(BasicBlockId, Vec)>, } -impl Context { - fn new(function: &Function, use_constraint_info: bool) -> Self { +impl<'brillig> Context<'brillig> { + fn new( + function: &Function, + use_constraint_info: bool, + brillig_info: Option>, + ) -> Self { Self { use_constraint_info, + brillig_info, visited_blocks: Default::default(), block_queue: Default::default(), constraint_simplification_mappings: Default::default(), @@ -177,8 +267,25 @@ impl Context { } } - // Otherwise, try inserting the instruction again to apply any optimizations using the newly resolved inputs. - let new_results = Self::push_instruction(id, instruction.clone(), &old_results, block, dfg); + let new_results = + // First try to inline a call to a brillig function with all constant arguments. + Self::try_inline_brillig_call_with_all_constants( + &instruction, + &old_results, + block, + dfg, + self.brillig_info, + ) + .unwrap_or_else(|| { + // Otherwise, try inserting the instruction again to apply any optimizations using the newly resolved inputs. + Self::push_instruction( + id, + instruction.clone(), + &old_results, + block, + dfg, + ) + }); Self::replace_result_ids(dfg, &old_results, &new_results); @@ -342,6 +449,166 @@ impl Context { results_for_instruction.get(&predicate)?.get(block, &mut self.dom) } + + /// Checks if the given instruction is a call to a brillig function with all constant arguments. + /// If so, we can try to evaluate that function and replace the results with the evaluation results. + fn try_inline_brillig_call_with_all_constants( + instruction: &Instruction, + old_results: &[ValueId], + block: BasicBlockId, + dfg: &mut DataFlowGraph, + brillig_info: Option, + ) -> Option> { + let evaluation_result = Self::evaluate_const_brillig_call( + instruction, + brillig_info?.brillig, + brillig_info?.brillig_functions, + dfg, + ); + + match evaluation_result { + EvaluationResult::NotABrilligCall | EvaluationResult::CannotEvaluate(_) => None, + EvaluationResult::Evaluated(memory_values) => { + let mut memory_index = 0; + let new_results = vecmap(old_results, |old_result| { + let typ = dfg.type_of_value(*old_result); + Self::new_value_for_type_and_memory_values( + typ, + block, + &memory_values, + &mut memory_index, + dfg, + ) + }); + Some(new_results) + } + } + } + + /// Tries to evaluate an instruction if it's a call that points to a brillig function, + /// and all its arguments are constant. + /// We do this by directly executing the function with a brillig VM. + fn evaluate_const_brillig_call( + instruction: &Instruction, + brillig: &Brillig, + brillig_functions: &BTreeMap, + dfg: &mut DataFlowGraph, + ) -> EvaluationResult { + let Instruction::Call { func: func_id, arguments } = instruction else { + return EvaluationResult::NotABrilligCall; + }; + + let func_value = &dfg[*func_id]; + let Value::Function(func_id) = func_value else { + return EvaluationResult::NotABrilligCall; + }; + + let Some(func) = brillig_functions.get(func_id) else { + return EvaluationResult::NotABrilligCall; + }; + + if !arguments.iter().all(|argument| dfg.is_constant(*argument)) { + return EvaluationResult::CannotEvaluate(*func_id); + } + + let mut brillig_arguments = Vec::new(); + for argument in arguments { + let typ = dfg.type_of_value(*argument); + let Some(parameter) = type_to_brillig_parameter(&typ) else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + brillig_arguments.push(parameter); + } + + // Check that return value types are supported by brillig + for return_id in func.returns().iter() { + let typ = func.dfg.type_of_value(*return_id); + if type_to_brillig_parameter(&typ).is_none() { + return EvaluationResult::CannotEvaluate(*func_id); + } + } + + let Ok(generated_brillig) = gen_brillig_for(func, brillig_arguments, brillig) else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + + let mut calldata = Vec::new(); + for argument in arguments { + value_id_to_calldata(*argument, dfg, &mut calldata); + } + + let bytecode = &generated_brillig.byte_code; + let foreign_call_results = Vec::new(); + let black_box_solver = Bn254BlackBoxSolver; + let profiling_active = false; + let mut vm = + VM::new(calldata, bytecode, foreign_call_results, &black_box_solver, profiling_active); + let vm_status: VMStatus<_> = vm.process_opcodes(); + let VMStatus::Finished { return_data_offset, return_data_size } = vm_status else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + + let memory = + vm.get_memory()[return_data_offset..(return_data_offset + return_data_size)].to_vec(); + + EvaluationResult::Evaluated(memory) + } + + /// Creates a new value inside this function by reading it from `memory_values` starting at + /// `memory_index` depending on the given Type: if it's an array multiple values will be read + /// and a new `make_array` instruction will be created. + fn new_value_for_type_and_memory_values( + typ: Type, + block_id: BasicBlockId, + memory_values: &[MemoryValue], + memory_index: &mut usize, + dfg: &mut DataFlowGraph, + ) -> ValueId { + match typ { + Type::Numeric(_) => { + let memory = memory_values[*memory_index]; + *memory_index += 1; + + let field_value = match memory { + MemoryValue::Field(field_value) => field_value, + MemoryValue::Integer(u128_value, _) => u128_value.into(), + }; + dfg.make_constant(field_value, typ) + } + Type::Array(types, length) => { + let mut new_array_values = Vector::new(); + for _ in 0..length { + for typ in types.iter() { + let new_value = Self::new_value_for_type_and_memory_values( + typ.clone(), + block_id, + memory_values, + memory_index, + dfg, + ); + new_array_values.push_back(new_value); + } + } + + let instruction = Instruction::MakeArray { + elements: new_array_values, + typ: Type::Array(types, length), + }; + let instruction_id = dfg.make_instruction(instruction, None); + dfg[block_id].instructions_mut().push(instruction_id); + *dfg.instruction_results(instruction_id).first().unwrap() + } + Type::Reference(_) => { + panic!("Unexpected reference type in brillig function result") + } + Type::Slice(_) => { + panic!("Unexpected slice type in brillig function result") + } + Type::Function => { + panic!("Unexpected function type in brillig function result") + } + } + } } impl ResultCache { @@ -376,6 +643,50 @@ enum CacheResult<'a> { NeedToHoistToCommonBlock(BasicBlockId, &'a [ValueId]), } +/// Result of trying to evaluate an instruction (any instruction) in this pass. +enum EvaluationResult { + /// Nothing was done because the instruction wasn't a call to a brillig function, + /// or some arguments to it were not constants. + NotABrilligCall, + /// The instruction was a call to a brillig function, but we couldn't evaluate it. + /// This can occur in the situation where the brillig function reaches a "trap" or a foreign call opcode. + CannotEvaluate(FunctionId), + /// The instruction was a call to a brillig function and we were able to evaluate it, + /// returning evaluation memory values. + Evaluated(Vec>), +} + +/// Similar to FunctionContext::ssa_type_to_parameter but never panics and disallows reference types. +pub(crate) fn type_to_brillig_parameter(typ: &Type) -> Option { + match typ { + Type::Numeric(_) => Some(BrilligParameter::SingleAddr(get_bit_size_from_ssa_type(typ))), + Type::Array(item_type, size) => { + let mut parameters = Vec::with_capacity(item_type.len()); + for item_typ in item_type.iter() { + parameters.push(type_to_brillig_parameter(item_typ)?); + } + Some(BrilligParameter::Array(parameters, *size)) + } + _ => None, + } +} + +fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut Vec) { + if let Some(value) = dfg.get_numeric_constant(value_id) { + calldata.push(value); + return; + } + + if let Some((values, _type)) = dfg.get_array_constant(value_id) { + for value in values { + value_id_to_calldata(value, dfg, calldata); + } + return; + } + + panic!("Expected ValueId to be numeric constant or array constant"); +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -854,4 +1165,180 @@ mod test { let ssa = ssa.fold_constants_using_constraints(); assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn inlines_brillig_call_without_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1() -> Field + return v0 + } + + brillig(inline) fn one f1 { + b0(): + v0 = add Field 2, Field 3 + return v0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_two_field_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, Field 3) -> Field + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_two_i32_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(i32 2, i32 3) -> i32 + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: i32, v1: i32): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return i32 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_array_return() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, Field 3, Field 4) -> [Field; 3] + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: Field, v2: Field): + v3 = make_array [v0, v1, v2] : [Field; 3] + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v3 = make_array [Field 2, Field 3, Field 4] : [Field; 3] + return v3 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_composite_array_return() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, i32 3, Field 4, i32 5) -> [(Field, i32); 2] + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: i32, v2: i32, v3: Field): + v4 = make_array [v0, v1, v2, v3] : [(Field, i32); 2] + return v4 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v4 = make_array [Field 2, i32 3, Field 4, i32 5] : [(Field, i32); 2] + return v4 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_array_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = make_array [Field 2, Field 3] : [Field; 2] + v1 = call f1(v0) -> Field + return v1 + } + + brillig(inline) fn one f1 { + b0(v0: [Field; 2]): + inc_rc v0 + v2 = array_get v0, index u32 0 -> Field + v4 = array_get v0, index u32 1 -> Field + v5 = add v2, v4 + dec_rc v0 + return v5 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v2 = make_array [Field 2, Field 3] : [Field; 2] + return Field 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } } diff --git a/compiler/noirc_evaluator/src/ssa/parser/mod.rs b/compiler/noirc_evaluator/src/ssa/parser/mod.rs index 2db2c636a8f..7753908b2bd 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/mod.rs @@ -28,6 +28,11 @@ mod tests; mod token; impl Ssa { + /// Creates an Ssa object from the given string. + /// + /// Note that the resulting Ssa might not be exactly the same as the given string. + /// This is because, internally, the Ssa is built using a `FunctionBuilder`, so + /// some instructions might be simplified while they are inserted. pub(crate) fn from_str(src: &str) -> Result { let mut parser = Parser::new(src).map_err(|err| SsaErrorWithSource::parse_error(err, src))?; diff --git a/compiler/noirc_frontend/src/ast/traits.rs b/compiler/noirc_frontend/src/ast/traits.rs index 723df775b1e..475e3ff1be9 100644 --- a/compiler/noirc_frontend/src/ast/traits.rs +++ b/compiler/noirc_frontend/src/ast/traits.rs @@ -24,6 +24,7 @@ pub struct NoirTrait { pub items: Vec>, pub attributes: Vec, pub visibility: ItemVisibility, + pub is_alias: bool, } /// Any declaration inside the body of a trait that a user is required to @@ -77,6 +78,9 @@ pub struct NoirTraitImpl { pub where_clause: Vec, pub items: Vec>, + + /// true if generated at compile-time, e.g. from a trait alias + pub is_synthetic: bool, } /// Represents a simple trait constraint such as `where Foo: TraitY` @@ -130,12 +134,19 @@ impl Display for TypeImpl { } } +// TODO: display where clauses (follow-up issue) impl Display for NoirTrait { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let generics = vecmap(&self.generics, |generic| generic.to_string()); let generics = if generics.is_empty() { "".into() } else { generics.join(", ") }; write!(f, "trait {}{}", self.name, generics)?; + + if self.is_alias { + let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); + return write!(f, " = {};", bounds); + } + if !self.bounds.is_empty() { let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); write!(f, ": {}", bounds)?; @@ -222,6 +233,11 @@ impl Display for TraitBound { impl Display for NoirTraitImpl { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Synthetic NoirTraitImpl's don't get printed + if self.is_synthetic { + return Ok(()); + } + write!(f, "impl")?; if !self.impl_generics.is_empty() { write!( diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index a63601a4280..a6b6120986e 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -208,7 +208,7 @@ pub enum TypeCheckError { #[derive(Debug, Clone, PartialEq, Eq)] pub struct NoMatchingImplFoundError { - constraints: Vec<(Type, String)>, + pub(crate) constraints: Vec<(Type, String)>, pub span: Span, } diff --git a/compiler/noirc_frontend/src/parser/errors.rs b/compiler/noirc_frontend/src/parser/errors.rs index 63471efac43..bcb4ce1c616 100644 --- a/compiler/noirc_frontend/src/parser/errors.rs +++ b/compiler/noirc_frontend/src/parser/errors.rs @@ -95,6 +95,8 @@ pub enum ParserErrorReason { AssociatedTypesNotAllowedInPaths, #[error("Associated types are not allowed on a method call")] AssociatedTypesNotAllowedInMethodCalls, + #[error("Empty trait alias")] + EmptyTraitAlias, #[error( "Wrong number of arguments for attribute `{}`. Expected {}, found {}", name, diff --git a/compiler/noirc_frontend/src/parser/parser/impls.rs b/compiler/noirc_frontend/src/parser/parser/impls.rs index 9215aec2742..8e6b3bae0e9 100644 --- a/compiler/noirc_frontend/src/parser/parser/impls.rs +++ b/compiler/noirc_frontend/src/parser/parser/impls.rs @@ -113,6 +113,7 @@ impl<'a> Parser<'a> { let object_type = self.parse_type_or_error(); let where_clause = self.parse_where_clause(); let items = self.parse_trait_impl_body(); + let is_synthetic = false; NoirTraitImpl { impl_generics, @@ -121,6 +122,7 @@ impl<'a> Parser<'a> { object_type, where_clause, items, + is_synthetic, } } diff --git a/compiler/noirc_frontend/src/parser/parser/item.rs b/compiler/noirc_frontend/src/parser/parser/item.rs index 4fbcd7abac5..ce712b559d8 100644 --- a/compiler/noirc_frontend/src/parser/parser/item.rs +++ b/compiler/noirc_frontend/src/parser/parser/item.rs @@ -1,3 +1,5 @@ +use iter_extended::vecmap; + use crate::{ parser::{labels::ParsingRuleLabel, Item, ItemKind}, token::{Keyword, Token}, @@ -13,31 +15,32 @@ impl<'a> Parser<'a> { } pub(crate) fn parse_module_items(&mut self, nested: bool) -> Vec { - self.parse_many("items", without_separator(), |parser| { + self.parse_many_to_many("items", without_separator(), |parser| { parser.parse_module_item_in_list(nested) }) } - fn parse_module_item_in_list(&mut self, nested: bool) -> Option { + fn parse_module_item_in_list(&mut self, nested: bool) -> Vec { loop { // We only break out of the loop on `}` if we are inside a `mod { ..` if nested && self.at(Token::RightBrace) { - return None; + return vec![]; } // We always break on EOF (we don't error because if we are inside `mod { ..` // the outer parsing logic will error instead) if self.at_eof() { - return None; + return vec![]; } - let Some(item) = self.parse_item() else { + let parsed_items = self.parse_item(); + if parsed_items.is_empty() { // If we couldn't parse an item we check which token we got match self.token.token() { Token::RightBrace if nested => { - return None; + return vec![]; } - Token::EOF => return None, + Token::EOF => return vec![], _ => (), } @@ -47,7 +50,7 @@ impl<'a> Parser<'a> { continue; }; - return Some(item); + return parsed_items; } } @@ -85,15 +88,19 @@ impl<'a> Parser<'a> { } /// Item = OuterDocComments ItemKind - fn parse_item(&mut self) -> Option { + fn parse_item(&mut self) -> Vec { let start_span = self.current_token_span; let doc_comments = self.parse_outer_doc_comments(); - let kind = self.parse_item_kind()?; + let kinds = self.parse_item_kind(); let span = self.span_since(start_span); - Some(Item { kind, span, doc_comments }) + vecmap(kinds, |kind| Item { kind, span, doc_comments: doc_comments.clone() }) } + /// This method returns one 'ItemKind' in the majority of cases. + /// The current exception is when parsing a trait alias, + /// which returns both the trait and the impl. + /// /// ItemKind /// = InnerAttribute /// | Attributes Modifiers @@ -106,9 +113,9 @@ impl<'a> Parser<'a> { /// | TypeAlias /// | Function /// ) - fn parse_item_kind(&mut self) -> Option { + fn parse_item_kind(&mut self) -> Vec { if let Some(kind) = self.parse_inner_attribute() { - return Some(ItemKind::InnerAttribute(kind)); + return vec![ItemKind::InnerAttribute(kind)]; } let start_span = self.current_token_span; @@ -122,78 +129,81 @@ impl<'a> Parser<'a> { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); let use_tree = self.parse_use_tree(); - return Some(ItemKind::Import(use_tree, modifiers.visibility)); + return vec![ItemKind::Import(use_tree, modifiers.visibility)]; } if let Some(is_contract) = self.eat_mod_or_contract() { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(self.parse_mod_or_contract(attributes, is_contract, modifiers.visibility)); + return vec![self.parse_mod_or_contract(attributes, is_contract, modifiers.visibility)]; } if self.eat_keyword(Keyword::Struct) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::Struct(self.parse_struct( + return vec![ItemKind::Struct(self.parse_struct( attributes, modifiers.visibility, start_span, - ))); + ))]; } if self.eat_keyword(Keyword::Impl) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(match self.parse_impl() { + return vec![match self.parse_impl() { Impl::Impl(type_impl) => ItemKind::Impl(type_impl), Impl::TraitImpl(noir_trait_impl) => ItemKind::TraitImpl(noir_trait_impl), - }); + }]; } if self.eat_keyword(Keyword::Trait) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::Trait(self.parse_trait( - attributes, - modifiers.visibility, - start_span, - ))); + let (noir_trait, noir_impl) = + self.parse_trait(attributes, modifiers.visibility, start_span); + let mut output = vec![ItemKind::Trait(noir_trait)]; + if let Some(noir_impl) = noir_impl { + output.push(ItemKind::TraitImpl(noir_impl)); + } + + return output; } if self.eat_keyword(Keyword::Global) { self.unconstrained_not_applicable(modifiers); - return Some(ItemKind::Global( + return vec![ItemKind::Global( self.parse_global( attributes, modifiers.comptime.is_some(), modifiers.mutable.is_some(), ), modifiers.visibility, - )); + )]; } if self.eat_keyword(Keyword::Type) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::TypeAlias( + return vec![ItemKind::TypeAlias( self.parse_type_alias(modifiers.visibility, start_span), - )); + )]; } if self.eat_keyword(Keyword::Fn) { self.mutable_not_applicable(modifiers); - return Some(ItemKind::Function(self.parse_function( + return vec![ItemKind::Function(self.parse_function( attributes, modifiers.visibility, modifiers.comptime.is_some(), modifiers.unconstrained.is_some(), false, // allow_self - ))); + ))]; } - None + vec![] } fn eat_mod_or_contract(&mut self) -> Option { diff --git a/compiler/noirc_frontend/src/parser/parser/parse_many.rs b/compiler/noirc_frontend/src/parser/parser/parse_many.rs index ea4dfe97122..be156eb1618 100644 --- a/compiler/noirc_frontend/src/parser/parser/parse_many.rs +++ b/compiler/noirc_frontend/src/parser/parser/parse_many.rs @@ -18,6 +18,19 @@ impl<'a> Parser<'a> { self.parse_many_return_trailing_separator_if_any(items, separated_by, f).0 } + /// parse_many, where the given function `f` may return multiple results + pub(super) fn parse_many_to_many( + &mut self, + items: &'static str, + separated_by: SeparatedBy, + f: F, + ) -> Vec + where + F: FnMut(&mut Parser<'a>) -> Vec, + { + self.parse_many_to_many_return_trailing_separator_if_any(items, separated_by, f).0 + } + /// Same as parse_many, but returns a bool indicating whether a trailing separator was found. pub(super) fn parse_many_return_trailing_separator_if_any( &mut self, @@ -27,6 +40,26 @@ impl<'a> Parser<'a> { ) -> (Vec, bool) where F: FnMut(&mut Parser<'a>) -> Option, + { + let f = |x: &mut Parser<'a>| { + if let Some(result) = f(x) { + vec![result] + } else { + vec![] + } + }; + self.parse_many_to_many_return_trailing_separator_if_any(items, separated_by, f) + } + + /// Same as parse_many, but returns a bool indicating whether a trailing separator was found. + fn parse_many_to_many_return_trailing_separator_if_any( + &mut self, + items: &'static str, + separated_by: SeparatedBy, + mut f: F, + ) -> (Vec, bool) + where + F: FnMut(&mut Parser<'a>) -> Vec, { let mut elements: Vec = Vec::new(); let mut trailing_separator = false; @@ -38,12 +71,13 @@ impl<'a> Parser<'a> { } let start_span = self.current_token_span; - let Some(element) = f(self) else { + let mut new_elements = f(self); + if new_elements.is_empty() { if let Some(end) = &separated_by.until { self.eat(end.clone()); } break; - }; + } if let Some(separator) = &separated_by.token { if !trailing_separator && !elements.is_empty() { @@ -51,7 +85,7 @@ impl<'a> Parser<'a> { } } - elements.push(element); + elements.append(&mut new_elements); trailing_separator = if let Some(separator) = &separated_by.token { self.eat(separator.clone()) diff --git a/compiler/noirc_frontend/src/parser/parser/traits.rs b/compiler/noirc_frontend/src/parser/parser/traits.rs index fead6a34c82..e03b629e9ea 100644 --- a/compiler/noirc_frontend/src/parser/parser/traits.rs +++ b/compiler/noirc_frontend/src/parser/parser/traits.rs @@ -1,9 +1,14 @@ +use iter_extended::vecmap; + use noirc_errors::Span; -use crate::ast::{Documented, ItemVisibility, NoirTrait, Pattern, TraitItem, UnresolvedType}; +use crate::ast::{ + Documented, GenericTypeArg, GenericTypeArgs, ItemVisibility, NoirTrait, Path, Pattern, + TraitItem, UnresolvedGeneric, UnresolvedTraitConstraint, UnresolvedType, +}; use crate::{ ast::{Ident, UnresolvedTypeData}, - parser::{labels::ParsingRuleLabel, ParserErrorReason}, + parser::{labels::ParsingRuleLabel, NoirTraitImpl, ParserErrorReason}, token::{Attribute, Keyword, SecondaryAttribute, Token}, }; @@ -12,34 +17,117 @@ use super::Parser; impl<'a> Parser<'a> { /// Trait = 'trait' identifier Generics ( ':' TraitBounds )? WhereClause TraitBody + /// | 'trait' identifier Generics '=' TraitBounds WhereClause ';' pub(crate) fn parse_trait( &mut self, attributes: Vec<(Attribute, Span)>, visibility: ItemVisibility, start_span: Span, - ) -> NoirTrait { + ) -> (NoirTrait, Option) { let attributes = self.validate_secondary_attributes(attributes); let Some(name) = self.eat_ident() else { self.expected_identifier(); - return empty_trait(attributes, visibility, self.span_since(start_span)); + let noir_trait = empty_trait(attributes, visibility, self.span_since(start_span)); + let no_implicit_impl = None; + return (noir_trait, no_implicit_impl); }; let generics = self.parse_generics(); - let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; - let where_clause = self.parse_where_clause(); - let items = self.parse_trait_body(); - NoirTrait { + // Trait aliases: + // trait Foo<..> = A + B + E where ..; + let (bounds, where_clause, items, is_alias) = if self.eat_assign() { + let bounds = self.parse_trait_bounds(); + + if bounds.is_empty() { + self.push_error(ParserErrorReason::EmptyTraitAlias, self.previous_token_span); + } + + let where_clause = self.parse_where_clause(); + let items = Vec::new(); + if !self.eat_semicolon() { + self.expected_token(Token::Semicolon); + } + + let is_alias = true; + (bounds, where_clause, items, is_alias) + } else { + let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; + let where_clause = self.parse_where_clause(); + let items = self.parse_trait_body(); + let is_alias = false; + (bounds, where_clause, items, is_alias) + }; + + let span = self.span_since(start_span); + + let noir_impl = is_alias.then(|| { + let object_type_ident = Ident::new("#T".to_string(), span); + let object_type_path = Path::from_ident(object_type_ident.clone()); + let object_type_generic = UnresolvedGeneric::Variable(object_type_ident); + + let is_synthesized = true; + let object_type = UnresolvedType { + typ: UnresolvedTypeData::Named(object_type_path, vec![].into(), is_synthesized), + span, + }; + + let mut impl_generics = generics.clone(); + impl_generics.push(object_type_generic); + + let trait_name = Path::from_ident(name.clone()); + let trait_generics: GenericTypeArgs = vecmap(generics.clone(), |generic| { + let is_synthesized = true; + let generic_type = UnresolvedType { + typ: UnresolvedTypeData::Named( + Path::from_ident(generic.ident().clone()), + vec![].into(), + is_synthesized, + ), + span, + }; + + GenericTypeArg::Ordered(generic_type) + }) + .into(); + + // bounds from trait + let mut where_clause = where_clause.clone(); + for bound in bounds.clone() { + where_clause.push(UnresolvedTraitConstraint { + typ: object_type.clone(), + trait_bound: bound, + }); + } + + let items = vec![]; + let is_synthetic = true; + + NoirTraitImpl { + impl_generics, + trait_name, + trait_generics, + object_type, + where_clause, + items, + is_synthetic, + } + }); + + let noir_trait = NoirTrait { name, generics, bounds, where_clause, - span: self.span_since(start_span), + span, items, attributes, visibility, - } + is_alias, + }; + + (noir_trait, noir_impl) } /// TraitBody = '{' ( OuterDocComments TraitItem )* '}' @@ -188,28 +276,51 @@ fn empty_trait( items: Vec::new(), attributes, visibility, + is_alias: false, } } #[cfg(test)] mod tests { use crate::{ - ast::{NoirTrait, TraitItem}, + ast::{NoirTrait, NoirTraitImpl, TraitItem}, parser::{ - parser::{parse_program, tests::expect_no_errors}, + parser::{parse_program, tests::expect_no_errors, ParserErrorReason}, ItemKind, }, }; - fn parse_trait_no_errors(src: &str) -> NoirTrait { + fn parse_trait_opt_impl_no_errors(src: &str) -> (NoirTrait, Option) { let (mut module, errors) = parse_program(src); expect_no_errors(&errors); - assert_eq!(module.items.len(), 1); - let item = module.items.remove(0); + let (item, impl_item) = if module.items.len() == 2 { + let item = module.items.remove(0); + let impl_item = module.items.remove(0); + (item, Some(impl_item)) + } else { + assert_eq!(module.items.len(), 1); + let item = module.items.remove(0); + (item, None) + }; let ItemKind::Trait(noir_trait) = item.kind else { panic!("Expected trait"); }; - noir_trait + let noir_trait_impl = impl_item.map(|impl_item| { + let ItemKind::TraitImpl(noir_trait_impl) = impl_item.kind else { + panic!("Expected impl"); + }; + noir_trait_impl + }); + (noir_trait, noir_trait_impl) + } + + fn parse_trait_with_impl_no_errors(src: &str) -> (NoirTrait, NoirTraitImpl) { + let (noir_trait, noir_trait_impl) = parse_trait_opt_impl_no_errors(src); + (noir_trait, noir_trait_impl.expect("expected a NoirTraitImpl")) + } + + fn parse_trait_no_errors(src: &str) -> NoirTrait { + parse_trait_opt_impl_no_errors(src).0 } #[test] @@ -220,6 +331,15 @@ mod tests { assert!(noir_trait.generics.is_empty()); assert!(noir_trait.where_clause.is_empty()); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias() { + let src = "trait Foo = ;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -230,6 +350,50 @@ mod tests { assert_eq!(noir_trait.generics.len(), 2); assert!(noir_trait.where_clause.is_empty()); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias_with_generics() { + let src = "trait Foo = Bar + Baz;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.name.to_string(), "Foo"); + assert_eq!(noir_trait_alias.generics.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert!(noir_trait_alias.where_clause.is_empty()); + assert!(noir_trait_alias.items.is_empty()); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 3); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(noir_trait_impl.where_clause.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.is_empty(), noir_trait_alias.where_clause.is_empty()); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias_with_generics() { + let src = "trait Foo = ;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -240,6 +404,54 @@ mod tests { assert_eq!(noir_trait.generics.len(), 2); assert_eq!(noir_trait.where_clause.len(), 1); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias_with_where_clause() { + let src = "trait Foo = Bar + Baz where A: Z;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.name.to_string(), "Foo"); + assert_eq!(noir_trait_alias.generics.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert_eq!(noir_trait_alias.where_clause.len(), 1); + assert!(noir_trait_alias.items.is_empty()); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 3); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(noir_trait_impl.where_clause.len(), 3); + assert_eq!(noir_trait_impl.where_clause[0].to_string(), "A: Z"); + assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Bar"); + assert_eq!(noir_trait_impl.where_clause[2].to_string(), "#T: Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz where A: Z {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.len(), noir_trait_alias.where_clause.len()); + assert_eq!( + noir_trait.where_clause[0].to_string(), + noir_trait_alias.where_clause[0].to_string() + ); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias_with_where_clause() { + let src = "trait Foo = where A: Z;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -253,6 +465,7 @@ mod tests { panic!("Expected type"); }; assert_eq!(name.to_string(), "Elem"); + assert!(!noir_trait.is_alias); } #[test] @@ -268,6 +481,7 @@ mod tests { assert_eq!(name.to_string(), "x"); assert_eq!(typ.to_string(), "Field"); assert_eq!(default_value.unwrap().to_string(), "1"); + assert!(!noir_trait.is_alias); } #[test] @@ -281,6 +495,7 @@ mod tests { panic!("Expected function"); }; assert!(body.is_none()); + assert!(!noir_trait.is_alias); } #[test] @@ -294,6 +509,7 @@ mod tests { panic!("Expected function"); }; assert!(body.is_some()); + assert!(!noir_trait.is_alias); } #[test] @@ -306,5 +522,39 @@ mod tests { assert_eq!(noir_trait.bounds[1].to_string(), "Baz"); assert_eq!(noir_trait.to_string(), "trait Foo: Bar + Baz {\n}"); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias() { + let src = "trait Foo = Bar + Baz;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.bounds.len(), 2); + + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + + assert_eq!(noir_trait_alias.to_string(), "trait Foo = Bar + Baz;"); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 1); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 0); + assert_eq!(noir_trait_impl.where_clause.len(), 2); + assert_eq!(noir_trait_impl.where_clause[0].to_string(), "#T: Bar"); + assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.is_empty(), noir_trait_alias.where_clause.is_empty()); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); } } diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs index 0adf5c90bea..811a32bab86 100644 --- a/compiler/noirc_frontend/src/tests/traits.rs +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -1,5 +1,6 @@ use crate::hir::def_collector::dc_crate::CompilationError; use crate::hir::resolution::errors::ResolverError; +use crate::hir::type_check::TypeCheckError; use crate::tests::{get_program_errors, get_program_with_maybe_parser_errors}; use super::assert_no_errors; @@ -320,6 +321,251 @@ fn regression_6314_double_inheritance() { assert_no_errors(src); } +#[test] +fn trait_alias_single_member() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Baz = Foo; + + impl Foo for Field { + fn foo(self) -> Self { self } + } + + fn baz(x: T) -> T where T: Baz { + x.foo() + } + + fn main() { + let x: Field = 0; + let _ = baz(x); + } + "#; + assert_no_errors(src); +} + +#[test] +fn trait_alias_two_members() { + let src = r#" + pub trait Foo { + fn foo(self) -> Self; + } + + pub trait Bar { + fn bar(self) -> Self; + } + + pub trait Baz = Foo + Bar; + + fn baz(x: T) -> T where T: Baz { + x.foo().bar() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> Self { + self + 2 + } + } + + fn main() { + assert(0.foo().bar() == baz(0)); + }"#; + + assert_no_errors(src); +} + +#[test] +fn trait_alias_polymorphic_inheritance() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Bar { + fn bar(self) -> T; + } + + trait Baz = Foo + Bar; + + fn baz(x: T) -> U where T: Baz { + x.foo().bar() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> bool { + true + } + } + + fn main() { + assert(0.foo().bar() == baz(0)); + }"#; + + assert_no_errors(src); +} + +// TODO(https://github.com/noir-lang/noir/issues/6467): currently fails with the +// same errors as the desugared version +#[test] +fn trait_alias_polymorphic_where_clause() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Bar { + fn bar(self) -> T; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux = Foo + Bar where T: Baz; + + fn qux(x: T) -> bool where T: Qux { + x.foo().bar().baz() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> bool { + true + } + } + + impl Baz for bool { + fn baz(self) -> bool { + self + } + } + + fn main() { + assert(0.foo().bar().baz() == qux(0)); + } + "#; + + // TODO(https://github.com/noir-lang/noir/issues/6467) + // assert_no_errors(src); + let errors = get_program_errors(src); + assert_eq!(errors.len(), 2); + + match &errors[0].0 { + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name, .. + }) => { + assert_eq!(method_name, "baz"); + } + other => { + panic!("expected UnresolvedMethodCall, but found {:?}", other); + } + } + + match &errors[1].0 { + CompilationError::TypeError(TypeCheckError::NoMatchingImplFound(err)) => { + assert_eq!(err.constraints.len(), 2); + assert_eq!(err.constraints[0].1, "Baz"); + assert_eq!(err.constraints[1].1, "Qux<_>"); + } + other => { + panic!("expected NoMatchingImplFound, but found {:?}", other); + } + } +} + +// TODO(https://github.com/noir-lang/noir/issues/6467): currently failing, so +// this just tests that the trait alias has an equivalent error to the expected +// desugared version +#[test] +fn trait_alias_with_where_clause_has_equivalent_errors() { + let src = r#" + trait Bar { + fn bar(self) -> Self; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux: Bar where T: Baz {} + + impl Qux for U where + U: Bar, + T: Baz, + {} + + pub fn qux(x: T, _: U) -> bool where U: Qux { + x.baz() + } + + fn main() {} + "#; + + let alias_src = r#" + trait Bar { + fn bar(self) -> Self; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux = Bar where T: Baz; + + pub fn qux(x: T, _: U) -> bool where U: Qux { + x.baz() + } + + fn main() {} + "#; + + let errors = get_program_errors(src); + let alias_errors = get_program_errors(alias_src); + + assert_eq!(errors.len(), 1); + assert_eq!(alias_errors.len(), 1); + + match (&errors[0].0, &alias_errors[0].0) { + ( + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name, + object_type, + .. + }), + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name: alias_method_name, + object_type: alias_object_type, + .. + }), + ) => { + assert_eq!(method_name, alias_method_name); + assert_eq!(object_type, alias_object_type); + } + other => { + panic!("expected UnresolvedMethodCall, but found {:?}", other); + } + } +} + #[test] fn removes_assumed_parent_traits_after_function_ends() { let src = r#" diff --git a/docs/docs/noir/concepts/traits.md b/docs/docs/noir/concepts/traits.md index 9da00a77587..b6c0a886eb0 100644 --- a/docs/docs/noir/concepts/traits.md +++ b/docs/docs/noir/concepts/traits.md @@ -490,12 +490,95 @@ trait CompSciStudent: Programmer + Student { } ``` +### Trait Aliases + +Similar to the proposed Rust feature for [trait aliases](https://github.com/rust-lang/rust/blob/4d215e2426d52ca8d1af166d5f6b5e172afbff67/src/doc/unstable-book/src/language-features/trait-alias.md), +Noir supports aliasing one or more traits and using those aliases wherever +traits would normally be used. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> Self; +} + +// Equivalent to: +// trait Baz: Foo + Bar {} +// +// impl Baz for T where T: Foo + Bar {} +trait Baz = Foo + Bar; + +// We can use `Baz` to refer to `Foo + Bar` +fn baz(x: T) -> T where T: Baz { + x.foo().bar() +} +``` + +#### Generic Trait Aliases + +Trait aliases can also be generic by placing the generic arguments after the +trait name. These generics are in scope of every item within the trait alias. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> T; +} + +// Equivalent to: +// trait Baz: Foo + Bar {} +// +// impl Baz for U where U: Foo + Bar {} +trait Baz = Foo + Bar; +``` + +#### Trait Alias Where Clauses + +Trait aliases support where clauses to add trait constraints to any of their +generic arguments, e.g. ensuring `T: Baz` for a trait alias `Qux`. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> T; +} + +trait Baz { + fn baz(self) -> bool; +} + +// Equivalent to: +// trait Qux: Foo + Bar where T: Baz {} +// +// impl Qux for U where +// U: Foo + Bar, +// T: Baz, +// {} +trait Qux = Foo + Bar where T: Baz; +``` + +Note that while trait aliases support where clauses, +the equivalent traits can fail due to [#6467](https://github.com/noir-lang/noir/issues/6467) + ### Visibility -By default, like functions, traits are private to the module they exist in. You can use `pub` -to make the trait public or `pub(crate)` to make it public to just its crate: +By default, like functions, traits and trait aliases are private to the module +they exist in. You can use `pub` to make the trait public or `pub(crate)` to make +it public to just its crate: ```rust // This trait is now public pub trait Trait {} -``` \ No newline at end of file + +// This trait alias is now public +pub trait Baz = Foo + Bar; +``` diff --git a/test_programs/execution_success/brillig_unitialised_arrays/Nargo.toml b/test_programs/execution_success/brillig_uninitialized_arrays/Nargo.toml similarity index 58% rename from test_programs/execution_success/brillig_unitialised_arrays/Nargo.toml rename to test_programs/execution_success/brillig_uninitialized_arrays/Nargo.toml index f23ecc787d0..68bcf9929cc 100644 --- a/test_programs/execution_success/brillig_unitialised_arrays/Nargo.toml +++ b/test_programs/execution_success/brillig_uninitialized_arrays/Nargo.toml @@ -1,5 +1,5 @@ [package] -name = "brillig_unitialised_arrays" +name = "brillig_uninitialized_arrays" type = "bin" authors = [""] diff --git a/test_programs/execution_success/brillig_unitialised_arrays/Prover.toml b/test_programs/execution_success/brillig_uninitialized_arrays/Prover.toml similarity index 100% rename from test_programs/execution_success/brillig_unitialised_arrays/Prover.toml rename to test_programs/execution_success/brillig_uninitialized_arrays/Prover.toml diff --git a/test_programs/execution_success/brillig_unitialised_arrays/src/main.nr b/test_programs/execution_success/brillig_uninitialized_arrays/src/main.nr similarity index 100% rename from test_programs/execution_success/brillig_unitialised_arrays/src/main.nr rename to test_programs/execution_success/brillig_uninitialized_arrays/src/main.nr diff --git a/tooling/nargo_fmt/build.rs b/tooling/nargo_fmt/build.rs index 47bb375f7d1..bd2db5f5b18 100644 --- a/tooling/nargo_fmt/build.rs +++ b/tooling/nargo_fmt/build.rs @@ -47,7 +47,10 @@ fn generate_formatter_tests(test_file: &mut File, test_data_dir: &Path) { .join("\n"); let output_source_path = outputs_dir.join(file_name).display().to_string(); - let output_source = std::fs::read_to_string(output_source_path.clone()).unwrap(); + let output_source = + std::fs::read_to_string(output_source_path.clone()).unwrap_or_else(|_| { + panic!("expected output source at {:?} was not found", &output_source_path) + }); write!( test_file, diff --git a/tooling/nargo_fmt/src/formatter/trait_impl.rs b/tooling/nargo_fmt/src/formatter/trait_impl.rs index 73d9a61b3d4..b31da8a4101 100644 --- a/tooling/nargo_fmt/src/formatter/trait_impl.rs +++ b/tooling/nargo_fmt/src/formatter/trait_impl.rs @@ -10,6 +10,11 @@ use super::Formatter; impl<'a> Formatter<'a> { pub(super) fn format_trait_impl(&mut self, trait_impl: NoirTraitImpl) { + // skip synthetic trait impl's, e.g. generated from trait aliases + if trait_impl.is_synthetic { + return; + } + let has_where_clause = !trait_impl.where_clause.is_empty(); self.write_indentation(); diff --git a/tooling/nargo_fmt/src/formatter/traits.rs b/tooling/nargo_fmt/src/formatter/traits.rs index 9a6b84c6537..1f192be471e 100644 --- a/tooling/nargo_fmt/src/formatter/traits.rs +++ b/tooling/nargo_fmt/src/formatter/traits.rs @@ -15,9 +15,18 @@ impl<'a> Formatter<'a> { self.write_identifier(noir_trait.name); self.format_generics(noir_trait.generics); + if noir_trait.is_alias { + self.write_space(); + self.write_token(Token::Assign); + } + if !noir_trait.bounds.is_empty() { self.skip_comments_and_whitespace(); - self.write_token(Token::Colon); + + if !noir_trait.is_alias { + self.write_token(Token::Colon); + } + self.write_space(); for (index, trait_bound) in noir_trait.bounds.into_iter().enumerate() { @@ -34,6 +43,12 @@ impl<'a> Formatter<'a> { self.format_where_clause(noir_trait.where_clause, true); } + // aliases have ';' in lieu of '{ items }' + if noir_trait.is_alias { + self.write_semicolon(); + return; + } + self.write_space(); self.write_left_brace(); if noir_trait.items.is_empty() { diff --git a/tooling/nargo_fmt/tests/expected/trait_alias.nr b/tooling/nargo_fmt/tests/expected/trait_alias.nr new file mode 100644 index 00000000000..926f3160279 --- /dev/null +++ b/tooling/nargo_fmt/tests/expected/trait_alias.nr @@ -0,0 +1,86 @@ +trait Foo { + fn foo(self) -> Self; +} + +trait Baz = Foo; + +impl Foo for Field { + fn foo(self) -> Self { + self + } +} + +fn baz(x: T) -> T +where + T: Baz, +{ + x.foo() +} + +pub trait Foo_2 { + fn foo_2(self) -> Self; +} + +pub trait Bar_2 { + fn bar_2(self) -> Self; +} + +pub trait Baz_2 = Foo_2 + Bar_2; + +fn baz_2(x: T) -> T +where + T: Baz_2, +{ + x.foo_2().bar_2() +} + +impl Foo_2 for Field { + fn foo_2(self) -> Self { + self + 1 + } +} + +impl Bar_2 for Field { + fn bar_2(self) -> Self { + self + 2 + } +} + +trait Foo_3 { + fn foo_3(self) -> Self; +} + +trait Bar_3 { + fn bar_3(self) -> T; +} + +trait Baz_3 = Foo_3 + Bar_3; + +fn baz_3(x: T) -> U +where + T: Baz_3, +{ + x.foo_3().bar_3() +} + +impl Foo_3 for Field { + fn foo_3(self) -> Self { + self + 1 + } +} + +impl Bar_3 for Field { + fn bar_3(self) -> bool { + true + } +} + +fn main() { + let x: Field = 0; + let _ = baz(x); + + assert(0.foo_2().bar_2() == baz_2(0)); + + assert(0.foo_3().bar_3() == baz_3(0)); +} + diff --git a/tooling/nargo_fmt/tests/input/trait_alias.nr b/tooling/nargo_fmt/tests/input/trait_alias.nr new file mode 100644 index 00000000000..53ae756795b --- /dev/null +++ b/tooling/nargo_fmt/tests/input/trait_alias.nr @@ -0,0 +1,78 @@ +trait Foo { + fn foo(self) -> Self; +} + +trait Baz = Foo; + +impl Foo for Field { + fn foo(self) -> Self { self } +} + +fn baz(x: T) -> T where T: Baz { + x.foo() +} + + +pub trait Foo_2 { + fn foo_2(self) -> Self; +} + +pub trait Bar_2 { + fn bar_2(self) -> Self; +} + +pub trait Baz_2 = Foo_2 + Bar_2; + +fn baz_2(x: T) -> T where T: Baz_2 { + x.foo_2().bar_2() +} + +impl Foo_2 for Field { + fn foo_2(self) -> Self { + self + 1 + } +} + +impl Bar_2 for Field { + fn bar_2(self) -> Self { + self + 2 + } +} + + +trait Foo_3 { + fn foo_3(self) -> Self; +} + +trait Bar_3 { + fn bar_3(self) -> T; +} + +trait Baz_3 = Foo_3 + Bar_3; + +fn baz_3(x: T) -> U where T: Baz_3 { + x.foo_3().bar_3() +} + +impl Foo_3 for Field { + fn foo_3(self) -> Self { + self + 1 + } +} + +impl Bar_3 for Field { + fn bar_3(self) -> bool { + true + } +} + + +fn main() { + let x: Field = 0; + let _ = baz(x); + + assert(0.foo_2().bar_2() == baz_2(0)); + + assert(0.foo_3().bar_3() == baz_3(0)); +} + diff --git a/tooling/nargo_toml/src/errors.rs b/tooling/nargo_toml/src/errors.rs index 1ee8e90c8e5..7e1003d04f7 100644 --- a/tooling/nargo_toml/src/errors.rs +++ b/tooling/nargo_toml/src/errors.rs @@ -80,6 +80,8 @@ pub enum ManifestError { #[allow(clippy::enum_variant_names)] #[derive(Error, Debug, PartialEq, Eq, Clone)] pub enum SemverError { + #[error("Invalid value for `compiler_version` in package {package_name}. Requirements may only refer to full releases")] + InvalidCompilerVersionRequirement { package_name: CrateName, required_compiler_version: String }, #[error("Incompatible compiler version in package {package_name}. Required compiler version is {required_compiler_version} but the compiler version is {compiler_version_found}.\n Update the compiler_version field in Nargo.toml to >={required_compiler_version} or compile this project with version {required_compiler_version}")] IncompatibleVersion { package_name: CrateName, diff --git a/tooling/nargo_toml/src/semver.rs b/tooling/nargo_toml/src/semver.rs index 253ac82aa34..ececa1b30dd 100644 --- a/tooling/nargo_toml/src/semver.rs +++ b/tooling/nargo_toml/src/semver.rs @@ -3,11 +3,14 @@ use nargo::{ package::{Dependency, Package}, workspace::Workspace, }; -use semver::{Error, Version, VersionReq}; +use noirc_driver::CrateName; +use semver::{Error, Prerelease, Version, VersionReq}; // Parse a semver compatible version string pub(crate) fn parse_semver_compatible_version(version: &str) -> Result { - Version::parse(version) + let mut version = Version::parse(version)?; + version.pre = Prerelease::EMPTY; + Ok(version) } // Check that all of the packages in the workspace are compatible with the current compiler version @@ -25,10 +28,7 @@ pub(crate) fn semver_check_workspace( } // Check that a package and all of its dependencies are compatible with the current compiler version -pub(crate) fn semver_check_package( - package: &Package, - compiler_version: &Version, -) -> Result<(), SemverError> { +fn semver_check_package(package: &Package, compiler_version: &Version) -> Result<(), SemverError> { // Check that this package's compiler version requirements are satisfied if let Some(version) = &package.compiler_required_version { let version_req = match VersionReq::parse(version) { @@ -40,6 +40,9 @@ pub(crate) fn semver_check_package( }) } }; + + validate_compiler_version_requirement(&package.name, &version_req)?; + if !version_req.matches(compiler_version) { return Err(SemverError::IncompatibleVersion { package_name: package.name.clone(), @@ -61,6 +64,20 @@ pub(crate) fn semver_check_package( Ok(()) } +fn validate_compiler_version_requirement( + package_name: &CrateName, + required_compiler_version: &VersionReq, +) -> Result<(), SemverError> { + if required_compiler_version.comparators.iter().any(|comparator| !comparator.pre.is_empty()) { + return Err(SemverError::InvalidCompilerVersionRequirement { + package_name: package_name.clone(), + required_compiler_version: required_compiler_version.to_string(), + }); + } + + Ok(()) +} + // Strip the build meta data from the version string since it is ignored by semver. fn strip_build_meta_data(version: &Version) -> String { let version_string = version.to_string(); @@ -191,6 +208,26 @@ mod tests { }; } + #[test] + fn test_semver_prerelease() { + let compiler_version = parse_semver_compatible_version("1.0.0-beta.0").unwrap(); + + let package = Package { + compiler_required_version: Some(">=0.1.0".to_string()), + root_dir: PathBuf::new(), + package_type: PackageType::Library, + entry_path: PathBuf::new(), + name: CrateName::from_str("test").unwrap(), + dependencies: BTreeMap::new(), + version: Some("1.0".to_string()), + expression_width: None, + }; + + if let Err(err) = semver_check_package(&package, &compiler_version) { + panic!("{err}"); + }; + } + #[test] fn test_semver_build_data() { let compiler_version = Version::parse("0.1.0+this-is-ignored-by-semver").unwrap(); diff --git a/tooling/noir_js_types/src/types.ts b/tooling/noir_js_types/src/types.ts index 03e2a93f851..0cfe5d05a17 100644 --- a/tooling/noir_js_types/src/types.ts +++ b/tooling/noir_js_types/src/types.ts @@ -1,5 +1,5 @@ export type Field = string | number | boolean; -export type InputValue = Field | InputMap | (Field | InputMap)[]; +export type InputValue = Field | InputMap | InputValue[]; export type InputMap = { [key: string]: InputValue }; export type Visibility = 'public' | 'private' | 'databus';