From ce4b82d238b52de749a8421a86a38007f76b4047 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Tue, 5 Nov 2024 16:05:16 -0500 Subject: [PATCH] Improve macro --- benchmark/Manifest.toml | 70 ++++++++++++++++++++++++++++++---------- benchmark/Project.toml | 1 + benchmark/groups/pinn.jl | 69 +++++++++++++++++++++++---------------- src/array.jl | 5 +++ src/chainrules.jl | 17 +++++----- src/primitive.jl | 11 ++++--- src/utils.jl | 3 ++ 7 files changed, 119 insertions(+), 57 deletions(-) diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 3a90546..8ff699c 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.11.1" manifest_format = "2.0" -project_hash = "448f3cae2dea59422644934cbad9b5490d5dc6eb" +project_hash = "89d6a775281b4cbd649c4e44a66577335ae263f3" [[deps.ADTypes]] git-tree-sha1 = "eea5d80188827b35333801ef97a40c2ed653b081" @@ -57,9 +57,9 @@ version = "0.1.38" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a" +git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.1.0" +version = "4.1.1" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -221,6 +221,36 @@ deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "1.1.1+0" +[[deps.ComponentArrays]] +deps = ["ArrayInterface", "ChainRulesCore", "ForwardDiff", "Functors", "LinearAlgebra", "PackageExtensionCompat", "StaticArrayInterface", "StaticArraysCore"] +git-tree-sha1 = "bc391f0c19fa242fb6f71794b949e256cfa3772c" +uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +version = "0.15.17" + + [deps.ComponentArrays.extensions] + ComponentArraysAdaptExt = "Adapt" + ComponentArraysConstructionBaseExt = "ConstructionBase" + ComponentArraysGPUArraysExt = "GPUArrays" + ComponentArraysOptimisersExt = "Optimisers" + ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" + ComponentArraysReverseDiffExt = "ReverseDiff" + ComponentArraysSciMLBaseExt = "SciMLBase" + ComponentArraysTrackerExt = "Tracker" + ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces" + ComponentArraysZygoteExt = "Zygote" + + [deps.ComponentArrays.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" + ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.CompositeTypes]] git-tree-sha1 = "bce26c3dab336582805503bed209faab1c279768" uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657" @@ -364,9 +394,9 @@ version = "1.0.4" [[deps.Enzyme]] deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random", "SparseArrays"] -git-tree-sha1 = "aba39bfce6e65ce740b29c8d9d0c8a6c5770e3c1" +git-tree-sha1 = "abcbb722aafe8ed9cc667884b3a1e1d259c5e562" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" -version = "0.13.12" +version = "0.13.13" [deps.Enzyme.extensions] EnzymeBFloat16sExt = "BFloat16s" @@ -383,9 +413,9 @@ version = "0.13.12" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.EnzymeCore]] -git-tree-sha1 = "9c3a42611e525352e9ad5e4134ddca5c692ff209" +git-tree-sha1 = "04c777af6ef65530a96ab68f0a81a4608113aa1d" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" -version = "0.8.4" +version = "0.8.5" weakdeps = ["Adapt"] [deps.EnzymeCore.extensions] @@ -393,9 +423,9 @@ weakdeps = ["Adapt"] [[deps.Enzyme_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c180391e0a09fedb2934e5c44455e13c38f859e6" +git-tree-sha1 = "62cf2140d8daa3181e9f9d7a8b5e7b9493a57f21" uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" -version = "0.0.157+0" +version = "0.0.159+0" [[deps.ExceptionUnwrapping]] deps = ["Test"] @@ -442,9 +472,9 @@ version = "1.3.7" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +git-tree-sha1 = "a9ce73d3c827adab2d70bf168aaece8cce196898" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" +version = "0.10.37" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -986,6 +1016,12 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" version = "0.11.31" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" @@ -1088,9 +1124,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "43cdc0987135597867a37fc3e8e0fc9fdef6ac66" +git-tree-sha1 = "ed2514425d030d7c9054fa0f2275ada45681788d" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.27.1" +version = "3.27.2" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -1152,9 +1188,9 @@ version = "0.1.0" [[deps.SciMLBase]] deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] -git-tree-sha1 = "86e1c491cddf233d77d8aadbe289005db44e8445" +git-tree-sha1 = "7a54136472ca0cb0f66ef22aa3f0ff198f379fa7" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.57.2" +version = "2.58.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -1374,9 +1410,9 @@ version = "3.7.2" [[deps.Symbolics]] deps = ["ADTypes", "ArrayInterface", "Bijections", "CommonWorldInvalidations", "ConstructionBase", "DataStructures", "DiffRules", "Distributions", "DocStringExtensions", "DomainSets", "DynamicPolynomials", "IfElse", "LaTeXStrings", "Latexify", "Libdl", "LinearAlgebra", "LogExpFunctions", "MacroTools", "Markdown", "NaNMath", "PrecompileTools", "Primes", "RecipesBase", "Reexport", "RuntimeGeneratedFunctions", "SciMLBase", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArraysCore", "SymbolicIndexingInterface", "SymbolicLimits", "SymbolicUtils", "TermInterface"] -git-tree-sha1 = "ef7532b95fbd529e1252cabb36bba64803020840" +git-tree-sha1 = "41852067b437d16a3ad4e01705ffc6e22925c42c" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" -version = "6.16.0" +version = "6.17.0" [deps.Symbolics.extensions] SymbolicsForwardDiffExt = "ForwardDiff" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 30011e7..9acf62c 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,5 +1,6 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" diff --git a/benchmark/groups/pinn.jl b/benchmark/groups/pinn.jl index d68dded..9a3e90f 100644 --- a/benchmark/groups/pinn.jl +++ b/benchmark/groups/pinn.jl @@ -1,41 +1,56 @@ -using Lux, Zygote +using Lux, Zygote, Enzyme, ComponentArrays -const input = 2 -const hidden = 16 - -model = Chain(Dense(input => hidden, Lux.relu), - Dense(hidden => hidden, Lux.relu), - Dense(hidden => 1), - first) - -ps, st = Lux.setup(rng, model) - -trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x, ps, st)[1] +function trial(model, x, ps, st) + u, st = Lux.apply(model, x, ps, st) + x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * u +end -x = rand(Float32, input) -trial(model, x) -function loss_by_finitediff(model, x) - ε = cbrt(eps(Float32)) - ε₁ = [ε, 0] - ε₂ = [0, ε] - error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) + - trial(model, x - ε₂) - 4 * trial(model, x)) / - ε^2 + sin(π * x[1]) * sin(π * x[2]) +function loss_by_finitediff(model, x, ps, st) + T = eltype(x) + ε = cbrt(eps(T)) + ε₁ = [ε, zero(T)] + ε₂ = [zero(T), ε] + f(x) = trial(model, x, ps, st) + error = (f(x + ε₁) + f(x - ε₁) + f(x + ε₂) + f(x - ε₂) - 4 * f(x)) / ε^2 + + sin(π * x[1]) * sin(π * x[2]) abs2(error) end -function loss_by_taylordiff(model, x) - f(x) = trial(model, x) +function loss_by_taylordiff(model, x, ps, st) + f(x) = trial(model, x, ps, st) error = derivative(f, x, Float32[1, 0], Val(2)) + derivative(f, x, Float32[0, 1], Val(2)) + sin(π * x[1]) * sin(π * x[2]) abs2(error) end +function loss_by_forwarddiff(model, x, ps, st) + f(x) = trial(model, x, ps, st) + error = derivative(f, x, Float32[1, 0], Val(2)) + + derivative(f, x, Float32[0, 1], Val(2)) + + sin(π * x[1]) * sin(π * x[2]) + abs2(error) +end + +const input = 2 +const hidden = 16 +model = Chain(Dense(input => hidden, exp), + Dense(hidden => hidden, exp), + Dense(hidden => 1), + first) +x = rand(Float32, input) +dx = deepcopy(x) +ps, st = Lux.setup(rng, model) +ps = ps |> ComponentArray +dps = deepcopy(ps) +dx .= 0; +dps .= 0; -pinn_t = BenchmarkGroup("primal" => (@benchmarkable loss_by_taylordiff($model, $x)), +pinn_t = BenchmarkGroup( + "primal" => (@benchmarkable loss_by_taylordiff($model, $x, $ps, $st)), "gradient" => (@benchmarkable gradient(loss_by_taylordiff, $model, - $x))) -pinn_f = BenchmarkGroup("primal" => (@benchmarkable loss_by_finitediff($model, $x)), + $x, $ps, $st))) +pinn_f = BenchmarkGroup( + "primal" => (@benchmarkable loss_by_finitediff($model, $x, $ps, $st)), "gradient" => (@benchmarkable gradient($loss_by_finitediff, $model, - $x))) + $x, $ps, $st))) pinn = BenchmarkGroup(["vector", "physical"], "taylordiff" => pinn_t, "finitediff" => pinn_f) diff --git a/src/array.jl b/src/array.jl index 1ace577..504ec1f 100644 --- a/src/array.jl +++ b/src/array.jl @@ -96,3 +96,8 @@ function find_taylor(a::Array{<:Tuple{TaylorScalar{T, P}, Any}, N}, rest) where TaylorArray{P}(zeros(T, size(a))) end find_taylor(::Any, rest) = find_taylor(rest) + +# function Base.copyto!(dest::TaylorArray, bc::Broadcast.Broadcasted{<:TaylorArrayStyle, Axes}) where Axes +# println("copyto!($(typeof(dest)), $(typeof(bc)))") +# error("Not implemented") +# end diff --git a/src/chainrules.jl b/src/chainrules.jl index 142ae1c..def2ee3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -17,20 +17,21 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} end function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T} - value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄) + z = zero(T) + partials_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(z, v̄) # for structural tangent, convert to tuple - function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P} - NoTangent(), TaylorScalar(zero(T), backing(v̄)) + function partials_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P} + NoTangent(), TaylorScalar(z, backing(v̄)) end - function value_pullback(v̄) - NoTangent(), TaylorScalar(zero(T), map(x -> convert(T, x), Tuple(v̄))) + function partials_pullback(::ZeroTangent) + NoTangent(), TaylorScalar(z, ntuple(j -> zero(T), Val(N))) end - return partials(t), value_pullback + return partials(t), partials_pullback end function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P} - value_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄) - return partials(t), value_pullback + partials_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄) + return partials(t), partials_pullback end function rrule(::typeof(extract_derivative), t::TaylorScalar{T, P}, diff --git a/src/primitive.jl b/src/primitive.jl index aeb18ed..8ade8f2 100644 --- a/src/primitive.jl +++ b/src/primitive.jl @@ -48,7 +48,6 @@ end f = flatten(t) v[0] = exp(f[0]) for i in 1:P - v[i] = zero(T) for j in 0:(i - 1) v[i] += (i - j) * v[j] * f[i - j] end @@ -62,8 +61,6 @@ for func in (:sin, :cos) f = flatten(t) s[0], c[0] = sincos(f[0]) for i in 1:P - s[i] = zero(T) - c[i] = zero(T) for j in 0:(i - 1) s[i] += (i - j) * c[j] * f[i - j] c[i] -= (i - j) * s[j] * f[i - j] @@ -107,7 +104,6 @@ end @immutable function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P} va, vb = flatten(a), flatten(b) for i in 0:P - v[i] = zero(T) for j in 0:i v[i] += va[j] * vb[i - j] end @@ -140,7 +136,6 @@ for R in (Integer, Real) f = flatten(t) v[0] = f[0]^n for i in 1:P - v[i] = zero(T) for j in 0:(i - 1) v[i] += (n * (i - j) - j) * v[j] * f[i - j] end @@ -164,3 +159,9 @@ end @inline raise(f0, d::TaylorScalar, t) = integrate(differentiate(t) * d, f0) @inline raise(f0, d::Number, t) = d * t @inline raiseinv(f0, d, t) = integrate(differentiate(t) / d, f0) + +# Array primitives + +# Pass-through linear operators + +# *(a::AbstractMatrix{T}, b::TaylorArray{T}) where {T} = TaylorArray(a * value(b), map(p -> a * p, partials(b))) diff --git a/src/utils.jl b/src/utils.jl index 3e55d7c..20100c4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -86,9 +86,12 @@ function process(d, expr) end # Modify indices magic_names = (:v, :s, :c) + known_names = Set() expr = postwalk(expr) do x @match x begin a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx]) + (a_ += b_) => a in known_names ? :($a += $b) : (push!(known_names, a); :($a = $b)) + (a_ -= b_) => a in known_names ? :($a -= $b) : (push!(known_names, a); :($a = -$b)) TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...)))) _ => x end