diff --git a/src/overdub.jl b/src/overdub.jl index d336c6c..eb377c1 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -127,7 +127,8 @@ const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments") # 4. If tagging is enabled, do the necessary IR transforms for the metadata tagging system function overdub_pass!(reflection::Reflection, context_type::DataType, - is_invoke::Bool = false) + is_invoke::Bool = false, + is_reflect_on::Bool = false) signature = reflection.signature method = reflection.method static_params = reflection.static_params @@ -176,6 +177,9 @@ function overdub_pass!(reflection::Reflection, n_actual_args = fieldcount(signature) n_method_args = Int(method.nargs) offset = 1 + if is_reflect_on + offset += 1 + end for i in 1:n_method_args if is_invoke && (i == 1 || i == 2) # With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args. @@ -491,17 +495,52 @@ const OVERDUB_FALLBACK = begin code_info end +""" + ReflectOn{Tuple{F, ArgTypes...}) + +When used in place of `f` in `overdub(ctx, f, g, args...)`, causes the method +of the function of type `F` with method signature `ArgTypes` to be overdubbed +and called with `args`. `g` is used as `#self#`, the function itself. + +It is assumed that the method body will work with `args` even though they may +not be the same type prescribed by the original method signature. Useful when +writing passes which you to extract the code for a base type and rewrite it to +work on a custom type. + +```julia +julia> Cassette.@context Foo + +julia> foo(x::Float64) = "float" + +julia> foo(x::Int) = "int" + +julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Int64}}(), foo, 1.0) +"int" + +julia> Cassette.overdub(Foo(), Cassette.ReflectOn{Tuple{typeof(foo), Float64}}(), foo, 1) +"float" +``` +""" +struct ReflectOn{T<:Tuple} +end + # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` function __overdub_generator__(self, context_type, args::Tuple) if nfields(args) > 0 is_builtin = args[1] <: Core.Builtin is_invoke = args[1] === typeof(Core.invoke) + is_reflect_on = args[1] <: ReflectOn if !is_builtin || is_invoke try - untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) - reflection = reflect(untagged_args) + if is_reflect_on + argtypes = (args[1].parameters[1].parameters...,) + reflection = reflect(argtypes) + else + untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) + reflection = reflect(untagged_args) + end if isa(reflection, Reflection) - result = overdub_pass!(reflection, context_type, is_invoke) + result = overdub_pass!(reflection, context_type, is_invoke, is_reflect_on) isa(result, Expr) && return result return reflection.code_info end diff --git a/test/misctests.jl b/test/misctests.jl index 6dd69ae..961aa4f 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -677,7 +677,7 @@ end ############################################################################################# -print(" running OverdubOverdubCtx test...") +println(" running OverdubOverdubCtx test...") # Fixed in PR #148 Cassette.@context OverdubOverdubCtx @@ -686,6 +686,28 @@ Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), ove ############################################################################################# +print(" running ReflectOn test...") +reflecton_test(x::Float64) = "float64" +reflecton_test(x::Int) = "int" + +Cassette.@context ReflectOnCtx +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Int}}(), reflecton_test, 1.0)) == "int" +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_test), Float64}}(), reflecton_test, 1)) == "float64" + +function reflecton_closure_test(x::Int64) + function inner(y::Int) + (x, "int") + end + function inner(y::Float64) + (x, "float64") + end + inner +end +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_closure_test(0)), Float64}}(), reflecton_closure_test(8), 1)) == (8, "float64") +@test @inferred(Cassette.overdub(ReflectOnCtx(), Cassette.ReflectOn{Tuple{typeof(reflecton_closure_test(0)), Int64}}(), reflecton_closure_test(8), 1.0)) == (8, "int") + +############################################################################################# + print(" running NukeCtx test...") @Cassette.context NukeContext