From aaad921bba52e729dc24ece07fab2edf09ccfa15 Mon Sep 17 00:00:00 2001 From: TomerStarkware <144585788+TomerStarkware@users.noreply.github.com> Date: Thu, 22 Aug 2024 13:58:18 +0300 Subject: [PATCH] added support for return statement and error propagation inside closures (#6254) --- .../cairo-lang-semantic/src/expr/compute.rs | 102 ++++++----- .../src/expr/test_data/closure | 159 ++++++++++++++++++ 2 files changed, 223 insertions(+), 38 deletions(-) diff --git a/crates/cairo-lang-semantic/src/expr/compute.rs b/crates/cairo-lang-semantic/src/expr/compute.rs index 117b3a327ae..9ad16a032c4 100644 --- a/crates/cairo-lang-semantic/src/expr/compute.rs +++ b/crates/cairo-lang-semantic/src/expr/compute.rs @@ -119,15 +119,17 @@ pub enum ContextFunction { Function(Maybe), } -/// Context inside loops. +/// Context inside loops or closures. #[derive(Debug, Clone)] -enum LoopContext { +enum InnerContext { /// Context inside a `loop` Loop { type_merger: FlowMergeTypeHelper }, /// Context inside a `while` loop While, /// Context inside a `for` loop For, + /// Context inside a `closure` + Closure { return_type: TypeId }, } /// Context for computing the semantic model of expression trees. @@ -142,7 +144,7 @@ pub struct ComputationContext<'ctx> { function_id: ContextFunction, /// Definitions of semantic variables. pub semantic_defs: UnorderedHashMap, - loop_ctx: Option, + inner_ctx: Option, cfg_set: Arc, /// whether to look for closures when calling variables. /// TODO(TomerStarkware): Remove this once we disallow calling shadowed functions. @@ -169,7 +171,7 @@ impl<'ctx> ComputationContext<'ctx> { arenas: Default::default(), function_id, semantic_defs, - loop_ctx: None, + inner_ctx: None, cfg_set: db .crate_config(owning_crate_id) .and_then(|cfg| cfg.settings.cfg_set.map(Arc::new)) @@ -249,6 +251,13 @@ impl<'ctx> ComputationContext<'ctx> { self.resolver.inference().internal_rewrite(stmt).no_err(); } } + /// Returns whether the current context is inside a loop. + fn is_inside_loop(&self) -> bool { + match self.inner_ctx { + None | Some(InnerContext::Closure { .. }) => false, + Some(InnerContext::Loop { .. } | InnerContext::While | InnerContext::For) => true, + } + } } // TODO(ilya): Change value to VarId. @@ -1216,7 +1225,7 @@ fn compute_arm_semantic( let ast::Expr::Block(arm_expr_syntax) = arm_expr_syntax else { unreachable!("Expected a block expression for a loop arm."); }; - let (id, _) = compute_loop_body_semantic(new_ctx, arm_expr_syntax, LoopContext::While); + let (id, _) = compute_loop_body_semantic(new_ctx, arm_expr_syntax, InnerContext::While); let expr = new_ctx.arenas.exprs[id].clone(); ExprAndId { expr, id } } else { @@ -1349,15 +1358,15 @@ fn compute_expr_loop_semantic( let db = ctx.db; let syntax_db = db.upcast(); - let (body, loop_ctx) = compute_loop_body_semantic( + let (body, inner_ctx) = compute_loop_body_semantic( ctx, syntax.body(syntax_db), - LoopContext::Loop { type_merger: FlowMergeTypeHelper::new(db, MultiArmExprKind::Loop) }, + InnerContext::Loop { type_merger: FlowMergeTypeHelper::new(db, MultiArmExprKind::Loop) }, ); Ok(Expr::Loop(ExprLoop { body, - ty: match loop_ctx { - LoopContext::Loop { type_merger, .. } => type_merger.get_final_type(), + ty: match inner_ctx { + InnerContext::Loop { type_merger, .. } => type_merger.get_final_type(), _ => unreachable!("Expected loop context"), }, stable_ptr: syntax.stable_ptr().into(), @@ -1390,8 +1399,8 @@ fn compute_expr_while_semantic( (Condition::Let(expr.id, patterns.iter().map(|pattern| pattern.id).collect()), body.id) } ast::Condition::Expr(expr) => { - let (body, _loop_ctx) = - compute_loop_body_semantic(ctx, syntax.body(syntax_db), LoopContext::While); + let (body, _inner_ctx) = + compute_loop_body_semantic(ctx, syntax.body(syntax_db), InnerContext::While); ( Condition::BoolExpr(compute_bool_condition_semantic(ctx, &expr.expr(syntax_db)).id), body, @@ -1511,8 +1520,8 @@ fn compute_expr_for_semantic( new_ctx.environment.variables.insert(v.name.clone(), var_def.clone()); new_ctx.semantic_defs.insert(var_def.id(), var_def); } - let (body, _loop_ctx) = - compute_loop_body_semantic(new_ctx, syntax.body(syntax_db), LoopContext::For); + let (body, _inner_ctx) = + compute_loop_body_semantic(new_ctx, syntax.body(syntax_db), InnerContext::For); (body, new_ctx.arenas.patterns.alloc(inner_pattern.pattern)) }); Ok(Expr::For(ExprFor { @@ -1531,13 +1540,13 @@ fn compute_expr_for_semantic( fn compute_loop_body_semantic( ctx: &mut ComputationContext<'_>, syntax: ast::ExprBlock, - loop_ctx: LoopContext, -) -> (ExprId, LoopContext) { + inner_ctx: InnerContext, +) -> (ExprId, InnerContext) { let db = ctx.db; let syntax_db = db.upcast(); ctx.run_in_subscope(|new_ctx| { - let old_loop_ctx = std::mem::replace(&mut new_ctx.loop_ctx, Some(loop_ctx)); + let old_inner_ctx = std::mem::replace(&mut new_ctx.inner_ctx, Some(inner_ctx)); let mut statements = syntax.statements(syntax_db).elements(syntax_db); // Remove the typed tail expression, if exists. @@ -1560,7 +1569,7 @@ fn compute_loop_body_semantic( } } - let loop_ctx = std::mem::replace(&mut new_ctx.loop_ctx, old_loop_ctx).unwrap(); + let inner_ctx = std::mem::replace(&mut new_ctx.inner_ctx, old_inner_ctx).unwrap(); let body = new_ctx.arenas.exprs.alloc(Expr::Block(ExprBlock { statements: statements_semantic, tail: tail.map(|tail| tail.id), @@ -1568,7 +1577,7 @@ fn compute_loop_body_semantic( stable_ptr: syntax.stable_ptr().into(), })); - (body, loop_ctx) + (body, inner_ctx) }) } @@ -1594,7 +1603,7 @@ fn compute_expr_closure_semantic( } else { vec![] }; - let ret_ty = match syntax.ret_ty(syntax_db) { + let return_type = match syntax.ret_ty(syntax_db) { OptionReturnTypeClause::ReturnTypeClause(ty_syntax) => resolve_type( new_ctx.db, new_ctx.diagnostics, @@ -1605,13 +1614,16 @@ fn compute_expr_closure_semantic( new_ctx.resolver.inference().new_type_var(Some(missing.stable_ptr().untyped())) } }; + let old_inner_ctx = + std::mem::replace(&mut new_ctx.inner_ctx, Some(InnerContext::Closure { return_type })); let body = match syntax.expr(syntax_db) { ast::Expr::Block(syntax) => compute_closure_body_semantic(new_ctx, syntax), _ => compute_expr_semantic(new_ctx, &syntax.expr(syntax_db)).id, }; + std::mem::replace(&mut new_ctx.inner_ctx, old_inner_ctx).unwrap(); let mut inference = new_ctx.resolver.inference(); if let Err((err_set, actual_ty, expected_ty)) = - inference.conform_ty_for_diag(new_ctx.arenas.exprs[body].ty(), ret_ty) + inference.conform_ty_for_diag(new_ctx.arenas.exprs[body].ty(), return_type) { let diag_added = new_ctx.diagnostics.report( syntax.expr(syntax_db).stable_ptr(), @@ -1619,7 +1631,7 @@ fn compute_expr_closure_semantic( ); inference.consume_reported_error(err_set, diag_added); } - (params, ret_ty, body) + (params, return_type, body) }); let parent_function = match ctx.function_id { ContextFunction::Global => Maybe::Err(ctx.diagnostics.report(syntax, ClosureInGlobalScope)), @@ -1704,9 +1716,18 @@ fn compute_expr_error_propagate_semantic( ) -> Maybe { let syntax_db = ctx.db.upcast(); - let func_signature = - ctx.get_signature(syntax.into(), UnsupportedOutsideOfFunctionFeatureName::ErrorPropagate)?; - let func_err_prop_ty = unwrap_error_propagation_type(ctx.db, func_signature.return_type) + let return_type = match ctx.inner_ctx { + Some(InnerContext::Closure { return_type }) => return_type, + None | Some(InnerContext::Loop { .. } | InnerContext::While | InnerContext::For) => { + ctx.get_signature( + syntax.into(), + UnsupportedOutsideOfFunctionFeatureName::ErrorPropagate, + )? + .return_type + } + }; + + let func_err_prop_ty = unwrap_error_propagation_type(ctx.db, return_type) .ok_or_else(|| ctx.diagnostics.report(syntax, ReturnTypeNotErrorPropagateType))?; // `inner_expr` is the expr inside the `?`. @@ -1729,7 +1750,7 @@ fn compute_expr_error_propagate_semantic( let inner_expr_err_variant = inner_expr_err_prop_ty.err_variant(); // Disallow error propagation inside a loop. - if ctx.loop_ctx.is_some() { + if ctx.is_inside_loop() { ctx.diagnostics.report(syntax, SemanticDiagnosticKind::ErrorPropagateNotAllowedInsideALoop); } @@ -1752,7 +1773,7 @@ fn compute_expr_error_propagate_semantic( ctx.diagnostics.report( syntax, IncompatibleErrorPropagateType { - return_ty: func_signature.return_type, + return_ty: return_type, err_ty: inner_expr_err_variant.ty, }, ); @@ -3315,7 +3336,7 @@ pub fn compute_statement_semantic( }) } ast::Statement::Continue(continue_syntax) => { - if ctx.loop_ctx.is_none() { + if !ctx.is_inside_loop() { return Err(ctx .diagnostics .report(continue_syntax, ContinueOnlyAllowedInsideALoop)); @@ -3325,7 +3346,7 @@ pub fn compute_statement_semantic( }) } ast::Statement::Return(return_syntax) => { - if ctx.loop_ctx.is_some() { + if ctx.is_inside_loop() { return Err(ctx.diagnostics.report(return_syntax, ReturnNotAllowedInsideALoop)); } @@ -3339,12 +3360,17 @@ pub fn compute_statement_semantic( (Some(expr.id), expr.ty(), expr_syntax.stable_ptr().untyped()) } }; - let expected_ty = ctx - .get_signature( - return_syntax.into(), - UnsupportedOutsideOfFunctionFeatureName::ReturnStatement, - )? - .return_type; + let expected_ty = match ctx.inner_ctx { + None => { + ctx.get_signature( + return_syntax.into(), + UnsupportedOutsideOfFunctionFeatureName::ReturnStatement, + )? + .return_type + } + Some(InnerContext::Closure { return_type }) => return_type, + _ => unreachable!("Return statement inside a loop"), + }; let expected_ty = ctx.reduce_ty(expected_ty); let expr_ty = ctx.reduce_ty(expr_ty); @@ -3377,11 +3403,11 @@ pub fn compute_statement_semantic( } }; let ty = ctx.reduce_ty(ty); - match &mut ctx.loop_ctx { - None => { + match &mut ctx.inner_ctx { + None | Some(InnerContext::Closure { .. }) => { return Err(ctx.diagnostics.report(break_syntax, BreakOnlyAllowedInsideALoop)); } - Some(LoopContext::Loop { type_merger, .. }) => { + Some(InnerContext::Loop { type_merger, .. }) => { type_merger.try_merge_types( ctx.db, ctx.diagnostics, @@ -3390,7 +3416,7 @@ pub fn compute_statement_semantic( stable_ptr, ); } - Some(LoopContext::While | LoopContext::For) => { + Some(InnerContext::While | InnerContext::For) => { if expr_option.is_some() { ctx.diagnostics.report(break_syntax, BreakWithValueOnlyAllowedInsideALoop); }; diff --git a/crates/cairo-lang-semantic/src/expr/test_data/closure b/crates/cairo-lang-semantic/src/expr/test_data/closure index 6ec61886e48..3ad1312e8c6 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/closure +++ b/crates/cairo-lang-semantic/src/expr/test_data/closure @@ -440,3 +440,162 @@ error: Arguments to closure functions cannot be references --> lib.cairo:6:23 let _f: u32 = bar(ref a); ^***^ + +//! > ========================================================================== + +//! > Closure with return statement argument. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: false) + +//! > function +fn foo() { + let bar = |a| { + return a; + }; + let a = 5; + let _f: u32 = bar(a); +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics + +//! > ========================================================================== + +//! > Closure with option propagation. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: false) + +//! > function +fn foo() { + let bar = |a: Option| -> Option { + Option::Some(a?) + }; + let a = Option::Some(5); + let _f: u32 = bar(a).unwrap(); +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics + +//! > ========================================================================== + +//! > Closure with option propagation inference required. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +//TODO(Tomer): Add support for inference of Option type in closures. +fn foo() { + let bar1 = |a| -> Option { + Option::Some(a?) + }; + let bar2 = |b: Option| { + Option::Some(b?) + }; + let a = Option::Some(5); + let _f: u32 = bar(a).unwrap(); + let _f: u32 = bar2(a).unwrap(); +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics +error: Type "?0" can not error propagate + --> lib.cairo:4:22 + Option::Some(a?) + ^^ + +error: `?` can only be used in a function with `Option` or `Result` return type. + --> lib.cairo:7:22 + Option::Some(b?) + ^^ + +warning[E0001]: Unused variable. Consider ignoring by prefixing with `_`. + --> lib.cairo:6:17 + let bar2 = |b: Option| { + ^************^ + +error: Function not found. + --> lib.cairo:10:19 + let _f: u32 = bar(a).unwrap(); + ^*^ + +error: Ambiguous method call. More than one applicable trait function with a suitable self type was found: OptionTrait::unwrap and ResultTrait::unwrap. Consider adding type annotations or explicitly refer to the impl function. + --> lib.cairo:10:26 + let _f: u32 = bar(a).unwrap(); + ^****^ + +warning[E0001]: Unused variable. Consider ignoring by prefixing with `_`. + --> lib.cairo:3:9 + let bar1 = |a| -> Option { + ^**^ + +//! > ========================================================================== + +//! > break inside a closure. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +fn foo() { + loop { + let c = || { + break; + }; + c(); + } +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics +error: `break` only allowed inside a `loop`. + --> lib.cairo:4:13 + break; + ^****^ + +//! > ========================================================================== + +//! > continue inside a closure. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +fn foo() { + loop { + let c = || { + continue; + }; + c(); + } +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics +error: `continue` only allowed inside a `loop`. + --> lib.cairo:4:13 + continue; + ^*******^