Skip to content

Commit

Permalink
Clean up ret_mode and region in Lambda.lfunction (#2985)
Browse files Browse the repository at this point in the history
Co-authored-by: Zesen Qian <[email protected]>
  • Loading branch information
mshinwell and riaqn authored Feb 18, 2025
1 parent e743d6c commit 1a73a72
Show file tree
Hide file tree
Showing 21 changed files with 96 additions and 170 deletions.
17 changes: 8 additions & 9 deletions lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ and lfunction =
loc: scoped_location;
mode: locality_mode;
ret_mode: locality_mode;
region: bool; }
}

and lambda_while =
{ wh_cond : lambda;
Expand Down Expand Up @@ -938,7 +938,7 @@ let max_arity () =
(* 126 = 127 (the maximal number of parameters supported in C--)
- 1 (the hidden parameter containing the environment) *)

let lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region =
let lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode =
assert (List.length params > 0);
assert (List.length params <= max_arity ());
(* A curried function type with n parameters has n arrows. Of these,
Expand All @@ -959,14 +959,13 @@ let lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region =
let nparams = List.length params in
assert (0 <= nlocal);
assert (nlocal <= nparams);
if not region then assert (nlocal >= 1);
if is_local_mode ret_mode then assert (nlocal >= 1);
if is_local_mode mode then assert (nlocal = nparams)
end;
{ kind; params; return; body; attr; loc; mode; ret_mode; region }
{ kind; params; return; body; attr; loc; mode; ret_mode }

let lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region =
Lfunction
(lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region)
let lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode =
Lfunction (lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode)

let lambda_unit = Lconst const_unit

Expand Down Expand Up @@ -1601,9 +1600,9 @@ let duplicate_function =
Ident.Map.empty).subst_lfunction

let map_lfunction f { kind; params; return; body; attr; loc;
mode; ret_mode; region } =
mode; ret_mode } =
let body = f body in
{ kind; params; return; body; attr; loc; mode; ret_mode; region }
{ kind; params; return; body; attr; loc; mode; ret_mode }

let shallow_map ~tail ~non_tail:f = function
| Lvar _
Expand Down
36 changes: 29 additions & 7 deletions lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,33 @@ type loop_attribute =
| Default_loop (* no [@loop] attribute *)

type curried_function_kind = { nlocal: int } [@@unboxed]
(* [nlocal] determines how many arguments may be partially applied
before the resulting closure must be locally allocated.
See [lfunction] for details *)
(** A well-formed function parameter list is of the form
[G @ L @ [ Final_arg ]],
where the values of G and L are of the form [More_args { partial_mode }],
where [partial_mode] has locality Global in G and locality Local in L.
[nlocal] is defined as follows:
- if {v |L| > 0 v}, then {v nlocal = |L| + 1 v}.
- if {v |L| = 0 v},
* if the function returns at mode local, the final arg has mode local,
or the function itself is allocated locally, then {v nlocal = 1 v}.
* otherwise, {v nlocal = 0 v}.
*)

(* CR-someday: Now that some functions' arity won't be changed downstream of
lambda (see [may_fuse_arity = false]), we could change [nlocal] to be
more expressive. I suggest the variant:
{[
type partial_application_is_local_when =
| Applied_up_to_nth_argument_from_end of int
| Never
]}
I believe this will allow us to get rid of the complicated logic for
|L| = 0, and help clarify how clients use this type. I plan on doing
this in a follow-on PR.
*)

type function_kind = Curried of curried_function_kind | Tupled

