-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use Mjolnir for tracing #78
base: master
Are you sure you want to change the base?
Conversation
julia> relu(x::Real) = ifelse(x == 0, x, zero(x))
relu (generic function with 2 methods)
julia> @symbolic relu(x)
ifelse(x == 0, x, 0.0) I forgot to mention that there's a syntax for variable types: julia> @symbolic relu(x::Int64)
ifelse(x == 0, x, 0) Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir. |
Codecov Report
@@ Coverage Diff @@
## master #78 +/- ##
==========================================
- Coverage 89.81% 87.34% -2.47%
==========================================
Files 8 9 +1
Lines 432 490 +58
==========================================
+ Hits 388 428 +40
- Misses 44 62 +18
Continue to review full report at Codecov.
|
Absolutely fantanstic!! I'm excited. It might be good to add a Symutils -> IR conversion given a vector of arguments. |
Amazing, Mike! This is a really promising direction. |
Glad you like it!
Yeah, that seems useful – especially for ultimately turning a Term into something you can evaluate. |
this is something we need to start adding methods for... I guess right now an array of symbols is all that works. I think we can use something like Mjolnir's shaped arrays. It could be stored as |
Wow, fantastic. Thanks, Mike! |
julia> using SymbolicUtils
julia> @syms a b c
(a, b, c)
julia> ir=IR(a+b+c, [a,b]; mod=Main)
1: (%1, %2)
%3 = (+)(%1, %2, Main.c)
return %3
julia> h=func(ir)
##260 (generic function with 1 method)
julia> h(2,3)
5 + c Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated? |
made the macro use values from the current env (if symbols are passed, uses the this lets you call stuff with anything: julia> f(y, x) = y .+ x
julia> @symbolic f([1,2], a)
broadcasted(+, [1, 2], a)
julia> @symbolic f([b,c], a)
broadcasted(+, [b, c], a)
julia> @syms a b::Float64
(a, b)
julia> f(x, y) = x > 1 ? y+2 : y
f (generic function with 2 methods)
julia> @symbolic f(2, a)
a + 2
julia> f(x::Float64, y) = float(y)
f (generic function with 2 methods)
julia> @symbolic f(b, a)
float(a)
|
julia> function make_and_call(expr, args, vals)
func(IR(expr, args))(vals...)
end
julia> make_and_call(a+b, [a,b], [1,2])
ERROR: MethodError: no method matching ##261(::Int64, ::Int64)
The applicable method may be too new: running in world age 27244, while current world is 27246.
Closest candidates are:
##261(::Any, ::Any) at /home/shashi/.julia/dev/IRTools/src/eval.jl:18 (method too new to be called from this world context.)
Stacktrace:
[1] make_and_call(::SymbolicUtils.Term{Number}, ::Array{SymbolicUtils.Sym,1}, ::Array{Int64,1}) at ./REPL[91]:2
[2] top-level scope at REPL[92]:1
[3] run_backend(::REPL.REPLBackend) at /home/shashi/.julia/packages/Revise/MgvIv/src/Revise.jl:1023
[4] top-level scope at none:0 lol I guess conversion back to IR is only useful to rewrite functions with n=1 basic block right now. |
using SymbolicUtils
using ModelingToolkit
using ModelingToolkit: expand_derivatives, to_mtk
using SymbolicUtils: to_symbolic
function D(f, T; simplify=true)
@syms t()::T
@derivatives DD'~to_mtk(t())
expr = @symbolic f(t())
deriv_expr = to_symbolic(expand_derivatives(DD(to_mtk(expr)), simplify))
@show deriv_expr
IR(deriv_expr, [t()])
end
julia> f(x::Float64) = sin(cos(x)) - cos(sin(x))
f (generic function with 1 method)
julia> D(f, Float64, simplify=false)
deriv_expr = (one(cos(sin(t()))) * sin(sin(t())) * cos(t())) + (-1 * one(sin(cos(t()))) * cos(cos(t())) * sin(t()))
1: (%1)
%2 = (cos)(%1)
%3 = (sin)(%2)
%4 = (one)(%3)
%5 = (cos)(%1)
%6 = (cos)(%5)
%7 = (sin)(%1)
%8 = (-)(%7)
%9 = (*)(%8, 1)
%10 = (*)(%6, %9)
%11 = (*)(%4, %10)
%12 = (sin)(%1)
%13 = (cos)(%12)
%14 = (one)(%13)
%15 = (-)(%14)
%16 = (sin)(%1)
%17 = (sin)(%16)
%18 = (-)(%17)
%19 = (cos)(%1)
%20 = (*)(%19, 1)
%21 = (*)(%18, %20)
%22 = (*)(%15, %21)
%23 = (+)(%11, %22)
return %23
julia> D(f, Float64, simplify=true)
deriv_expr = (sin(sin(t())) * cos(t())) + (-1 * cos(cos(t())) * sin(t()))
1: (%1)
%2 = (sin)(%1)
%3 = (sin)(%2)
%4 = (cos)(%1)
%5 = (*)(%3, %4)
%6 = (cos)(%1)
%7 = (cos)(%6)
%8 = (sin)(%1)
%9 = (*)(-1, %7, %8)
%10 = (+)(%5, %9)
return %10 tracing AD with simplification! 2nd derivative goes from 78 lines -> 20 lines after simplify |
I think the right way to solve this is to use a generated function which does a trace based on input types; and it should also add backedges from all traced functions to solve 265-y issues. |
@MikeInnes I just added a fuzzer for Mjolnir if you're interested in having a look (only constructs simple functions from SymbolicUtils exprs): julia> include("test/fuzzlib.jl")
fuzz_test (generic function with 1 method)
julia> for i=1:500; fuzz_test(0, num_spec, mjolnir=true); end
err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#63),), (const(//), const(-91), const(33)), (const(Rational), const(-91), const(33)), (const(Rational{Int64}), const(-91), const(33)), (const(divgcd), const(-91), const(33)), (const(div), const(-91), const(1)), (const(checked_sdiv_int), const(-91), const(1))])
function ()
-91//33
end
err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#71), Real), (const(//), const(39), const(1)), (const(Rational), const(39), const(1)), (const(Rational{Int64}), const(39), const(1)), (const(divgcd), const(39), const(1)), (const(div), const(39), const(1)), (const(checked_sdiv_int), const(39), const(1))])
function (b,)
(39//1 - 29//36*im) - b
end
... It might sometimes hit #30 but I will fix that. |
This is just a prototype for the time being (Mjolnir is not registered yet) but it shows off the basic ideas. This is the right solution to #16 (handling dispatch correctly), #8 (handling
typeof
etc. correctly) and #67 (handlingifelse
andif
/else
).Demo of #8: