Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Optimize calls to a statically known function #1790

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions compiler/lib-wasm/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,16 @@ module Generate (Target : Target_sig.S) = struct

let zero_divide_pc = -2

(* Collapse an application kind to the former boolean "exact" flag:
   only [Generic] applications are not exact; [Exact] and [Known _]
   both are. *)
let exact_call = function
  | Generic -> false
  | Exact | Known _ -> true

let rec translate_expr ctx context x e =
match e with
| Apply { f; args; exact }
when exact || List.length args = if Var.Set.mem x ctx.in_cps then 2 else 1 ->
| Apply { f; args; kind }
when exact_call kind || List.length args = if Var.Set.mem x ctx.in_cps then 2 else 1
->
let rec loop acc l =
match l with
| [] -> (
Expand All @@ -204,13 +210,14 @@ module Generate (Target : Target_sig.S) = struct
if b
then return (W.Call (f, List.rev (closure :: acc)))
else
match funct with
| W.RefFunc g ->
match funct, kind with
| W.RefFunc g, _ ->
(* Functions with constant closures ignore their
environment. In case of partial application, we
still need the closure. *)
let* cl = if exact then Value.unit else return closure in
let* cl = if exact_call kind then Value.unit else return closure in
return (W.Call (g, List.rev (cl :: acc)))
| _, Known g -> return (W.Call (g, List.rev (closure :: acc)))
| _ -> return (W.Call_ref (ty, funct, List.rev (closure :: acc))))
| x :: r ->
let* x = load x in
Expand Down
17 changes: 12 additions & 5 deletions compiler/lib/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,16 @@ type field_type =
| Non_float
| Float

