-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft DA script #57
base: main
Are you sure you want to change the base?
Draft DA script #57
Changes from all commits
440252c
9fd4453
3fd90c4
884b9e3
1def6a1
a5a2e05
57da3ff
dc713b0
b846fa4
4263ae7
5a2aeb4
8db658b
15dfa9f
7e3c93d
f905a41
8ac1455
f11a63e
1fa3c93
8cb4338
73dd433
f71ab32
25cebf4
c729879
cf8ce02
6723a94
1452069
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[deps] | ||
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" | ||
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" | ||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" | ||
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" | ||
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" | ||
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" | ||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
using Distributions | ||
using Random | ||
using SSMProblems | ||
using UnPack | ||
using OrdinaryDiffEq | ||
using LinearAlgebra | ||
using PDMats | ||
using GLMakie | ||
|
||
include("particles.jl") | ||
include("resamplers.jl") | ||
include("simple-filters.jl") | ||
|
||
Base.@kwdef struct Parameters{T<:Real} | ||
β::T = 8 / 3 | ||
ρ::T = 28.0 | ||
σ::T = 10.0 | ||
ν::T = 1.0 # Obs noise variance | ||
dt::T = 0.025 # Time step | ||
end | ||
|
||
function lorenz!(du, u, p::Parameters, t) | ||
@unpack β, ρ, σ = p | ||
du[1] = σ * (u[2] - u[1]) | ||
du[2] = u[1] * (ρ - u[3]) - u[2] | ||
return du[3] = u[1] * u[2] - β * u[3] | ||
end | ||
|
||
struct LatentNoiseProcess{T} <: LatentDynamics{Vector{T}} | ||
σ::AbstractPDMat{T} | ||
dt::T | ||
integrator | ||
end | ||
|
||
struct ObservationNoiseProcess{T} <: ObservationProcess{Vector{T}} | ||
σ::AbstractPDMat{T} | ||
end | ||
|
||
function SSMProblems.distribution(dyn::LatentNoiseProcess, step::Integer, prev_state, extra) | ||
reinit!(dyn.integrator, prev_state) | ||
step!(dyn.integrator, dyn.dt, true) | ||
return MvNormal(dyn.integrator.u, dyn.σ) | ||
end | ||
|
||
function SSMProblems.distribution(dyn::LatentNoiseProcess, extra) | ||
return MvNormal([1; 0; 0], dyn.σ) | ||
end | ||
|
||
function SSMProblems.distribution(obs::ObservationNoiseProcess, step::Integer, state, extra) | ||
return MvNormal(state, obs.σ * I) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the importance of having the identity matrix in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's redundant, |
||
end | ||
|
||
# Simulate some data | ||
u0 = [1.0; 0.0; 0.0] | ||
params = Parameters() | ||
|
||
dt = 0.025 | ||
N = 100 | ||
Np = 512 | ||
tspan = (0.0, dt * N) | ||
|
||
rng = MersenneTwister() | ||
|
||
prob = ODEProblem(lorenz!, u0, tspan, params) | ||
alg = Tsit5() | ||
integrator = init(prob, Tsit5(); dt=dt, adaptive=false) | ||
sol = solve(prob, alg; dt=dt, adaptive=false) | ||
|
||
# SSM Noise Model | ||
dyn = LatentNoiseProcess(ScalMat(3, params.dt), params.dt, integrator) | ||
obs = ObservationNoiseProcess(ScalMat(3, params.ν)) | ||
model = StateSpaceModel(dyn, obs) | ||
x0, x, y = sample(rng, model, N) | ||
|
||
filter = BF(Np; threshold=1.0, resampler=Systematic()); | ||
sparse_ancestry = AncestorCallback(eltype(model.dyn), filter.N, 1.0); | ||
tree, llbf = sample(rng, model, filter, y; callback=sparse_ancestry); | ||
lineage = get_ancestry(sparse_ancestry.tree) | ||
|
||
# Fancy 3D plot | ||
# fig = Figure() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm obsessed with this plot |
||
# lines(fig[1, 1], hcat(x0, x...)) | ||
# for (i, path) in enumerate(lineage) | ||
# lines!(fig[1, 1], hcat(path...), color=:black) | ||
# end | ||
|
||
fig = Figure() | ||
for i in eachindex(first(x)) | ||
lines(fig[i, 1], hcat(x0, x...)[i, :]) | ||
for path in lineage | ||
lines!(fig[i, 1], hcat(path...)[i, :]; color=:black, alpha=0.1) | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
using DataStructures: Stack | ||
using StatsBase | ||
|
||
## PARTICLES ############################################################################### | ||
|
||
mutable struct ParticleContainer{T,WT<:Real} | ||
filtered::Vector{T} | ||
proposed::Vector{T} | ||
ancestors::Vector{Int64} | ||
log_weights::Vector{WT} | ||
|
||
function ParticleContainer( | ||
initial_states::Vector{T}, log_weights::Vector{WT} | ||
) where {T,WT<:Real} | ||
return new{T,WT}( | ||
initial_states, similar(initial_states), eachindex(log_weights), log_weights | ||
) | ||
end | ||
end | ||
|
||
Base.collect(pc::ParticleContainer) = pc.vals | ||
Base.length(pc::ParticleContainer) = length(pc.vals) | ||
Base.keys(pc::ParticleContainer) = LinearIndices(pc.vals) | ||
|
||
# not sure if this is kosher, since it doesn't follow the convention of Base.getindex | ||
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i] | ||
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Vector{Int}) = pc.vals[i] | ||
|
||
function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real} | ||
fill!(pc.log_weights, zero(WT)) | ||
return pc.log_weights | ||
end | ||
|
||
function StatsBase.weights(pc::ParticleContainer) | ||
return softmax(pc.log_weights) | ||
end | ||
|
||
## SPARSE PARTICLE STORAGE ################################################################# | ||
|
||
Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a) | ||
|
||
mutable struct ParticleTree{T} | ||
states::Vector{T} | ||
parents::Vector{Int64} | ||
leaves::Vector{Int64} | ||
offspring::Vector{Int64} | ||
free_indices::Stack{Int64} | ||
|
||
function ParticleTree(states::Vector{T}, M::Integer) where {T} | ||
nodes = Vector{T}(undef, M) | ||
initial_free_indices = Stack{Int64}() | ||
append!(initial_free_indices, M:-1:(length(states) + 1)) | ||
@inbounds nodes[1:length(states)] = states | ||
return new{T}( | ||
nodes, zeros(Int64, M), 1:length(states), zeros(Int64, M), initial_free_indices | ||
) | ||
end | ||
end | ||
|
||
Base.length(tree::ParticleTree) = length(tree.states) | ||
Base.keys(tree::ParticleTree) = LinearIndices(tree.states) | ||
|
||
function prune!(tree::ParticleTree, offspring::Vector{Int64}) | ||
# insert new offspring counts | ||
setindex!(tree.offspring, offspring, tree.leaves) | ||
|
||
# update each branch | ||
@inbounds for i in eachindex(offspring) | ||
j = tree.leaves[i] | ||
while (j > 0) && (tree.offspring[j] == 0) | ||
push!(tree.free_indices, j) | ||
j = tree.parents[j] | ||
if j > 0 | ||
tree.offspring[j] -= 1 | ||
end | ||
end | ||
end | ||
end | ||
|
||
function insert!( | ||
tree::ParticleTree{T}, states::Vector{T}, a::AbstractVector{Int64} | ||
) where {T} | ||
# parents of new generation | ||
parents = getindex(tree.leaves, a) | ||
|
||
# ensure there are enough dead branches | ||
if (length(tree.free_indices) < length(a)) | ||
@debug "expanding tree" | ||
expand!(tree) | ||
end | ||
|
||
# find places for new states | ||
@inbounds for i in eachindex(states) | ||
tree.leaves[i] = pop!(tree.free_indices) | ||
end | ||
|
||
# insert new generation and update parent child relationships | ||
setindex!(tree.states, states, tree.leaves) | ||
setindex!(tree.parents, parents, tree.leaves) | ||
return tree | ||
end | ||
|
||
function expand!(tree::ParticleTree) | ||
M = length(tree) | ||
resize!(tree.states, 2 * M) | ||
|
||
# new allocations must be zero valued, this is not a perfect solution | ||
tree.parents = [tree.parents; zero(tree.parents)] | ||
tree.offspring = [tree.offspring; zero(tree.offspring)] | ||
append!(tree.free_indices, (2 * M):-1:(M + 1)) | ||
return tree | ||
end | ||
|
||
function get_offspring(a::AbstractVector{Int64}) | ||
offspring = zero(a) | ||
for i in a | ||
offspring[i] += 1 | ||
end | ||
return offspring | ||
end | ||
|
||
function get_ancestry(tree::ParticleTree{T}) where {T} | ||
paths = Vector{Vector{T}}(undef, length(tree.leaves)) | ||
@inbounds for (k, i) in enumerate(tree.leaves) | ||
j = tree.parents[i] | ||
xi = tree.states[i] | ||
|
||
xs = [xi] | ||
while j > 0 | ||
push!(xs, tree.states[j]) | ||
j = tree.parents[j] | ||
end | ||
paths[k] = reverse(xs) | ||
end | ||
return paths | ||
end | ||
|
||
## ANCESTOR STORAGE CALLBACK ############################################################### | ||
|
||
mutable struct AncestorCallback | ||
tree::ParticleTree | ||
|
||
function AncestorCallback(::Type{T}, N::Integer, C::Real=1.0) where {T} | ||
M = floor(Int64, C * N * log(N)) | ||
nodes = Vector{T}(undef, N) | ||
return new(ParticleTree(nodes, M)) | ||
end | ||
end | ||
|
||
function (c::AncestorCallback)(model, filter, step, states, data; kwargs...) | ||
if step == 1 | ||
# this may be incorrect, but it is functional | ||
@inbounds c.tree.states[1:(filter.N)] = deepcopy(states.filtered) | ||
end | ||
prune!(c.tree, get_offspring(states.ancestors)) | ||
insert!(c.tree, states.filtered, states.ancestors) | ||
return nothing | ||
end | ||
|
||
mutable struct ResamplerCallback | ||
tree::ParticleTree | ||
|
||
function ResamplerCallback(N::Integer, C::Real=1.0) | ||
M = floor(Int64, C * N * log(N)) | ||
nodes = collect(1:N) | ||
return new(ParticleTree(nodes, M)) | ||
end | ||
end | ||
|
||
function (c::ResamplerCallback)(model, filter, step, states, data; kwargs...) | ||
if step != 1 | ||
prune!(c.tree, get_offspring(states.ancestors)) | ||
insert!(c.tree, collect(1:(filter.N)), states.ancestors) | ||
end | ||
return nothing | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
using Random | ||
using Distributions | ||
|
||
abstract type AbstractResampler end | ||
|
||
## DOUBLE PRECISION STABLE ALGORITHMS ###################################################### | ||
|
||
struct Multinomial <: AbstractResampler end | ||
|
||
function resample( | ||
rng::AbstractRNG, ::Multinomial, weights::AbstractVector{WT}, n::Int64=length(weights) | ||
) where {WT<:Real} | ||
return rand(rng, Distributions.Categorical(weights), n) | ||
end | ||
|
||
struct Systematic <: AbstractResampler end | ||
|
||
function resample( | ||
rng::AbstractRNG, ::Systematic, weights::AbstractVector{WT}, n::Int64=length(weights) | ||
) where {WT<:Real} | ||
# pre-calculations | ||
@inbounds v = n * weights[1] | ||
u = oftype(v, rand(rng)) | ||
|
||
# initialize sampling algorithm | ||
a = Vector{Int64}(undef, n) | ||
idx = 1 | ||
|
||
@inbounds for i in 1:n | ||
while v < u | ||
idx += 1 | ||
v += n * weights[idx] | ||
end | ||
a[i] = idx | ||
u += one(u) | ||
end | ||
|
||
return a | ||
end | ||
|
||
## SINGLE PRECISION STABLE ALGORITHMS ###################################################### | ||
|
||
struct Metropolis <: AbstractResampler | ||
ε::Float64 | ||
function Metropolis(ε::Float64=0.01) | ||
return new(ε) | ||
end | ||
end | ||
|
||
# TODO: this should be done in the log domain and also parallelized | ||
function resample( | ||
rng::AbstractRNG, | ||
resampler::Metropolis, | ||
weights::AbstractVector{WT}, | ||
n::Int64=length(weights); | ||
) where {WT<:Real} | ||
# pre-calculations | ||
β = mean(weights) | ||
B = Int64(cld(log(resampler.ε), log(1 - β))) | ||
|
||
# initialize the algorithm | ||
a = Vector{Int64}(undef, n) | ||
|
||
@inbounds for i in 1:n | ||
k = i | ||
for _ in 1:B | ||
j = rand(rng, 1:n) | ||
v = weights[j] / weights[k] | ||
if rand(rng) ≤ v | ||
k = j | ||
end | ||
end | ||
a[i] = k | ||
end | ||
|
||
return a | ||
end | ||
|
||
struct Rejection <: AbstractResampler end | ||
|
||
# TODO: this should be done in the log domain and also parallelized | ||
function resample( | ||
rng::AbstractRNG, ::Rejection, weights::AbstractVector{WT}, n::Int64=length(weights) | ||
) where {WT<:Real} | ||
# pre-calculations | ||
max_weight = maximum(weights) | ||
|
||
# initialize the algorithm | ||
a = Vector{Int64}(undef, n) | ||
|
||
@inbounds for i in 1:n | ||
j = i | ||
u = rand(rng) | ||
while u > weights[j] / max_weight | ||
j = rand(rng, 1:n) | ||
u = rand(rng) | ||
end | ||
a[i] = j | ||
end | ||
|
||
return a | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sure this is just left over from some experimentation. Did you try running anything with
StaticArrays
?