Skip to content

Commit

Permalink
Improve macro
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Nov 5, 2024
1 parent aeeceef commit ce4b82d
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 57 deletions.
70 changes: 53 additions & 17 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.11.1"
manifest_format = "2.0"
project_hash = "448f3cae2dea59422644934cbad9b5490d5dc6eb"
project_hash = "89d6a775281b4cbd649c4e44a66577335ae263f3"

[[deps.ADTypes]]
git-tree-sha1 = "eea5d80188827b35333801ef97a40c2ed653b081"
Expand Down Expand Up @@ -57,9 +57,9 @@ version = "0.1.38"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a"
git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.1.0"
version = "4.1.1"
weakdeps = ["StaticArrays"]

[deps.Adapt.extensions]
Expand Down Expand Up @@ -221,6 +221,36 @@ deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0"

[[deps.ComponentArrays]]
deps = ["ArrayInterface", "ChainRulesCore", "ForwardDiff", "Functors", "LinearAlgebra", "PackageExtensionCompat", "StaticArrayInterface", "StaticArraysCore"]
git-tree-sha1 = "bc391f0c19fa242fb6f71794b949e256cfa3772c"
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
version = "0.15.17"

[deps.ComponentArrays.extensions]
ComponentArraysAdaptExt = "Adapt"
ComponentArraysConstructionBaseExt = "ConstructionBase"
ComponentArraysGPUArraysExt = "GPUArrays"
ComponentArraysOptimisersExt = "Optimisers"
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
ComponentArraysReverseDiffExt = "ReverseDiff"
ComponentArraysSciMLBaseExt = "SciMLBase"
ComponentArraysTrackerExt = "Tracker"
ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces"
ComponentArraysZygoteExt = "Zygote"

[deps.ComponentArrays.weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.CompositeTypes]]
git-tree-sha1 = "bce26c3dab336582805503bed209faab1c279768"
uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657"
Expand Down Expand Up @@ -364,9 +394,9 @@ version = "1.0.4"

[[deps.Enzyme]]
deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random", "SparseArrays"]
git-tree-sha1 = "aba39bfce6e65ce740b29c8d9d0c8a6c5770e3c1"
git-tree-sha1 = "abcbb722aafe8ed9cc667884b3a1e1d259c5e562"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
version = "0.13.12"
version = "0.13.13"

[deps.Enzyme.extensions]
EnzymeBFloat16sExt = "BFloat16s"
Expand All @@ -383,19 +413,19 @@ version = "0.13.12"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.EnzymeCore]]
git-tree-sha1 = "9c3a42611e525352e9ad5e4134ddca5c692ff209"
git-tree-sha1 = "04c777af6ef65530a96ab68f0a81a4608113aa1d"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
version = "0.8.4"
version = "0.8.5"
weakdeps = ["Adapt"]

[deps.EnzymeCore.extensions]
AdaptExt = "Adapt"

