Skip to content

Commit

Permalink
* handle const function objects.
Browse files Browse the repository at this point in the history
* don't generate argument values for captured values in top-level functions, these are directly passed to the code generation
  • Loading branch information
jumerckx committed Jul 17, 2024
1 parent 86e7905 commit 203d0c2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
7 changes: 6 additions & 1 deletion src/Generate/CodegenContext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ mutable struct CodegenContext{T} <: AbstractCodegenContext

function (cg::CodegenContext)(f, types)
mod = IR.Module()
methods = collect_methods(f, types)
toplevel, methods = collect_methods(f, types)
funcs = []

captures = collect_captures(f)
mlir_func = generate!(cg, toplevel.ir, toplevel.ret; mi=toplevel.mi, captures)
push!(funcs, mlir_func)

for (mi, (ir, ret)) in methods
mlir_func = generate!(cg, ir, ret; mi)
push!(funcs, mlir_func)
Expand Down
54 changes: 31 additions & 23 deletions src/Generate/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ function handle_return(cg, val::T) where T
generate_return(cg, returnvalues; location=IR.Location())
end
function handle_invoke(cg, fname, ret, args...)
if (first(args) isa Core.Const)
# disregard first argument which contains the called function, if it is const.
args = args[begin+1:end]
if first(args) isa Core.Const
args[begin] = args[begin].val
end
unpacked = []
for arg in args
Expand Down Expand Up @@ -61,7 +60,7 @@ function infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance)
min_world = world = CC.get_world_counter()
max_world = Base.get_world_counter()
irsv = CC.IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world)
rt = CC._ir_abstract_constant_propagation(interp, irsv)
ret = CC._ir_abstract_constant_propagation(interp, irsv)
return ir
end

Expand Down Expand Up @@ -120,7 +119,7 @@ region = builder(args)
For a more easy-to-use interface, use methods from `CodegenContext`.
"""
function generate!(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MODULE__))
function generate!(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MODULE__), captures=nothing)
original_currentblock = IR.currentblock[]

reg = IR.Region()
Expand All @@ -131,15 +130,14 @@ function generate!(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MOD
]
entryblock = first(blocks)

# the first argument is the function object itself. In the case where this is a closure,
# we want the captured variables to be included as function arguments in the generated function
captures = if first(ir.argtypes) isa Core.Const
()
else
map(IR.unpack(first(ir.argtypes))) do t
if isnothing(captures)
# The first argument is the function object itself. In the case where this is a closure,
# we want the captured variables to be included as function arguments in the generated function.
func_arg = first(ir.argtypes) isa Core.Const ? first(ir.argtypes).val : first(ir.argtypes)
captures = map(IR.unpack(func_arg)) do t
val = IR.push_argument!(entryblock, IR.Type(t))
end |> Tuple
end
end

args = map(enumerate(ir.argtypes[begin+1:end])) do (i, argtype)
temp = map(IR.unpack(argtype)) do t
Expand All @@ -158,7 +156,6 @@ function generate!(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MOD
end

transformed_ir = transform(cg, ir, blocks, next_block)

builder! = Core.OpaqueClosure(transformed_ir, captures...; do_compile=true)

builder!(args...)
Expand All @@ -172,8 +169,8 @@ function generate!(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MOD

return f_mlir
end
generate(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MODULE__)) = generate!(cg, Core.Compiler.copy(ir), ret; mi)
generate(cg, f, types) = generate!(cg, only(Core.Compiler.code_ircode(f, types, interp=MLIRInterpreter()))...)
generate(cg, ir::CC.IRCode, ret; mi=get_toplevel_mi_from_ir(ir, @__MODULE__), captures=nothing) = generate!(cg, Core.Compiler.copy(ir), ret; mi, captures)
generate(cg, f, types; captures=getfield.(f, fieldnames(typeof(f)))) = generate!(cg, only(Core.Compiler.code_ircode(f, types, interp=MLIRInterpreter()))...; captures)


function get_mi(f, types)
Expand Down Expand Up @@ -203,27 +200,38 @@ function find_invokes(ir)
return callees
end

collect_captures(f) = getfield.(f, fieldnames(typeof(f)))
collect_captures(f::Core.Const) = collect_captures(f.val)

function collect_methods(f, types)
mi = get_mi(f, types)
ir, rt = only(Core.Compiler.code_ircode_by_type(mi.specTypes, interp=MLIRInterpreter()))
ir, ret = only(Core.Compiler.code_ircode_by_type(mi.specTypes, interp=MLIRInterpreter()))

worklist = Core.Compiler.IRCode[ir]
methods = Dict{Core.MethodInstance, Tuple{Core.Compiler.IRCode, Type}}(
mi => (ir, rt)
mi => (ir, ret)
)
while !isempty(worklist)
code = pop!(worklist)
callees = find_invokes(code)
for callee in callees
if !haskey(methods, callee) && !is_intrinsic(callee.specTypes)
ir, rt = only(Core.Compiler.code_ircode_by_type(callee.specTypes, interp=MLIRInterpreter()))
methods[callee] = (ir, rt)
push!(worklist, ir)
if !haskey(methods, callee)
if is_intrinsic(callee.specTypes)
# TODO: if an intrinsic does nested code generation,
# invocations that are called within should be captured as well.
nothing
else
ir, ret = only(Core.Compiler.code_ircode_by_type(callee.specTypes, interp=MLIRInterpreter()))
methods[callee] = (ir, ret)
push!(worklist, ir)
end
end
end
end

return methods
ir, ret = methods[mi]
toplevel = (; mi, ir, ret)
delete!(methods, mi) # remove the toplevel method from the list of methods to generate
return toplevel, methods
end


Expand Down

0 comments on commit 203d0c2

Please sign in to comment.