Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix llvmcall overdubbing #139

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,95 @@ function overdub_pass!(reflection::Reflection,
append!(overdubbed_code, code_info.code)
append!(overdubbed_codelocs, code_info.codelocs)

#=== mark all `llvmcall`s as nooverdub, optionally mark all `Intrinsics`/`Builtins` nooverdub ===#

function unravel_intrinsics(stmt)
if Base.Meta.isexpr(stmt, :call)
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
f = applycall ? stmt.args[2] : stmt.args[1]
f = ir_element(f, overdubbed_code)
if f isa Expr && Base.Meta.isexpr(f, :call) &&
is_ir_element(f.args[1], GlobalRef(Base, :getproperty), overdubbed_code)

# resolve getproperty here
# this is formed by Core.Intrinsics.llvmcall
# %1 = Base.getproperty(Core, :Intrinsics)
# %2 = GlobalRef(%1, :llvmcall)
mod = ir_element(f.args[2], overdubbed_code)
if mod isa GlobalRef
mod = resolve_early(mod) # returns nothing if fails
end
if !(mod isa Module)
# might be nothing or a Slot
return nothing
end
fname = ir_element(f.args[3], overdubbed_code)
if fname isa QuoteNode
fname = fname.value
end
f = GlobalRef(mod, fname)
end
if f isa GlobalRef
f = resolve_early(f)
end
return f
end
return nothing
end

# TODO: add user-facing flag to do this for all intrinsics
if !iskwfunc
insert_statements!(overdubbed_code, overdubbed_codelocs,
(x, i) -> begin
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
intrinsic = unravel_intrinsics(stmt)
intrinsic === nothing && return nothing

applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
if intrinsic === Core.Intrinsics.llvmcall
if istaggingenabled
count = 0
offset = applycall ? 3 : 2
while offset <= length(stmt.args)
arg = stmt.args[offset]
if isa(arg, SSAValue) || isa(arg, SlotNumber)
count += 1
end
offset += 1
end
return count + 1
else
return 1
end
end
end,
(x, i) -> begin
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
intrinsic = unravel_intrinsics(stmt)
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
items = Any[]

idx = applycall ? 2 : 1
# using stmt.args[idx] instead of `intrinsic` leads to a bug
stmt.args[idx] = Expr(:nooverdub, intrinsic)
idx += 1

if istaggingenabled
while idx <= length(stmt.args)
arg = stmt.args[idx]
if isa(arg, SSAValue) || isa(arg, SlotNumber)
stmt.args[idx] = SSAValue(i + length(items))
push!(items, Expr(:call, Expr(:nooverdub, GlobalRef(Cassette, :untag)), arg, overdub_ctx_slot))
end
idx += 1
end
end

push!(items, x)
return items
end)
end

#=== perform tagged module transformation if tagging is enabled ===#

if istaggingenabled && !iskwfunc
Expand Down
32 changes: 32 additions & 0 deletions src/pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,35 @@ function is_ir_element(x, y, code::Vector)
end
return result
end

"""
ir_element(x, code::Vector)

Follows the series of `SSAValue` that define `x`.

See also: [`is_ir_element`](@ref)
"""
function ir_element(x, code::Vector)
while isa(x, Core.SSAValue)
x = code[x.id]
end
return x
end
Copy link
Contributor

@oxinabox oxinabox Oct 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is super cute


"""
resolve_early(ref::GlobalRef)

Resolves a `Core.Compiler.GlobalRef` during compilation, may
return `nothing` if the binding is not resolved or defined yet.
Only use this when you are certain that the result of the lookup
will not change.
"""
function resolve_early(ref::GlobalRef)
mod = ref.mod
name = ref.name
if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
return getfield(mod, name)
else
return nothing
end
end
31 changes: 31 additions & 0 deletions test/misctaggingtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,34 @@ result = overdub(ctx, matrixliteral, tag(1, ctx, "hi"))
@test metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2)

println("done (took ", time() - before_time, " seconds)")

#############################################################################################

print(" running TaggedLLVMCallCtx test...")
before_time = time()
Cassette.@context TaggedLLVMCallCtx
Cassette.metadatatype(::Type{<:TaggedLLVMCallCtx}, ::Type{Float64}) = Float64

function Cassette.overdub(ctx::TaggedLLVMCallCtx, f, args...)
if Cassette.canrecurse(ctx, f, args...)
Cassette.recurse(ctx, f, args...)
else
Cassette.fallback(ctx, f, args...)
end
end

function llvm_sin(x::Float64)
Core.Intrinsics.llvmcall(
(
"""declare double @llvm.sin.f64(double)""",
"""%2 = call double @llvm.sin.f64(double %0)
ret double %2"""
),
Float64, Tuple{Float64}, x
)
end

ctx = enabletagging(TaggedLLVMCallCtx(), llvm_sin)
Cassette.@overdub ctx llvm_sin(tag(4.0, ctx, 1.0))

println("done (took ", time() - before_time, " seconds)")
38 changes: 37 additions & 1 deletion test/misctests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,12 @@ callback()

println("done (took ", time() - before_time, " seconds)")

#############################################################################################
# Test overdubbing of a call overload invoke

print(" running CtxCallOverload test...")
before_time = time()

using LinearAlgebra

struct Dense{F,S,T}
Expand Down Expand Up @@ -664,11 +668,43 @@ let d = Dense(3,3)
Cassette.overdub(CtxCallOverload(), d, data)
end

println("done (took ", time() - before_time, " seconds)")

#############################################################################################

print(" running OverdubOverdubCtx test...")
println(" running OverdubOverdubCtx test...")

# Fixed in PR #148
Cassette.@context OverdubOverdubCtx;
overdub_overdub_me() = 2
Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), overdub_overdub_me)

#############################################################################################

print(" running LLVMCallCtx test...")
before_time = time()
Cassette.@context LLVMCallCtx

# This overdub does nothing, intentionally not marked `@inline`
function Cassette.overdub(ctx::LLVMCallCtx, f, args...)
if Cassette.canrecurse(ctx, f, args...)
Cassette.recurse(ctx, f, args...)
else
Cassette.fallback(ctx, f, args...)
end
end

function llvm_sin(x::Float64)
Core.Intrinsics.llvmcall(
(
"""declare double @llvm.sin.f64(double)""",
"""%2 = call double @llvm.sin.f64(double %0)
ret double %2"""
),
Float64, Tuple{Float64}, x
)
end

Cassette.@overdub LLVMCallCtx() llvm_sin(4.0)

println("done (took ", time() - before_time, " seconds)")