diff --git a/Project.toml b/Project.toml index f00a396..56ecf9c 100644 --- a/Project.toml +++ b/Project.toml @@ -17,10 +17,12 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [extensions] AMDGPUExt = "AMDGPU" CUDAExt = "CUDA" +EnzymeExt = "Enzyme" [compat] AMDGPU = "0.8" @@ -29,6 +31,7 @@ BatchedRoutines = "0.2" CUDA = "4, 5" ChainRulesCore = "1" Combinatorics = "1.0" +Enzyme = "0.13.16" MacroTools = "0.5" OMEinsumContractionOrders = "0.9" TupleTools = "1.2, 1.3" @@ -49,4 +52,4 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "CUDA", "Documenter", "LinearAlgebra", "LuxorGraphPlot", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"] +test = ["Test", "CUDA", "Documenter", "Enzyme", "LinearAlgebra", "LuxorGraphPlot", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"] diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl new file mode 100644 index 0000000..f140935 --- /dev/null +++ b/ext/EnzymeExt.jl @@ -0,0 +1,40 @@ +module EnzymeExt +using Enzyme.EnzymeRules, OMEinsum, Enzyme +using OMEinsum: get_size_dict! + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(einsum!)}, ::Type, + code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const) + @assert sx.val == 1 && sy.val == 0 "Only α = 1 and β = 0 is supported, got: $sx, $sy" + # Compute primal + if EnzymeRules.needs_primal(config) + primal = func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val) + else + primal = nothing + end + # Save x in tape if x will be overwritten + if EnzymeRules.overwritten(config)[3] + tape = copy(xs.val) + else + tape = nothing + end + shadow = ys.dval + return EnzymeRules.AugmentedReturn(primal, shadow, tape) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape, + code::Const, + xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const) + xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val + for i=1:length(xs.val) + xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val), + xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i) + end + return (nothing, nothing, nothing, nothing, nothing, nothing) +end + +# EnzymeRules.inactive(::typeof(get_size_dict!), args...) = nothing + +end \ No newline at end of file diff --git a/test/EnzymeExt.jl b/test/EnzymeExt.jl new file mode 100644 index 0000000..6a16c5a --- /dev/null +++ b/test/EnzymeExt.jl @@ -0,0 +1,70 @@ +using Enzyme, OMEinsum, Test +function testf1(x) + y = zeros(size(x, 1)) + einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>3)) + return sum(y) +end + +function testf2(x) + y = einsum(ein"ii->i", (x,), Dict('i'=>3)) + return sum(y) +end + +function testf4(x) + y = ein"ii->i"(x) + return sum(y) +end + +@testset "EnzymeExt" begin + x = randn(3, 3); + gx = zero(x); + + autodiff(ReverseWithPrimal, testf1, Active, Duplicated(x, gx)) + @test gx == [1 0 0; 0 1 0; 0 0 1] + + autodiff(ReverseWithPrimal, testf2, Active, Duplicated(x, gx)) + @test gx == [2 0 0; 0 2 0; 0 0 2] + + autodiff(ReverseWithPrimal, testf4, Active, Duplicated(x, gx)) + @test gx == [3 0 0; 0 3 0; 0 0 3] +end + +@testset "EnzymeExt error" begin + x = randn(3, 3); + gx = zero(x); + function testf3(x) + y = zeros(size(x, 1)) + einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>3)) + return sum(y) + end + autodiff(ReverseWithPrimal, testf3, Active, Duplicated(x, gx)) + @test gx == [1 0 0; 0 1 0; 0 0 1] +end + +@testset "EnzymeExt bp check" begin + A, B, C = randn(2, 3), randn(3, 4), randn(4, 2) + cost0 = ein"(ij, jk), ki->"(A, B, C)[] + gA = zero(A); gB = zero(B); gC = zero(C); + Enzyme.autodiff(Reverse, (a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], Active, Duplicated(A, gA), Duplicated(B, gB), Duplicated(C, gC)) + cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C)) + @test cost[] ≈ cost0 + @test all(gA .≈ mg[1]) + @test all(gB .≈ mg[2]) + @test all(gC .≈ mg[3]) +end + +@testset "EnzymeExt bp check 2" begin + A, B, C = randn(2, 3), randn(3, 4), randn(4, 2) + code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA()) + cost0 = code(A, B, C)[] + gA = zero(A); gB = zero(B); gC = zero(C); + f(code, a, b, c) = code(a, b, c)[] + Enzyme.autodiff(set_runtime_activity(Reverse), f, Active, Const(code), Duplicated(A, gA), Duplicated(B, gB), Duplicated(C, gC)) + cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C)) + @test cost[] ≈ cost0 + @test all(gA .≈ mg[1]) + @test all(gB .≈ mg[2]) + @test all(gC .≈ mg[3]) +end + +# liquid state machine \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e731b50..65b9c58 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,10 @@ end include("bp.jl") end +@testset "EnzymeExt" begin + include("EnzymeExt.jl") +end + @testset "docstring" begin Documenter.doctest(OMEinsum; manual=false) end