From 05dc3433ceeb3a395673b9b8431cfdbdc762249f Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 4 Feb 2025 16:06:21 -0300 Subject: [PATCH 1/2] feat: infer lambda parameter types from return type and let type (#7267) --- .../src/elaborator/expressions.rs | 105 +++++++++--- compiler/noirc_frontend/src/elaborator/mod.rs | 3 +- .../src/elaborator/statements.rs | 30 +++- compiler/noirc_frontend/src/tests.rs | 153 ++++++++++++++++++ 4 files changed, 260 insertions(+), 31 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 16278995104..4c3fb5a7616 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -37,9 +37,17 @@ use super::{Elaborator, LambdaContext, UnsafeBlockStatus}; impl<'context> Elaborator<'context> { pub(crate) fn elaborate_expression(&mut self, expr: Expression) -> (ExprId, Type) { + self.elaborate_expression_with_target_type(expr, None) + } + + pub(crate) fn elaborate_expression_with_target_type( + &mut self, + expr: Expression, + target_type: Option<&Type>, + ) -> (ExprId, Type) { let (hir_expr, typ) = match expr.kind { ExpressionKind::Literal(literal) => self.elaborate_literal(literal, expr.span), - ExpressionKind::Block(block) => self.elaborate_block(block), + ExpressionKind::Block(block) => self.elaborate_block(block, target_type), ExpressionKind::Prefix(prefix) => return self.elaborate_prefix(*prefix, expr.span), ExpressionKind::Index(index) => self.elaborate_index(*index), ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span), @@ -50,18 +58,22 @@ impl<'context> Elaborator<'context> { } ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), - ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::If(if_) => self.elaborate_if(*if_, target_type), ExpressionKind::Match(match_) => self.elaborate_match(*match_), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), - ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), - ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None), - ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), + ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple, target_type), + ExpressionKind::Lambda(lambda) => { + self.elaborate_lambda_with_target_type(*lambda, target_type) + } + ExpressionKind::Parenthesized(expr) => { + return self.elaborate_expression_with_target_type(*expr, target_type) + } ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span), ExpressionKind::Comptime(comptime, _) => { - return self.elaborate_comptime_block(comptime, expr.span) + return self.elaborate_comptime_block(comptime, expr.span, target_type) } ExpressionKind::Unsafe(block_expression, span) => { - self.elaborate_unsafe_block(block_expression, span) + self.elaborate_unsafe_block(block_expression, span, target_type) } ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)), ExpressionKind::Interned(id) => { @@ -112,18 +124,29 @@ impl<'context> Elaborator<'context> { } } - pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) { - let (block, typ) = self.elaborate_block_expression(block); + pub(super) fn elaborate_block( + &mut self, + block: BlockExpression, + target_type: Option<&Type>, + ) -> (HirExpression, Type) { + let (block, typ) = self.elaborate_block_expression(block, target_type); (HirExpression::Block(block), typ) } - fn elaborate_block_expression(&mut self, block: BlockExpression) -> (HirBlockExpression, Type) { + fn elaborate_block_expression( + &mut self, + block: BlockExpression, + target_type: Option<&Type>, + ) -> (HirBlockExpression, Type) { self.push_scope(); let mut block_type = Type::Unit; - let mut statements = Vec::with_capacity(block.statements.len()); + let statements_len = block.statements.len(); + let mut statements = Vec::with_capacity(statements_len); for (i, statement) in block.statements.into_iter().enumerate() { - let (id, stmt_type) = self.elaborate_statement(statement); + let statement_target_type = if i == statements_len - 1 { target_type } else { None }; + let (id, stmt_type) = + self.elaborate_statement_with_target_type(statement, statement_target_type); statements.push(id); if let HirStatement::Semi(expr) = self.interner.statement(&id) { @@ -149,6 +172,7 @@ impl<'context> Elaborator<'context> { &mut self, block: BlockExpression, span: Span, + target_type: Option<&Type>, ) -> (HirExpression, Type) { // Before entering the block we cache the old value of `in_unsafe_block` so it can be restored. let old_in_unsafe_block = self.unsafe_block_status; @@ -161,7 +185,7 @@ impl<'context> Elaborator<'context> { self.unsafe_block_status = UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls; - let (hir_block_expression, typ) = self.elaborate_block_expression(block); + let (hir_block_expression, typ) = self.elaborate_block_expression(block, target_type); if let UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls = self.unsafe_block_status { @@ -572,7 +596,7 @@ impl<'context> Elaborator<'context> { let span = arg.span; let type_hint = if let Some(Type::Function(func_args, _, _, _)) = typ { Some(func_args) } else { None }; - let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint); + let (hir_expr, typ) = self.elaborate_lambda_with_parameter_type_hints(*lambda, type_hint); let id = self.interner.push_expr(hir_expr); self.interner.push_expr_location(id, span, self.file); self.interner.push_expr_type(id, typ.clone()); @@ -884,10 +908,15 @@ impl<'context> Elaborator<'context> { } } - fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) { + fn elaborate_if( + &mut self, + if_expr: IfExpression, + target_type: Option<&Type>, + ) -> (HirExpression, Type) { let expr_span = if_expr.condition.span; let (condition, cond_type) = self.elaborate_expression(if_expr.condition); - let (consequence, mut ret_type) = self.elaborate_expression(if_expr.consequence); + let (consequence, mut ret_type) = + self.elaborate_expression_with_target_type(if_expr.consequence, target_type); self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { expected_typ: Type::Bool.to_string(), @@ -897,7 +926,8 @@ impl<'context> Elaborator<'context> { let alternative = if_expr.alternative.map(|alternative| { let expr_span = alternative.span; - let (else_, else_type) = self.elaborate_expression(alternative); + let (else_, else_type) = + self.elaborate_expression_with_target_type(alternative, target_type); self.unify(&ret_type, &else_type, || { let err = TypeCheckError::TypeMismatch { @@ -931,12 +961,19 @@ impl<'context> Elaborator<'context> { (HirExpression::Error, Type::Error) } - fn elaborate_tuple(&mut self, tuple: Vec) -> (HirExpression, Type) { + fn elaborate_tuple( + &mut self, + tuple: Vec, + target_type: Option<&Type>, + ) -> (HirExpression, Type) { let mut element_ids = Vec::with_capacity(tuple.len()); let mut element_types = Vec::with_capacity(tuple.len()); - for element in tuple { - let (id, typ) = self.elaborate_expression(element); + for (index, element) in tuple.into_iter().enumerate() { + let target_type = target_type.map(|typ| typ.follow_bindings()); + let expr_target_type = + if let Some(Type::Tuple(types)) = &target_type { types.get(index) } else { None }; + let (id, typ) = self.elaborate_expression_with_target_type(element, expr_target_type); element_ids.push(id); element_types.push(typ); } @@ -944,10 +981,24 @@ impl<'context> Elaborator<'context> { (HirExpression::Tuple(element_ids), Type::Tuple(element_types)) } + fn elaborate_lambda_with_target_type( + &mut self, + lambda: Lambda, + target_type: Option<&Type>, + ) -> (HirExpression, Type) { + let target_type = target_type.map(|typ| typ.follow_bindings()); + + if let Some(Type::Function(args, _, _, _)) = target_type { + return self.elaborate_lambda_with_parameter_type_hints(lambda, Some(&args)); + } + + self.elaborate_lambda_with_parameter_type_hints(lambda, None) + } + /// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential /// call that has this lambda as the argument. /// The parameter type hints will be the types of the function type corresponding to the lambda argument. - fn elaborate_lambda( + fn elaborate_lambda_with_parameter_type_hints( &mut self, lambda: Lambda, parameters_type_hints: Option<&Vec>, @@ -1013,9 +1064,15 @@ impl<'context> Elaborator<'context> { } } - fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) { - let (block, _typ) = - self.elaborate_in_comptime_context(|this| this.elaborate_block_expression(block)); + fn elaborate_comptime_block( + &mut self, + block: BlockExpression, + span: Span, + target_type: Option<&Type>, + ) -> (ExprId, Type) { + let (block, _typ) = self.elaborate_in_comptime_context(|this| { + this.elaborate_block_expression(block, target_type) + }); let mut interpreter = self.setup_interpreter(); let value = interpreter.evaluate_block(block); diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index c895f87ef88..a8e722a9205 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -500,7 +500,8 @@ impl<'context> Elaborator<'context> { | FunctionKind::Oracle | FunctionKind::TraitFunctionWithoutBody => (HirFunction::empty(), Type::Error), FunctionKind::Normal => { - let (block, body_type) = self.elaborate_block(body); + let return_type = func_meta.return_type(); + let (block, body_type) = self.elaborate_block(body, Some(return_type)); let expr_id = self.intern_expr(block, body_span); self.interner.push_expr_type(expr_id, body_type.clone()); (HirFunction::unchecked_from_expr(expr_id), body_type) diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index a95e260b6a5..b17052d01ef 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -28,6 +28,14 @@ use super::{lints, Elaborator, Loop}; impl<'context> Elaborator<'context> { fn elaborate_statement_value(&mut self, statement: Statement) -> (HirStatement, Type) { + self.elaborate_statement_value_with_target_type(statement, None) + } + + fn elaborate_statement_value_with_target_type( + &mut self, + statement: Statement, + target_type: Option<&Type>, + ) -> (HirStatement, Type) { match statement.kind { StatementKind::Let(let_stmt) => self.elaborate_local_let(let_stmt), StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain), @@ -38,7 +46,7 @@ impl<'context> Elaborator<'context> { StatementKind::Continue => self.elaborate_jump(false, statement.span), StatementKind::Comptime(statement) => self.elaborate_comptime_statement(*statement), StatementKind::Expression(expr) => { - let (expr, typ) = self.elaborate_expression(expr); + let (expr, typ) = self.elaborate_expression_with_target_type(expr, target_type); (HirStatement::Expression(expr), typ) } StatementKind::Semi(expr) => { @@ -48,15 +56,24 @@ impl<'context> Elaborator<'context> { StatementKind::Interned(id) => { let kind = self.interner.get_statement_kind(id); let statement = Statement { kind: kind.clone(), span: statement.span }; - self.elaborate_statement_value(statement) + self.elaborate_statement_value_with_target_type(statement, target_type) } StatementKind::Error => (HirStatement::Error, Type::Error), } } pub(crate) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { + self.elaborate_statement_with_target_type(statement, None) + } + + pub(crate) fn elaborate_statement_with_target_type( + &mut self, + statement: Statement, + target_type: Option<&Type>, + ) -> (StmtId, Type) { let span = statement.span; - let (hir_statement, typ) = self.elaborate_statement_value(statement); + let (hir_statement, typ) = + self.elaborate_statement_value_with_target_type(statement, target_type); let id = self.interner.push_stmt(hir_statement); self.interner.push_stmt_location(id, span, self.file); (id, typ) @@ -75,12 +92,13 @@ impl<'context> Elaborator<'context> { let_stmt: LetStatement, global_id: Option, ) -> (HirStatement, Type) { - let expr_span = let_stmt.expression.span; - let (expression, expr_type) = self.elaborate_expression(let_stmt.expression); - let type_contains_unspecified = let_stmt.r#type.contains_unspecified(); let annotated_type = self.resolve_inferred_type(let_stmt.r#type); + let expr_span = let_stmt.expression.span; + let (expression, expr_type) = + self.elaborate_expression_with_target_type(let_stmt.expression, Some(&annotated_type)); + // Require the top-level of a global's type to be fully-specified if type_contains_unspecified && global_id.is_some() { let span = expr_span; diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index b7723ce4242..ed6321dbe50 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -4053,6 +4053,159 @@ fn infers_lambda_argument_from_call_function_type_in_generic_call() { assert_no_errors(src); } +#[test] +fn infers_lambda_argument_from_call_function_type_as_alias() { + let src = r#" + struct Foo { + value: Field, + } + + type MyFn = fn(Foo) -> Field; + + fn call(f: MyFn) -> Field { + f(Foo { value: 1 }) + } + + fn main() { + let _ = call(|foo| foo.value); + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_function_return_type() { + let src = r#" + pub struct Foo { + value: Field, + } + + pub fn func() -> fn(Foo) -> Field { + |foo| foo.value + } + + fn main() { + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_function_return_type_multiple_statements() { + let src = r#" + pub struct Foo { + value: Field, + } + + pub fn func() -> fn(Foo) -> Field { + let _ = 1; + |foo| foo.value + } + + fn main() { + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_function_return_type_when_inside_if() { + let src = r#" + pub struct Foo { + value: Field, + } + + pub fn func() -> fn(Foo) -> Field { + if true { + |foo| foo.value + } else { + |foo| foo.value + } + } + + fn main() { + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_variable_type() { + let src = r#" + pub struct Foo { + value: Field, + } + + fn main() { + let _: fn(Foo) -> Field = |foo| foo.value; + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_variable_alias_type() { + let src = r#" + pub struct Foo { + value: Field, + } + + type FooFn = fn(Foo) -> Field; + + fn main() { + let _: FooFn = |foo| foo.value; + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_variable_double_alias_type() { + let src = r#" + pub struct Foo { + value: Field, + } + + type FooFn = fn(Foo) -> Field; + type FooFn2 = FooFn; + + fn main() { + let _: FooFn2 = |foo| foo.value; + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_variable_tuple_type() { + let src = r#" + pub struct Foo { + value: Field, + } + + fn main() { + let _: (fn(Foo) -> Field, _) = (|foo| foo.value, 1); + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_variable_tuple_type_aliased() { + let src = r#" + pub struct Foo { + value: Field, + } + + type Alias = (fn(Foo) -> Field, Field); + + fn main() { + let _: Alias = (|foo| foo.value, 1); + } + "#; + assert_no_errors(src); +} + #[test] fn regression_7088() { // A test for code that initially broke when implementing inferring From 3a42eb5c68f9616f0ebe367c894f0376ba41e0ef Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Tue, 4 Feb 2025 19:32:34 +0000 Subject: [PATCH 2/2] chore: add sha256 library to test suite (#7278) --- .../critical_libraries_status/noir-lang/sha256/.failures.jsonl | 0 EXTERNAL_NOIR_LIBRARIES.yml | 3 +++ 2 files changed, 3 insertions(+) create mode 100644 .github/critical_libraries_status/noir-lang/sha256/.failures.jsonl diff --git a/.github/critical_libraries_status/noir-lang/sha256/.failures.jsonl b/.github/critical_libraries_status/noir-lang/sha256/.failures.jsonl new file mode 100644 index 00000000000..e69de29bb2d diff --git a/EXTERNAL_NOIR_LIBRARIES.yml b/EXTERNAL_NOIR_LIBRARIES.yml index f9b4d6cc07d..5e95d435f3e 100644 --- a/EXTERNAL_NOIR_LIBRARIES.yml +++ b/EXTERNAL_NOIR_LIBRARIES.yml @@ -40,6 +40,9 @@ libraries: noir_json_parser: repo: noir-lang/noir_json_parser timeout: 10 + sha256: + repo: noir-lang/sha256 + timeout: 3 aztec_nr: repo: AztecProtocol/aztec-packages path: noir-projects/aztec-nr