Skip to content

Commit

Permalink
Merge pull request #1112 from SciML/dw/totallength
Browse files Browse the repository at this point in the history
Import `totallength` in ForwardDiff extension
  • Loading branch information
ChrisRackauckas authored Feb 10, 2025
2 parents b5faf40 + 907508a commit 1a870e8
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <[email protected]>"]
version = "6.162.1"
version = "6.162.2"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
15 changes: 4 additions & 11 deletions ext/DiffEqBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,19 +501,19 @@ unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.unitfulvalue(x))

sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x))
function totallength(x::ForwardDiff.Dual)
totallength(ForwardDiff.value(x)) + sum(totallength, ForwardDiff.partials(x))
function DiffEqBase.totallength(x::ForwardDiff.Dual)
return DiffEqBase.totallength(ForwardDiff.value(x)) + sum(DiffEqBase.totallength, ForwardDiff.partials(x))
end

@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u))
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}},
t::Any) where {Tag, T}
sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u))
sqrt(DiffEqBase.__sum(sse, u; init = sse(zero(T))) / DiffEqBase.totallength(u))
end
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::ForwardDiff.Dual) = sqrt(sse(u))
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}},
::ForwardDiff.Dual) where {Tag, T}
sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u))
sqrt(DiffEqBase.__sum(sse, u; init = sse(zero(T))) / DiffEqBase.totallength(u))
end

if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual})
Expand All @@ -528,13 +528,6 @@ end

# bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi())

# Static Arrays don't support the `init` keyword argument for `sum`
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
@inline function __sum(
f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F}
return mapreduce(f, +, a...; init, kwargs...)
end

# Differentiation of internal solver

function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...)
Expand Down
1 change: 0 additions & 1 deletion ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using DiffEqBase
import DiffEqBase: value
import ReverseDiff
import DiffEqBase.ArrayInterface
import DiffEqBase.ForwardDiff

function DiffEqBase.anyeltypedual(::Type{T},
::Type{Val{counter}} = Val{0}) where {counter} where {
Expand Down
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ unitfulvalue(x) = x
isdistribution(u0) = false
sse(x::Number) = abs2(x)

# Static Arrays don't support the `init` keyword argument for `sum`
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
@inline function __sum(
f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F}
return mapreduce(f, +, a...; init, kwargs...)
end

totallength(x::Number) = 1
totallength(x::AbstractArray) = __sum(totallength, x; init = 0)

Expand Down
10 changes: 10 additions & 0 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,13 @@ end
ReverseDiff.TrackedReal{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(NaN, [NaN], 0.0) isa Float64
@test DiffEqBase.promote_u0([1.0], [NaN], 0.0) isa Vector{Float64}

# totallength
val = rand(10)
par = rand(10)
u = Dual.(val, par)
@test DiffEqBase.totallength(val[1]) == 1
@test DiffEqBase.totallength(val) == length(val)
@test DiffEqBase.totallength(par) == length(par)
@test DiffEqBase.totallength(u[1]) == DiffEqBase.totallength(val[1]) + DiffEqBase.totallength(par[1])
@test DiffEqBase.totallength(u) == sum(DiffEqBase.totallength, u)

0 comments on commit 1a870e8

Please sign in to comment.