(* What is statically known about a function application. *)
type apply_kind =
| Generic (* no exactness guarantee: # of arguments may differ from # of parameters *)
| Exact (* # of arguments = # of parameters *)
| Known of Var.t (* Exact and we know which function is called *)

type expr =
| Apply of
{ f : Var.t
; args : Var.t list
; exact : bool
; kind : apply_kind
}
| Block of int * Var.t array * array_or_not * mutability
| Field of Var.t * int * field_type
Expand Down Expand Up @@ -556,10 +561,12 @@ module Print = struct

let expr f e =
match e with
| Apply { f = g; args; exact } ->
if exact
then Format.fprintf f "%a!(%a)" Var.print g var_list args
else Format.fprintf f "%a(%a)" Var.print g var_list args
| Apply { f = g; args; kind } -> (
match kind with
| Generic -> Format.fprintf f "%a(%a)" Var.print g var_list args
| Exact -> Format.fprintf f "%a!(%a)" Var.print g var_list args
| Known h -> Format.fprintf f "%a{=%a}(%a)" Var.print g Var.print h var_list args
)
| Block (t, a, _, mut) ->
Format.fprintf
f
Expand Down
7 changes: 6 additions & 1 deletion compiler/lib/code.mli
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,16 @@ type field_type =
| Non_float
| Float

(* What is statically known about a function application. *)
type apply_kind =
| Generic (* no exactness guarantee: # of arguments may differ from # of parameters *)
| Exact (* # of arguments = # of parameters *)
| Known of Var.t (* Exact and we know which function is called *)

type expr =
| Apply of
{ f : Var.t
; args : Var.t list
; exact : bool (* if true, then # of arguments = # of parameters *)
; kind : apply_kind
}
| Block of int * Var.t array * array_or_not * mutability
| Field of Var.t * int * field_type
Expand Down
6 changes: 3 additions & 3 deletions compiler/lib/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ let effects ~deadcode_sentinal p =
p
|> Effects.f ~flow_info:info ~live_vars
|> map_fst
(match effects with
| `Double_translation -> Fun.id
| `Cps -> Lambda_lifting.f)
(match Config.target (), effects with
| `Wasm, _ | _, `Double_translation -> Fun.id
| `JavaScript, `Cps -> Lambda_lifting.f)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a bug fix?

| `Disabled | `Jspi ->
( p
, (Code.Var.Set.empty : Effects.trampolined_calls)
Expand Down
68 changes: 40 additions & 28 deletions compiler/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,15 @@ let allocate_closure ~st ~params ~body ~branch =
let name = Var.fresh () in
[ Let (name, Closure (params, (pc, []))) ], name

let tail_call ~st ?(instrs = []) ~exact ~in_cps ~check ~f args =
assert (exact || check);
let tail_call ~st ?(instrs = []) ~kind ~in_cps ~check ~f args =
assert (
match kind with
| Generic -> check
| Exact | Known _ -> true);
let ret = Var.fresh () in
if check then st.trampolined_calls := Var.Set.add ret !(st.trampolined_calls);
if in_cps then st.in_cps := Var.Set.add ret !(st.in_cps);
instrs @ [ Let (ret, Apply { f; args; exact }) ], Return ret
instrs @ [ Let (ret, Apply { f; args; kind }) ], Return ret

let cps_branch ~st ~src (pc, args) =
match Addr.Set.mem pc st.blocks_to_transform with
Expand All @@ -359,14 +362,8 @@ let cps_branch ~st ~src (pc, args) =
(* We check the stack depth only for backward edges (so, at
least once per loop iteration) *)
let check = Hashtbl.find st.block_order src >= Hashtbl.find st.block_order pc in
tail_call
~st
~instrs
~exact:true
~in_cps:false
~check
~f:(closure_of_pc ~st pc)
args
let f = closure_of_pc ~st pc in
tail_call ~st ~instrs ~kind:(Known f) ~in_cps:false ~check ~f args

let cps_jump_cont ~st ~src ((pc, _) as cont) =
match Addr.Set.mem pc st.blocks_to_transform with
Expand Down Expand Up @@ -433,7 +430,7 @@ let cps_last ~st ~alloc_jump_closures pc (last : last) ~k : instr list * last =
(* If the number of successive 'returns' is unbounded in CPS, it
means that we have an unbounded number of calls in direct style
(even with tail call optimization) *)
tail_call ~st ~exact:true ~in_cps:false ~check:false ~f:k [ x ]
tail_call ~st ~kind:Exact ~in_cps:false ~check:false ~f:k [ x ]
| Raise (x, rmode) -> (
assert (List.is_empty alloc_jump_closures);
match Hashtbl.find_opt st.matching_exn_handler pc with
Expand Down Expand Up @@ -468,7 +465,7 @@ let cps_last ~st ~alloc_jump_closures pc (last : last) ~k : instr list * last =
tail_call
~st
~instrs:(Let (exn_handler, Prim (Extern "caml_pop_trap", [])) :: instrs)
~exact:true
~kind:Exact
~in_cps:false
~check:false
~f:exn_handler
Expand Down Expand Up @@ -522,6 +519,14 @@ let cps_last ~st ~alloc_jump_closures pc (last : last) ~k : instr list * last =
@ (Let (exn_handler, Prim (Extern "caml_pop_trap", [])) :: body)
, branch ))

(* Combine two pieces of information about the same application and keep
   the more precise one. Precision is ordered
   [Generic] < [Exact] < [Known _]; when both kinds are equally precise,
   the first argument [k] is returned. *)
let refine_kind k k' =
  let precision = function
    | Generic -> 0
    | Exact -> 1
    | Known _ -> 2
  in
  if precision k' > precision k then k' else k

let rewrite_instr ~st (instr : instr) : instr =
match instr with
| Let (x, Closure (_, (pc, _))) when Var.Set.mem x st.cps_needed ->
Expand All @@ -542,27 +547,34 @@ let rewrite_instr ~st (instr : instr) : instr =
(Extern "caml_alloc_dummy_function", [ size; Pc (Int (Targetint.succ a)) ])
)
| _ -> assert false)
| Let (x, Apply { f; args; _ }) when not (Var.Set.mem x st.cps_needed) ->
| Let (x, Apply { f; args; kind }) when not (Var.Set.mem x st.cps_needed) ->
(* At the moment, we turn into CPS any function not called with
the right number of parameters *)
assert (
let kind' =
(* If this function is unknown to the global flow analysis, then it was
introduced by the lambda lifting and we don't have exactness info any more. *)
Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
|| Global_flow.exact_call st.flow_info f (List.length args));
Let (x, Apply { f; args; exact = true })
if Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
then Exact
else Global_flow.apply_kind st.flow_info f (List.length args)
in
assert (
match kind' with
| Generic -> false
| Exact | Known _ -> true);
Let (x, Apply { f; args; kind = refine_kind kind kind' })
| Let (_, e) when effect_primitive_or_application e ->
(* For the CPS target, applications of CPS functions and effect primitives require
more work (allocating a continuation and/or modifying end-of-block branches) and
are handled in a specialized function. *)
assert false
| _ -> instr

let call_exact flow_info (f : Var.t) nargs : bool =
let call_kind flow_info (f : Var.t) nargs =
(* If [f] is unknown to the global flow analysis, then it was introduced by
the lambda lifting and we don't have exactness information about it. *)
Var.idx f < Var.Tbl.length flow_info.Global_flow.info_approximation
&& Global_flow.exact_call flow_info f nargs
if Var.idx f >= Var.Tbl.length flow_info.Global_flow.info_approximation
then Generic
else Global_flow.apply_kind flow_info f nargs

let cps_instr ~st (instr : instr) : instr list =
match instr with
Expand All @@ -571,7 +583,7 @@ let cps_instr ~st (instr : instr) : instr list =
Otherwise, the runtime primitive is used. *)
let unit = Var.fresh_n "unit" in
[ Let (unit, Constant (Int Targetint.zero))
; Let (x, Apply { exact = call_exact st.flow_info f 1; f; args = [ unit ] })
; Let (x, Apply { kind = call_kind st.flow_info f 1; f; args = [ unit ] })
]
| _ -> [ rewrite_instr ~st instr ]

Expand Down Expand Up @@ -646,11 +658,11 @@ let cps_block ~st ~k ~orig_pc block =
[ Let (x, e) ], Return x)
in
match e with
| Apply { f; args; exact } when Var.Set.mem x st.cps_needed ->
| Apply { f; args; kind } when Var.Set.mem x st.cps_needed ->
Some
(fun ~k ->
let exact = exact || call_exact st.flow_info f (List.length args) in
tail_call ~st ~exact ~in_cps:true ~check:true ~f (args @ [ k ]))
let kind = refine_kind kind (call_kind st.flow_info f (List.length args)) in
tail_call ~st ~kind ~in_cps:true ~check:true ~f (args @ [ k ]))
| Prim (Extern "%resume", [ Pv stack; Pv f; Pv arg; tail ]) ->
Some
(fun ~k ->
Expand All @@ -659,7 +671,7 @@ let cps_block ~st ~k ~orig_pc block =
~st
~instrs:
[ Let (k', Prim (Extern "caml_resume_stack", [ Pv stack; tail; Pv k ])) ]
~exact:(call_exact st.flow_info f 1)
~kind:(call_kind st.flow_info f 1)
~in_cps:true
~check:true
~f
Expand Down Expand Up @@ -747,8 +759,8 @@ let rewrite_direct_block ~st ~cps_needed ~closure_info ~pc block =
(* We just need to call [f] in direct style. *)
let unit = Var.fresh_n "unit" in
let unit_val = Int Targetint.zero in
let exact = call_exact st.flow_info f 1 in
[ Let (unit, Constant unit_val); Let (x, Apply { exact; f; args = [ unit ] }) ]
let kind = call_kind st.flow_info f 1 in
[ Let (unit, Constant unit_val); Let (x, Apply { kind; f; args = [ unit ] }) ]
| (Let _ | Assign _ | Set_field _ | Offset_ref _ | Array_set _ | Event _) as instr
-> [ instr ]
in
Expand Down
14 changes: 12 additions & 2 deletions compiler/lib/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,14 @@ module Share = struct
List.fold_left block.body ~init:share ~f:(fun share i ->
match i with
| Let (_, Constant c) -> get_constant c share
| Let (x, Apply { args; exact; _ }) ->
| Let (x, Apply { args; kind; _ }) ->
let trampolined = Var.Set.mem x trampolined_calls in
let in_cps = Var.Set.mem x in_cps in
let exact =
match kind with
| Generic -> false
| Exact | Known _ -> true
in
if (not exact) || trampolined
then
add_apply
Expand Down Expand Up @@ -1230,7 +1235,12 @@ let remove_unused_tail_args ctx exact trampolined args =
let rec translate_expr ctx loc x e level : (_ * J.statement_list) Expr_builder.t =
let open Expr_builder in
match e with
| Apply { f; args; exact } ->
| Apply { f; args; kind } ->
let exact =
match kind with
| Generic -> false
| Exact | Known _ -> true
in
let trampolined = Var.Set.mem x ctx.Ctx.trampolined_calls in
let args = remove_unused_tail_args ctx exact trampolined args in
let* () = info ~need_loc:true mutator_p in
Expand Down
14 changes: 7 additions & 7 deletions compiler/lib/generate_closure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ let rec collect_apply pc blocks visited tc =
match block.branch with
| Return x -> (
match List.last block.body with
| Some (Let (y, Apply { f; exact = true; _ })) when Code.Var.compare x y = 0 ->
Some (add_multi f pc tc)
| Some (Let (y, Apply { f; kind = Exact | Known _; _ }))
when Code.Var.compare x y = 0 -> Some (add_multi f pc tc)
| None -> None
| Some _ -> None)
| _ -> None
Expand Down Expand Up @@ -100,7 +100,7 @@ module Trampoline = struct
match counter with
| None ->
{ params = []
; body = [ Let (return, Apply { f; args; exact = true }) ]
; body = [ Let (return, Apply { f; args; kind = Known f }) ]
; branch = Return return
}
| Some counter ->
Expand All @@ -110,7 +110,7 @@ module Trampoline = struct
[ Let
( counter_plus_1
, Prim (Extern "%int_add", [ Pv counter; Pc (Int Targetint.one) ]) )
; Let (return, Apply { f; args = counter_plus_1 :: args; exact = true })
; Let (return, Apply { f; args = counter_plus_1 :: args; kind = Known f })
]
; branch = Return return
}
Expand Down Expand Up @@ -139,14 +139,14 @@ module Trampoline = struct
(match counter with
| None ->
[ Event loc
; Let (result1, Apply { f; args; exact = true })
; Let (result1, Apply { f; args; kind = Known f })
; Event Parse_info.zero
; Let (result2, Prim (Extern "caml_trampoline", [ Pv result1 ]))
]
| Some counter ->
[ Event loc
; Let (counter, Constant (Int Targetint.zero))
; Let (result1, Apply { f; args = counter :: args; exact = true })
; Let (result1, Apply { f; args = counter :: args; kind = Known f })
; Event Parse_info.zero
; Let (result2, Prim (Extern "caml_trampoline", [ Pv result1 ]))
])
Expand Down Expand Up @@ -222,7 +222,7 @@ module Trampoline = struct
let bounce_call_pc = free_pc + 1 in
let free_pc = free_pc + 2 in
match List.rev block.body with
| Let (x, Apply { f; args; exact = true }) :: rem_rev ->
| Let (x, Apply { f; args; kind = Exact | Known _ }) :: rem_rev ->
assert (Var.equal f ci.f_name);
let blocks =
Addr.Map.add
Expand Down
39 changes: 26 additions & 13 deletions compiler/lib/global_flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -704,17 +704,29 @@ let f ~fast p =
; info_return_vals = rets
}

let exact_call info f n =
let apply_kind info f n =
match Var.Tbl.get info.info_approximation f with
| Top | Values { others = true; _ } -> false
| Values { known; others = false } ->
Var.Set.for_all
(fun g ->
match info.info_defs.(Var.idx g) with
| Expr (Closure (params, _)) -> List.length params = n
| Expr (Block _) -> true
| Expr _ | Phi _ -> assert false)
known
| Top | Values { others = true; _ } -> Generic
| Values { known; others = false } -> (
match
Var.Set.fold
(fun g acc ->
match info.info_defs.(Var.idx g) with
| Expr (Closure (params, _)) ->
if List.length params = n
then
match acc with
| None -> Some (Known g)
| Some (Known _) -> Some Exact
| Some (Exact | Generic) -> acc
else Some Generic
| Expr (Block _) -> acc
| Expr _ | Phi _ -> assert false)
known
None
with
| None -> Exact
| Some kind -> kind)

let function_arity info f =
match Var.Tbl.get info.info_approximation f with
Expand All @@ -727,9 +739,10 @@ let function_arity info f =
| Expr (Closure (params, _)) -> (
let n = List.length params in
match acc with
| None -> Some (Some n)
| Some (Some n') when n <> n' -> Some None
| Some _ -> acc)
| None -> Some (Some (n, Known g))
| Some (Some (n', _)) when n <> n' -> Some None
| Some (Some (_, Known _)) -> Some (Some (n, Exact))
| Some (None | Some (_, (Exact | Generic))) -> acc)
| Expr (Block _) -> acc
| Expr _ | Phi _ -> assert false)
known
Expand Down
Loading
Loading