Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReflectOn: pick which method's body to rewrite #157

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments")
# 4. If tagging is enabled, do the necessary IR transforms for the metadata tagging system
function overdub_pass!(reflection::Reflection,
context_type::DataType,
is_invoke::Bool = false)
is_invoke::Bool = false,
is_reflect_on::Bool = false)
signature = reflection.signature
method = reflection.method
static_params = reflection.static_params
Expand Down Expand Up @@ -176,6 +177,9 @@ function overdub_pass!(reflection::Reflection,
n_actual_args = fieldcount(signature)
n_method_args = Int(method.nargs)
offset = 1
if is_reflect_on
offset += 1
end
for i in 1:n_method_args
if is_invoke && (i == 1 || i == 2)
# With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args.
Expand Down Expand Up @@ -491,17 +495,52 @@ const OVERDUB_FALLBACK = begin
code_info
end

"""
ReflectOn{Tuple{F, ArgTypes...})

When used in place of `f` in `overdub(ctx, f, g, args...)`, causes the method
of the function of type `F` with method signature `ArgTypes` to be overdubbed
and called with `args`. `g` is used as `#self#`, the function itself.

It is assumed that the method body will work with `args` even though they may
not be the same type prescribed by the original method signature. Useful when
writing passes which you to extract the code for a base type and rewrite it to
work on a custom type.

```julia
julia> Cassette.@context Foo

julia> foo(x::Float64) = "float"

julia> foo(x::Int) = "int"

julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Int64}}(), foo, 1.0)
"int"

julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Float64}}(), foo, 1)
"float"
```
"""
struct ReflectOn{T<:Tuple}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for API consistency, it should be ReflectOn{f, Tuple{ArgTypes...}} instead of ReflectOn{Tuple{f, ArgTypes...}}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough.

end

# `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)`
function __overdub_generator__(self, context_type, args::Tuple)
if nfields(args) > 0
is_builtin = args[1] <: Core.Builtin
is_invoke = args[1] === typeof(Core.invoke)
is_reflect_on = args[1] <: ReflectOn
if !is_builtin || is_invoke
try
untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,)
reflection = reflect(untagged_args)
if is_reflect_on
argtypes = (args[1].parameters[1].parameters...,)
reflection = reflect(argtypes)
else
untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,)
reflection = reflect(untagged_args)
end
if isa(reflection, Reflection)
result = overdub_pass!(reflection, context_type, is_invoke)
result = overdub_pass!(reflection, context_type, is_invoke, is_reflect_on)
isa(result, Expr) && return result
return reflection.code_info
end
Expand Down
24 changes: 23 additions & 1 deletion test/misctests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ end

#############################################################################################

print(" running OverdubOverdubCtx test...")
println(" running OverdubOverdubCtx test...")

# Fixed in PR #148
Cassette.@context OverdubOverdubCtx
Expand All @@ -686,6 +686,28 @@ Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), ove

#############################################################################################

print(" running ReflectOn test...")
reflecton_test(x::Float64) = "float64"
reflecton_test(x::Int) = "int"

Cassette.@context ReflectOnCtx
@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Int}}(), reflecton_test, 1.0)) == "int"
@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Float64}}(), reflecton_test, 1)) == "float64"

function reflecton_closure_test(x::Int64)
function inner(y::Int)
(x, "int")
end
function inner(y::Float64)
(x, "float64")
end
inner
end
@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_closure_test(0)), Float64}}(), reflecton_closure_test(8), 1)) == (8, "float64")
@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_closure_test(0)), Int64}}(), reflecton_closure_test(8), 1.0)) == (8, "int")

#############################################################################################

print(" running NukeCtx test...")

@Cassette.context NukeContext
Expand Down