Skip to content

Commit

Permalink
Refactor LetPr translation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ShinWonho committed Jul 17, 2024
1 parent fc947bc commit 8705f5a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 52 deletions.
138 changes: 90 additions & 48 deletions spectec/src/il2al/translate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ let wrap typ e = e $$ no_region % typ
let top = Il.VarT ("TOP" $ no_region, []) $ no_region
let hole = Il.TextE "_" |> wrap top

let insert_nop instrs = match instrs with [] -> [ nopI () ] | _ -> instrs

(* Insert `target` at the innermost if instruction *)
let rec insert_instrs target il =
match Util.Lib.List.split_last_opt il with
| Some ([], { it = OtherwiseI il'; _ }) -> [ otherwiseI (il' @ insert_nop target) ]
| Some (h, { it = IfI (cond, il', []); _ }) ->
h @ [ ifI (cond, insert_instrs (insert_nop target) il' , []) ]
| _ -> il @ target


(** Translation *)

(* `Il.atom` -> `atom` *)
let translate_atom atom = atom.it, atom.note.Il.Atom.def

Expand Down Expand Up @@ -335,8 +348,6 @@ let insert_pop e =

[ insert_assert e; popI (translate_exp e') ~at:e'.at ]

let insert_nop instrs = match instrs with [] -> [ nopI () ] | _ -> instrs


(* Assume that only the iter variable is unbound *)
let is_unbound vars e =
Expand Down Expand Up @@ -407,8 +418,9 @@ let rec translate_rhs exp =
letI (varE "F", frameE (Some (varE arity.it), varE fid.it)) ~at:at;
enterI (varE "F", listE ([caseE (translate_atom atom, [])]), translate_rhs le) ~at:at;
]
(* TODO: Label *)
(* Label *)
| Il.CaseE (

[ { it = Atom "LABEL_"; _ } as atom ] :: _,
{ it = Il.TupE [ arity; e1; e2 ]; _ }
) ->
Expand Down Expand Up @@ -467,61 +479,69 @@ let contains_diff target_ns e =
let free_ns = free_expr e in
not (IdSet.is_empty free_ns) && IdSet.disjoint free_ns target_ns

let extract_diff lhs rhs ids cont =
let handle_partial_bindings lhs rhs ids =
match lhs.it with
(* TODO: Make this actually consider the targets *)
| CallE (_, _) ->
lhs, rhs, cont
lhs, rhs, []
| _ ->
let conds = ref [] in
let target_ns = IdSet.of_list (List.map it ids) in
let target_ns = IdSet.of_list ids in
let pre_expr = (fun e ->
if not (contains_diff target_ns e) then
e
else
else (
let new_e = get_lhs_name () in
conds := [ binE (EqOp, new_e, e) ];
conds := !conds @ [ binE (EqOp, new_e, e) ];
new_e
)
) in
let walker = Al.Walk.walk_expr { Al.Walk.default_config with
pre_expr;
stop_cond_expr = contains_diff target_ns;
} in
let new_lhs = walker lhs in
new_lhs, rhs, List.fold_right (fun c il -> [ ifI (c, il, []) ]) !conds cont
new_lhs, rhs, List.fold_left (fun il c -> [ ifI (c, il, []) ]) [] !conds


let rec translate_bindings ids cont bindings =
let rec translate_bindings ids bindings =
List.fold_right (fun (l, r) cont ->
match l with
| _ when IdSet.is_empty (free_expr l) -> [ ifI (binE (EqOp, r, l), cont, []) ]
| _ -> translate_letpr l r ids cont
) bindings cont

and translate_letpr lhs rhs free_ids cont =
(* helpers *)
let contains_free expr =
free_ids
|> List.map it
| _ when IdSet.is_empty (free_expr l) -> [ ifI (binE (EqOp, r, l), [], []) ]
| _ -> insert_instrs cont (handle_special_lhs l r ids)
) bindings []

and handle_inverse_function lhs rhs free_ids =
(* Helper functions *)
let contains_ids ids expr =
ids
|> IdSet.of_list
|> IdSet.disjoint (free_expr expr)
|> not
in
let contains_free = contains_ids free_ids in
let rhs2args e =
match e.it with
| TupE el -> el
| _ -> [e]
in
let args2lhs args = if List.length args = 1 then List.hd args else tupE args in

let lhs, rhs, cont = extract_diff lhs rhs free_ids cont in
let at = over_region [ lhs.at; rhs.at ] in
match lhs.it with
| CallE (f, args) when List.for_all contains_free args ->
(* Get function name and arguments *)
let f, args =
match lhs.it with
| CallE (f, args) -> f, args
| _ -> assert (false);
in

(* All arguments are free *)
if List.for_all contains_free args then
let new_lhs = args2lhs args in
let new_rhs = InvCallE (f, [], rhs2args rhs) $ lhs.at in
translate_letpr new_lhs new_rhs free_ids cont
| CallE (f, args) when List.exists contains_free args ->
handle_special_lhs new_lhs new_rhs free_ids

(* Some arguments are free *)
else if List.exists contains_free args then

(* Distinguish free arguments and bound arguments *)
let free_args_with_index, bound_args =
args
Expand All @@ -543,54 +563,69 @@ and translate_letpr lhs rhs free_ids cont =
let new_rhs = InvCallE (f, indices, bound_args @ rhs2args rhs) $ lhs.at in

(* Recursively translate new_lhs and new_rhs *)
translate_letpr new_lhs new_rhs free_ids cont
handle_special_lhs new_lhs new_rhs free_ids

(* No argument is free *)
else
Print.string_of_expr lhs
|> sprintf "lhs expression %s doesn't contain free variable"
|> error lhs.at


and handle_special_lhs lhs rhs free_ids =

let at = over_region [ lhs.at; rhs.at ] in
match lhs.it with
(* Handle inverse function call *)
| CallE _ -> handle_inverse_function lhs rhs free_ids
(* Normal cases *)
| CaseE (tag, es) ->
let bindings, es' = extract_non_names es in
[
ifI (
isCaseOfE (rhs, tag),
letI (caseE (tag, es') ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids cont bindings,
letI (caseE (tag, es') ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids bindings,
[]
);
]
| ListE es ->
let bindings, es' = extract_non_names es in
if List.length es >= 2 then (* TODO: remove this. This is temporarily for a pure function returning stores *)
letI (listE es' ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids cont bindings
letI (listE es' ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids bindings
else
[
ifI
( binE (EqOp, lenE rhs, numE (Z.of_int (List.length es))),
letI (listE es' ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids cont bindings,
letI (listE es' ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids bindings,
[] );
]
| OptE None ->
[
ifI
( unE (NotOp, isDefinedE rhs),
cont,
[],
[] );
]
| OptE (Some ({ it = VarE _; _ })) ->
[
ifI
( isDefinedE rhs,
letI (lhs, rhs) ~at:at :: cont,
[letI (lhs, rhs) ~at:at],
[] );
]
| OptE (Some e) ->
let fresh = get_lhs_name() in
[
ifI
( isDefinedE rhs,
letI (optE (Some fresh) ~at:lhs.at, rhs) ~at:at :: translate_letpr e fresh free_ids cont,
letI (optE (Some fresh) ~at:lhs.at, rhs) ~at:at :: handle_special_lhs e fresh free_ids,
[] );
]
| BinE (AddOp, a, b) ->
[
ifI
( binE (GeOp, rhs, b),
letI (a, binE (SubOp, rhs, b) ~at:at) ~at:at :: cont,
[letI (a, binE (SubOp, rhs, b) ~at:at) ~at:at],
[] );
]
| CatE (prefix, suffix) ->
Expand All @@ -616,25 +651,41 @@ and translate_letpr lhs rhs free_ids cont =
ifI
( cond,
letI (catE (prefix', suffix') ~at:lhs.at, rhs) ~at:at
:: translate_bindings free_ids cont (bindings_p @ bindings_s),
:: translate_bindings free_ids (bindings_p @ bindings_s),
[] );
]
| SubE (s, t) ->
[
ifI
( hasTypeE (rhs, t),
letI (varE s ~at:lhs.at, rhs) ~at:at :: cont,
[letI (varE s ~at:lhs.at, rhs) ~at:at],
[] )
]
| _ -> letI (lhs, rhs) ~at:at :: cont
| _ -> [letI (lhs, rhs) ~at:at]

let translate_letpr lhs rhs free_ids =
(* Translate *)
let al_lhs, al_rhs = translate_exp lhs, translate_exp rhs in
let al_ids = List.map it free_ids in

(* Handle partial bindings *)
let al_lhs', al_rhs', cond_instrs = handle_partial_bindings al_lhs al_rhs al_ids in

(* Construct binding instructions *)
let instrs = handle_special_lhs al_lhs' al_rhs' al_ids in

(* Insert conditions *)
if List.length cond_instrs = 0 then instrs
else insert_instrs cond_instrs instrs


(* HARDCODE: Translate each RulePr manually based on their names *)
let translate_rulepr id exp =
let at = id.at in
match id.it, translate_argexp exp with
| "Eval_expr", [z; lhs; _; rhs] ->
(* Note: State is automatically converted into frame by remove_state *)
[
(* TODO: not pushing store without store remover transpiler *)
pushI (frameE (None, z));
letI (rhs, callE ("eval_expr", [ lhs ])) ~at:at;
popI (frameE (None, z));
Expand Down Expand Up @@ -679,20 +730,11 @@ and translate_prem prem =
| Il.ElsePr -> [ otherwiseI [] ~at:at ]
| Il.LetPr (exp1, exp2, ids) ->
init_lhs_id ();
translate_letpr (translate_exp exp1) (translate_exp exp2) ids []
translate_letpr exp1 exp2 ids
| Il.RulePr (id, _, exp) -> translate_rulepr id exp
| Il.IterPr (pr, exp) -> translate_iterpr pr exp


(* Insert `target` at the innermost if instruction *)
let rec insert_instrs target il =
match Util.Lib.List.split_last_opt il with
| Some ([], { it = OtherwiseI il'; _ }) -> [ otherwiseI (il' @ insert_nop target) ]
| Some (h, { it = IfI (cond, il', []); _ }) ->
h @ [ ifI (cond, insert_instrs (insert_nop target) il' , []) ]
| _ -> il @ target


(* `premise list` -> `instr list` (return instructions) -> `instr list` *)
let translate_prems =
List.fold_right (fun prem il -> translate_prem prem |> insert_instrs il)
Expand Down
7 changes: 3 additions & 4 deletions spectec/test-prose/TEST.md
Original file line number Diff line number Diff line change
Expand Up @@ -7033,10 +7033,9 @@ execution_of_ARRAY.NEW_DATA x y
8. Let (mut, zt) be y_0.
9. If ((i + ((n · $zsize(zt)) / 8)) > |$data(z, y).BYTES|), then:
a. Trap.
10. Let $zbytes(y_0, c)^n be $concat_^-1($data(z, y).BYTES[i : ((n · $zsize(zt)) / 8)]).
11. If (y_0 is zt), then:
a. Push the values $const($cunpack(zt), $cunpacknum(zt, c))^n to the stack.
b. Execute the instruction (ARRAY.NEW_FIXED x n).
10. Let $zbytes(zt, c)^n be $concat_^-1($data(z, y).BYTES[i : ((n · $zsize(zt)) / 8)]).
11. Push the values $const($cunpack(zt), $cunpacknum(zt, c))^n to the stack.
12. Execute the instruction (ARRAY.NEW_FIXED x n).

execution_of_ARRAY.GET sx? x
1. Let z be the current state.
Expand Down

0 comments on commit 8705f5a

Please sign in to comment.