From 00f7a5ac7c47fcd6d68913a3b0d5891e024b4c56 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 13 Jul 2019 19:59:39 -0400 Subject: [PATCH] Handle llvmcall in overdub_pass --- src/overdub.jl | 91 ++++++++++++++++++++++++++++++++++++++++ src/pass.jl | 32 ++++++++++++++ test/misctaggingtests.jl | 31 ++++++++++++++ test/misctests.jl | 38 ++++++++++++++++- 4 files changed, 191 insertions(+), 1 deletion(-) diff --git a/src/overdub.jl b/src/overdub.jl index 36da74d..7779db5 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -223,6 +223,97 @@ function overdub_pass!(reflection::Reflection, append!(overdubbed_code, code_info.code) append!(overdubbed_codelocs, code_info.codelocs) + #=== mark all `llvmcall`s as nooverdub, optionally mark all `Intrinsics`/`Builtins` nooverdub ===# + + function unravel_intrinsics(x) + stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x + if Base.Meta.isexpr(stmt, :call) + applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code) + f = applycall ? stmt.args[2] : stmt.args[1] + f = ir_element(f, overdubbed_code) + if f isa Expr && Base.Meta.isexpr(f, :call) && + is_ir_element(f.args[1], GlobalRef(Base, :getproperty), overdubbed_code) + + # resolve getproperty here + # this is formed by Core.Intrinsics.llvmcall + # %1 = Base.getproperty(Core, :Intrinsics) + # %2 = GlobalRef(%1, :llvmcall) + mod = ir_element(f.args[2], overdubbed_code) + if mod isa GlobalRef + mod = resolve_early(mod) # returns nothing if fails + end + if !(mod isa Module) + # might be nothing or a Slot + return nothing + end + fname = ir_element(f.args[3], overdubbed_code) + if fname isa QuoteNode + fname = fname.value + end + f = GlobalRef(mod, fname) + end + if f isa GlobalRef + f = resolve_early(f) + end + return f + end + return nothing + end + + # TODO: add user-facing flag to do this for all intrinsics + if !iskwfunc + insert_statements!(overdubbed_code, overdubbed_codelocs, + (x, i) -> begin + intrinsic = unravel_intrinsics(x) + if intrinsic === nothing + return nothing + end + if intrinsic === Core.Intrinsics.llvmcall + if istaggingenabled + count = 0 + for arg in stmt.args + if isa(arg, SSAValue) || isa(arg, SlotNumber) + count += 1 + end + end + return count + 1 + else + return 1 + end + end + end, + (x, i) -> begin + stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x + applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code) + intrinsic = unravel_intrinsics(x) + items = Any[] + args = nothing + if istaggingenabled + args = Any[] + for arg in stmt.args + if isa(arg, SSAValue) || isa(arg, SlotNumber) + push!(args, SSAValue(i + length(items))) + push!(items, Expr(:call, Expr(:nooverdub, GlobalRef(Cassette, :untag)), arg, overdub_ctx_slot)) + else + push!(result.args, arg) + end + end + end + idx = 1 + if applycall + idx = 2 + end + # using stmt.args[idx] instead of `intrinsic` leads to a bug + stmt.args[idx] = Expr(:nooverdub, intrinsic) + if args !== nothing + idx += 1 + stmt.args[idx:end] = args + end + push!(items, x) + return items + end) + end + #=== perform tagged module transformation if tagging is enabled ===# if istaggingenabled && !iskwfunc diff --git a/src/pass.jl b/src/pass.jl index b510700..029902e 100644 --- a/src/pass.jl +++ b/src/pass.jl @@ -197,3 +197,35 @@ function is_ir_element(x, y, code::Vector) end return result end + +""" + ir_element(x, code::Vector) + +Follows the series of `SSAValue` that define `x`. + +See also: [`is_ir_element`](@ref) +""" +function ir_element(x, code::Vector) + while isa(x, Core.SSAValue) + x = code[x.id] + end + return x +end + +""" + resolve_early(ref::GlobalRef) + +Resolves a `Core.Compiler.GlobalRef` during compilation, may +return `nothing` if the binding is not resolved or defined yet. +Only use this when you are certain that the result of the lookup +will not change. +""" +function resolve_early(ref::GlobalRef) + mod = ref.mod + name = ref.name + if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name) + return getfield(mod, name) + else + return nothing + end +end diff --git a/test/misctaggingtests.jl b/test/misctaggingtests.jl index 9f43861..65f5f18 100644 --- a/test/misctaggingtests.jl +++ b/test/misctaggingtests.jl @@ -496,3 +496,34 @@ result = overdub(ctx, matrixliteral, tag(1, ctx, "hi")) @test metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2) println("done (took ", time() - before_time, " seconds)") + +############################################################################################# + +print(" running TaggedLLVMCallCtx test...") +before_time = time() +Cassette.@context TaggedLLVMCallCtx +Cassette.metadatatype(::Type{<:ArrayIndexCtx}, ::Type{Float64}) = Float64 + +function Cassette.overdub(ctx::TaggedLLVMCallCtx, f, args...) + if Cassette.canrecurse(ctx, f, args...) + Cassette.recurse(ctx, f, args...) + else + Cassette.fallback(ctx, f, args...) + end +end + +function llvm_sin(x::Float64) + Core.Intrinsics.llvmcall( + ( + """declare double @llvm.sin.f64(double)""", + """%2 = call double @llvm.sin.f64(double %0) + ret double %2""" + ), + Float64, Tuple{Float64}, x + ) +end + +ctx = enabletagging(TaggedLLVMCallCtx(), llvm_sin) +Cassette.@overdub ctx llvm_sin(tag(4.0, ctx, 1.0)) + +println("done (took ", time() - before_time, " seconds)") \ No newline at end of file diff --git a/test/misctests.jl b/test/misctests.jl index 36482f7..1021105 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -633,8 +633,12 @@ callback() println("done (took ", time() - before_time, " seconds)") +############################################################################################# # Test overdubbing of a call overload invoke +print(" running CtxCallOverload test...") +before_time = time() + using LinearAlgebra struct Dense{F,S,T} @@ -664,11 +668,43 @@ let d = Dense(3,3) Cassette.overdub(CtxCallOverload(), d, data) end +println("done (took ", time() - before_time, " seconds)") + ############################################################################################# -print(" running OverdubOverdubCtx test...") +println(" running OverdubOverdubCtx test...") # Fixed in PR #148 Cassette.@context OverdubOverdubCtx; overdub_overdub_me() = 2 Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), overdub_overdub_me) + +############################################################################################# + +print(" running LLVMCallCtx test...") +before_time = time() +Cassette.@context LLVMCallCtx + +# This overdub does nothing, intentionally not marked `@inline` +function Cassette.overdub(ctx::LLVMCallCtx, f, args...) + if Cassette.canrecurse(ctx, f, args...) + Cassette.recurse(ctx, f, args...) + else + Cassette.fallback(ctx, f, args...) + end +end + +function llvm_sin(x::Float64) + Core.Intrinsics.llvmcall( + ( + """declare double @llvm.sin.f64(double)""", + """%2 = call double @llvm.sin.f64(double %0) + ret double %2""" + ), + Float64, Tuple{Float64}, x + ) +end + +Cassette.@overdub LLVMCallCtx() llvm_sin(4.0) + +println("done (took ", time() - before_time, " seconds)")