Skip to content

Commit

Permalink
Handle llvmcall in overdub_pass
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 12, 2019
1 parent b318b15 commit 00f7a5a
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 1 deletion.
91 changes: 91 additions & 0 deletions src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,97 @@ function overdub_pass!(reflection::Reflection,
append!(overdubbed_code, code_info.code)
append!(overdubbed_codelocs, code_info.codelocs)

#=== mark all `llvmcall`s as nooverdub, optionally mark all `Intrinsics`/`Builtins` nooverdub ===#

function unravel_intrinsics(x)
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
if Base.Meta.isexpr(stmt, :call)
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
f = applycall ? stmt.args[2] : stmt.args[1]
f = ir_element(f, overdubbed_code)
if f isa Expr && Base.Meta.isexpr(f, :call) &&
is_ir_element(f.args[1], GlobalRef(Base, :getproperty), overdubbed_code)

# resolve getproperty here
# this is formed by Core.Intrinsics.llvmcall
# %1 = Base.getproperty(Core, :Intrinsics)
# %2 = GlobalRef(%1, :llvmcall)
mod = ir_element(f.args[2], overdubbed_code)
if mod isa GlobalRef
mod = resolve_early(mod) # returns nothing if fails
end
if !(mod isa Module)
# might be nothing or a Slot
return nothing
end
fname = ir_element(f.args[3], overdubbed_code)
if fname isa QuoteNode
fname = fname.value
end
f = GlobalRef(mod, fname)
end
if f isa GlobalRef
f = resolve_early(f)
end
return f
end
return nothing
end

# TODO: add user-facing flag to do this for all intrinsics
if !iskwfunc
insert_statements!(overdubbed_code, overdubbed_codelocs,
(x, i) -> begin
intrinsic = unravel_intrinsics(x)
if intrinsic === nothing
return nothing
end
if intrinsic === Core.Intrinsics.llvmcall
if istaggingenabled
count = 0
for arg in stmt.args
if isa(arg, SSAValue) || isa(arg, SlotNumber)
count += 1
end
end
return count + 1
else
return 1
end
end
end,
(x, i) -> begin
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
intrinsic = unravel_intrinsics(x)
items = Any[]
args = nothing
if istaggingenabled
args = Any[]
for arg in stmt.args
if isa(arg, SSAValue) || isa(arg, SlotNumber)
push!(args, SSAValue(i + length(items)))
push!(items, Expr(:call, Expr(:nooverdub, GlobalRef(Cassette, :untag)), arg, overdub_ctx_slot))
else
push!(result.args, arg)
end
end
end
idx = 1
if applycall
idx = 2
end
# using stmt.args[idx] instead of `intrinsic` leads to a bug
stmt.args[idx] = Expr(:nooverdub, intrinsic)
if args !== nothing
idx += 1
stmt.args[idx:end] = args
end
push!(items, x)
return items
end)
end

#=== perform tagged module transformation if tagging is enabled ===#

if istaggingenabled && !iskwfunc
Expand Down
32 changes: 32 additions & 0 deletions src/pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,35 @@ function is_ir_element(x, y, code::Vector)
end
return result
end

"""
ir_element(x, code::Vector)
Follows the series of `SSAValue` that define `x`.
See also: [`is_ir_element`](@ref)
"""
function ir_element(x, code::Vector)
while isa(x, Core.SSAValue)
x = code[x.id]
end
return x
end

"""
resolve_early(ref::GlobalRef)
Resolves a `Core.Compiler.GlobalRef` during compilation, may
return `nothing` if the binding is not resolved or defined yet.
Only use this when you are certain that the result of the lookup
will not change.
"""
function resolve_early(ref::GlobalRef)
mod = ref.mod
name = ref.name
if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
return getfield(mod, name)
else
return nothing
end
end
31 changes: 31 additions & 0 deletions test/misctaggingtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,34 @@ result = overdub(ctx, matrixliteral, tag(1, ctx, "hi"))
@test metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2)

println("done (took ", time() - before_time, " seconds)")

#############################################################################################

print(" running TaggedLLVMCallCtx test...")
before_time = time()
Cassette.@context TaggedLLVMCallCtx
Cassette.metadatatype(::Type{<:ArrayIndexCtx}, ::Type{Float64}) = Float64

function Cassette.overdub(ctx::TaggedLLVMCallCtx, f, args...)
if Cassette.canrecurse(ctx, f, args...)
Cassette.recurse(ctx, f, args...)
else
Cassette.fallback(ctx, f, args...)
end
end

function llvm_sin(x::Float64)
Core.Intrinsics.llvmcall(
(
"""declare double @llvm.sin.f64(double)""",
"""%2 = call double @llvm.sin.f64(double %0)
ret double %2"""
),
Float64, Tuple{Float64}, x
)
end

ctx = enabletagging(TaggedLLVMCallCtx(), llvm_sin)
Cassette.@overdub ctx llvm_sin(tag(4.0, ctx, 1.0))

println("done (took ", time() - before_time, " seconds)")
38 changes: 37 additions & 1 deletion test/misctests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,12 @@ callback()

println("done (took ", time() - before_time, " seconds)")

#############################################################################################
# Test overdubbing of a call overload invoke

print(" running CtxCallOverload test...")
before_time = time()

using LinearAlgebra

struct Dense{F,S,T}
Expand Down Expand Up @@ -664,11 +668,43 @@ let d = Dense(3,3)
Cassette.overdub(CtxCallOverload(), d, data)
end

println("done (took ", time() - before_time, " seconds)")

#############################################################################################

print(" running OverdubOverdubCtx test...")
println(" running OverdubOverdubCtx test...")

# Fixed in PR #148
Cassette.@context OverdubOverdubCtx;
overdub_overdub_me() = 2
Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), overdub_overdub_me)

#############################################################################################

print(" running LLVMCallCtx test...")
before_time = time()
Cassette.@context LLVMCallCtx

# This overdub does nothing, intentionally not marked `@inline`
function Cassette.overdub(ctx::LLVMCallCtx, f, args...)
if Cassette.canrecurse(ctx, f, args...)
Cassette.recurse(ctx, f, args...)
else
Cassette.fallback(ctx, f, args...)
end
end

function llvm_sin(x::Float64)
Core.Intrinsics.llvmcall(
(
"""declare double @llvm.sin.f64(double)""",
"""%2 = call double @llvm.sin.f64(double %0)
ret double %2"""
),
Float64, Tuple{Float64}, x
)
end

Cassette.@overdub LLVMCallCtx() llvm_sin(4.0)

println("done (took ", time() - before_time, " seconds)")

0 comments on commit 00f7a5a

Please sign in to comment.