[[deps.Enzyme_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "c180391e0a09fedb2934e5c44455e13c38f859e6"
git-tree-sha1 = "62cf2140d8daa3181e9f9d7a8b5e7b9493a57f21"
uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef"
version = "0.0.157+0"
version = "0.0.159+0"

[[deps.ExceptionUnwrapping]]
deps = ["Test"]
Expand Down Expand Up @@ -442,9 +472,9 @@ version = "1.3.7"

[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
git-tree-sha1 = "a9ce73d3c827adab2d70bf168aaece8cce196898"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.36"
version = "0.10.37"
weakdeps = ["StaticArrays"]

[deps.ForwardDiff.extensions]
Expand Down Expand Up @@ -986,6 +1016,12 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.31"

[[deps.PackageExtensionCompat]]
git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518"
uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930"
version = "1.0.2"
weakdeps = ["Requires", "TOML"]

[[deps.Parsers]]
deps = ["Dates", "PrecompileTools", "UUIDs"]
git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
Expand Down Expand Up @@ -1088,9 +1124,9 @@ version = "1.3.4"

[[deps.RecursiveArrayTools]]
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "43cdc0987135597867a37fc3e8e0fc9fdef6ac66"
git-tree-sha1 = "ed2514425d030d7c9054fa0f2275ada45681788d"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "3.27.1"
version = "3.27.2"

[deps.RecursiveArrayTools.extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
Expand Down Expand Up @@ -1152,9 +1188,9 @@ version = "0.1.0"

[[deps.SciMLBase]]
deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"]
git-tree-sha1 = "86e1c491cddf233d77d8aadbe289005db44e8445"
git-tree-sha1 = "7a54136472ca0cb0f66ef22aa3f0ff198f379fa7"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
version = "2.57.2"
version = "2.58.0"

[deps.SciMLBase.extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -1374,9 +1410,9 @@ version = "3.7.2"

[[deps.Symbolics]]
deps = ["ADTypes", "ArrayInterface", "Bijections", "CommonWorldInvalidations", "ConstructionBase", "DataStructures", "DiffRules", "Distributions", "DocStringExtensions", "DomainSets", "DynamicPolynomials", "IfElse", "LaTeXStrings", "Latexify", "Libdl", "LinearAlgebra", "LogExpFunctions", "MacroTools", "Markdown", "NaNMath", "PrecompileTools", "Primes", "RecipesBase", "Reexport", "RuntimeGeneratedFunctions", "SciMLBase", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArraysCore", "SymbolicIndexingInterface", "SymbolicLimits", "SymbolicUtils", "TermInterface"]
git-tree-sha1 = "ef7532b95fbd529e1252cabb36bba64803020840"
git-tree-sha1 = "41852067b437d16a3ad4e01705ffc6e22925c42c"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
version = "6.16.0"
version = "6.17.0"

[deps.Symbolics.extensions]
SymbolicsForwardDiffExt = "ForwardDiff"
Expand Down
1 change: 1 addition & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
Expand Down
69 changes: 42 additions & 27 deletions benchmark/groups/pinn.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,56 @@
using Lux, Zygote
using Lux, Zygote, Enzyme, ComponentArrays

const input = 2
const hidden = 16

model = Chain(Dense(input => hidden, Lux.relu),
Dense(hidden => hidden, Lux.relu),
Dense(hidden => 1),
first)

ps, st = Lux.setup(rng, model)

trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x, ps, st)[1]
function trial(model, x, ps, st)
u, st = Lux.apply(model, x, ps, st)
x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * u
end

x = rand(Float32, input)
trial(model, x)
function loss_by_finitediff(model, x)
ε = cbrt(eps(Float32))
ε₁ = [ε, 0]
ε₂ = [0, ε]
error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) +
trial(model, x - ε₂) - 4 * trial(model, x)) /
ε^2 + sin* x[1]) * sin* x[2])
function loss_by_finitediff(model, x, ps, st)
T = eltype(x)
ε = cbrt(eps(T))
ε₁ = [ε, zero(T)]
ε₂ = [zero(T), ε]
f(x) = trial(model, x, ps, st)
error = (f(x + ε₁) + f(x - ε₁) + f(x + ε₂) + f(x - ε₂) - 4 * f(x)) / ε^2 +
sin* x[1]) * sin* x[2])
abs2(error)
end
function loss_by_taylordiff(model, x)
f(x) = trial(model, x)
function loss_by_taylordiff(model, x, ps, st)
f(x) = trial(model, x, ps, st)
error = derivative(f, x, Float32[1, 0], Val(2)) +
derivative(f, x, Float32[0, 1], Val(2)) +
sin* x[1]) * sin* x[2])
abs2(error)
end
function loss_by_forwarddiff(model, x, ps, st)
f(x) = trial(model, x, ps, st)
error = derivative(f, x, Float32[1, 0], Val(2)) +
derivative(f, x, Float32[0, 1], Val(2)) +
sin* x[1]) * sin* x[2])
abs2(error)
end

const input = 2
const hidden = 16
model = Chain(Dense(input => hidden, exp),
Dense(hidden => hidden, exp),
Dense(hidden => 1),
first)
x = rand(Float32, input)
dx = deepcopy(x)
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray
dps = deepcopy(ps)
dx .= 0;
dps .= 0;

