Skip to content

Commit

Permalink
added support for return statement and error propagation inside closu…
Browse files Browse the repository at this point in the history
…res (#6254)
  • Loading branch information
TomerStarkware authored Aug 22, 2024
1 parent 24d1268 commit aaad921
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 38 deletions.
102 changes: 64 additions & 38 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,17 @@ pub enum ContextFunction {
Function(Maybe<FunctionId>),
}

/// Context inside loops.
/// Context inside loops or closures.
#[derive(Debug, Clone)]
enum LoopContext {
enum InnerContext {
/// Context inside a `loop`
Loop { type_merger: FlowMergeTypeHelper },
/// Context inside a `while` loop
While,
/// Context inside a `for` loop
For,
/// Context inside a `closure`
Closure { return_type: TypeId },
}

/// Context for computing the semantic model of expression trees.
Expand All @@ -142,7 +144,7 @@ pub struct ComputationContext<'ctx> {
function_id: ContextFunction,
/// Definitions of semantic variables.
pub semantic_defs: UnorderedHashMap<semantic::VarId, semantic::Variable>,
loop_ctx: Option<LoopContext>,
inner_ctx: Option<InnerContext>,
cfg_set: Arc<CfgSet>,
/// whether to look for closures when calling variables.
/// TODO(TomerStarkware): Remove this once we disallow calling shadowed functions.
Expand All @@ -169,7 +171,7 @@ impl<'ctx> ComputationContext<'ctx> {
arenas: Default::default(),
function_id,
semantic_defs,
loop_ctx: None,
inner_ctx: None,
cfg_set: db
.crate_config(owning_crate_id)
.and_then(|cfg| cfg.settings.cfg_set.map(Arc::new))
Expand Down Expand Up @@ -249,6 +251,13 @@ impl<'ctx> ComputationContext<'ctx> {
self.resolver.inference().internal_rewrite(stmt).no_err();
}
}
/// Returns whether the current context is inside a loop.
fn is_inside_loop(&self) -> bool {
match self.inner_ctx {
None | Some(InnerContext::Closure { .. }) => false,
Some(InnerContext::Loop { .. } | InnerContext::While | InnerContext::For) => true,
}
}
}

// TODO(ilya): Change value to VarId.
Expand Down Expand Up @@ -1216,7 +1225,7 @@ fn compute_arm_semantic(
let ast::Expr::Block(arm_expr_syntax) = arm_expr_syntax else {
unreachable!("Expected a block expression for a loop arm.");
};
let (id, _) = compute_loop_body_semantic(new_ctx, arm_expr_syntax, LoopContext::While);
let (id, _) = compute_loop_body_semantic(new_ctx, arm_expr_syntax, InnerContext::While);
let expr = new_ctx.arenas.exprs[id].clone();
ExprAndId { expr, id }
} else {
Expand Down Expand Up @@ -1349,15 +1358,15 @@ fn compute_expr_loop_semantic(
let db = ctx.db;
let syntax_db = db.upcast();

let (body, loop_ctx) = compute_loop_body_semantic(
let (body, inner_ctx) = compute_loop_body_semantic(
ctx,
syntax.body(syntax_db),
LoopContext::Loop { type_merger: FlowMergeTypeHelper::new(db, MultiArmExprKind::Loop) },
InnerContext::Loop { type_merger: FlowMergeTypeHelper::new(db, MultiArmExprKind::Loop) },
);
Ok(Expr::Loop(ExprLoop {
body,
ty: match loop_ctx {
LoopContext::Loop { type_merger, .. } => type_merger.get_final_type(),
ty: match inner_ctx {
InnerContext::Loop { type_merger, .. } => type_merger.get_final_type(),
_ => unreachable!("Expected loop context"),
},
stable_ptr: syntax.stable_ptr().into(),
Expand Down Expand Up @@ -1390,8 +1399,8 @@ fn compute_expr_while_semantic(
(Condition::Let(expr.id, patterns.iter().map(|pattern| pattern.id).collect()), body.id)
}
ast::Condition::Expr(expr) => {
let (body, _loop_ctx) =
compute_loop_body_semantic(ctx, syntax.body(syntax_db), LoopContext::While);
let (body, _inner_ctx) =
compute_loop_body_semantic(ctx, syntax.body(syntax_db), InnerContext::While);
(
Condition::BoolExpr(compute_bool_condition_semantic(ctx, &expr.expr(syntax_db)).id),
body,
Expand Down Expand Up @@ -1511,8 +1520,8 @@ fn compute_expr_for_semantic(
new_ctx.environment.variables.insert(v.name.clone(), var_def.clone());
new_ctx.semantic_defs.insert(var_def.id(), var_def);
}
let (body, _loop_ctx) =
compute_loop_body_semantic(new_ctx, syntax.body(syntax_db), LoopContext::For);
let (body, _inner_ctx) =
compute_loop_body_semantic(new_ctx, syntax.body(syntax_db), InnerContext::For);
(body, new_ctx.arenas.patterns.alloc(inner_pattern.pattern))
});
Ok(Expr::For(ExprFor {
Expand All @@ -1531,13 +1540,13 @@ fn compute_expr_for_semantic(
fn compute_loop_body_semantic(
ctx: &mut ComputationContext<'_>,
syntax: ast::ExprBlock,
loop_ctx: LoopContext,
) -> (ExprId, LoopContext) {
inner_ctx: InnerContext,
) -> (ExprId, InnerContext) {
let db = ctx.db;
let syntax_db = db.upcast();

ctx.run_in_subscope(|new_ctx| {
let old_loop_ctx = std::mem::replace(&mut new_ctx.loop_ctx, Some(loop_ctx));
let old_inner_ctx = std::mem::replace(&mut new_ctx.inner_ctx, Some(inner_ctx));

let mut statements = syntax.statements(syntax_db).elements(syntax_db);
// Remove the typed tail expression, if exists.
Expand All @@ -1560,15 +1569,15 @@ fn compute_loop_body_semantic(
}
}

let loop_ctx = std::mem::replace(&mut new_ctx.loop_ctx, old_loop_ctx).unwrap();
let inner_ctx = std::mem::replace(&mut new_ctx.inner_ctx, old_inner_ctx).unwrap();
let body = new_ctx.arenas.exprs.alloc(Expr::Block(ExprBlock {
statements: statements_semantic,
tail: tail.map(|tail| tail.id),
ty: unit_ty(db),
stable_ptr: syntax.stable_ptr().into(),
}));

(body, loop_ctx)
(body, inner_ctx)
})
}

Expand All @@ -1594,7 +1603,7 @@ fn compute_expr_closure_semantic(
} else {
vec![]
};
let ret_ty = match syntax.ret_ty(syntax_db) {
let return_type = match syntax.ret_ty(syntax_db) {
OptionReturnTypeClause::ReturnTypeClause(ty_syntax) => resolve_type(
new_ctx.db,
new_ctx.diagnostics,
Expand All @@ -1605,21 +1614,24 @@ fn compute_expr_closure_semantic(
new_ctx.resolver.inference().new_type_var(Some(missing.stable_ptr().untyped()))
}
};
let old_inner_ctx =
std::mem::replace(&mut new_ctx.inner_ctx, Some(InnerContext::Closure { return_type }));
let body = match syntax.expr(syntax_db) {
ast::Expr::Block(syntax) => compute_closure_body_semantic(new_ctx, syntax),
_ => compute_expr_semantic(new_ctx, &syntax.expr(syntax_db)).id,
};
std::mem::replace(&mut new_ctx.inner_ctx, old_inner_ctx).unwrap();
let mut inference = new_ctx.resolver.inference();
if let Err((err_set, actual_ty, expected_ty)) =
inference.conform_ty_for_diag(new_ctx.arenas.exprs[body].ty(), ret_ty)
inference.conform_ty_for_diag(new_ctx.arenas.exprs[body].ty(), return_type)
{
let diag_added = new_ctx.diagnostics.report(
syntax.expr(syntax_db).stable_ptr(),
WrongReturnType { expected_ty, actual_ty },
);
inference.consume_reported_error(err_set, diag_added);
}
(params, ret_ty, body)
(params, return_type, body)
});
let parent_function = match ctx.function_id {
ContextFunction::Global => Maybe::Err(ctx.diagnostics.report(syntax, ClosureInGlobalScope)),
Expand Down Expand Up @@ -1704,9 +1716,18 @@ fn compute_expr_error_propagate_semantic(
) -> Maybe<Expr> {
let syntax_db = ctx.db.upcast();

let func_signature =
ctx.get_signature(syntax.into(), UnsupportedOutsideOfFunctionFeatureName::ErrorPropagate)?;
let func_err_prop_ty = unwrap_error_propagation_type(ctx.db, func_signature.return_type)
let return_type = match ctx.inner_ctx {
Some(InnerContext::Closure { return_type }) => return_type,
None | Some(InnerContext::Loop { .. } | InnerContext::While | InnerContext::For) => {
ctx.get_signature(
syntax.into(),
UnsupportedOutsideOfFunctionFeatureName::ErrorPropagate,
)?
.return_type
}
};

let func_err_prop_ty = unwrap_error_propagation_type(ctx.db, return_type)
.ok_or_else(|| ctx.diagnostics.report(syntax, ReturnTypeNotErrorPropagateType))?;

// `inner_expr` is the expr inside the `?`.
Expand All @@ -1729,7 +1750,7 @@ fn compute_expr_error_propagate_semantic(
let inner_expr_err_variant = inner_expr_err_prop_ty.err_variant();

// Disallow error propagation inside a loop.
if ctx.loop_ctx.is_some() {
if ctx.is_inside_loop() {
ctx.diagnostics.report(syntax, SemanticDiagnosticKind::ErrorPropagateNotAllowedInsideALoop);
}

Expand All @@ -1752,7 +1773,7 @@ fn compute_expr_error_propagate_semantic(
ctx.diagnostics.report(
syntax,
IncompatibleErrorPropagateType {
return_ty: func_signature.return_type,
return_ty: return_type,
err_ty: inner_expr_err_variant.ty,
},
);
Expand Down Expand Up @@ -3315,7 +3336,7 @@ pub fn compute_statement_semantic(
})
}
ast::Statement::Continue(continue_syntax) => {
if ctx.loop_ctx.is_none() {
if !ctx.is_inside_loop() {
return Err(ctx
.diagnostics
.report(continue_syntax, ContinueOnlyAllowedInsideALoop));
Expand All @@ -3325,7 +3346,7 @@ pub fn compute_statement_semantic(
})
}
ast::Statement::Return(return_syntax) => {
if ctx.loop_ctx.is_some() {
if ctx.is_inside_loop() {
return Err(ctx.diagnostics.report(return_syntax, ReturnNotAllowedInsideALoop));
}

Expand All @@ -3339,12 +3360,17 @@ pub fn compute_statement_semantic(
(Some(expr.id), expr.ty(), expr_syntax.stable_ptr().untyped())
}
};
let expected_ty = ctx
.get_signature(
return_syntax.into(),
UnsupportedOutsideOfFunctionFeatureName::ReturnStatement,
)?
.return_type;
let expected_ty = match ctx.inner_ctx {
None => {
ctx.get_signature(
return_syntax.into(),
UnsupportedOutsideOfFunctionFeatureName::ReturnStatement,
)?
.return_type
}
Some(InnerContext::Closure { return_type }) => return_type,
_ => unreachable!("Return statement inside a loop"),
};

let expected_ty = ctx.reduce_ty(expected_ty);
let expr_ty = ctx.reduce_ty(expr_ty);
Expand Down Expand Up @@ -3377,11 +3403,11 @@ pub fn compute_statement_semantic(
}
};
let ty = ctx.reduce_ty(ty);
match &mut ctx.loop_ctx {
None => {
match &mut ctx.inner_ctx {
None | Some(InnerContext::Closure { .. }) => {
return Err(ctx.diagnostics.report(break_syntax, BreakOnlyAllowedInsideALoop));
}
Some(LoopContext::Loop { type_merger, .. }) => {
Some(InnerContext::Loop { type_merger, .. }) => {
type_merger.try_merge_types(
ctx.db,
ctx.diagnostics,
Expand All @@ -3390,7 +3416,7 @@ pub fn compute_statement_semantic(
stable_ptr,
);
}
Some(LoopContext::While | LoopContext::For) => {
Some(InnerContext::While | InnerContext::For) => {
if expr_option.is_some() {
ctx.diagnostics.report(break_syntax, BreakWithValueOnlyAllowedInsideALoop);
};
Expand Down
Loading

0 comments on commit aaad921

Please sign in to comment.