Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow non-constant likelihoods #105

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
19 changes: 15 additions & 4 deletions src/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,16 @@ function expected_loglikelihood(
mc::MonteCarloExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
# take `n_samples` reparameterised samples
f_μ = mean.(q_f)
fs = f_μ .+ std.(q_f) .* randn(eltype(f_μ), length(q_f), mc.n_samples)
lls = loglikelihood.(lik.(fs), y)
r = randn(typeof(mean(first(q_f))), length(q_f), mc.n_samples)
lls = _mc_exp_loglikelihood_kernel.(_maybe_ref(lik), q_f, y, r)
return sum(lls) / mc.n_samples
end

function _mc_exp_loglikelihood_kernel(lik, q_f, y, r)
f = mean(q_f) + std(q_f) * r
return loglikelihood(lik(f), y)
end

# Compute the expected_loglikelihood over a collection of observations and marginal distributions
function expected_loglikelihood(
gh::GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
Expand All @@ -92,14 +96,21 @@ function expected_loglikelihood(
# type stable. Compared to other type stable implementations, e.g.
# using a custom two-argument pairwise sum, this is faster to
# differentiate using Zygote.
A = loglikelihood.(lik.(sqrt2 .* std.(q_f) .* gh.xs' .+ mean.(q_f)), y) .* gh.ws'
A = _gh_exp_loglikelihood_kernel.(_maybe_ref(lik), q_f, y, gh.xs', gh.ws')
return invsqrtπ * sum(A)
end

function _gh_exp_loglikelihood_kernel(lik, q_f, y, x, w)
return loglikelihood(lik(sqrt2 * std(q_f) * x + mean(q_f)), y) * w
end

function expected_loglikelihood(
::AnalyticExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
return error(
"No analytic solution exists for $(typeof(lik)). Use `DefaultExpectationMethod`, `GaussHermiteExpectation` or `MonteCarloExpectation` instead.",
)
end

_maybe_ref(lik) = Ref(lik)
_maybe_ref(liks::AbstractArray) = liks
14 changes: 12 additions & 2 deletions src/likelihoods/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,18 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
f_μ = mean.(q_f)
return sum(-f_μ - y .* exp.((var.(q_f) / 2) .- f_μ))
return sum(_exp_exp_loglikelihood_kernel.(q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
::AbstractVector{<:ExponentialLikelihood{ExpLink}},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_exp_exp_loglikelihood_kernel.(q_f, y))
end

_exp_exp_loglikelihood_kernel(q_f, y) = -mean(q_f) - y * exp((var(q_f) / 2) - mean(q_f))

default_expectation_method(::ExponentialLikelihood{ExpLink}) = AnalyticExpectation()
20 changes: 15 additions & 5 deletions src/likelihoods/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,21 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
f_μ = mean.(q_f)
return sum(
(lik.α - 1) * log.(y) .- y .* exp.((var.(q_f) / 2) .- f_μ) .- lik.α * f_μ .-
loggamma(lik.α),
)
return sum(_gamma_exp_loglikelihood_kernel.(lik.α, q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
liks::AbstractVector{<:GammaLikelihood{ExpLink}},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_gamma_exp_loglikelihood_kernel.(getfield.(liks, :α), q_f, y))
end

function _gamma_exp_loglikelihood_kernel(α, q_f, y)
return (α - 1) * log(y) - y * exp((var(q_f) / 2) - mean(q_f)) - α * mean(q_f) -
loggamma(α)
simsurace marked this conversation as resolved.
Show resolved Hide resolved
end

default_expectation_method(::GammaLikelihood{ExpLink}) = AnalyticExpectation()
17 changes: 14 additions & 3 deletions src/likelihoods/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,20 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(
-0.5 * (log(2π) .+ log.(lik.σ²) .+ ((y .- mean.(q_f)) .^ 2 .+ var.(q_f)) / lik.σ²)
)
return sum(_gaussian_exp_loglikelihood_kernel.(lik.σ², q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
liks::AbstractVector{<:GaussianLikelihood},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_gaussian_exp_loglikelihood_kernel.(only.(getfield.(liks, :σ²)), q_f, y))
end

function _gaussian_exp_loglikelihood_kernel(σ², q_f, y)
return -0.5 * (log(2π) + log(σ²) + ((y - mean(q_f))^2 + var(q_f)) / σ²)
end

default_expectation_method(::GaussianLikelihood) = AnalyticExpectation()
Expand Down
16 changes: 14 additions & 2 deletions src/likelihoods/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,20 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
f_μ = mean.(q_f)
return sum((y .* f_μ) - exp.(f_μ .+ (var.(q_f) / 2)) - loggamma.(y .+ 1))
return sum(_poisson_exp_loglikelihood_kernel.(q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
::AbstractArray{<:PoissonLikelihood{ExpLink}},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_poisson_exp_loglikelihood_kernel.(q_f, y))
end

function _poisson_exp_loglikelihood_kernel(q_f, y)
return (y * mean(q_f)) - exp(mean(q_f) + (var(q_f) / 2)) - loggamma(y + 1)
end

default_expectation_method(::PoissonLikelihood{ExpLink}) = AnalyticExpectation()
24 changes: 24 additions & 0 deletions test/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
m.lik for m in implementation_types if
m.quadrature == GPLikelihoods.AnalyticExpectation && m.lik != Any
]
filter!(x -> !(x <: AbstractArray), analytic_likelihoods)
for lik_type in analytic_likelihoods
lik_type_instances = filter(lik -> isa(lik, lik_type), likelihoods_to_test)
@test !isempty(lik_type_instances)
Expand Down Expand Up @@ -120,4 +121,27 @@
)
@test isfinite(glogα)
end

@testset "non-constant likelihood" begin
@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
liks = fill(lik, 10)
# Test that the various methods of computing expectations return the same
# result.
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = [rand(rng, lik(0.)) for lik in liks]
simsurace marked this conversation as resolved.
Show resolved Hide resolved

results = map(
m -> GPLikelihoods.expected_loglikelihood(m, liks, q_f, y), methods
)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
end
end
end
Loading