Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Mjolnir for tracing #78

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
Draft

Conversation

MikeInnes
Copy link

@MikeInnes MikeInnes commented May 8, 2020

This is just a prototype for the time being (Mjolnir is not registered yet) but it shows off the basic ideas. This is the right solution to #16 (handling dispatch correctly), #8 (handling typeof etc. correctly) and #67 (handling ifelse and if/else).

julia> using SymbolicUtils

julia> @symbolic 3x^2 + 2y + 1
1 + (3 * (x ^ 2)) + (2 * y)

julia> f(x::Real) = abs2(Complex(x, 2x))
f (generic function with 1 method)

julia> @symbolic f(x)+y
(5 * (x ^ 2)) + y

Demo of #8:

julia> _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
_qreltype (generic function with 1 method)

julia> @symbolic zero(_qreltype(typeof(x)))
0.0

@MikeInnes
Copy link
Author

MikeInnes commented May 8, 2020

ifelse works:

julia> relu(x::Real) = ifelse(x == 0, x, zero(x))
relu (generic function with 2 methods)

julia> @symbolic relu(x)
ifelse(x == 0, x, 0.0)

I forgot to mention that there's a syntax for variable types:

julia> @symbolic relu(x::Int64)
ifelse(x == 0, x, 0)

Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir.

@codecov-io
Copy link

codecov-io commented May 8, 2020

Codecov Report

Merging #78 into master will decrease coverage by 2.46%.
The diff coverage is 0.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #78      +/-   ##
==========================================
- Coverage   89.81%   87.34%   -2.47%     
==========================================
  Files           8        9       +1     
  Lines         432      490      +58     