Expand Down Expand Up @@ -797,8 +821,8 @@ and lfunction = private
loc : scoped_location;
mode : locality_mode; (* locality of the closure itself *)
ret_mode: locality_mode;
region : bool; (* false if this function may locally
allocate in the caller's region *)
(** alloc mode of the returned value. Also indicates if the function might
allocate in the caller's region. *)
}

and lambda_while =
Expand Down Expand Up @@ -1002,7 +1026,6 @@ val lfunction :
loc:scoped_location ->
mode:locality_mode ->
ret_mode:locality_mode ->
region:bool ->
lambda

val lfunction' :
Expand All @@ -1014,7 +1037,6 @@ val lfunction' :
loc:scoped_location ->
mode:locality_mode ->
ret_mode:locality_mode ->
region:bool ->
lfunction


Expand Down
35 changes: 18 additions & 17 deletions lambda/simplif.ml
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,10 @@ let simplify_exits lam =
| Lletrec(bindings, body) ->
let bindings =
List.map (fun ({ def = {kind; params; return; body = l; attr; loc;
mode; ret_mode; region} }
mode; ret_mode } }
as rb) ->
let def =
lfunction' ~kind ~params ~return ~mode ~ret_mode ~region
lfunction' ~kind ~params ~return ~mode ~ret_mode
~body:(simplif ~layout:None ~try_depth l) ~attr ~loc
in
{ rb with def })
Expand Down Expand Up @@ -587,12 +587,12 @@ let simplify_lets lam =
| _ -> no_opt ()
end
| Lfunction{kind=outer_kind; params; return=outer_return; body = l;
attr=attr1; loc; ret_mode; mode; region=outer_region} ->
begin match outer_kind, outer_region, simplif l with
attr=attr1; loc; ret_mode; mode} ->
begin match outer_kind, ret_mode, simplif l with
Curried {nlocal=0},
true,
Alloc_heap,
Lfunction{kind=Curried _ as kind; params=params'; return=return2;
body; attr=attr2; loc; mode=inner_mode; ret_mode; region}
body; attr=attr2; loc; mode=inner_mode; ret_mode}
when optimize &&
attr1.may_fuse_arity && attr2.may_fuse_arity &&
List.length params + List.length params' <= Lambda.max_arity() ->
Expand All @@ -603,9 +603,9 @@ let simplify_lets lam =
type of the merged function taking [params @ params'] as
parameters is the type returned after applying [params']. *)
let return = return2 in
lfunction ~kind ~params:(params @ params') ~return ~body ~attr:attr1 ~loc ~mode ~ret_mode ~region
| kind, region, body ->
lfunction ~kind ~params ~return:outer_return ~body ~attr:attr1 ~loc ~mode ~ret_mode ~region
lfunction ~kind ~params:(params @ params') ~return ~body ~attr:attr1 ~loc ~mode ~ret_mode
| kind, ret_mode, body ->
lfunction ~kind ~params ~return:outer_return ~body ~attr:attr1 ~loc ~mode ~ret_mode
end
| Llet(_str, _k, v, Lvar w, l2) when optimize ->
Hashtbl.add subst v (simplif (Lvar w));
Expand Down Expand Up @@ -802,7 +802,7 @@ and emit_tail_infos_lfunction _is_tail lfun =
function's body. *)

let split_default_wrapper ~id:fun_id ~kind ~params ~return ~body
~attr ~loc ~mode ~ret_mode ~region:orig_region =
~attr ~loc ~mode ~ret_mode =
let rec aux map add_region = function
(* When compiling [fun ?(x=expr) -> body], this is first translated
to:
Expand Down Expand Up @@ -882,28 +882,29 @@ let split_default_wrapper ~id:fun_id ~kind ~params ~return ~body
let inner_fun =
lfunction' ~kind:(Curried {nlocal=0})
~params:new_ids
~return ~body ~attr ~loc ~mode ~ret_mode ~region:true
~return ~body ~attr ~loc ~mode ~ret_mode
in
(wrapper_body, { id = inner_id;
def = inner_fun })
in
try
(* TODO: enable this optimisation even in the presence of local returns *)
begin match kind with
| Curried {nlocal} when nlocal > 0 -> raise Exit
| Tupled when not orig_region -> raise Exit
| _ -> assert orig_region
begin match kind, ret_mode with
| Curried {nlocal}, _ when nlocal > 0 -> raise Exit
| Tupled, Alloc_local -> raise Exit
| _, Alloc_heap -> ()
| _, Alloc_local -> assert false
end;
let body, inner = aux [] false body in
let attr = { default_stub_attribute with zero_alloc = attr.zero_alloc } in
[{ id = fun_id;
def = lfunction' ~kind ~params ~return ~body ~attr ~loc
~mode ~ret_mode ~region:true };
~mode ~ret_mode };
inner]
with Exit ->
[{ id = fun_id;
def = lfunction' ~kind ~params ~return ~body ~attr ~loc
~mode ~ret_mode ~region:orig_region }]
~mode ~ret_mode }]

(* Simplify local let-bound functions: if all occurrences are
fully-applied function calls in the same "tail scope", replace the
Expand Down
1 change: 0 additions & 1 deletion lambda/simplif.mli
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,4 @@ val split_default_wrapper
-> loc:Lambda.scoped_location
-> mode:Lambda.locality_mode
-> ret_mode:Lambda.locality_mode
-> region:bool
-> rec_binding list
5 changes: 2 additions & 3 deletions lambda/tmc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1045,9 +1045,9 @@ and make_dps_variant var inner_ctx outer_ctx (lfun : lfunction) =
(Debuginfo.Scoped_location.to_location lfun.loc)
Warnings.Unused_tmc_attribute;
let direct =
let { kind; params; return; body = _; attr; loc; mode; ret_mode; region } = lfun in
let { kind; params; return; body = _; attr; loc; mode; ret_mode } = lfun in
let body = Choice.direct fun_choice in
lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region in
lfunction' ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode in
let dps =
let dst_param = {
var = Ident.create_local "dst";
Expand Down Expand Up @@ -1076,7 +1076,6 @@ and make_dps_variant var inner_ctx outer_ctx (lfun : lfunction) =
~loc:lfun.loc
~mode:lfun.mode
~ret_mode:lfun.ret_mode
~region:true
in
let dps_var = special.dps_id in
[var, direct; dps_var, dps]
Expand Down
3 changes: 1 addition & 2 deletions lambda/transl_list_comprehension.ml
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,7 @@ let rec translate_bindings ~transl_exp ~scopes ~loc ~inner_body ~accumulator =
mode = alloc_local
} ]
~return:layout_any_value ~attr:default_function_attribute ~loc
~mode:alloc_local ~ret_mode:alloc_local ~region:false
~body:(add_bindings body)
~mode:alloc_local ~ret_mode:alloc_local ~body:(add_bindings body)
in
let result =
Lambda_utils.apply ~loc ~mode:alloc_local (Lazy.force builder)
Expand Down
4 changes: 2 additions & 2 deletions lambda/translattribute.ml
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ let check_opaque_local loc attr =


let lfunction_with_attr ~attr
{ kind; params; return; body; attr=_; loc; mode; ret_mode; region } =
lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode ~region
{ kind; params; return; body; attr=_; loc; mode; ret_mode } =
lfunction ~kind ~params ~return ~body ~attr ~loc ~mode ~ret_mode

let add_inline_attribute expr loc attributes =
match expr with
Expand Down
16 changes: 4 additions & 12 deletions lambda/translclass.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ let layout_meth = layout_any_value
let layout_tables = layout_any_value


let lfunction ?(kind=Curried {nlocal=0}) ?(region=true) ?(ret_mode=alloc_heap) return_layout params body =
let lfunction ?(kind=Curried {nlocal=0}) ?(ret_mode=alloc_heap) return_layout params body =
if params = [] then body else
match kind, body with
| Curried {nlocal=0},
Lfunction {kind = Curried _ as kind; params = params';
body = body'; attr; loc; mode = Alloc_heap; ret_mode; region}
body = body'; attr; loc; mode = Alloc_heap; ret_mode }
when attr.may_fuse_arity &&
List.length params + List.length params' <= Lambda.max_arity() ->
lfunction ~kind ~params:(params @ params')
Expand All @@ -52,15 +52,13 @@ let lfunction ?(kind=Curried {nlocal=0}) ?(region=true) ?(ret_mode=alloc_heap) r
~loc
~mode:alloc_heap
~ret_mode
~region
| _ ->
lfunction ~kind ~params ~return:return_layout
~body
~attr:default_function_attribute
~loc:Loc_unknown
~mode:alloc_heap
~ret_mode
~region

let lapply ap =
match ap.ap_func with
Expand Down Expand Up @@ -230,7 +228,6 @@ let rec build_object_init ~scopes cl_table obj params inh_init obj_init cl =
~body
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
in
begin match obj_init with
Lfunction {kind = Curried {nlocal=0}; params; body = rem} ->
Expand Down Expand Up @@ -520,7 +517,6 @@ let rec transl_class_rebind ~scopes obj_init cl vf =
~body
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
in
(path, path_lam,
match obj_init with
Expand Down Expand Up @@ -799,7 +795,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
let new_ids_meths = ref [] in
let no_env_update _ _ env = env in
let msubst arr = function
Lfunction {kind = Curried _ as kind; region; ret_mode;
Lfunction {kind = Curried _ as kind; ret_mode;
params = self :: args; return; body} ->
let env = Ident.create_local "env" in
let body' =
Expand All @@ -811,7 +807,7 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
if not arr || !Clflags.debug then raise Not_found;
builtin_meths [self.name] env env2 (lfunction return args body')
with Not_found ->
[lfunction ~kind ~region ~ret_mode return (self :: args)
[lfunction ~kind ~ret_mode return (self :: args)
(if not (Ident.Set.mem env (free_variables body')) then body' else
Llet(Alias, layout_block, env,
Lprim(Pfield_computed Reads_vary,
Expand Down Expand Up @@ -885,7 +881,6 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~return:layout_function
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~params:[lparam cla layout_table]
~body:cl_init,
Dynamic (* Placeholder, real kind is computed in [lbody] below *))
Expand Down Expand Up @@ -917,7 +912,6 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~return:layout_function
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~params:[lparam cla layout_table] ~body:cl_init;
lambda_unit; lenvs],
Loc_unknown),
Expand Down Expand Up @@ -980,7 +974,6 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~loc:Loc_unknown
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~body:(def_ids cla cl_init), lam)
and lset cached i lam =
Lprim(Psetfield(i, Pointer, Assignment modify_heap),
Expand All @@ -999,7 +992,6 @@ let transl_class ~scopes ids cl_id pub_meths cl vflag =
~loc:Loc_unknown
~mode:alloc_heap
~ret_mode:alloc_heap
~region:true
~return:layout_function
~params:[lparam cla layout_table]
~body:(def_ids cla cl_init))
Expand Down
Loading

0 comments on commit 1a73a72

Please sign in to comment.