Skip to content

Commit

Permalink
Merge pull request #1060 from lorchrob/any-ops-imp-funs
Browse files Browse the repository at this point in the history
Desugar any operators to imported function calls when possible (rather than always desugaring to imported node calls)
  • Loading branch information
daniel-larraz authored Apr 4, 2024
2 parents 82107c0 + fa033d6 commit 9c26262
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 87 deletions.
38 changes: 38 additions & 0 deletions src/lustre/lustreAstHelpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,44 @@ let rec vars_without_node_call_ids: expr -> iset =
(* Node calls *)
| Call (_, _, es) -> SI.flatten (List.map vars es)

let rec calls_of_expr: expr -> iset =
function
(* Node calls *)
| Call (_, i, es) -> SI.union (SI.singleton i) (SI.flatten (List.map calls_of_expr es))
| Condact (_, e1, e2, i, es1, es2) ->
SI.union (SI.singleton i)
(SI.flatten (calls_of_expr e1 :: calls_of_expr e2 ::
List.map calls_of_expr es1 @ List.map calls_of_expr es2))
| Activate (_, i, e1, e2, es) ->
SI.union (SI.singleton i)
(SI.flatten (calls_of_expr e1 :: calls_of_expr e2 :: List.map calls_of_expr es))
| RestartEvery (_, i, es, e) ->
SI.union (SI.singleton i)
(SI.flatten (calls_of_expr e :: List.map calls_of_expr es))
(* Everything else *)
| Ident _ -> SI.empty
| ModeRef _ -> SI.empty
| RecordProject (_, e, _) -> calls_of_expr e
| TupleProject (_, e, _) -> calls_of_expr e
| Const _ -> SI.empty
| UnaryOp (_,_,e) -> calls_of_expr e
| BinaryOp (_,_,e1, e2) -> calls_of_expr e1 |> SI.union (calls_of_expr e2)
| TernaryOp (_,_, e1, e2, e3) -> calls_of_expr e1 |> SI.union (calls_of_expr e2) |> SI.union (calls_of_expr e3)
| ConvOp (_,_,e) -> calls_of_expr e
| CompOp (_,_,e1, e2) -> (calls_of_expr e1) |> SI.union (calls_of_expr e2)
| RecordExpr (_, _, flds) -> SI.flatten (List.map calls_of_expr (snd (List.split flds)))
| GroupExpr (_, _, es) -> SI.flatten (List.map calls_of_expr es)
| StructUpdate (_, e1, _, e2) -> SI.union (calls_of_expr e1) (calls_of_expr e2)
| ArrayConstr (_, e1, e2) -> SI.union (calls_of_expr e1) (calls_of_expr e2)
| ArrayIndex (_, e1, e2) -> SI.union (calls_of_expr e1) (calls_of_expr e2)
| Quantifier (_, _, _, e) -> calls_of_expr e
| When (_, e, _) -> calls_of_expr e
| Merge (_, _, es) -> List.split es |> snd |> List.map calls_of_expr |> SI.flatten
| AnyOp (_, (_, i, _), e, None) -> SI.diff (calls_of_expr e) (SI.singleton i)
| AnyOp (_, (_, i, _), e1, Some e2) -> SI.diff (SI.union (calls_of_expr e1) (calls_of_expr e2)) (SI.singleton i)
| Pre (_, e) -> calls_of_expr e
| Arrow (_, e1, e2) -> SI.union (calls_of_expr e1) (calls_of_expr e2)

(* Like 'vars_without_node_calls', but only those vars that are not under a 'pre' expression *)
let rec vars_without_node_call_ids_current: expr -> iset =
let vars = vars_without_node_call_ids_current in
Expand Down
3 changes: 3 additions & 0 deletions src/lustre/lustreAstHelpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ val defined_vars_with_pos: node_item -> (Lib.position * index) list
val vars_of_ty_ids: typed_ident -> SI.t
(** returns a singleton set with the only identifier in a typed identifier declaration *)

val calls_of_expr: expr -> SI.t
(** [calls_of_expr e] returns all node/function names for those nodes/functions called in [e] *)