==========================================
+ Hits          388      428      +40     
- Misses         44       62      +18     
Impacted Files Coverage Δ
src/SymbolicUtils.jl 100.00% <ø> (ø)
src/trace.jl 0.00% <0.00%> (ø)
src/rulesets.jl 50.00% <0.00%> (ø)
src/matchers.jl 90.09% <0.00%> (+0.30%) ⬆️
src/rule_dsl.jl 98.87% <0.00%> (+1.54%) ⬆️
src/simplify.jl 95.08% <0.00%> (+2.90%) ⬆️
src/types.jl 89.58% <0.00%> (+2.91%) ⬆️
src/methods.jl 96.29% <0.00%> (+7.40%) ⬆️
src/util.jl 76.19% <0.00%> (+9.52%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5a8f2f9...fd1e2af. Read the comment docs.

@shashi
Copy link
Member

shashi commented May 8, 2020

Absolutely fantanstic!! I'm excited.

It might be good to add a Symutils -> IR conversion given a vector of arguments.

@MasonProtter
Copy link
Member

Amazing, Mike! This is a really promising direction.

@MikeInnes
Copy link
Author

Glad you like it!

It might be good to add a Symutils -> IR conversion given a vector of arguments.

Yeah, that seems useful – especially for ultimately turning a Term into something you can evaluate.

@shashi
Copy link
Member

shashi commented May 8, 2020

Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir.

this is something we need to start adding methods for... I guess right now an array of symbols is all that works. I think we can use something like Mjolnir's shaped arrays. It could be stored as T in Symbolic{T}.

@YingboMa
Copy link
Member

YingboMa commented May 8, 2020

Wow, fantastic. Thanks, Mike!

@shashi shashi marked this pull request as draft May 10, 2020 16:05
@shashi
Copy link
Member

shashi commented May 10, 2020

julia> using SymbolicUtils

julia> @syms a b c
(a, b, c)

julia> ir=IR(a+b+c, [a,b]; mod=Main)
1: (%1, %2)
  %3 = (+)(%1, %2, Main.c)
  return %3

julia> h=func(ir)
##260 (generic function with 1 method)

julia> h(2,3)
5 + c

Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated?

@shashi
Copy link
Member

shashi commented May 10, 2020

made the macro use values from the current env (if symbols are passed, uses the symtype to trace, everything else is a Const).

this lets you call stuff with anything:

julia> f(y, x) = y .+ x

julia> @symbolic f([1,2], a)
broadcasted(+, [1, 2], a)

julia> @symbolic f([b,c], a)
broadcasted(+, [b, c], a)

julia> @syms a b::Float64
(a, b)

julia> f(x, y) =  x > 1 ? y+2 : y
f (generic function with 2 methods)

julia> @symbolic f(2, a)
a + 2

julia> f(x::Float64, y) =  float(y)
f (generic function with 2 methods)

julia> @symbolic f(b, a)
float(a)

@shashi
Copy link
Member

shashi commented May 10, 2020

world age

julia> function make_and_call(expr, args, vals)
           func(IR(expr, args))(vals...)
       end

julia> make_and_call(a+b, [a,b], [1,2])
ERROR: MethodError: no method matching ##261(::Int64, ::Int64)
The applicable method may be too new: running in world age 27244, while current world is 27246.
Closest candidates are:
  ##261(::Any, ::Any) at /home/shashi/.julia/dev/IRTools/src/eval.jl:18 (method too new to be called from this world context.)
Stacktrace:
 [1] make_and_call(::SymbolicUtils.Term{Number}, ::Array{SymbolicUtils.Sym,1}, ::Array{Int64,1}) at ./REPL[91]:2
 [2] top-level scope at REPL[92]:1
 [3] run_backend(::REPL.REPLBackend) at /home/shashi/.julia/packages/Revise/MgvIv/src/Revise.jl:1023
 [4] top-level scope at none:0

lol I guess conversion back to IR is only useful to rewrite functions with n=1 basic block right now.

@shashi
Copy link
Member

shashi commented May 10, 2020

using SymbolicUtils
using ModelingToolkit
using ModelingToolkit: expand_derivatives, to_mtk
using SymbolicUtils: to_symbolic
function D(f, T; simplify=true)
    @syms t()::T
    @derivatives DD'~to_mtk(t())
    expr = @symbolic f(t())
    deriv_expr = to_symbolic(expand_derivatives(DD(to_mtk(expr)), simplify))
    @show deriv_expr
    IR(deriv_expr, [t()])
end
julia> f(x::Float64) = sin(cos(x)) - cos(sin(x))
f (generic function with 1 method)
julia> D(f, Float64, simplify=false)
deriv_expr = (one(cos(sin(t()))) * sin(sin(t())) * cos(t())) + (-1 * one(sin(cos(t()))) * cos(cos(t())) * sin(t()))
1: (%1)
  %2 = (cos)(%1)
  %3 = (sin)(%2)
  %4 = (one)(%3)
  %5 = (cos)(%1)
  %6 = (cos)(%5)
  %7 = (sin)(%1)
  %8 = (-)(%7)
  %9 = (*)(%8, 1)
  %10 = (*)(%6, %9)
  %11 = (*)(%4, %10)
  %12 = (sin)(%1)
  %13 = (cos)(%12)
  %14 = (one)(%13)
  %15 = (-)(%14)
  %16 = (sin)(%1)
  %17 = (sin)(%16)
  %18 = (-)(%17)
  %19 = (cos)(%1)
  %20 = (*)(%19, 1)
  %21 = (*)(%18, %20)
  %22 = (*)(%15, %21)
  %23 = (+)(%11, %22)
  return %23
julia> D(f, Float64, simplify=true)
deriv_expr = (sin(sin(t())) * cos(t())) + (-1 * cos(cos(t())) * sin(t()))
1: (%1)
  %2 = (sin)(%1)
  %3 = (sin)(%2)
  %4 = (cos)(%1)
  %5 = (*)(%3, %4)
  %6 = (cos)(%1)
  %7 = (cos)(%6)
  %8 = (sin)(%1)
  %9 = (*)(-1, %7, %8)
  %10 = (+)(%5, %9)
  return %10

tracing AD with simplification! 2nd derivative goes from 78 lines -> 20 lines after simplify

@MikeInnes
Copy link
Author

Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated?

func just calls eval on the expression, so it has all the usual issues that does.

I think the right way to solve this is to use a generated function which does a trace based on input types; and it should also add backedges from all traced functions to solve 265-y issues.

@shashi
Copy link
Member

shashi commented May 11, 2020

@MikeInnes I just added a fuzzer for Mjolnir if you're interested in having a look (only constructs simple functions from SymbolicUtils exprs):

julia> include("test/fuzzlib.jl")
fuzz_test (generic function with 1 method)

julia> for i=1:500; fuzz_test(0, num_spec, mjolnir=true); end
err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#63),), (const(//), const(-91), const(33)), (const(Rational), const(-91), const(33)), (const(Rational{Int64}), const(-91), const(33)), (const(divgcd), const(-91), const(33)), (const(div), const(-91), const(1)), (const(checked_sdiv_int), const(-91), const(1))])
function ()
    -91//33
end

err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#71), Real), (const(//), const(39), const(1)), (const(Rational), const(39), const(1)), (const(Rational{Int64}), const(39), const(1)), (const(divgcd), const(39), const(1)), (const(div), const(39), const(1)), (const(checked_sdiv_int), const(39), const(1))])
function (b,)
    (39//1 - 29//36*im) - b
end
...

It might sometimes hit #30 but I will fix that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants