Skip to content

Commit

Permalink
new interface, fix pre-sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Oct 23, 2024
1 parent d0ab700 commit 29152da
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/algorithms/apf.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
export AuxiliaryParticleFilter, APF

struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
N::Integer
resampler::RS
aux::Vector # Auxiliary weights
end

function AuxiliaryParticleFilter(
N::Integer, threshold::Real=1.0, resampler::AbstractResampler=Systematic()
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N))
Expand All @@ -25,7 +25,7 @@ function initialise(
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
initial_weights = fill(-log(T(filter.N)), filter.N)

return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state)
return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
end

function update_weights!(
Expand Down Expand Up @@ -57,16 +57,16 @@ function predict(
auxiliary_weights = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted
)
state.filtered.log_weights .+= auxiliary_weights
states.filtered.log_weights .+= auxiliary_weights
filter.aux = auxiliary_weights

states.proposed = resample(rng, filter.resampler, states.filtered, filter)
states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
states.proposed.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
states.proposed.particles,
)

return update_ref!(states, ref_state, step)
return update_ref!(states, ref_state, filter, step)
end

function update(
Expand Down

0 comments on commit 29152da

Please sign in to comment.