val vars_of_type: lustre_type -> SI.t
(** [vars_of_type ty] returns all variable identifiers that appear in the type [ty]
while excluding node call identifiers and refinement type bound variables *)
Expand Down
198 changes: 111 additions & 87 deletions src/lustre/lustreDesugarAnyOps.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ fun pos node_name ->
let name = HString.concat2 name pos in
HString.concat2 node_name name

let rec desugar_expr: Ctx.tc_context -> HString.t -> A.expr -> A.expr * A.declaration list =
fun ctx node_name ->
function
let rec desugar_expr: Ctx.tc_context -> HString.t -> HString.t list -> A.expr -> A.expr * A.declaration list =
fun ctx node_name fun_ids expr ->
let rec_call = desugar_expr ctx node_name fun_ids in
match expr with
| A.AnyOp (pos, (_, id, ty), expr1, expr2_opt) ->
let span = { A.start_pos = pos; A.end_pos = Lib.dummy_pos } in
let contract = match expr2_opt with
Expand Down Expand Up @@ -70,201 +71,224 @@ fun ctx node_name ->
| None -> assert false
) inputs in
let name = mk_fresh_fn_name pos node_name in
(* If the any op expressions are temporal or call a node, we generate an imported node.
Otherwise, we generate an imported function. *)
let has_pre_arrow_or_node_call = match expr2_opt with
| Some expr2 ->
let node_calls1 = AH.calls_of_expr expr1 |> Ctx.SI.elements |> List.filter (fun i -> not (List.mem i fun_ids)) in
let node_calls2 = AH.calls_of_expr expr2 |> Ctx.SI.elements |> List.filter (fun i -> not (List.mem i fun_ids)) in
(AH.has_pre_or_arrow expr1 != None) || node_calls1 != [] ||
(AH.has_pre_or_arrow expr2 != None) || node_calls2 != []
| None ->
let node_calls1 = AH.calls_of_expr expr1 |> Ctx.SI.elements |> List.filter (fun i -> not (List.mem i fun_ids)) in
(AH.has_pre_or_arrow expr1 != None) || (node_calls1 != [])
in
let generated_node =
A.NodeDecl (span,
(name, true, [], inputs,
[pos, id, ty, A.ClockTrue], [], [], Some contract))
if has_pre_arrow_or_node_call then
A.NodeDecl (span,
(name, true, [], inputs,
[pos, id, ty, A.ClockTrue], [], [], Some contract))
else
A.FuncDecl (span,
(name, true, [], inputs,
[pos, id, ty, A.ClockTrue], [], [], Some contract))
in
A.Call(pos, name, inputs_call), [generated_node]

