From 04ecb4eec657c62edb957be1b1608bd98af2d9cc Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 13:04:25 -0500 Subject: [PATCH 1/4] fixed type stability of linear filter --- src/algorithms/kalman.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/kalman.jl b/src/algorithms/kalman.jl index 2702b1f..825aedd 100644 --- a/src/algorithms/kalman.jl +++ b/src/algorithms/kalman.jl @@ -11,7 +11,7 @@ function initialise( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs... ) μ0, Σ0 = calc_initial(model.dyn; kwargs...) - return Gaussian(μ0, Σ0) + return Gaussian(μ0, Matrix(Σ0)) end function predict( From 152917b6a02c8f61238ffae22912493d02243ff9 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 13:05:10 -0500 Subject: [PATCH 2/4] added MLE demonstration --- research/maximum_likelihood/Project.toml | 7 +++ research/maximum_likelihood/mle_demo.jl | 72 ++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 research/maximum_likelihood/Project.toml create mode 100644 research/maximum_likelihood/mle_demo.jl diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml new file mode 100644 index 0000000..abf67da --- /dev/null +++ b/research/maximum_likelihood/Project.toml @@ -0,0 +1,7 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl new file mode 100644 index 0000000..c438da6 --- /dev/null +++ b/research/maximum_likelihood/mle_demo.jl @@ -0,0 +1,72 @@ +using GeneralisedFilters +using SSMProblems +using LinearAlgebra +using Random + +## TOY MODEL ############################################################################### + +# this is taken from an example in Kalman.jl +function toy_model(θ::T) where {T<:Real} + μ0 = T[1.0, 0.0] + Σ0 = Diagonal(ones(T, 2)) + + A = T[0.8 θ/2; -0.1 0.8] + Q = Diagonal(T[0.2, 1.0]) + b = zeros(T, 2) + + H = Matrix{T}(I, 1, 2) + R = Diagonal(T[0.2]) + c = zeros(T, 1) + + return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) +end + +# data generation process +rng = MersenneTwister(1234) +true_model = toy_model(1.0) +_, _, ys = sample(rng, true_model, 10000) + +# evaluate and return the log evidence +function logℓ(θ, data) + rng = MersenneTwister(1234) + _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) + return ll +end + +# check type stability (important for use with Enzyme) +@code_warntype logℓ([1.0], ys) + +## MLE ##################################################################################### + +using DifferentiationInterface +using ForwardDiff +using Optimisers + +# initial value +θ = [0.7] + +# setup optimiser (feel free to use other backends) +state = Optimisers.setup(Optimisers.Descent(0.5), θ) +backend = AutoForwardDiff() +num_epochs = 1000 + +# prepare gradients for faster AD +grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) +hess_prep = prepare_hessian(logℓ, backend, θ, Constant(ys)) + +for epoch in 1:num_epochs + # calculate gradients + val, ∇logℓ = DifferentiationInterface.value_and_gradient( + logℓ, grad_prep, backend, θ, Constant(ys) + ) + + # adjust the learning rate for a hacky Newton's method + H = DifferentiationInterface.hessian(logℓ, hess_prep, backend, θ, Constant(ys)) + Optimisers.update!(state, θ, inv(H)*∇logℓ) + + # stopping condition and printer + (epoch % 5) == 1 && println("$(epoch-1):\t $(θ[])") + if (∇logℓ'*∇logℓ) < 1e-12 + break + end +end From efe1573c95be91dd9bd76ed327ce21c66951b1e9 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 16:11:02 -0500 Subject: [PATCH 3/4] flipped sign of objective function --- research/maximum_likelihood/mle_demo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index c438da6..c7f86fa 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -30,7 +30,7 @@ _, _, ys = sample(rng, true_model, 10000) function logℓ(θ, data) rng = MersenneTwister(1234) _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) - return ll + return -ll end # check type stability (important for use with Enzyme) From 5bb89e60592759994a171c169aa5369dc79e6562 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 15 Jan 2025 17:46:04 -0500 Subject: [PATCH 4/4] fixed toml --- research/maximum_likelihood/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml index abf67da..337df8c 100644 --- a/research/maximum_likelihood/Project.toml +++ b/research/maximum_likelihood/Project.toml @@ -1,7 +1,10 @@ [deps] DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"