Skip to content

Commit

Permalink
Merge pull request #907 from hacspec/rewrite-control-flow
Browse files Browse the repository at this point in the history
Add RewriteControlFlow phase.
  • Loading branch information
W95Psp authored Sep 25, 2024
2 parents c2093b4 + 120ba54 commit cc38130
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 216 deletions.
1 change: 1 addition & 0 deletions engine/backends/fstar/fstar_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,7 @@ module TransformToInputLanguage =
|> Side_effect_utils.Hoist
|> Phases.Hoist_disjunctive_patterns
|> Phases.Simplify_match_return
|> Phases.Rewrite_control_flow
|> Phases.Drop_needless_returns
|> Phases.Local_mutation
|> Phases.Reject.Continue
Expand Down
1 change: 1 addition & 0 deletions engine/lib/diagnostics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ module Phase = struct
| ResugarWhileLoops
| ResugarForIndexLoops
| ResugarQuestionMarks
| RewriteControlFlow
| SimplifyQuestionMarks
| Specialize
| HoistSideEffects
Expand Down
107 changes: 107 additions & 0 deletions engine/lib/phases/phase_rewrite_control_flow.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
(* This phase rewrites: `if c {return a}; b` as `if c {return a; b} else {b}`
and does the equivalent transformation for pattern matchings. *)

open! Prelude

module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.RewriteControlFlow

open Ast.Make (F)
module U = Ast_utils.Make (F)
module Visitors = Ast_visitors.Make (F)

module Error = Phase_utils.MakeError (struct
let ctx = Diagnostics.Context.Phase phase_id
end)

let has_return =
object (_self)
inherit [_] Visitors.reduce as super
method zero = false
method plus = ( || )

method! visit_expr' () e =
match e with Return _ -> true | _ -> super#visit_expr' () e
end

let rewrite_control_flow =
object (self)
inherit [_] Visitors.map as super

method! visit_expr () e =
match e.e with
| _ when not (has_return#visit_expr () e) -> e
(* Returns in loops will be handled by issue #196 *)
| Loop _ -> e
| Let _ -> (
(* Collect let bindings to get the sequence
of "statements", find the first "statement" that is a
control flow containing a return. Rewrite it.
*)
let stmts, final = U.collect_let_bindings e in
let inline_in_branch branch p stmts_after final =
let branch_stmts, branch_final =
U.collect_let_bindings branch
in
let stmts_to_add =
match (p, branch_final) with
(* This avoids adding `let _ = ()` *)
| { p = PWild; _ }, { e = GlobalVar (`TupleCons 0); _ } ->
stmts_after
| stmt -> stmt :: stmts_after
in
U.make_lets (branch_stmts @ stmts_to_add) final
in
let stmts_before, stmt_and_stmts_after =
List.split_while stmts ~f:(fun (_, e) ->
match e.e with
| (If _ | Match _) when has_return#visit_expr () e ->
false
| Return _ -> false
| _ -> true)
in
match stmt_and_stmts_after with
| (p, ({ e = If { cond; then_; else_ }; _ } as rhs))
:: stmts_after ->
(* We know there is no "return" in the condition
so we must rewrite the if *)
let then_ = inline_in_branch then_ p stmts_after final in
let else_ =
Some
(match else_ with
| Some else_ ->
inline_in_branch else_ p stmts_after final
| None -> U.make_lets stmts_after final)
in
U.make_lets stmts_before
{ rhs with e = If { cond; then_; else_ } }
|> self#visit_expr ()
| (p, ({ e = Match { scrutinee; arms }; _ } as rhs))
:: stmts_after ->
let arms =
List.map arms ~f:(fun arm ->
let body =
inline_in_branch arm.arm.body p stmts_after final
in
{ arm with arm = { arm.arm with body } })
in
U.make_lets stmts_before
{ rhs with e = Match { scrutinee; arms } }
|> self#visit_expr ()
(* The statements coming after a "return" are useless. *)
| (_, ({ e = Return _; _ } as rhs)) :: _ ->
U.make_lets stmts_before rhs |> self#visit_expr ()
| _ ->
let stmts =
List.map stmts ~f:(fun (p, e) ->
(p, self#visit_expr () e))
in
U.make_lets stmts (self#visit_expr () final))
| _ -> super#visit_expr () e
end

let ditems = List.map ~f:(rewrite_control_flow#visit_item ())
end)
6 changes: 6 additions & 0 deletions engine/lib/phases/phase_rewrite_control_flow.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
(** This phase finds control flow expression (`if` or `match`) with a `return` expression
in one of the branches. We replace them by replicating what comes after in all the branches.
This allows the `return` to be eliminated by `drop_needless_returns`.
This phase should come after `phase_local_mutation`. *)

module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE
Loading

0 comments on commit cc38130

Please sign in to comment.