| Ident _ as e -> e, []
| ModeRef (_, _) as e -> e, []
| Const (_, _) as e -> e, []
| RecordProject (pos, e, idx) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
RecordProject (pos, e, idx), gen_nodes
| TupleProject (pos, e, idx) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
TupleProject (pos, e, idx), gen_nodes
| UnaryOp (pos, op, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
UnaryOp (pos, op, e), gen_nodes
| BinaryOp (pos, op, e1, e2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
BinaryOp (pos, op, e1, e2), gen_nodes1 @ gen_nodes2
| TernaryOp (pos, op, e1, e2, e3) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e3, gen_nodes3 = desugar_expr ctx node_name e3 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
let e3, gen_nodes3 = rec_call e3 in
TernaryOp (pos, op, e1, e2, e3), gen_nodes1 @ gen_nodes2 @ gen_nodes3
| ConvOp (pos, op, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
ConvOp (pos, op, e), gen_nodes
| CompOp (pos, op, e1, e2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
CompOp (pos, op, e1, e2), gen_nodes1 @ gen_nodes2
| RecordExpr (pos, ident, expr_list) ->
let id_list, exprs_gen_nodes =
List.map (fun (i, e) -> (i, (desugar_expr ctx node_name) e)) expr_list |> List.split
List.map (fun (i, e) -> (i, (rec_call) e)) expr_list |> List.split
in
let expr_list, gen_nodes = List.split exprs_gen_nodes in
RecordExpr (pos, ident, List.combine id_list expr_list), List.flatten gen_nodes
| GroupExpr (pos, kind, expr_list) ->
let expr_list, gen_nodes = List.map (desugar_expr ctx node_name) expr_list |> List.split in
let expr_list, gen_nodes = List.map (rec_call) expr_list |> List.split in
GroupExpr (pos, kind, expr_list), List.flatten gen_nodes
| StructUpdate (pos, e1, idx, e2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
StructUpdate (pos, e1, idx, e2), gen_nodes1 @ gen_nodes2
| ArrayConstr (pos, e1, e2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
ArrayConstr (pos, e1, e2), gen_nodes1 @ gen_nodes2
| ArrayIndex (pos, e1, e2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
ArrayIndex (pos, e1, e2), gen_nodes1 @ gen_nodes2
| Quantifier (pos, kind, idents, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
Quantifier (pos, kind, idents, e), gen_nodes
| When (pos, e, clock) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
When (pos, e, clock), gen_nodes
| Condact (pos, e1, e2, id, expr_list1, expr_list2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let expr_list1, gen_nodes3 = List.map (desugar_expr ctx node_name) expr_list1 |> List.split in
let expr_list2, gen_nodes4 = List.map (desugar_expr ctx node_name) expr_list2 |> List.split in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
let expr_list1, gen_nodes3 = List.map rec_call expr_list1 |> List.split in
let expr_list2, gen_nodes4 = List.map rec_call expr_list2 |> List.split in
Condact (pos, e1, e2, id, expr_list1, expr_list2), gen_nodes1 @ gen_nodes2 @
List.flatten gen_nodes3 @ List.flatten gen_nodes4
| Activate (pos, ident, e1, e2, expr_list) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
Activate (pos, ident, e1, e2, expr_list), gen_nodes1 @ gen_nodes2
| Merge (pos, ident, expr_list) ->
let id_list, exprs_gen_nodes =
List.map (fun (i, e) -> (i, (desugar_expr ctx node_name) e)) expr_list |> List.split
List.map (fun (i, e) -> (i, (rec_call) e)) expr_list |> List.split
in
let expr_list, gen_nodes = List.split exprs_gen_nodes in
Merge (pos, ident, List.combine id_list expr_list), List.flatten gen_nodes
| RestartEvery (pos, ident, expr_list, e) ->
let expr_list, gen_nodes1 = List.map (desugar_expr ctx node_name) expr_list |> List.split in
let e, gen_nodes2 = desugar_expr ctx node_name e in
let expr_list, gen_nodes1 = List.map (rec_call) expr_list |> List.split in
let e, gen_nodes2 = rec_call e in
RestartEvery (pos, ident, expr_list, e), List.flatten gen_nodes1 @ gen_nodes2
| Pre (pos, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
Pre (pos, e), gen_nodes
| Arrow (pos, e1, e2) ->
let e1, gen_nodes1 = desugar_expr ctx node_name e1 in
let e2, gen_nodes2 = desugar_expr ctx node_name e2 in
let e1, gen_nodes1 = rec_call e1 in
let e2, gen_nodes2 = rec_call e2 in
Arrow (pos, e1, e2), gen_nodes1 @ gen_nodes2
| Call (pos, id, expr_list) ->
let expr_list, gen_nodes = List.map (desugar_expr ctx node_name) expr_list |> List.split in
let expr_list, gen_nodes = List.map rec_call expr_list |> List.split in
Call (pos, id, expr_list), List.flatten gen_nodes

let desugar_contract_item: Ctx.tc_context -> HString.t -> A.contract_node_equation -> A.contract_node_equation * A.declaration list =
fun ctx node_name ->
function
let desugar_contract_item: Ctx.tc_context -> HString.t -> HString.t list -> A.contract_node_equation -> A.contract_node_equation * A.declaration list =
fun ctx node_name fun_ids ci ->
let rec_call = desugar_expr ctx node_name fun_ids in
match ci with
| A.GhostVars (pos, lhs, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
A.GhostVars (pos, lhs, e), gen_nodes
| Assume (pos, name, b, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
Assume (pos, name, b, e), gen_nodes
| Guarantee (pos, name, b, e) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = rec_call e in
Guarantee (pos, name, b, e), gen_nodes
| Mode (pos, i, reqs, enss) ->
let (reqs, gen_nodes1) =
List.map (fun (pos, id, expr) -> (pos, id, desugar_expr ctx node_name expr)) reqs |>
List.map (fun (pos, id, expr) -> (pos, id, rec_call expr)) reqs |>
List.map (fun (pos, id, (expr, decls)) -> ((pos, id, expr), decls)) |>
List.split in
let (enss, gen_nodes2) =
List.map (fun (pos, id, expr) -> (pos, id, desugar_expr ctx node_name expr)) enss |>
List.map (fun (pos, id, expr) -> (pos, id, rec_call expr)) enss |>
List.map (fun (pos, id, (expr, decls)) -> ((pos, id, expr), decls)) |>
List.split in
Mode (pos, i, reqs, enss), (List.flatten gen_nodes1) @ (List.flatten gen_nodes2)
| ContractCall (pos, i, exprs, ids) ->
let (exprs, gen_nodes) = List.map (desugar_expr ctx node_name) exprs |> List.split in
let (exprs, gen_nodes) = List.map rec_call exprs |> List.split in
ContractCall (pos, i, exprs, ids), List.flatten gen_nodes
| GhostConst _
| AssumptionVars _ as ci -> ci, []

let desugar_contract: Ctx.tc_context -> HString.t -> A.contract_node_equation list option -> A.contract_node_equation list option * A.declaration list =
fun ctx node_name contract ->
let desugar_contract: Ctx.tc_context -> HString.t -> HString.t list -> A.contract_node_equation list option -> A.contract_node_equation list option * A.declaration list =
fun ctx node_name fun_ids contract ->
match contract with
| Some contract_items ->
let items, gen_nodes = (List.map (desugar_contract_item ctx node_name) contract_items) |> List.split in
Some items, List.flatten gen_nodes
| None -> None, []
| Some contract_items ->
let items, gen_nodes = (List.map (desugar_contract_item ctx node_name fun_ids) contract_items) |> List.split in
Some items, List.flatten gen_nodes
| None -> None, []

let rec desugar_node_item: Ctx.tc_context -> HString.t -> A.node_item -> A.node_item * A.declaration list =
fun ctx node_name ni ->
let rec desugar_node_item: Ctx.tc_context -> HString.t -> HString.t list -> A.node_item -> A.node_item * A.declaration list =
fun ctx node_name fun_ids ni ->
let rec_call = desugar_node_item ctx node_name fun_ids in
match ni with
| A.Body (Equation (pos, lhs, rhs)) ->
let rhs, gen_nodes = desugar_expr ctx node_name rhs in
let rhs, gen_nodes = desugar_expr ctx node_name fun_ids rhs in
A.Body (Equation (pos, lhs, rhs)), gen_nodes
| AnnotProperty (pos, name, e, k) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = desugar_expr ctx node_name fun_ids e in
AnnotProperty(pos, name, e, k), gen_nodes
| IfBlock (pos, cond, nis1, nis2) ->
let nis1, gen_nodes1 = List.map (desugar_node_item ctx node_name) nis1 |> List.split in
let nis2, gen_nodes2 = List.map (desugar_node_item ctx node_name) nis2 |> List.split in
let cond, gen_nodes3 = desugar_expr ctx node_name cond in
let nis1, gen_nodes1 = List.map rec_call nis1 |> List.split in
let nis2, gen_nodes2 = List.map rec_call nis2 |> List.split in
let cond, gen_nodes3 = desugar_expr ctx node_name fun_ids cond in
A.IfBlock (pos, cond, nis1, nis2), List.flatten gen_nodes1 @ List.flatten gen_nodes2 @ gen_nodes3
| FrameBlock (pos, vars, nes, nis) ->
let nes = List.map (fun x -> A.Body x) nes in
let nes, gen_nodes1 = List.map (desugar_node_item ctx node_name) nes |> List.split in
let nes, gen_nodes1 = List.map rec_call nes |> List.split in
let nes = List.map (fun ne -> match ne with
| A.Body (A.Equation _ as eq) -> eq
| _ -> assert false
) nes in
let nis, gen_nodes2 = List.map (desugar_node_item ctx node_name) nis |> List.split in
let nis, gen_nodes2 = List.map rec_call nis |> List.split in
FrameBlock(pos, vars, nes, nis), List.flatten gen_nodes1 @ List.flatten gen_nodes2
| Body (Assert (pos, e)) ->
let e, gen_nodes = desugar_expr ctx node_name e in
let e, gen_nodes = desugar_expr ctx node_name fun_ids e in
Body (Assert (pos, e)), gen_nodes
| AnnotMain _ -> ni, []


let desugar_any_ops: Ctx.tc_context -> A.declaration list -> A.declaration list =
fun ctx decls ->
let fun_ids = List.filter_map
(fun decl -> match decl with | A.FuncDecl (_, (id, _, _, _, _, _, _, _)) -> Some id | _ -> None)
decls
in
let decls =
List.fold_left (fun decls decl ->
match decl with
| A.NodeDecl (span, (id, ext, params, inputs, outputs, locals, items, contract)) ->
(
match Chk.add_full_node_ctx ctx id inputs outputs locals with
| Ok ctx ->
let items, gen_nodes = List.map (desugar_node_item ctx id) items |> List.split in
let contract, gen_nodes2 = desugar_contract ctx id contract in
let gen_nodes = List.flatten gen_nodes in
decls @ gen_nodes @ gen_nodes2 @ [A.NodeDecl (span, (id, ext, params, inputs, outputs, locals, items, contract))]
(* If there is an error in context collection, it will be detected later in type checking *)
| Error _ -> decl :: decls
| Ok ctx ->
let items, gen_nodes = List.map (desugar_node_item ctx id fun_ids) items |> List.split in
let contract, gen_nodes2 = desugar_contract ctx id fun_ids contract in
let gen_nodes = List.flatten gen_nodes in
decls @ gen_nodes @ gen_nodes2 @ [A.NodeDecl (span, (id, ext, params, inputs, outputs, locals, items, contract))]
(* If there is an error in context collection, it will be detected later in type checking *)
| Error _ -> decl :: decls
)
| A.FuncDecl (span, (id, ext, params, inputs, outputs, locals, items, contract)) ->
(
match Chk.add_full_node_ctx ctx id inputs outputs locals with
| Ok ctx ->
let items, gen_nodes = List.map (desugar_node_item ctx id) items |> List.split in
let contract, gen_nodes2 = desugar_contract ctx id contract in
let gen_nodes = List.flatten gen_nodes in
decls @ gen_nodes @ gen_nodes2 @ [A.FuncDecl (span, (id, ext, params, inputs, outputs, locals, items, contract))]
(* If there is an error in context collection, it will be detected later in type checking *)
| Error _ -> decl :: decls
| Ok ctx ->
let items, gen_nodes = List.map (desugar_node_item ctx id fun_ids) items |> List.split in
let contract, gen_nodes2 = desugar_contract ctx id fun_ids contract in
let gen_nodes = List.flatten gen_nodes in
decls @ gen_nodes @ gen_nodes2 @ [A.FuncDecl (span, (id, ext, params, inputs, outputs, locals, items, contract))]
(* If there is an error in context collection, it will be detected later in type checking *)
| Error _ -> decl :: decls
)
| A.ContractNodeDecl (span, (id, params, inputs, outputs, contract)) ->
(
let ctx = Chk.add_io_node_ctx ctx inputs outputs in
let contract, gen_nodes = desugar_contract ctx id (Some contract) in
let contract = match contract with
| Some contract -> contract
| None -> assert false in (* Must have a contract *)
decls @ gen_nodes @ [A.ContractNodeDecl (span, (id, params, inputs, outputs, contract))]
)
(
let ctx = Chk.add_io_node_ctx ctx inputs outputs in
let contract, gen_nodes = desugar_contract ctx id fun_ids (Some contract) in
let contract = match contract with
| Some contract -> contract
| None -> assert false in (* Must have a contract *)
decls @ gen_nodes @ [A.ContractNodeDecl (span, (id, params, inputs, outputs, contract))]
)
| _ -> decl :: decls
) [] decls in
decls

0 comments on commit 9c26262

Please sign in to comment.