diff --git a/src/specialfunctions.jl b/src/specialfunctions.jl index c0a5e11..1499e9e 100644 --- a/src/specialfunctions.jl +++ b/src/specialfunctions.jl @@ -1,4 +1,4 @@ -import SpecialFunctions: airy, airyai, airyaiprime, airyaiprimex, airyaix, airybi, airybiprime, airybiprimex, airybix, airyprime, airyx, besselh, besselhx, besseli, besselix, besselj, besselj0, besselj1, besseljx, besselk, besselkx, bessely, bessely0, bessely1, besselyx, beta, cosint, dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, eta, gamma, hankelh1, hankelh1x, hankelh2, hankelh2x, invdigamma, lbeta, lfact, lfactorial, lgamma, logabsgamma, loggamma, polygamma, sinint, trigamma, zeta +import SpecialFunctions: airy, airyai, airyaiprime, airyaiprimex, airyaix, airybi, airybiprime, airybiprimex, airybix, airyprime, airyx, besselh, besselhx, besseli, besselix, besselj, besselj0, besselj1, besseljx, besselk, besselkx, bessely, bessely0, bessely1, besselyx, beta, cosint, dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, eta, gamma, hankelh1, hankelh1x, hankelh2, hankelh2x, invdigamma, lbeta, lfact, lfactorial, polygamma, sinint, trigamma, zeta # , lgamma, logabsgamma, loggamma # `airy(k,x)` is deprecated, use `airyai(x)`, `airyaiprime(x)`, `airybi(x)` or `airybiprime(x)` instead. @primitive airyai(x),dy (dy.*(airyaiprime.(x))) @@ -47,7 +47,17 @@ import SpecialFunctions: airy, airyai, airyaiprime, airyaiprimex, airyaix, airyb # lfactorial # logabsgamma # `lgamma` is deprecated, use `(logabsgamma(x))[1]` instead. I use `loggamma` which throws a DomainError if gamma(x) is negative. -# @primitive lgamma(x),dy,y (dy.*(digamma.(x))) +# `lgamma` deprecated, using loggamma. +# TODO: remove this once everybody uses loggamma and SpecialFunctions 0.8+ +if !isdefined(SpecialFunctions, :loggamma) && isdefined(SpecialFunctions, :lgamma) + import SpecialFunctions: lgamma + loggamma(x) = lgamma(x) +end +if isdefined(SpecialFunctions, :loggamma) && !isdefined(SpecialFunctions, :lgamma) + import SpecialFunctions: loggamma, logabsgamma + lgamma(x) = loggamma(x) +end +@primitive lgamma(x),dy,y (dy.*(digamma.(x))) @primitive loggamma(x),dy,y (dy.*(digamma.(x))) @primitive polygamma(x1,x2),dy,y nothing unbroadcast(x2,dy.*polygamma(x1+1,x2)) # sinint diff --git a/test/specialfunctions.jl b/test/specialfunctions.jl index 8045c06..d3bc9bc 100644 --- a/test/specialfunctions.jl +++ b/test/specialfunctions.jl @@ -1,9 +1,10 @@ include("header.jl") using SpecialFunctions +using AutoGrad: lgamma, loggamma # TODO: delete after everyone has SpecialFunctions 0.8 @testset "specialfunctions" begin - o = (:delta=>0.0001,:rtol=>0.01,:atol=>0.01) + o = (:delta=>0.0001,:rtol=>0.02,:atol=>0.01) ϵ = 0.1 val_0_2(x)=rand() * (2-2ϵ) + ϵ val_gt_0(x)=abs(x) + ϵ