pinn_t = BenchmarkGroup("primal" => (@benchmarkable loss_by_taylordiff($model, $x)),
pinn_t = BenchmarkGroup(
"primal" => (@benchmarkable loss_by_taylordiff($model, $x, $ps, $st)),
"gradient" => (@benchmarkable gradient(loss_by_taylordiff, $model,
$x)))
pinn_f = BenchmarkGroup("primal" => (@benchmarkable loss_by_finitediff($model, $x)),
$x, $ps, $st)))
pinn_f = BenchmarkGroup(
"primal" => (@benchmarkable loss_by_finitediff($model, $x, $ps, $st)),
"gradient" => (@benchmarkable gradient($loss_by_finitediff, $model,
$x)))
$x, $ps, $st)))
pinn = BenchmarkGroup(["vector", "physical"], "taylordiff" => pinn_t,
"finitediff" => pinn_f)
5 changes: 5 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,8 @@ function find_taylor(a::Array{<:Tuple{TaylorScalar{T, P}, Any}, N}, rest) where
TaylorArray{P}(zeros(T, size(a)))
end
find_taylor(::Any, rest) = find_taylor(rest)

# function Base.copyto!(dest::TaylorArray, bc::Broadcast.Broadcasted{<:TaylorArrayStyle, Axes}) where Axes
# println("copyto!($(typeof(dest)), $(typeof(bc)))")
# error("Not implemented")
# end
17 changes: 9 additions & 8 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
end

function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄)
z = zero(T)
partials_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(z, v̄)
# for structural tangent, convert to tuple
function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
NoTangent(), TaylorScalar(zero(T), backing(v̄))
function partials_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
NoTangent(), TaylorScalar(z, backing(v̄))
end
function value_pullback(v̄)
NoTangent(), TaylorScalar(zero(T), map(x -> convert(T, x), Tuple(v̄)))
function partials_pullback(::ZeroTangent)
NoTangent(), TaylorScalar(z, ntuple(j -> zero(T), Val(N)))
end
return partials(t), value_pullback
return partials(t), partials_pullback
end

function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P}
value_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
return partials(t), value_pullback
partials_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
return partials(t), partials_pullback
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, P},
Expand Down
11 changes: 6 additions & 5 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ end
f = flatten(t)
v[0] = exp(f[0])
for i in 1:P
v[i] = zero(T)
for j in 0:(i - 1)
v[i] += (i - j) * v[j] * f[i - j]
end
Expand All @@ -62,8 +61,6 @@ for func in (:sin, :cos)
f = flatten(t)
s[0], c[0] = sincos(f[0])
for i in 1:P
s[i] = zero(T)
c[i] = zero(T)
for j in 0:(i - 1)
s[i] += (i - j) * c[j] * f[i - j]
c[i] -= (i - j) * s[j] * f[i - j]
Expand Down Expand Up @@ -107,7 +104,6 @@ end
@immutable function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
va, vb = flatten(a), flatten(b)
for i in 0:P
v[i] = zero(T)
for j in 0:i
v[i] += va[j] * vb[i - j]
end
Expand Down Expand Up @@ -140,7 +136,6 @@ for R in (Integer, Real)
f = flatten(t)
v[0] = f[0]^n
for i in 1:P
v[i] = zero(T)
for j in 0:(i - 1)
v[i] += (n * (i - j) - j) * v[j] * f[i - j]
end
Expand All @@ -164,3 +159,9 @@ end
@inline raise(f0, d::TaylorScalar, t) = integrate(differentiate(t) * d, f0)
@inline raise(f0, d::Number, t) = d * t
@inline raiseinv(f0, d, t) = integrate(differentiate(t) / d, f0)

# Array primitives

# Pass-through linear operators

# *(a::AbstractMatrix{T}, b::TaylorArray{T}) where {T} = TaylorArray(a * value(b), map(p -> a * p, partials(b)))
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ function process(d, expr)
end
# Modify indices
magic_names = (:v, :s, :c)
known_names = Set()
expr = postwalk(expr) do x
@match x begin
a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx])
(a_ += b_) => a in known_names ? :($a += $b) : (push!(known_names, a); :($a = $b))
(a_ -= b_) => a in known_names ? :($a -= $b) : (push!(known_names, a); :($a = -$b))
TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...))))
_ => x
end
Expand Down

0 comments on commit ce4b82d

Please sign in to comment.