Skip to content

Commit

Permalink
Merge pull request #163 from KristofferC/kc/fix_apply_iterate
Browse files Browse the repository at this point in the history
fix bug in Core._apply_iterate overdubbing, args[1] also needs to be overdubbed
  • Loading branch information
vchuravy authored Jan 27, 2020
2 parents 18b5480 + 6d75f47 commit 2e0827a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ macro context(_Ctx)

@inline Cassette.overdub(ctx::$Ctx, ::typeof(Core._apply), f, args...) = Core._apply(overdub, (ctx, f), args...)
if VERSION >= v"1.4.0-DEV.304"
@inline Cassette.overdub(ctx::$Ctx, ::typeof(Core._apply_iterate), f, args...) = Core._apply_iterate((args...)->overdub(ctx, f, args...), args...)
@inline function Cassette.overdub(ctx::$Ctx, ::typeof(Core._apply_iterate), f, args...)
new_args = ((_args...) -> overdub(ctx, args[1], _args...), Base.tail(args)...)
Core._apply_iterate((args...)->overdub(ctx, f, args...), new_args...)
end
end

# TODO: There are certain non-`Core.Builtin` functions which the compiler often
Expand Down
5 changes: 4 additions & 1 deletion src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,10 @@ function recurse end

recurse(ctx::Context, ::typeof(Core._apply), f, args...) = Core._apply(recurse, (ctx, f), args...)
if VERSION >= v"1.4.0-DEV.304"
recurse(ctx::Context, ::typeof(Core._apply_iterate), f, args...) = Core._apply_iterate((args...)->recurse(ctx, f, args...), args...)
function recurse(ctx::Context, ::typeof(Core._apply_iterate), f, args...)
new_args = ((_args...) -> overdub(ctx, args[1], _args...), Base.tail(args)...)
Core._apply_iterate((args...)->recurse(ctx, f, args...), new_args...)
end
end

function overdub_definition(line, file)
Expand Down
14 changes: 14 additions & 0 deletions test/misctests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,3 +706,17 @@ if VERSION >= v"1.4.0-DEV.304"
else
@test_broken Cassette.overdub(NukeContext(), launch, Silo()) === ()
end

if VERSION >= v"1.4.0-DEV.304"
Cassette.@context ApplyIterateCtx;

const instructions = []
function Cassette.prehook(ctx::ApplyIterateCtx,
op::Any,
a::T1, b::T2) where {T1, T2}
push!(instructions, (op, T1, T2))
end

Cassette.overdub(ApplyIterateCtx(), ()->pi*2.0)
@test instructions[end] === (Core.Intrinsics.mul_float, Float64, Float64)
end

0 comments on commit 2e0827a

Please sign in to comment.