Skip to content

Commit

Permalink
Generalize inverse function translation
Browse files Browse the repository at this point in the history
  • Loading branch information
Wonho Shin authored and Wonho Shin committed Jul 16, 2024
1 parent 3dc017c commit 225e6e4
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 85 deletions.
1 change: 1 addition & 0 deletions spectec/src/al/free.mli
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ open Ast

module IdSet : Set.S with type elt = string

val free_list : ('a -> IdSet.t) -> 'a list -> IdSet.t
val free_expr : expr -> IdSet.t
val free_instr : instr -> IdSet.t
73 changes: 54 additions & 19 deletions spectec/src/il2al/translate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -468,14 +468,10 @@ let contains_diff target_ns e =
not (IdSet.is_empty free_ns) && IdSet.disjoint free_ns target_ns

let extract_diff lhs rhs ids cont =
let at = lhs.at in
match lhs.it with
(* TODO: Make this actually consider the targets *)
| CallE (f, args) ->
let hds, tl = Util.Lib.List.split_last args in
let new_lhs = tl in
let new_rhs = callE ("inverse_of_" ^ f, hds @ [ rhs ]) ~at:at in
new_lhs, new_rhs, cont
| CallE (_, _) ->
lhs, rhs, cont
| _ ->
let conds = ref [] in
let target_ns = IdSet.of_list (List.map it ids) in
Expand All @@ -484,7 +480,7 @@ let extract_diff lhs rhs ids cont =
e
else
let new_e = get_lhs_name () in
conds := !conds @ [ binE (EqOp, new_e, e) ];
conds := [ binE (EqOp, new_e, e) ];
new_e
) in
let walker = Al.Walk.walk_expr { Al.Walk.default_config with
Expand All @@ -502,31 +498,70 @@ let rec translate_bindings ids cont bindings =
| _ -> translate_letpr l r ids cont
) bindings cont

and translate_letpr lhs rhs free_ids cont =
(* helpers *)
let contains_free expr =
free_ids
|> List.map it
|> IdSet.of_list
|> IdSet.disjoint (free_expr expr)
|> not
in
let rhs2args e =
match e.it with
| TupE el -> el
| _ -> [e]
in
let args2lhs args = if List.length args = 1 then List.hd args else tupE args in

and translate_letpr lhs rhs ids cont =
let lhs, rhs, cont = extract_diff lhs rhs ids cont in
let lhs_at = lhs.at in
let rhs_at = rhs.at in
let at = over_region [ lhs_at; rhs_at ] in
let lhs, rhs, cont = extract_diff lhs rhs free_ids cont in
let at = over_region [ lhs.at; rhs.at ] in
match lhs.it with
| CallE (f, args) when List.for_all contains_free args ->
let new_lhs = args2lhs args in
let new_rhs = InvCallE (f, [], rhs2args rhs) $ lhs.at in
translate_letpr new_lhs new_rhs free_ids cont
| CallE (f, args) when List.exists contains_free args ->
(* Distinguish free arguments and bound arguments *)
let free_args_with_index, bound_args =
args
|> List.mapi (fun i arg ->
if contains_free arg then Some (arg, i), None
else None, Some arg
)
|> List.split
in
let bound_args = List.filter_map (fun x -> x) bound_args in
let free_args, indices =
free_args_with_index
|> List.filter_map (fun x -> x)
|> List.split
in

(* Free argument become new lhs & InvCallE become new rhs *)
let new_lhs = args2lhs free_args in
let new_rhs = InvCallE (f, indices, bound_args @ rhs2args rhs) $ lhs.at in

(* Recursively translate new_lhs and new_rhs *)
translate_letpr new_lhs new_rhs free_ids cont
| CaseE (tag, es) ->
let bindings, es' = extract_non_names es in
[
ifI (
isCaseOfE (rhs, tag),
letI (caseE (tag, es') ~at:lhs_at, rhs) ~at:at :: translate_bindings ids cont bindings,
letI (caseE (tag, es') ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids cont bindings,
[]
);
]
| ListE es ->
let bindings, es' = extract_non_names es in
if List.length es >= 2 then (* TODO: remove this. This is temporarily for a pure function returning stores *)
letI (listE es' ~at:lhs_at, rhs) ~at:at :: translate_bindings ids cont bindings
letI (listE es' ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids cont bindings
else
[
ifI
( binE (EqOp, lenE rhs, numE (Z.of_int (List.length es))),
letI (listE es' ~at:lhs_at, rhs) ~at:at :: translate_bindings ids cont bindings,
letI (listE es' ~at:lhs.at, rhs) ~at:at :: translate_bindings free_ids cont bindings,
[] );
]
| OptE None ->
Expand All @@ -548,7 +583,7 @@ and translate_letpr lhs rhs ids cont =
[
ifI
( isDefinedE rhs,
letI (optE (Some fresh) ~at:lhs_at, rhs) ~at:at :: translate_letpr e fresh ids cont,
letI (optE (Some fresh) ~at:lhs.at, rhs) ~at:at :: translate_letpr e fresh free_ids cont,
[] );
]
| BinE (AddOp, a, b) ->
Expand Down Expand Up @@ -580,15 +615,15 @@ and translate_letpr lhs rhs ids cont =
[
ifI
( cond,
letI (catE (prefix', suffix') ~at:lhs_at, rhs) ~at:at
:: translate_bindings ids cont (bindings_p @ bindings_s),
letI (catE (prefix', suffix') ~at:lhs.at, rhs) ~at:at
:: translate_bindings free_ids cont (bindings_p @ bindings_s),
[] );
]
| SubE (s, t) ->
[
ifI
( hasTypeE (rhs, t),
letI (varE s ~at:lhs_at, rhs) ~at:at :: cont,
letI (varE s ~at:lhs.at, rhs) ~at:at :: cont,
[] )
]
| _ -> letI (lhs, rhs) ~at:at :: cont
Expand Down
Loading

0 comments on commit 225e6e4

Please sign in to comment.