Skip to content

Commit

Permalink
Moved to DI.jl for AD interface
Browse files Browse the repository at this point in the history
  • Loading branch information
cfarm6 committed Jul 18, 2024
1 parent d959a70 commit c08b2ed
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 147 deletions.
18 changes: 5 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,17 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
HyperelasticsEnzymeExt = "Enzyme"
HyperelasticsFastDifferentiationExt = "FastDifferentiation"
HyperelasticsFiniteDiffExt = "FiniteDiff"
HyperelasticsFiniteDifferences = "FiniteDifferences"
HyperelasticsForwardDiffExt = "ForwardDiff"
HyperelasticsDifferentiationInterfaceExt = "DifferentiationInterface"
HyperelasticsOptimizationExt = "Optimization"
HyperelasticsZygoteExt = "Zygote"


[compat]
ADTypes = "1"
Expand All @@ -70,6 +61,7 @@ julia = "1.9"
[extras]
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -84,4 +76,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ComponentArrays", "ForwardDiff", "InteractiveUtils", "FiniteDiff", "Zygote", "Enzyme", "Optimization", "OptimizationOptimJL"]
test = ["Test", "ComponentArrays", "ForwardDiff", "InteractiveUtils", "FiniteDiff", "Zygote", "Enzyme", "Optimization", "OptimizationOptimJL", "DifferentiationInterface"]
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
Expand Down
18 changes: 8 additions & 10 deletions ext/HyperelasticsDifferentiationInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ module HyperelasticsDifferentiationInterfaceExt

using DifferentiationInterface
using Hyperelastics
using ADTypes.AbstractADType
using ADTypes

function Hyperelastics.∂ψ(
ψ::Hyperelastics.AbstractHyperelasticModel{R},
λ⃗::Vector{S},
λ⃗::Vector{T},
p,
ad_type <: AbstractADType;
ad_type::ADTypes.AbstractADType;
kwargs...,
) where {R,T,S}

) where {R,T}
W(λ⃗) = StrainEnergyDensity(ψ, λ⃗, p)
∂W∂λ = gradient(W, ad_type, λ⃗; kwargs...)
return ∂W∂λ
end
return gradient(
(λ⃗::Vector{S}) -> ψ(λ⃗, p; kwargs...),
λ⃗,
ad_type,
)

end
19 changes: 0 additions & 19 deletions ext/HyperelasticsEnzymeExt.jl

This file was deleted.

21 changes: 0 additions & 21 deletions ext/HyperelasticsFiniteDiffExt.jl

This file was deleted.

18 changes: 0 additions & 18 deletions ext/HyperelasticsFiniteDifferences.jl

This file was deleted.

19 changes: 0 additions & 19 deletions ext/HyperelasticsForwardDiffExt.jl

This file was deleted.

45 changes: 0 additions & 45 deletions ext/HyperelasticsZygoteExt.jl

This file was deleted.

3 changes: 2 additions & 1 deletion test/model_fitting.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testset "Model Fitting" begin
using Optimization, OptimizationOptimJL, ComponentArrays, ForwardDiff
using Optimization, OptimizationOptimJL, ComponentArrays, DifferentiationInterface
using ForwardDiff
# Determine if the model is exported by hyperelastics.
usemodel(model) = Base.isexported(Hyperelastics, Symbol(model))

Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ using LinearAlgebra

using Hyperelastics
using Test
using ForwardDiff, FiniteDiff, Zygote, Enzyme
using DifferentiationInterface
import ForwardDiff, FiniteDiff, Zygote, Enzyme
using InteractiveUtils
using ADTypes

Expand Down

0 comments on commit c08b2ed

Please sign in to comment.