Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enzyme AD - almost works! #176

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
AMDGPUExt = "AMDGPU"
CUDAExt = "CUDA"
EnzymeExt = "Enzyme"

[compat]
AMDGPU = "0.8"
Expand All @@ -29,6 +31,7 @@ BatchedRoutines = "0.2"
CUDA = "4, 5"
ChainRulesCore = "1"
Combinatorics = "1.0"
Enzyme = "0.13.16"
MacroTools = "0.5"
OMEinsumContractionOrders = "0.9"
TupleTools = "1.2, 1.3"
Expand All @@ -49,4 +52,4 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "CUDA", "Documenter", "LinearAlgebra", "LuxorGraphPlot", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"]
test = ["Test", "CUDA", "Documenter", "Enzyme", "LinearAlgebra", "LuxorGraphPlot", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"]
40 changes: 40 additions & 0 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module EnzymeExt
using Enzyme.EnzymeRules, OMEinsum, Enzyme
using OMEinsum: get_size_dict!

"""
    EnzymeRules.augmented_primal(config, func, ret, code, xs, ys, sx, sy, size_dict)

Augmented forward pass of the custom Enzyme reverse rule for `einsum!`.

Only the plain contraction `α = 1`, `β = 0` is supported; any other scaling
throws an `ArgumentError`. If `config` reports that the inputs `xs` may be
overwritten between the forward and reverse passes, a copy of them is saved
on the tape so the reverse pass can still see the original values.
"""
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, ::Type,
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    # Validate with an explicit throw rather than `@assert`: asserts are for
    # internal invariants and may be disabled at higher optimization levels.
    (sx.val == 1 && sy.val == 0) ||
        throw(ArgumentError("Only α = 1 and β = 0 is supported, got: $(sx.val), $(sy.val)"))
    # Run the primal computation only when Enzyme actually needs its value.
    if EnzymeRules.needs_primal(config)
        primal = func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
    else
        primal = nothing
    end
    # Slot 3 of the rule's arguments corresponds to `xs`; tape a copy only if
    # Enzyme says those arrays may be mutated before the reverse pass runs.
    tape = EnzymeRules.overwritten(config)[3] ? copy(xs.val) : nothing
    # The shadow of the output; the reverse pass accumulates through it.
    shadow = ys.dval
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

"""
    EnzymeRules.reverse(config, func, dret, tape, code, xs, ys, sx, sy, size_dict)

Reverse pass of the custom Enzyme rule for `einsum!`: accumulate the gradient
of each input tensor into the corresponding shadow `xs.dval[i]` using
`OMEinsum.einsum_grad`.
"""
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape,
        code::Const,
        xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    # Use the taped copy when the inputs may have been overwritten between the
    # forward and reverse passes (see augmented_primal); otherwise read xs.val.
    xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
    # Iterate over `xval` (not `xs.val`) so the loop is consistent with the
    # values actually differentiated, as flagged in review.
    # NOTE(review): an in-place variant of einsum_grad would avoid allocating
    # one temporary gradient array per input tensor.
    for i in eachindex(xval)
        xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
            xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i)
    end
    # One `nothing` per rule argument: no gradient flows to code/sx/sy/size_dict,
    # and the xs/ys gradients were handled through their shadows above.
    return (nothing, nothing, nothing, nothing, nothing, nothing)
end

# EnzymeRules.inactive(::typeof(get_size_dict!), args...) = nothing

end
70 changes: 70 additions & 0 deletions test/EnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using Enzyme, OMEinsum, Test
# Sum of the diagonal of a square matrix `x`, computed through the in-place
# `einsum!` path. The gradient w.r.t. `x` is the identity matrix.
function testf1(x)
    y = zeros(size(x, 1))
    # Use the actual matrix size instead of a hard-coded 3 so the helper
    # works for any square input.
    einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>size(x, 1)))
    return sum(y)
end

# Sum of the diagonal of a square matrix `x`, via the out-of-place `einsum`
# entry point. Size is taken from `x` rather than hard-coded to 3.
function testf2(x)
    y = einsum(ein"ii->i", (x,), Dict('i'=>size(x, 1)))
    return sum(y)
end

# Sum of the diagonal of `x`, via the callable string-macro form of einsum.
testf4(x) = sum(ein"ii->i"(x))

@testset "EnzymeExt" begin
x = randn(3, 3);
gx = zero(x);

# d(sum(diag(x)))/dx is the identity matrix.
autodiff(ReverseWithPrimal, testf1, Active, Duplicated(x, gx))
@test gx == [1 0 0; 0 1 0; 0 0 1]

# NOTE: gx is deliberately NOT reset between calls — Enzyme accumulates into
# the shadow, so each subsequent test expects the running total (2I, then 3I).
autodiff(ReverseWithPrimal, testf2, Active, Duplicated(x, gx))
@test gx == [2 0 0; 0 2 0; 0 0 2]

autodiff(ReverseWithPrimal, testf4, Active, Duplicated(x, gx))
@test gx == [3 0 0; 0 3 0; 0 0 3]
end

@testset "EnzymeExt error" begin
    # Same diagonal-sum computation as testf1, but defined as a testset-local
    # function; checks the gradient is the 3x3 identity.
    mat = randn(3, 3)
    grad = zero(mat)
    function testf3(m)
        d = zeros(size(m, 1))
        einsum!(ein"ii->i", (m,), d, 1, 0, Dict('i'=>3))
        return sum(d)
    end
    autodiff(ReverseWithPrimal, testf3, Active, Duplicated(mat, grad))
    @test grad == [1 0 0; 0 1 0; 0 0 1]
end

@testset "EnzymeExt bp check" begin
    # Enzyme gradients of a nested three-tensor contraction, checked against
    # OMEinsum's own cost_and_gradient.
    A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
    loss(a, b, c) = ein"(ij, jk), ki->"(a, b, c)[]
    cost0 = loss(A, B, C)
    dA, dB, dC = zero(A), zero(B), zero(C)
    Enzyme.autodiff(Reverse, loss, Active,
        Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC))
    cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
    @test cost[] ≈ cost0
    @test all(dA .≈ mg[1])
    @test all(dB .≈ mg[2])
    @test all(dC .≈ mg[3])
end

# Gradients through a TreeSA-optimized contraction order, checked against
# OMEinsum.cost_and_gradient on the same optimized code.
@testset "EnzymeExt bp check 2" begin
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
cost0 = code(A, B, C)[]
gA = zero(A); gB = zero(B); gC = zero(C);
f(code, a, b, c) = code(a, b, c)[]
# NOTE(review): runtime activity is enabled here, presumably because the
# optimized code object mixes differentiable and constant data — confirm
# against the Enzyme documentation.
Enzyme.autodiff(set_runtime_activity(Reverse), f, Active, Const(code), Duplicated(A, gA), Duplicated(B, gB), Duplicated(C, gC))
cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
@test cost[] ≈ cost0
@test all(gA .≈ mg[1])
@test all(gB .≈ mg[2])
@test all(gC .≈ mg[3])
end

# liquid state machine
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ end
include("bp.jl")
end

# Enzyme custom-rule tests (exercise ext/EnzymeExt.jl).
@testset "EnzymeExt" begin
include("EnzymeExt.jl")
end

@testset "docstring" begin
Documenter.doctest(OMEinsum; manual=false)
end
Loading