Skip to content

Commit

Permalink
add AD (Enzyme) support via MeshIntegralsEnzymeExt (#152)
Browse files Browse the repository at this point in the history
* add Enzyme as a potential differentiation method for the jacobian

* refactor check for enzyme support

* add FP to _default_diff_method

* add `using Enzyme` to benchmarks.jl

* update CoordRefSystems.jl compat

Co-authored-by: Joshua Lampert <[email protected]>

* add Enzyme to Benchmark Project.toml

* fix Meshes compat in Benchmark Project.toml

Co-authored-by: Joshua Lampert <[email protected]>

* use import Enzyme, not using Enzyme

Co-authored-by: Joshua Lampert <[email protected]>

* fix typo in Benchmarks Project.toml

* remove Meshes version check in combinations.jl

* Apply format suggestion

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/Project.toml

Co-authored-by: Joshua Lampert <[email protected]>

* Bump compat of Enzyme to v0.13.19

* test supports_autoenzyme to combinations; test both backends for wrong dims

* Restore recently-updated FiniteDifference constructors

* Add docstrings, formatting

* Formatting

* Add test for two-arg jacobian

* Use rest of MeshIntegrals namespace

* Disambiguate use of jacobian

* fix test

* use `import Enzyme`

Co-authored-by: Joshua Lampert <[email protected]>

* use `import Enzyme`

Co-authored-by: Joshua Lampert <[email protected]>

* remove unneeded MeshIntegrals.jl

Co-authored-by: Joshua Lampert <[email protected]>

---------

Co-authored-by: Joshua Lampert <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Michael Ingold <[email protected]>
  • Loading branch information
4 people authored Dec 14, 2024
1 parent c7c0a47 commit 969ee0a
Show file tree
Hide file tree
Showing 15 changed files with 175 additions and 68 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# development related
.vscode
dev
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@ Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
MeshIntegralsEnzymeExt = "Enzyme"

[compat]
CliffordNumbers = "0.1.9"
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
CoordRefSystems = "0.15, 0.16"
Enzyme = "0.13.19"
FastGaussQuadrature = "1"
HCubature = "1.5"
LinearAlgebra = "1"
Meshes = "0.50, 0.51, 0.52"
Meshes = "0.51.20, 0.52"
QuadGK = "2.1.1"
Unitful = "1.19"
julia = "1.9"
4 changes: 3 additions & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
BenchmarkTools = "1.5"
Enzyme = "0.13.19"
LinearAlgebra = "1"
Meshes = "0.50, 0.51, 0.52"
Meshes = "0.51.20, 0.52"
Unitful = "1.19"
julia = "1.9"
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using LinearAlgebra
using Meshes
using MeshIntegrals
using Unitful
import Enzyme

const SUITE = BenchmarkGroup()

Expand Down
19 changes: 19 additions & 0 deletions ext/MeshIntegralsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module MeshIntegralsEnzymeExt

using MeshIntegrals: MeshIntegrals, AutoEnzyme
using Meshes: Meshes
using Enzyme: Enzyme

function MeshIntegrals.jacobian(
geometry::Meshes.Geometry,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
::AutoEnzyme
) where {T <: AbstractFloat}
Dim = Meshes.paramdim(geometry)
if Dim != length(ts)
throw(ArgumentError("ts must have same number of dimensions as geometry."))
end
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
end

end
2 changes: 1 addition & 1 deletion src/MeshIntegrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import QuadGK
import Unitful

include("differentiation.jl")
export DifferentiationMethod, FiniteDifference, jacobian
export DifferentiationMethod, FiniteDifference, AutoEnzyme, jacobian

include("utils.jl")

Expand Down
16 changes: 11 additions & 5 deletions src/differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ A category of types used to specify the desired method for calculating derivativ
Derivatives are used to form Jacobian matrices when calculating the differential
element size throughout the integration region.
See also [`FiniteDifference`](@ref).
See also [`FiniteDifference`](@ref), [`AutoEnzyme`](@ref).
"""
abstract type DifferentiationMethod end

Expand All @@ -27,8 +27,14 @@ end
FiniteDifference{T}() where {T <: AbstractFloat} = FiniteDifference{T}(T(1e-6))
FiniteDifference() = FiniteDifference{Float64}()

"""
AutoEnzyme()
Use to specify use of the Enzyme.jl for calculating derivatives.
"""
struct AutoEnzyme <: DifferentiationMethod end

# Future Support:
# struct AutoEnzyme <: DifferentiationMethod end
# struct AutoZygote <: DifferentiationMethod end

################################################################################
Expand All @@ -52,7 +58,7 @@ function jacobian(
geometry::G,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}}
) where {G <: Geometry, T <: AbstractFloat}
return jacobian(geometry, ts, _default_diff_method(G))
return jacobian(geometry, ts, _default_diff_method(G, T))
end

function jacobian(
Expand All @@ -68,7 +74,7 @@ function jacobian(
# Get the partial derivative along the n'th axis via finite difference
# approximation, where ts is the current parametric position
function ∂ₙr(ts, n, ε)
# Build left/right parametric coordinates with non-allocating iterators
# Build left/right parametric coordinates with non-allocating iterators
left = Iterators.map(((i, t),) -> i == n ? t - ε : t, enumerate(ts))
right = Iterators.map(((i, t),) -> i == n ? t + ε : t, enumerate(ts))
# Select orientation of finite-diff
Expand Down Expand Up @@ -107,7 +113,7 @@ possible and finite difference approximations otherwise.
function differential(
geometry::G,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
diff_method::DifferentiationMethod = _default_diff_method(G)
diff_method::DifferentiationMethod = _default_diff_method(G, T)
) where {G <: Geometry, T <: AbstractFloat}
J = Iterators.map(_KVector, jacobian(geometry, ts, diff_method))
return LinearAlgebra.norm(foldl(, J))
Expand Down
16 changes: 11 additions & 5 deletions src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
################################################################################

"""
integral(f, geometry[, rule]; diff_method=_default_method(geometry), FP=Float64)
integral(f, geometry[, rule]; diff_method=_default_diff_method(geometry, FP), FP=Float64)
Numerically integrate a given function `f(::Point)` over the domain defined by
a `geometry` using a particular numerical integration `rule` with floating point
Expand All @@ -16,7 +16,7 @@ precision of type `FP`.
`GaussKronrod()` in 1D and `HAdaptiveCubature()` else)
# Keyword Arguments
- `diff_method::DifferentiationMethod = _default_method(geometry)`: the method to
- `diff_method::DifferentiationMethod = _default_diff_method(geometry, FP)`: the method to
use for calculating Jacobians that are used to calculate differential elements
- `FP = Float64`: the floating point precision desired.
"""
Expand All @@ -42,8 +42,10 @@ function _integral(
geometry,
rule::GaussKronrod;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
diff_method::DM = _default_diff_method(geometry, FP)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

# Implementation depends on number of parametric dimensions over which to integrate
N = Meshes.paramdim(geometry)
if N == 1
Expand All @@ -70,8 +72,10 @@ function _integral(
geometry,
rule::GaussLegendre;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
diff_method::DM = _default_diff_method(geometry, FP)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

N = Meshes.paramdim(geometry)

# Get Gauss-Legendre nodes and weights of type FP for a region [-1,1]ᴺ
Expand Down Expand Up @@ -99,8 +103,10 @@ function _integral(
geometry,
rule::HAdaptiveCubature;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
diff_method::DM = _default_diff_method(geometry, FP)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

N = Meshes.paramdim(geometry)

integrand(ts) = f(geometry(ts...)) * differential(geometry, ts, diff_method)
Expand Down
8 changes: 6 additions & 2 deletions src/specializations/BezierCurve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@ function integral(
curve::Meshes.BezierCurve,
rule::IntegrationRule;
alg::Meshes.BezierEvalMethod = Meshes.Horner(),
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(curve, FP),
kwargs...
)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(curve, diff_method)

# Generate a _ParametricGeometry whose parametric function auto-applies the alg kwarg
param_curve = _ParametricGeometry(_parametric(curve, alg), Meshes.paramdim(curve))

# Integrate the _ParametricGeometry using the standard methods
return _integral(f, param_curve, rule; kwargs...)
return _integral(f, param_curve, rule; diff_method = diff_method, FP = FP, kwargs...)
end

################################################################################
Expand Down
12 changes: 8 additions & 4 deletions src/specializations/CylinderSurface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,22 @@ function integral(
f,
cyl::Meshes.CylinderSurface,
rule::I;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(cyl, FP),
kwargs...
) where {I <: IntegrationRule}
) where {I <: IntegrationRule, DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(cyl, diff_method)

# The generic method only parametrizes the sides
sides = _integral(f, cyl, rule; kwargs...)
sides = _integral(f, cyl, rule; diff_method = diff_method, FP = FP, kwargs...)

# Integrate the Disk at the top
disk_top = Meshes.Disk(cyl.top, cyl.radius)
top = _integral(f, disk_top, rule; kwargs...)
top = _integral(f, disk_top, rule; diff_method = diff_method, FP = FP, kwargs...)

# Integrate the Disk at the bottom
disk_bottom = Meshes.Disk(cyl.bot, cyl.radius)
bottom = _integral(f, disk_bottom, rule; kwargs...)
bottom = _integral(f, disk_bottom, rule; diff_method = diff_method, FP = FP, kwargs...)

return sides + top + bottom
end
47 changes: 41 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,50 @@ end
# DifferentiationMethod
################################################################################

# Return the default DifferentiationMethod instance for a particular geometry type
"""
supports_autoenzyme(geometry)
Return whether a geometry (or geometry type) has a parametric function that can be
differentiated with Enzyme. See GitHub Issue #154 for more information.
"""
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G)

"""
_check_diff_method_support(::Geometry, ::DifferentiationMethod) -> nothing
Throw an error if incompatible geometry-diff_method combination detected.
"""
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing
function _check_diff_method_support(geometry::Geometry, ::AutoEnzyme)
if !supports_autoenzyme(geometry)
throw(ArgumentError("AutoEnzyme not supported for this geometry."))
end
end

"""
_default_diff_method(geometry, FP)
Return an instance of the default DifferentiationMethod for a particular geometry
(or geometry type) and floating point type.
"""
function _default_diff_method(
g::Type{G}
) where {G <: Geometry}
return FiniteDifference()
g::Type{G}, FP::Type{T}
) where {G <: Geometry, T <: AbstractFloat}
if supports_autoenzyme(g) && FP <: Union{Float32, Float64}
AutoEnzyme()
else
FiniteDifference()
end
end

# Return the default DifferentiationMethod instance for a particular geometry instance
_default_diff_method(g::G) where {G <: Geometry} = _default_diff_method(G)
function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat}
_default_diff_method(G, T)
end

################################################################################
# Numerical Tools
Expand Down
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CoordRefSystems = "b46f11dc-f210-4604-bfba-323c1ec968cb"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Expand All @@ -12,10 +13,11 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Aqua = "0.7, 0.8"
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
CoordRefSystems = "0.15, 0.16"
Enzyme = "0.13.19"
ExplicitImports = "1.6.0"
LinearAlgebra = "1"
Meshes = "0.50, 0.51, 0.52"
Meshes = "0.51.20, 0.52"
SpecialFunctions = "2"
TestItemRunner = "1"
TestItems = "1"
Expand Down
Loading

0 comments on commit 969ee0a

Please sign in to comment.