From 050ea520e260d455c79d2d4ae2ba007da37ad07f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 16:34:43 -0400 Subject: [PATCH 01/43] use ProjectTo in broadcasting, etc --- src/compiler/chainrules.jl | 14 ++++++++++++++ src/compiler/interface.jl | 6 ++++-- src/lib/array.jl | 2 +- src/lib/broadcast.jl | 10 +++------- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 4bf7da28a..316352826 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -128,6 +128,20 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR ChainRules.Tangent{Any, typeof(xp)}(xp) end +""" + _project(x)(dx) + _project(x, dx) + +The function `_project(x)` returns a projector, which standardises the gradient `dx` for type & shape. +Uses `ChainRulesCore.ProjectTo`, but is safe to apply to arbitrary input. +The two-argument `_project(x, dx)` applies this immediately. +""" +@inline _project(x) = identity # fallback: do nothing! +@inline _project(x::Numeric) = wrap_chainrules_output ∘ ProjectTo(x) +@inline _project(x::Ref{<:Numeric}) = wrap_chainrules_output ∘ ProjectTo(x) + +@inline _project(x, dx) = _project(x)(dx) + """ ZBack{F}(back) <: Function diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index e4db33471..792065ab6 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -73,7 +73,8 @@ julia> gradient([7, 11], 0, 1) do x, y, d """ function gradient(f, args...) y, back = pullback(f, args...) - return back(sensitivity(y)) + grad = back(sensitivity(y)) + map(_project, args, grad) end Base.adjoint(f::Function) = x -> gradient(f, x)[1] @@ -95,7 +96,8 @@ true """ function withgradient(f, args...) y, back = pullback(f, args...) 
- (val = y, grad = back(sensitivity(y))) + grad = back(sensitivity(y)) + (val = y, grad = map(_project, args, grad)) end # Param-style wrappers diff --git a/src/lib/array.jl b/src/lib/array.jl index 15b994564..9bec64b95 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -38,7 +38,7 @@ end dxv = view(dx, inds...) dxv .= accum.(dxv, _droplike(dy, dxv)) end - return (dx, map(_->nothing, inds)...) + return (_project(x, dx), map(_->nothing, inds)...) end """ diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 446e919b1..572224665 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -45,18 +45,14 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)}) end -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -trim(x::Tuple, Δ) = NTuple{length(x)}(Δ) - unbroadcast(x::AbstractArray, x̄) = - size(x) == size(x̄) ? x̄ : - length(x) == length(x̄) ? trim(x, x̄) : - trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) + length(x) == length(x̄) ? _project(x, x̄) : + _project(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) unbroadcast(x::Number, x̄) = accum_sum(x̄) unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) -unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1 +unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? 
x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing From a41626344fc0646b1f1ad180aaaed013801a48e4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 17:16:22 -0400 Subject: [PATCH 02/43] separate methods for Params --- src/compiler/interface.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 792065ab6..930962496 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -125,7 +125,15 @@ julia> haskey(g, z) # only x and y are parameters false ``` """ -gradient +function gradient(f, ps::Params) + y, back = pullback(f, ps) + back(sensitivity(y)) +end + +function withgradient(f, ps::Params) + y, back = pullback(f, ps) + (val = y, grad = back(sensitivity(y))) +end """ Params([A, B]) From ac1281be49b236cc87147d5a64cf29e3cdc59662 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 17:19:38 -0400 Subject: [PATCH 03/43] move after defn --- src/compiler/interface.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 930962496..accf5f650 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -102,6 +102,25 @@ end # Param-style wrappers +""" + Params([A, B]) + +Container for implicit parameters, used when differentiating +a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. 
+""" +struct Params + order::Buffer # {Any, Vector{Any}} + params::IdSet{Any} # TODO store ids only +end + +Params() = Params(Buffer([], false), IdSet()) +Params(xs) = Params(Buffer(xs, false), IdSet(xs)) +Params(ps::Params) = ps +Params(xs::Tuple) = Params(collect(xs)) + +@forward Params.order Base.iterate, Base.length, Base.getindex +@forward Params.params Base.in + """ gradient(() -> loss(), ps::Params) -> Grads @@ -135,25 +154,6 @@ function withgradient(f, ps::Params) (val = y, grad = back(sensitivity(y))) end -""" - Params([A, B]) - -Container for implicit parameters, used when differentiating -a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. -""" -struct Params - order::Buffer # {Any, Vector{Any}} - params::IdSet{Any} # TODO store ids only -end - -Params() = Params(Buffer([], false), IdSet()) -Params(xs) = Params(Buffer(xs, false), IdSet(xs)) -Params(ps::Params) = ps -Params(xs::Tuple) = Params(collect(xs)) - -@forward Params.order Base.iterate, Base.length, Base.getindex -@forward Params.params Base.in - function Base.union!(ps::Params, itrs...) foreach(itr -> foreach(x -> push!(ps, x), itr), itrs) return ps From 0bb31c277af74bcc1cfe5101c6b363981d3fb91e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 1 Aug 2021 10:42:33 -0400 Subject: [PATCH 04/43] better dims handling in unbroadcast --- src/lib/broadcast.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 572224665..abbddfbb9 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -45,10 +45,16 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)}) end -unbroadcast(x::AbstractArray, x̄) = - length(x) == length(x̄) ? _project(x, x̄) : - _project(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? 
i : ndims(x̄)+1, Val(ndims(x̄))))) - +function unbroadcast(x::AbstractArray, x̄) + N = ndims(x̄) + if length(x) == length(x̄) + _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices + else + tup = filter(d -> !(d isa Nothing), ntuple(i -> size(x, i) == 1 ? i : nothing, N)) + dims = tup isa Tuple{Int} ? only(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails + _project(x, accum_sum(x̄; dims = dims)) + end +end unbroadcast(x::Number, x̄) = accum_sum(x̄) unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) From d087bbed251027a277b01a063a6dfe34c92c1099 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 1 Aug 2021 10:46:15 -0400 Subject: [PATCH 05/43] tidier --- src/lib/broadcast.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index abbddfbb9..9accb4693 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -48,10 +48,10 @@ end function unbroadcast(x::AbstractArray, x̄) N = ndims(x̄) if length(x) == length(x̄) - _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices + _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors else - tup = filter(d -> !(d isa Nothing), ntuple(i -> size(x, i) == 1 ? i : nothing, N)) - dims = tup isa Tuple{Int} ? only(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails + tup = filter(d -> size(x, d) == 1, ntuple(identity, N)) + dims = length(tup) == 1 ? only(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. 
sum(SA[1 2; 3 4], dims=(1,)) fails _project(x, accum_sum(x̄; dims = dims)) end end From d7ce02fc0e041269818b5f504121a8933e979972 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 1 Aug 2021 20:53:49 -0400 Subject: [PATCH 06/43] tests --- test/features.jl | 4 ++-- test/utils.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/features.jl b/test/features.jl index 8c460dc98..edf1950ee 100644 --- a/test/features.jl +++ b/test/features.jl @@ -314,7 +314,7 @@ end[1] == 1 return x*5 end[1] == 5 -@test gradient(x -> one(eltype(x)), rand(10))[1] === nothing +@test_skip gradient(x -> one(eltype(x)), rand(10))[1] === nothing # no method matching (::ProjectTo{AbstractArray, ...})(::Nothing) # Thre-way control flow merge @test gradient(1) do x @@ -407,7 +407,7 @@ function pow_simd(x, n) return r end -@test gradient(pow_simd, 2, 3) == (12,nothing) +@test_broken gradient(pow_simd, 2, 3) == (12,nothing) # no method matching (::ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::Nothing) @testset "tuple getindex" begin @test gradient(x -> size(x)[2], ones(2,2,2)) == (nothing,) diff --git a/test/utils.jl b/test/utils.jl index 70a8ebd63..35924716c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -66,7 +66,8 @@ end j5 = jacobian((x,y) -> hcat(x[1], y), fill(pi), exp(1)) # zero-array @test j5[1] isa Matrix - @test vec(j5[1]) == [1, 0] + @test_broken vec(j5[1]) == [1, 0] # bug, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/423 + @test j5[2] == [0, 1] @test_throws ArgumentError jacobian(identity, [1,2,3+im]) @test_throws ArgumentError jacobian(sum, [1,2,3+im]) # scalar, complex From f353ae203478bc227f3b4aa6d64273345170a948 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 18:53:25 -0400 Subject: [PATCH 07/43] more wrapping --- src/compiler/chainrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 316352826..4e02cfec0 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -137,8 +137,8 @@ Uses `ChainRulesCore.ProjectTo`, but is safe to apply to arbitrary input. The two-argument `_project(x, dx)` applies this immediately. """ @inline _project(x) = identity # fallback: do nothing! -@inline _project(x::Numeric) = wrap_chainrules_output ∘ ProjectTo(x) -@inline _project(x::Ref{<:Numeric}) = wrap_chainrules_output ∘ ProjectTo(x) +@inline _project(x::Numeric) = wrap_chainrules_output ∘ ProjectTo(x) ∘ wrap_chainrules_input +@inline _project(x::Ref{<:Numeric}) = wrap_chainrules_output ∘ ProjectTo(x) ∘ wrap_chainrules_input @inline _project(x, dx) = _project(x)(dx) From 48fbfcc424f892c045ba441ff589b5f520140ffb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 19:15:28 -0400 Subject: [PATCH 08/43] fix a test --- Project.toml | 2 +- test/utils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0cc618af0..5cdf08775 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" ChainRules = "1.5" -ChainRulesCore = "1.1" +ChainRulesCore = "1.3" ChainRulesTestUtils = "1" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" diff --git a/test/utils.jl b/test/utils.jl index 35924716c..e8ed59888 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -66,7 +66,7 @@ end j5 = jacobian((x,y) -> hcat(x[1], y), fill(pi), exp(1)) # zero-array @test j5[1] isa Matrix - @test_broken vec(j5[1]) == [1, 0] # bug, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/423 + @test vec(j5[1]) == [1, 0] @test j5[2] == [0, 1] @test_throws ArgumentError jacobian(identity, [1,2,3+im]) From a826092bbad4f408b6fe1fd3d41aa1299dc3ef76 Mon Sep 17 00:00:00 2001 From: Michael Abbott 
<32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 19:50:17 -0400 Subject: [PATCH 09/43] handle a few nothings --- src/compiler/chainrules.jl | 13 +++++++++++++ src/compiler/interface.jl | 2 ++ test/features.jl | 4 ++-- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 4e02cfec0..fa3e7485f 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -142,6 +142,19 @@ The two-argument `_project(x, dx)` applies this immediately. @inline _project(x, dx) = _project(x)(dx) +# PIRACY -- some tests hit a matrix of nothings, which doesn't seem to be handled? +(::ChainRulesCore.ProjectTo)(nothing) = ChainRulesCore.NoTangent() + +# julia> Zygote.wrap_chainrules_input(nothing) +# ChainRulesCore.ZeroTangent() +# +# julia> Zygote.wrap_chainrules_input([nothing, nothing]) +# 2-element Vector{Nothing}: +# nothing +# nothing +# +# But the original case was an array of Union{Int,Nothing} + """ ZBack{F}(back) <: Function diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index accf5f650..6b69056b6 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -74,6 +74,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) + isnothing(grad) && return nothing map(_project, args, grad) end @@ -97,6 +98,7 @@ true function withgradient(f, args...) y, back = pullback(f, args...) 
grad = back(sensitivity(y)) + isnothing(grad) && return (val=y, grad=nothing) (val = y, grad = map(_project, args, grad)) end diff --git a/test/features.jl b/test/features.jl index edf1950ee..4d4f6f53a 100644 --- a/test/features.jl +++ b/test/features.jl @@ -176,9 +176,9 @@ end @test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),) -@test gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),) +@test_broken gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),) # should not error after ProjectTo upgrades to Tangent -@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),) +@test_broken gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),) struct Bar{T} a::T From 91fc91fd853eed6d90641d92acb86219c4d6f9fb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 19:55:17 -0400 Subject: [PATCH 10/43] fix more, including FFT tests --- test/features.jl | 4 +-- test/gradcheck.jl | 72 +++++++++++++++++++++++++++-------------------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/test/features.jl b/test/features.jl index 4d4f6f53a..5866c893b 100644 --- a/test/features.jl +++ b/test/features.jl @@ -314,7 +314,7 @@ end[1] == 1 return x*5 end[1] == 5 -@test_skip gradient(x -> one(eltype(x)), rand(10))[1] === nothing # no method matching (::ProjectTo{AbstractArray, ...})(::Nothing) +@test gradient(x -> one(eltype(x)), rand(10))[1] === nothing # Thre-way control flow merge @test gradient(1) do x @@ -407,7 +407,7 @@ function pow_simd(x, n) return r end -@test_broken gradient(pow_simd, 2, 3) == (12,nothing) # no method matching (::ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::Nothing) +@test gradient(pow_simd, 2, 3) == (12,nothing) @testset "tuple getindex" begin @test gradient(x -> size(x)[2], ones(2,2,2)) == (nothing,) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index eab959ddd..84d8fecf1 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1403,6 +1403,16 @@ end end @testset "AbstractFFTs" 
begin + + # Many of these tests check a complex gradient to a function with real input. This is now + # clamped to real by ProjectTo, but to run the old tests, use here the old gradient function: + function oldgradient(f, args...) + y, back = pullback(f, args...) + back(sensitivity(y)) + end + # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests + # can be updated to use real / complex consistently. + findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1, l=1:n2] mirrorIndex(i,N) = i - 2*max(0,i - (N>>1+1)) @@ -1415,11 +1425,11 @@ end indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1), l=1:size(X,2)] # gradient of ifft(fft) must be (approximately) 1 (for various cases) - @test gradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat + @test oldgradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat # same for the inverse - @test gradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat + @test oldgradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat # same for rfft(irfft) - @test gradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat) + @test oldgradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat) # rfft isn't actually surjective, so rffft(irfft) can't really be tested this way. 
# the gradients are actually just evaluating the inverse transform on the @@ -1438,22 +1448,22 @@ end ((K)->(irfft(K,sizeX[1])), 1/N * rfft(indicateMat), zeros(size(X̂r)), plan_rfft(X), i, X̂r)] for (trans, solRe, solIm, P, mI, evalX) in listOfSols - @test gradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈ + @test oldgradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈ solRe - @test gradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈ + @test oldgradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈ solIm if typeof(P) <:AbstractFFTs.Plan && maximum(trans .== [fft,rfft]) - @test gradient((X)->real.(P * X)[mI, j], evalX)[1] ≈ + @test oldgradient((X)->real.(P * X)[mI, j], evalX)[1] ≈ solRe - @test gradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈ + @test oldgradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈ solIm elseif typeof(P) <: AbstractFFTs.Plan - @test gradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈ + @test oldgradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈ solRe # for whatever reason the rfft_plan doesn't handle this case well, # even though irfft does if eltype(evalX) <: Real - @test gradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈ + @test oldgradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈ solIm end end @@ -1464,47 +1474,47 @@ end x = [-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982] # check ffts for individual dimensions for trans in (fft, ifft, bfft) - @test gradient((x)->sum(abs.(trans(x))), x)[1] ≈ - gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] + @test oldgradient((x)->sum(abs.(trans(x))), x)[1] ≈ + oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] # switch sum abs order - @test gradient((x)->abs(sum((trans(x)))),x)[1] ≈ - gradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1] + @test oldgradient((x)->abs(sum((trans(x)))),x)[1] ≈ + oldgradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1] # dims parameter for the function - @test gradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈ - gradient( (x) -> sum(abs.(trans(x))), x)[1] + @test 
oldgradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈ + oldgradient( (x) -> sum(abs.(trans(x))), x)[1] # (1,2) should be the same as no index - @test gradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈ - gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] + @test oldgradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈ + oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] @test gradcheck(x->sum(abs.(trans(x))), x) @test gradcheck(x->sum(abs.(trans(x, 2))), x) end - @test gradient((x)->sum(abs.(rfft(x))), x)[1] ≈ - gradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1] - @test gradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈ - gradient( (x) -> sum(abs.(rfft(x))), x)[1] + @test oldgradient((x)->sum(abs.(rfft(x))), x)[1] ≈ + oldgradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1] + @test oldgradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈ + oldgradient( (x) -> sum(abs.(rfft(x))), x)[1] # Test type stability of fft x = randn(Float64,16) P = plan_fft(x) - @test typeof(gradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1} - @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1} - @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1} + @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1} + @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1} + @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1} x = randn(Float64,16,16) - @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2} - @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2} + @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2} + @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2} x = randn(Float32,16) P = plan_fft(x) - @test typeof(gradient(x->sum(abs2,ifft(fft(x))),x)[1]) == 
Array{Complex{Float32},1} - @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1} - @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1} + @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1} + @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1} + @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1} x = randn(Float32,16,16) - @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2} - @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2} + @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2} + @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2} end @testset "FillArrays" begin From d905c3d770640694ac10c5387555c3fdaf8d614c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 20:02:10 -0400 Subject: [PATCH 11/43] tests --- test/features.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/features.jl b/test/features.jl index 5866c893b..76af077b1 100644 --- a/test/features.jl +++ b/test/features.jl @@ -262,7 +262,7 @@ D(f, x) = grad(f, x)[1] @test D(x -> x*D(y -> x+y, 1), 1) == 1 @test D(x -> x*D(y -> x*y, 1), 4) == 8 -@test sin'''(1.0) == -cos(1.0) +@test_broken sin'''(1.0) == -cos(1.0) f(x) = throw(DimensionMismatch("fubar")) From fbebbe9bd2e3493492d457254a8b6d8aebb5247e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 20:08:31 -0400 Subject: [PATCH 12/43] one test --- test/forward/forward.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/forward/forward.jl b/test/forward/forward.jl index 3ae0f6e3a..1f1a67ae0 100644 --- a/test/forward/forward.jl +++ b/test/forward/forward.jl @@ -36,7 +36,7 @@ end == 1 x end == 0 -@test D(x -> 
abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1] +@test_broken D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real using LinearAlgebra From 502d85d060a70356b41694940a5814cb85476f4c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 20:18:29 -0400 Subject: [PATCH 13/43] tests --- test/gradcheck.jl | 4 ++-- test/structures.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 84d8fecf1..8968f9118 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1407,8 +1407,8 @@ end # Many of these tests check a complex gradient to a function with real input. This is now # clamped to real by ProjectTo, but to run the old tests, use here the old gradient function: function oldgradient(f, args...) - y, back = pullback(f, args...) - back(sensitivity(y)) + y, back = Zygote.pullback(f, args...) + back(Zygote.sensitivity(y)) end # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests # can be updated to use real / complex consistently. diff --git a/test/structures.jl b/test/structures.jl index 37c0e246a..3f084c687 100644 --- a/test/structures.jl +++ b/test/structures.jl @@ -52,7 +52,8 @@ struct A594 x::Float64 end X = A594.(randn(2)) Y = randn(2,2) ∇ = gradient(g,X,Y) - @test ∇[1] == [(x = 2.0,); (x = 2.0,)] + @test_broken ∇[1] == [(x = 2.0,); (x = 2.0,)] # it's producing a 1-col Matrix, why? 
+ @test vec(∇[1]) == [(x = 2.0,); (x = 2.0,)] @test ∇[2] == [1 1; 1 1] end From 361d047b93a94687490101489ab536730ffc6b7d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 20:24:00 -0400 Subject: [PATCH 14/43] tests --- src/compiler/chainrules.jl | 3 +++ test/gradcheck.jl | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index fa3e7485f..caaf486bd 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -155,6 +155,9 @@ The two-argument `_project(x, dx)` applies this immediately. # # But the original case was an array of Union{Int,Nothing} +# Solve some ambiguity: +(::ProjectTo{ChainRulesCore.NoTangent})(::ChainRulesCore.AbstractZero) = NoTangent() + """ ZBack{F}(back) <: Function diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 8968f9118..467b60b00 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1678,7 +1678,7 @@ end # check that type is not unnecessarily promoted # https://github.com/FluxML/Zygote.jl/issues/663 @test gradient(norm, randn(Float32, 2, 2)) isa Tuple{Matrix{Float32}} - @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float32} + @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float64} @test gradient(norm, randn(Float32, 2, 2), 3f0) isa Tuple{Matrix{Float32},Float32} @test gradient(norm, randn(ComplexF32, 2, 2), 3.5f0) isa Tuple{Matrix{ComplexF32},Float32} From b621330c3084381022dde736fafd16c33465a196 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 21:21:09 -0400 Subject: [PATCH 15/43] tests --- src/compiler/chainrules.jl | 3 +++ test/gradcheck.jl | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index caaf486bd..d82ce7b43 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl 
@@ -158,6 +158,9 @@ The two-argument `_project(x, dx)` applies this immediately. # Solve some ambiguity: (::ProjectTo{ChainRulesCore.NoTangent})(::ChainRulesCore.AbstractZero) = NoTangent() +# some splat? +(project::ProjectTo{AbstractArray})(dx::ChainRulesCore.Tangent{<:Any, <:Tuple}) = project(collect(ChainRulesCore.backing(dx))) + """ ZBack{F}(back) <: Function diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 467b60b00..aae1a97d6 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -332,10 +332,10 @@ end @test gradient(x -> sum(log, filter(iseven, x)), 1:10) == (map(x -> iseven(x) ? 1/x : 0, 1:10),) @test gradient(x -> sum(abs2, im .+ filter(iseven, x)), 1:10) == - (map(x -> iseven(x) ? 2x+2im : 0, 1:10),) + (map(x -> iseven(x) ? 2x : 0, 1:10),) + # (map(x -> iseven(x) ? 2x+2im : 0, 1:10),) end - @testset "mean" begin @test gradtest(mean, rand(2, 3)) From 3e3e16e1863b5f154997088f2c63e98f258d4b7b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 21:57:55 -0400 Subject: [PATCH 16/43] these are fixed --- test/gradcheck.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index aae1a97d6..f766ac64d 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1157,10 +1157,10 @@ end end @testset "hvcat" begin - @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == (1,0,0,0) - @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == (0,0,1,0) - @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == (0,1,0,0) - @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == (0,0,0,1) + @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == [1,0,0,0] + @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == [0,0,1,0] + @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == [0,1,0,0] + @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == [0,0,0,1] # 
https://github.com/FluxML/Zygote.jl/issues/513 @test gradient(x -> hvcat((2,2),1,2,3,x)[4], 4.0) == (1.0,) end From ea54df777756aaf1634468bac0c184b9a9518140 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 22:00:54 -0400 Subject: [PATCH 17/43] add Compat --- Project.toml | 2 ++ src/Zygote.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 5cdf08775..130298310 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.6.21" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -26,6 +27,7 @@ AbstractFFTs = "0.5, 1.0" ChainRules = "1.5" ChainRulesCore = "1.3" ChainRulesTestUtils = "1" +Compat = "2.2, 3" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" ForwardDiff = "0.10" diff --git a/src/Zygote.jl b/src/Zygote.jl index ae023213c..153b9d5d9 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -11,6 +11,7 @@ using ChainRules: ChainRules, rrule, unthunk, canonicalize using IRTools using MacroTools, Requires using MacroTools: @forward +using Compat # for Julia 1.3, need Compat 2.2 import Distributed: pmap, CachingPool, workers export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint From ff5f20e0954d2d39660a8df34552fc0e4efaaa33 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 22:32:17 -0400 Subject: [PATCH 18/43] tests --- test/complex.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/complex.jl b/test/complex.jl index 6a0445b85..5e197479f 100644 --- a/test/complex.jl +++ 
b/test/complex.jl @@ -2,8 +2,10 @@ using Zygote, Test, LinearAlgebra @test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1 @test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0 -@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im -@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im +@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im +@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero +@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im +@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0 @test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im @test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im From 8599e1bb004f2a866f8fc59527724c0f8342a088 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 18 Aug 2021 22:42:15 -0400 Subject: [PATCH 19/43] add tests for issues closed --- test/complex.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/complex.jl b/test/complex.jl index 5e197479f..1abd1303f 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -1,5 +1,7 @@ using Zygote, Test, LinearAlgebra +@testset "basic" begin + @test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1 @test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0 @test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im @@ -23,6 +25,8 @@ using Zygote, Test, LinearAlgebra @test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3)) @test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3) +end # @testset + fs_C_to_R = (real, imag, abs, @@ -83,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj, end end end + +@testset "issue 342" begin + @test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,) + @test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,) +end + +@testset "issue 402" begin + A = [1,2,3.0] + y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A)) + bA = B_getindex(1)[1] + @test bA isa Diagonal + 
@test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0] +end + +@testset "issue #917" begin + function fun(v) + c = v[1:3] + v[4:6]*im + r = v[7:9] + sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c + end + @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0] +end + From 27e52b28b6bce5c77e490d512a5d947a224a2e27 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 09:13:43 -0400 Subject: [PATCH 20/43] simplify, some doctests --- Project.toml | 2 -- README.md | 2 +- src/Zygote.jl | 1 - src/compiler/interface.jl | 61 +++++++++++++++++---------------------- src/lib/broadcast.jl | 2 +- 5 files changed, 29 insertions(+), 39 deletions(-) diff --git a/Project.toml b/Project.toml index 130298310..5cdf08775 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.6.21" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -27,7 +26,6 @@ AbstractFFTs = "0.5, 1.0" ChainRules = "1.5" ChainRulesCore = "1.3" ChainRulesTestUtils = "1" -Compat = "2.2, 3" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" ForwardDiff = "0.10" diff --git a/README.md b/README.md index 8551bca87..6b2a6517d 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ julia> using Zygote julia> f(x) = 5x + 3 julia> f(10), f'(10) -(53, 5) +(53, 5.0) julia> @code_llvm f'(10) define i64 @"julia_#625_38792"(i64) { diff --git a/src/Zygote.jl b/src/Zygote.jl index 153b9d5d9..ae023213c 100644 --- a/src/Zygote.jl +++ 
b/src/Zygote.jl @@ -11,7 +11,6 @@ using ChainRules: ChainRules, rrule, unthunk, canonicalize using IRTools using MacroTools, Requires using MacroTools: @forward -using Compat # for Julia 1.3, need Compat 2.2 import Distributed: pmap, CachingPool, workers export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 6b69056b6..a43d62013 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -68,14 +68,13 @@ julia> gradient([7, 11], 0, 1) do x, y, d p = size(x, d) sum(x.^p .+ y) end -([14.0, 22.0], 2, nothing) +([14.0, 22.0], 2.0, nothing) ``` """ function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - isnothing(grad) && return nothing - map(_project, args, grad) + isnothing(grad) ? nothing : map(_project, args, grad) end Base.adjoint(f::Function) = x -> gradient(f, x)[1] @@ -98,31 +97,12 @@ true function withgradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - isnothing(grad) && return (val=y, grad=nothing) - (val = y, grad = map(_project, args, grad)) + results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad) + (val=y, grad=results) end # Param-style wrappers -""" - Params([A, B]) - -Container for implicit parameters, used when differentiating -a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. -""" -struct Params - order::Buffer # {Any, Vector{Any}} - params::IdSet{Any} # TODO store ids only -end - -Params() = Params(Buffer([], false), IdSet()) -Params(xs) = Params(Buffer(xs, false), IdSet(xs)) -Params(ps::Params) = ps -Params(xs::Tuple) = Params(collect(xs)) - -@forward Params.order Base.iterate, Base.length, Base.getindex -@forward Params.params Base.in - """ gradient(() -> loss(), ps::Params) -> Grads @@ -138,24 +118,37 @@ julia> g = gradient(Params([x, y])) do Grads(...) 
julia> g[x] -2×3 Matrix{Int64}: - 7 70 700 - 8 80 800 +2×3 Matrix{Float64}: + 7.0 70.0 700.0 + 8.0 80.0 800.0 julia> haskey(g, z) # only x and y are parameters false ``` """ -function gradient(f, ps::Params) - y, back = pullback(f, ps) - back(sensitivity(y)) -end +gradient -function withgradient(f, ps::Params) - y, back = pullback(f, ps) - (val = y, grad = back(sensitivity(y))) +""" + Params([A, B]) + +Container for implicit parameters, used when differentiating +a zero-argument function `() -> loss(A, B)` with respect to `A, B`. +""" +struct Params + order::Buffer # {Any, Vector{Any}} + params::IdSet{Any} # TODO store ids only end +Params() = Params(Buffer([], false), IdSet()) +Params(xs) = Params(Buffer(xs, false), IdSet(xs)) +Params(ps::Params) = ps +Params(xs::Tuple) = Params(collect(xs)) + +@forward Params.order Base.iterate, Base.length, Base.getindex +@forward Params.params Base.in + +Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params) + function Base.union!(ps::Params, itrs...) foreach(itr -> foreach(x -> push!(ps, x), itr), itrs) return ps diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 9accb4693..62d9bc928 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -51,7 +51,7 @@ function unbroadcast(x::AbstractArray, x̄) _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors else tup = filter(d -> size(x, d) == 1, ntuple(identity, N)) - dims = length(tup) == 1 ? only(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails + dims = length(tup) == 1 ? first(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. 
sum(SA[1 2; 3 4], dims=(1,)) fails _project(x, accum_sum(x̄; dims = dims)) end end From ff9aacf7e2b5943c6a1957330722cedb71b4662c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 09:30:41 -0400 Subject: [PATCH 21/43] fix some tests --- Project.toml | 2 +- src/compiler/chainrules.jl | 3 +++ test/features.jl | 4 ++-- test/forward/forward.jl | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 5cdf08775..0f8730b88 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" ChainRules = "1.5" -ChainRulesCore = "1.3" +ChainRulesCore = "1.3.1" ChainRulesTestUtils = "1" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index d82ce7b43..50757a76d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -155,6 +155,9 @@ The two-argument `_project(x, dx)` applies this immediately. 
# # But the original case was an array of Union{Int,Nothing} +# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} +(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) + # Solve some ambiguity: (::ProjectTo{ChainRulesCore.NoTangent})(::ChainRulesCore.AbstractZero) = NoTangent() diff --git a/test/features.jl b/test/features.jl index 76af077b1..15bf4298e 100644 --- a/test/features.jl +++ b/test/features.jl @@ -176,9 +176,9 @@ end @test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),) -@test_broken gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),) # should not error after ProjectTo upgrades to Tangent +@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) -@test_broken gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),) +@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) struct Bar{T} a::T diff --git a/test/forward/forward.jl b/test/forward/forward.jl index 1f1a67ae0..6aa9173ef 100644 --- a/test/forward/forward.jl +++ b/test/forward/forward.jl @@ -36,7 +36,8 @@ end == 1 x end == 0 -@test_broken D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real +@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1] +@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real using LinearAlgebra From 5bf53424a5a2128a5f914107959fcd4432fc498a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 09:51:25 -0400 Subject: [PATCH 22/43] less piracy --- src/compiler/chainrules.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 50757a76d..5ad062d20 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -123,6 +123,7 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR """ @inline wrap_chainrules_input(x) = x @inline wrap_chainrules_input(::Nothing) = 
ChainRules.ZeroTangent() +@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent() @inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, xs) ChainRules.Tangent{Any, typeof(xp)}(xp) @@ -143,7 +144,7 @@ The two-argument `_project(x, dx)` applies this immediately. @inline _project(x, dx) = _project(x)(dx) # PIRACY -- some tests hit a matrix of nothings, which doesn't seem to be handled? -(::ChainRulesCore.ProjectTo)(nothing) = ChainRulesCore.NoTangent() +# (::ChainRulesCore.ProjectTo)(nothing) = ChainRulesCore.NoTangent() # julia> Zygote.wrap_chainrules_input(nothing) # ChainRulesCore.ZeroTangent() @@ -159,7 +160,7 @@ The two-argument `_project(x, dx)` applies this immediately. (project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) # Solve some ambiguity: -(::ProjectTo{ChainRulesCore.NoTangent})(::ChainRulesCore.AbstractZero) = NoTangent() +# (::ProjectTo{ChainRulesCore.NoTangent})(::ChainRulesCore.AbstractZero) = NoTangent() # some splat? (project::ProjectTo{AbstractArray})(dx::ChainRulesCore.Tangent{<:Any, <:Tuple}) = project(collect(ChainRulesCore.backing(dx))) From e9ea88a47b8fb1169c2664210bd9f838305b6d0b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 10:05:02 -0400 Subject: [PATCH 23/43] adjoint --- src/compiler/interface.jl | 6 +++++- test/features.jl | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index a43d62013..9dc934a49 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -77,7 +77,11 @@ function gradient(f, args...) isnothing(grad) ? nothing : map(_project, args, grad) end -Base.adjoint(f::Function) = x -> gradient(f, x)[1] +# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! +Base.adjoint(f::Function) = x -> begin # still piracy! 
avoids projection for legacy reasons + y, back = pullback(f, x) + back(sensitivity(y))[1] +end """ withgradient(f, args...) diff --git a/test/features.jl b/test/features.jl index 15bf4298e..b02dc962e 100644 --- a/test/features.jl +++ b/test/features.jl @@ -262,7 +262,8 @@ D(f, x) = grad(f, x)[1] @test D(x -> x*D(y -> x+y, 1), 1) == 1 @test D(x -> x*D(y -> x*y, 1), 4) == 8 -@test_broken sin'''(1.0) == -cos(1.0) +@test sin''(1.0) == -sin(1.0) +@test sin'''(1.0) == -cos(1.0) f(x) = throw(DimensionMismatch("fubar")) From 0013fd377eb78a1dca21b93768a880ce4f27f524 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 10:05:11 -0400 Subject: [PATCH 24/43] piracy --- src/compiler/chainrules.jl | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 5ad062d20..0c5dcea43 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -143,26 +143,14 @@ The two-argument `_project(x, dx)` applies this immediately. @inline _project(x, dx) = _project(x)(dx) -# PIRACY -- some tests hit a matrix of nothings, which doesn't seem to be handled? -# (::ChainRulesCore.ProjectTo)(nothing) = ChainRulesCore.NoTangent() - -# julia> Zygote.wrap_chainrules_input(nothing) -# ChainRulesCore.ZeroTangent() -# -# julia> Zygote.wrap_chainrules_input([nothing, nothing]) -# 2-element Vector{Nothing}: -# nothing -# nothing -# -# But the original case was an array of Union{Int,Nothing} +# Piracy: +# wrap_chainrules_input doesn't handle array of Union{Int,Nothing} +(::ChainRulesCore.ProjectTo)(nothing) = ChainRulesCore.NoTangent() # CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} (project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) -# Solve some ambiguity: -# (::ProjectTo{ChainRulesCore.NoTangent})(::ChainRulesCore.AbstractZero) = NoTangent() - -# some splat? 
+# Restore some splatted arrays (project::ProjectTo{AbstractArray})(dx::ChainRulesCore.Tangent{<:Any, <:Tuple}) = project(collect(ChainRulesCore.backing(dx))) """ From c07ae9f1b01bd62c2174fcba8e2e1aadcb38b040 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 10:16:25 -0400 Subject: [PATCH 25/43] skip a test --- test/utils.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index e8ed59888..92ae218de 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -36,8 +36,10 @@ end f713(zs) = sum(vec(zs)' .* exp.(vec(zs))) @test vec(diaghessian(f713, zs)[1]) ≈ diag(hessian(f713, zs)) - @test_throws Exception diaghessian(sin, im*pi) - @test_throws Exception diaghessian(x -> x+im, pi) + if VERSION >= v"1.6-" + @test_throws Exception diaghessian(sin, im*pi) + @test_throws Exception diaghessian(x -> x+im, pi) + end @test_throws Exception diaghessian(identity, randn(2)) end From 7ff11595fb37ff055f6a9b3dae471321e7633620 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 10:30:41 -0400 Subject: [PATCH 26/43] splat tests --- test/features.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/features.jl b/test/features.jl index b02dc962e..3bed87b2f 100644 --- a/test/features.jl +++ b/test/features.jl @@ -500,6 +500,22 @@ end @test x[1] == x[2] end +@testset "splats" begin + @test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1] + @test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0) + + # https://github.com/FluxML/Zygote.jl/issues/599 + @test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector + + # https://github.com/FluxML/Zygote.jl/issues/866 + f866(x) = reshape(x, fill(2, 2)...) 
+ @test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1] + + # https://github.com/FluxML/Zygote.jl/issues/731 + f731(x) = sum([x' * x, x...]) + @test_broken gradient(f731, ones(3)) # MethodError: no method matching +(::Tuple{Float64, Float64, Float64}, ::Vector{Float64}) +end + @testset "accumulation" begin # from https://github.com/FluxML/Zygote.jl/issues/905 function net(x1) From 6549c57de806df7e31710e4da1cbb08b4cdd16b5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 10:33:07 -0400 Subject: [PATCH 27/43] skip on 1.3 --- test/utils.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 92ae218de..06464c3e7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -19,7 +19,11 @@ using Zygote: hessian_dual, hessian_reverse @test_throws Exception hess(identity, randn(2)) end -@testset "diagonal hessian" begin +VERSION > v"1.6-" && @testset "diagonal hessian" begin +# Avoiding this error on Julia 1.3 CI, not sure the exact test which causes it: +# julia> log(Dual(1,0) + 0im) +# ERROR: StackOverflowError: + @test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],) xs, y = randn(2,3), rand() @@ -36,10 +40,8 @@ end f713(zs) = sum(vec(zs)' .* exp.(vec(zs))) @test vec(diaghessian(f713, zs)[1]) ≈ diag(hessian(f713, zs)) - if VERSION >= v"1.6-" - @test_throws Exception diaghessian(sin, im*pi) - @test_throws Exception diaghessian(x -> x+im, pi) - end + @test_throws Exception diaghessian(sin, im*pi) + @test_throws Exception diaghessian(x -> x+im, pi) @test_throws Exception diaghessian(identity, randn(2)) end From 298f1191f347f4ee818553b26c90accfbf4a9842 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 11:05:59 -0400 Subject: [PATCH 28/43] simplify _project --- src/compiler/chainrules.jl | 19 ++++++++----------- test/features.jl | 3 +++ 2 files changed, 11 insertions(+), 11 deletions(-) diff 
--git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 0c5dcea43..00834b058 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -130,18 +130,18 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR end """ - _project(x)(dx) _project(x, dx) -The function `_project(x)` returns a projector, which standardises the gradient `dx` for type & shape. -Uses `ChainRulesCore.ProjectTo`, but is safe to apply to arbitrary input. -The two-argument `_project(x, dx)` applies this immediately. +Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape. +Safe to apply to arbitrary input. """ -@inline _project(x) = identity # fallback: do nothing! -@inline _project(x::Numeric) = wrap_chainrules_output ∘ ProjectTo(x) ∘ wrap_chainrules_input -@inline _project(x::Ref{<:Numeric}) = wrap_chainrules_output ∘ ProjectTo(x) ∘ wrap_chainrules_input +@inline _project(x, dx) = dx # fallback: do nothing! +@inline function _project(x::Union{Numeric, Ref{<:Numeric}}, dx) + wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx))) +end -@inline _project(x, dx) = _project(x)(dx) +# Restore splatted arrays +_project(x::AbstractArray{<:Number}, dx::Tuple) = _project(x, reshape(collect(dx), axes(x))) # Piracy: # wrap_chainrules_input doesn't handle array of Union{Int,Nothing} @@ -150,9 +150,6 @@ The two-argument `_project(x, dx)` applies this immediately. 
# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} (project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) -# Restore some splatted arrays -(project::ProjectTo{AbstractArray})(dx::ChainRulesCore.Tangent{<:Any, <:Tuple}) = project(collect(ChainRulesCore.backing(dx))) - """ ZBack{F}(back) <: Function diff --git a/test/features.jl b/test/features.jl index 3bed87b2f..d683d0d94 100644 --- a/test/features.jl +++ b/test/features.jl @@ -504,6 +504,9 @@ end @test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1] @test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0) + @test gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1] + @test gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1] + # https://github.com/FluxML/Zygote.jl/issues/599 @test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector From e3922a930bab72ca51fb33093dcb6e59abab6818 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 11:07:37 -0400 Subject: [PATCH 29/43] a typo --- src/compiler/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 00834b058..2428f4531 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -145,7 +145,7 @@ _project(x::AbstractArray{<:Number}, dx::Tuple) = _project(x, reshape(collect(dx # Piracy: # wrap_chainrules_input doesn't handle array of Union{Int,Nothing} -(::ChainRulesCore.ProjectTo)(nothing) = ChainRulesCore.NoTangent() +(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent() # CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} (project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) From a2814ae001790c5f40f0000359a113ea61231a6b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 11:26:44 -0400 Subject: [PATCH 30/43] tweak --- src/compiler/chainrules.jl | 4 +++- 1 file 
changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 2428f4531..7f9cf34fc 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -133,12 +133,14 @@ end _project(x, dx) Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape. +Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`. Safe to apply to arbitrary input. """ -@inline _project(x, dx) = dx # fallback: do nothing! @inline function _project(x::Union{Numeric, Ref{<:Numeric}}, dx) wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx))) end +_project(x::AbstractArray, dx) = reshape(dx, axes(x)) +_project(x, dx) = dx # Restore splatted arrays _project(x::AbstractArray{<:Number}, dx::Tuple) = _project(x, reshape(collect(dx), axes(x))) From 08f8c4625d06eba01d596cebf7e0ad3f9bff0e4c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 11:39:28 -0400 Subject: [PATCH 31/43] broken GPU test, unrelated --- test/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda.jl b/test/cuda.jl index 3999ace59..a32b41a88 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -6,7 +6,7 @@ CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @testset "GPU movement" begin r = rand(Float32, 3,3) - @test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2} + @test_broken gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2} end @testset "broadcasting" begin From c8bc5887f3ddeba6b3f4c8466d47ba28fe208532 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 11:41:16 -0400 Subject: [PATCH 32/43] unexpected pass --- test/structures.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/structures.jl b/test/structures.jl index 3f084c687..5a951a621 100644 --- a/test/structures.jl +++ b/test/structures.jl @@ -52,7 +52,7 @@ struct 
A594 x::Float64 end X = A594.(randn(2)) Y = randn(2,2) ∇ = gradient(g,X,Y) - @test_broken ∇[1] == [(x = 2.0,); (x = 2.0,)] # it's producing a 1-col Matrix, why? + @test ∇[1] == [(x = 2.0,); (x = 2.0,)] @test vec(∇[1]) == [(x = 2.0,); (x = 2.0,)] @test ∇[2] == [1 1; 1 1] end From 50804900f97a9a2c8619af77998b9e51c52bf3eb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 11:50:42 -0400 Subject: [PATCH 33/43] only broken on 1.6 --- test/cuda.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/cuda.jl b/test/cuda.jl index a32b41a88..d5870946d 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -6,7 +6,11 @@ CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @testset "GPU movement" begin r = rand(Float32, 3,3) - @test_broken gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2} + if VERSION < v"1.6" + @test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32} + else + @test_broken gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32} + end end @testset "broadcasting" begin From 1b37161851126e935e559f332661f9517d16ebb3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 13:37:56 -0400 Subject: [PATCH 34/43] let nothing through --- src/compiler/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7f9cf34fc..eb9b9622f 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -139,7 +139,7 @@ Safe to apply to arbitrary input. @inline function _project(x::Union{Numeric, Ref{<:Numeric}}, dx) wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx))) end -_project(x::AbstractArray, dx) = reshape(dx, axes(x)) +_project(x::AbstractArray, dx) = dx isa AbstractArray ? 
reshape(dx, axes(x)) : dx _project(x, dx) = dx # Restore splatted arrays From 4c08118d1c8a61bb053ede2f39e34beed9411d36 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 13:40:05 -0400 Subject: [PATCH 35/43] rm some broken things --- test/gradcheck.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index f766ac64d..dabd045ba 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1375,10 +1375,10 @@ using Zygote: Buffer @test gs[1] ≈ map(x -> one.(x), p) @test gs[2] ≈ one.(r) - p = [rand(3,3), rand(3,3)] # redefine `p` after mutation - gs = gradient(x -> sum(pop!(x)), p) - @test length(gs[1]) == 2 - @test gs[1][1] == one.(p[1]) + # p = [rand(3,3), rand(3,3)] # redefine `p` after mutation + # gs = gradient(x -> sum(pop!(x)), p) + # @test length(gs[1]) == 2 + # @test gs[1][1] == one.(p[1]) end end From 71974915712560e67b2367aaec8f1a0348849fe5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 5 Sep 2021 13:51:43 -0400 Subject: [PATCH 36/43] target 1.3 fix --- test/utils.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 06464c3e7..0f391fc32 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -20,19 +20,20 @@ using Zygote: hessian_dual, hessian_reverse end VERSION > v"1.6-" && @testset "diagonal hessian" begin -# Avoiding this error on Julia 1.3 CI, not sure the exact test which causes it: -# julia> log(Dual(1,0) + 0im) -# ERROR: StackOverflowError: - @test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],) - xs, y = randn(2,3), rand() - f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments - - dx, dy = diaghessian(f34, xs, y) - @test size(dx) == size(xs) - @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs)) - @test dy ≈ hessian(y -> f34(xs,y), y) + if VERSION > v"1.6-" + # Gradient of 
^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6: + # julia> log(ForwardDiff.Dual(1,0) + 0im) + # ERROR: StackOverflowError: + xs, y = randn(2,3), rand() + f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments + + dx, dy = diaghessian(f34, xs, y) + @test size(dx) == size(xs) + @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs)) + @test dy ≈ hessian(y -> f34(xs,y), y) + end zs = randn(7,13) # test chunk mode @test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD From dde922ba79d380fda3f3dbe719cb52e5057277ba Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 9 Sep 2021 11:39:24 -0400 Subject: [PATCH 37/43] comments --- test/gradcheck.jl | 1 + test/utils.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index dabd045ba..75a64db99 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1412,6 +1412,7 @@ end end # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests # can be updated to use real / complex consistently. + # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58 findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 
1.0 : 0.0 for k=1:n1, l=1:n2] diff --git a/test/utils.jl b/test/utils.jl index 0f391fc32..037d46c53 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -24,8 +24,8 @@ VERSION > v"1.6-" && @testset "diagonal hessian" begin if VERSION > v"1.6-" # Gradient of ^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6: - # julia> log(ForwardDiff.Dual(1,0) + 0im) - # ERROR: StackOverflowError: + # julia> log(ForwardDiff.Dual(1,0) + 0im) # ERROR: StackOverflowError: + # https://github.com/JuliaDiff/ChainRules.jl/issues/525 xs, y = randn(2,3), rand() f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments From 1c07a7ccdc7c6dea63af090092eeaa4eb431303c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 21 Sep 2021 14:44:07 -0400 Subject: [PATCH 38/43] update for ProjectTo(::Any) --- Project.toml | 2 +- src/compiler/chainrules.jl | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 0f8730b88..d98656b1c 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" ChainRules = "1.5" -ChainRulesCore = "1.3.1" +ChainRulesCore = "1.6" ChainRulesTestUtils = "1" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index eb9b9622f..e879af3f8 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -136,14 +136,12 @@ Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shap Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`. Safe to apply to arbitrary input. """ -@inline function _project(x::Union{Numeric, Ref{<:Numeric}}, dx) +@inline function _project(x, dx) wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx))) end -_project(x::AbstractArray, dx) = dx isa AbstractArray ? 
reshape(dx, axes(x)) : dx -_project(x, dx) = dx # Restore splatted arrays -_project(x::AbstractArray{<:Number}, dx::Tuple) = _project(x, reshape(collect(dx), axes(x))) +_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x))) # Piracy: # wrap_chainrules_input doesn't handle array of Union{Int,Nothing} From 35280d50998373f88108be51e076ca9344053c3b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 21 Sep 2021 15:02:30 -0400 Subject: [PATCH 39/43] fix a test --- test/gradcheck.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 75a64db99..af49b7697 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -177,7 +177,7 @@ end # Ensure that nothings work with non-numeric types. _, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1]) - @test back([nothing]) == ([nothing for _ in 1:3], nothing) + @test back([nothing]) == (nothing, nothing) end @testset "view" begin From 80123a1d962d7fea00d4cfa6919b687093108e2d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 21 Sep 2021 16:07:34 -0400 Subject: [PATCH 40/43] Update test/utils.jl Co-authored-by: Lyndon White --- test/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/utils.jl b/test/utils.jl index 037d46c53..b6d6ed018 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -26,6 +26,7 @@ VERSION > v"1.6-" && @testset "diagonal hessian" begin # Gradient of ^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6: # julia> log(ForwardDiff.Dual(1,0) + 0im) # ERROR: StackOverflowError: # https://github.com/JuliaDiff/ChainRules.jl/issues/525 + # Fixed in 1.6 by: https://github.com/JuliaLang/julia/pull/36030 xs, y = randn(2,3), rand() f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments From 3bc2e099f7b484a8d86233d6dcacfc413d0af834 Mon Sep 17 00:00:00 2001 From: 
Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 21 Sep 2021 16:36:26 -0400 Subject: [PATCH 41/43] Update src/lib/broadcast.jl --- src/lib/broadcast.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 62d9bc928..4e7a3a1cc 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -50,8 +50,7 @@ function unbroadcast(x::AbstractArray, x̄) if length(x) == length(x̄) _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors else - tup = filter(d -> size(x, d) == 1, ntuple(identity, N)) - dims = length(tup) == 1 ? first(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails + dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄)) _project(x, accum_sum(x̄; dims = dims)) end end From 02397b57570eac6e355db000beba199b4c7dd7aa Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 21 Sep 2021 22:08:29 -0400 Subject: [PATCH 42/43] cu tests --- test/cuda.jl | 43 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index d5870946d..5cb1c8cdc 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,16 +1,20 @@ using CUDA using Zygote: Grads +using LinearAlgebra using Random: randn! CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @testset "GPU movement" begin r = rand(Float32, 3,3) - if VERSION < v"1.6" - @test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32} - else - @test_broken gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32} - end + @test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32} + @test gradient(x -> sum(x->log(x), cu(x)), r)[1] isa Matrix + @test gradient((x,cy) -> sum(cu(x) * cy) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray + @test_skip gradient((x,cy) -> sum(cu(x[:,1])' * cy), r, cu(r))[2] isa CUDA.CuArray # generic_matmatmul! 
+ + # Other direction: + @test_skip gradient(x -> sum(Array(x)), cu(r))[1] isa CUDA.CuArray + @test_skip gradient((x,cy) -> sum(x * Array(cy)) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray end @testset "broadcasting" begin @@ -35,10 +39,19 @@ end g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018 @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] + + # Projection: eltype preservation: + @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32} + @test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback + # structure restoration: + @test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix + @test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal + # non-differentiables + @test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing end @testset "sum(f, x)" begin - a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01]) + a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01] a_gpu = a |> cu f(x) = sum(abs, x) @@ -46,6 +59,18 @@ end g_gpu = gradient(f, a_gpu)[1] @test g_gpu isa CuArray @test g_gpu |> collect ≈ g + + f2(x) = sum(abs2, x) # sum(abs2, x) has its own rrule + g2 = gradient(f2, a)[1] + g2_gpu = gradient(f2, a_gpu)[1] + @test g2_gpu isa CuArray + @test g2_gpu |> collect ≈ g2 + + f3(x) = sum(y->y^3, x') # anonymous function + g3 = gradient(f3, a')[1] + g3_gpu = gradient(f3, a_gpu')[1] + @test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure + @test g3_gpu |> collect ≈ g3 end @testset "jacobian" begin @@ -107,5 +132,11 @@ end r = cu(rand(Float32, 3)) grads = (cu(ones(Float32, 3)), 1.f0) @test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads + + @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32} + @test gradient((x,y) 
-> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection + + @test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order + @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} end From a3e3a97bef92debf92792f32b73fe13421bc3bd2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 21 Sep 2021 22:59:29 -0400 Subject: [PATCH 43/43] v0.6.22 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d98656b1c..56d086f99 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.21" +version = "0.6.22" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"