Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant frule definitions #63

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesOverloadGeneration = "0.1"
IrrationalConstants = "0.2"
SliceMap = "0.2"
SpecialFunctions = "2"
IrrationalConstants = "0.2"
SymbolicUtils = "1"
Zygote = "0.6.55"
julia = "1.6"
35 changes: 6 additions & 29 deletions src/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
using ChainRules
using ChainRulesCore
using SpecialFunctions
using IrrationalConstants: sqrtπ
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: BasicSymbolic, Pow

@scalar_rule +(x::BasicSymbolic) true
@scalar_rule -(x::BasicSymbolic) -1
@scalar_rule deg2rad(x::BasicSymbolic) deg2rad(one(x))
@scalar_rule rad2deg(x::BasicSymbolic) rad2deg(one(x))
@scalar_rule asin(x::BasicSymbolic) inv(sqrt(1 - x^2))
@scalar_rule acos(x::BasicSymbolic) inv(-sqrt(1 - x^2))
@scalar_rule atan(x::BasicSymbolic) inv(-(1 + x^2))
@scalar_rule acot(x::BasicSymbolic) inv(-(1 + x^2))
@scalar_rule acsc(x::BasicSymbolic) inv(x^2 * -sqrt(1 - x^-2))
@scalar_rule asec(x::BasicSymbolic) inv(x^2 * sqrt(1 - x^-2))
@scalar_rule log(x::BasicSymbolic) inv(x)
@scalar_rule log10(x::BasicSymbolic) inv(log(10.0) * x)
@scalar_rule log1p(x::BasicSymbolic) inv(x + 1)
@scalar_rule log2(x::BasicSymbolic) inv(log(2.0) * x)
@scalar_rule sinh(x::BasicSymbolic) cosh(x)
@scalar_rule cosh(x::BasicSymbolic) sinh(x)
@scalar_rule tanh(x::BasicSymbolic) 1-Ω^2
@scalar_rule acosh(x::BasicSymbolic) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x::BasicSymbolic) inv(1 - x^2)
@scalar_rule acsch(x::BasicSymbolic) inv(x^2 * -sqrt(1 + x^-2))
@scalar_rule asech(x::BasicSymbolic) inv(x * -sqrt(1 - x^2))
@scalar_rule asinh(x::BasicSymbolic) inv(sqrt(x^2 + 1))
@scalar_rule atanh(x::BasicSymbolic) inv(1 - x^2)
@scalar_rule erf(x::BasicSymbolic) exp(-x^2) * 2/sqrtπ
using SymbolicUtils: Pow

dummy = (NoTangent(), 1)
@syms t₁
@variables z
for func in (+, -, deg2rad, rad2deg,
sinh, cosh, tanh,
asin, acos, atan, asec, acsc, acot,
Expand All @@ -43,15 +20,15 @@ for func in (+, -, deg2rad, rad2deg,
t0, t1 = value(t)
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
end
der = frule(dummy, func, t₁)[2]
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
# recursion by raising
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
der_expr = $(QuoteNode(toexpr(term)))
f = $func
quote
$(Expr(:meta, :inline))
t₁ = TaylorScalar{T, N - 1}(t)
z = TaylorScalar{T, N - 1}(t)
df = $der_expr
$$raiser($f(value(t)[1]), df, t)
end
Expand Down