From 852c2b3366b64ece6f17b5a7edc74aaa84778bb6 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 12 Jan 2020 22:14:04 -0500 Subject: [PATCH 1/2] ReflectOn: pick which method's body to rewrite Co-authored-by: "Yingbo Ma" Co-authored-by: "Shashi Gowda" --- src/overdub.jl | 37 +++++++++++++++++++++++++++++++++++-- test/misctests.jl | 10 +++++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/src/overdub.jl b/src/overdub.jl index 74dcb26..592bf4c 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -491,6 +491,34 @@ const OVERDUB_FALLBACK = begin code_info end +""" + ReflectOn{Tuple{F, ArgTypes...}) + +When used in place of `f` in `overdub(ctx, f, args...)`, causes the method +of the function of type `F` with method signature `ArgTypes` to be overdubbed +and called with `args`. + +It is assumed that the method body will work with `args` even though they may +not be the same type prescribed by the original method signature. Useful when +writing passes which you to extract the code for a base type and rewrite it to +work on a custom type. + +```julia +julia> Cassette.@context Foo + +julia> foo(x::Float64) = "float" + +julia> foo(x::Int) = "int" + +julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Int64}}(), 1.0) +"int" + +julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Float64}}(), 1) +"float" +``` +""" +struct ReflectOn{T<:Tuple} end + # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` function __overdub_generator__(self, context_type, args::Tuple) if nfields(args) > 0 @@ -498,8 +526,13 @@ function __overdub_generator__(self, context_type, args::Tuple) is_invoke = args[1] === typeof(Core.invoke) if !is_builtin || is_invoke try - untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) - reflection = reflect(untagged_args) + if args[1] <: ReflectOn + argtypes = (args[1].parameters[1].parameters...,) + reflection = reflect(argtypes) + else + untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) + reflection = reflect(untagged_args) + end if isa(reflection, Reflection) result = overdub_pass!(reflection, context_type, is_invoke) isa(result, Expr) && return result diff --git a/test/misctests.jl b/test/misctests.jl index 36482f7..b4f8382 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -666,9 +666,17 @@ end ############################################################################################# -print(" running OverdubOverdubCtx test...") +println(" running OverdubOverdubCtx test...") # Fixed in PR #148 Cassette.@context OverdubOverdubCtx; overdub_overdub_me() = 2 Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), overdub_overdub_me) + +print(" running ReflectOn test...") +reflecton_test(x::Float64) = "float64" +reflecton_test(x::Int) = "int" + +Cassette.@context ReflectOnCtx +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Int}}(), 1.0)) == "int" +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Float64}}(), 1)) == "float64" From e5711d553eac4c2de344cb543a9874d82fcf26ad Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 13 Jan 2020 00:46:55 -0500 Subject: [PATCH 2/2] Make ReflectOn-overdub take the function as the third argument. This allows usage with closures. --- src/overdub.jl | 22 ++++++++++++++-------- test/misctests.jl | 16 ++++++++++++++-- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/overdub.jl b/src/overdub.jl index 592bf4c..47c6f5b 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -127,7 +127,8 @@ const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments") # 4. If tagging is enabled, do the necessary IR transforms for the metadata tagging system function overdub_pass!(reflection::Reflection, context_type::DataType, - is_invoke::Bool = false) + is_invoke::Bool = false, + is_reflect_on::Bool = false) signature = reflection.signature method = reflection.method static_params = reflection.static_params @@ -176,6 +177,9 @@ function overdub_pass!(reflection::Reflection, n_actual_args = fieldcount(signature) n_method_args = Int(method.nargs) offset = 1 + if is_reflect_on + offset += 1 + end for i in 1:n_method_args if is_invoke && (i == 1 || i == 2) # With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args. @@ -494,9 +498,9 @@ end """ ReflectOn{Tuple{F, ArgTypes...}) -When used in place of `f` in `overdub(ctx, f, args...)`, causes the method +When used in place of `f` in `overdub(ctx, f, g, args...)`, causes the method of the function of type `F` with method signature `ArgTypes` to be overdubbed -and called with `args`. +and called with `args`. `g` is used as `#self#`, the function itself. It is assumed that the method body will work with `args` even though they may not be the same type prescribed by the original method signature. Useful when @@ -510,23 +514,25 @@ julia> foo(x::Float64) = "float" julia> foo(x::Int) = "int" -julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Int64}}(), 1.0) +julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Int64}}(), foo, 1.0) "int" -julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Float64}}(), 1) +julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Float64}}(), foo, 1) "float" ``` """ -struct ReflectOn{T<:Tuple} end +struct ReflectOn{T<:Tuple} +end # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` function __overdub_generator__(self, context_type, args::Tuple) if nfields(args) > 0 is_builtin = args[1] <: Core.Builtin is_invoke = args[1] === typeof(Core.invoke) + is_reflect_on = args[1] <: ReflectOn if !is_builtin || is_invoke try - if args[1] <: ReflectOn + if is_reflect_on argtypes = (args[1].parameters[1].parameters...,) reflection = reflect(argtypes) else @@ -534,7 +540,7 @@ function __overdub_generator__(self, context_type, args::Tuple) reflection = reflect(untagged_args) end if isa(reflection, Reflection) - result = overdub_pass!(reflection, context_type, is_invoke) + result = overdub_pass!(reflection, context_type, is_invoke, is_reflect_on) isa(result, Expr) && return result return reflection.code_info end diff --git a/test/misctests.jl b/test/misctests.jl index b4f8382..4d2bf86 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -678,5 +678,17 @@ reflecton_test(x::Float64) = "float64" reflecton_test(x::Int) = "int" Cassette.@context ReflectOnCtx -@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Int}}(), 1.0)) == "int" -@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Float64}}(), 1)) == "float64" +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Int}}(), reflecton_test, 1.0)) == "int" +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Float64}}(), reflecton_test, 1)) == "float64" + +function reflecton_closure_test(x::Int64) + function inner(y::Int) + (x, "int") + end + function inner(y::Float64) + (x, "float64") + end + inner +end +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_closure_test(0)), Float64}}(), reflecton_closure_test(8), 1)) == (8, "float64") +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_closure_test(0)), Int64}}(), reflecton_closure_test(8), 1.0)) == (8, "int")