Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Oct 23, 2024
1 parent 29152da commit 0fadd2b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/algorithms/apf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: Abst
end

function AuxiliaryParticleFilter(
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N))
Expand Down
45 changes: 3 additions & 42 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,53 +95,14 @@ end
_, _, data = sample(rng, model, 20)

bf = BF(2^10; threshold=0.8)
apf = APF(2^10, threshold=1.)
_, llbf = GeneralisedFilters.filter(rng, model, bf, data)
_, llapf= GeneralisedFilters.filter(rng, model, apf, data)
_, llkf = GeneralisedFilters.filter(rng, model, KF(), data)

# since this is log valued, we can up the tolerance
@test llkf llbf atol = 2
end

@testitem "APF filter test" begin
using GeneralisedFilters
using SSMProblems
using StableRNGs
using PDMats
using LinearAlgebra
using Random: randexp

T = Float32
rng = StableRNG(1234)
σx², σy² = randexp(rng, T, 2)

# initial state distribution
μ0 = zeros(T, 2)
Σ0 = PDMat(T[1 0; 0 1])

# state transition equation
A = T[1 1; 0 1]
b = T[0; 0]
Q = PDiagMat([σx²; 0])

# observation equation
H = T[1 0]
c = T[0;]
R = [σy²;;]

# when working with PDMats, the Kalman filter doesn't play nicely without this
function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real}
return PDMat(Symmetric(mat))
end

model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
_, _, data = sample(rng, model, 20)

bf = APF(2^10, threshold=0.8)
_, llbf = GeneralisedFilters.filter(rng, model, bf, data)
_, llkf = GeneralisedFilters.filter(rng, model, KF(), data)

# since this is log valued, we can up the tolerance
@test llkf llbf atol = 2
@test llkf llapf atol = 2
end

@testitem "Forward algorithm test" begin
Expand Down

0 comments on commit 0fadd2b

Please sign in to comment.