From fe446d140a5eeaf273de9922f52b861563facd16 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 16 Feb 2023 18:12:42 +0000 Subject: [PATCH] Fixed significant bugs (#141) * fixed bugs so now the package is actually useable * added proper testing of the swapping capabilities * bumped patch version * added some docstrings --- Project.toml | 2 +- src/sampler.jl | 2 +- src/state.jl | 34 ++++++++++++++++++++++++---- test/runtests.jl | 59 ++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 88 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 844122b..52cefd8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCTempering" uuid = "ce233488-44ea-4441-b732-192676ce2298" authors = ["Harrison Wilde and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/sampler.jl b/src/sampler.jl index f180de9..e719e76 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -44,7 +44,7 @@ If `I...` is not specified, the sampler corresponding to `β=1.0` will be return """ sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, state.chain_to_process[I...]) + return getsampler(sampler.sampler, chain_to_process(state, I...)) end """ diff --git a/src/state.jl b/src/state.jl index 83b610a..9da6bc8 100644 --- a/src/state.jl +++ b/src/state.jl @@ -33,7 +33,7 @@ temperatures between workers rather than the full states. This implementation follows approach (2). -Here's an exemplar realisation of five steps of sampling and swap-attempts: +Here's an example realisation of five steps of sampling and swap-attempts: ``` Chains: process_to_chain chain_to_process inverse_temperatures[process_to_chain[i]] @@ -82,6 +82,24 @@ The indices here are exactly those represented by `states[k].chain_to_process[1] swap_acceptance_ratios end +""" + process_to_chain(state, I...) + +Return the chain index corresponding to the process index `I`. +""" +process_to_chain(state::TemperedState, I...) = process_to_chain(state.process_to_chain, I...) +# NOTE: Array impl. is useful for testing. +process_to_chain(proc2chain::AbstractArray, I...) = proc2chain[I...] + +""" + chain_to_process(state, I...) + +Return the process index corresponding to the chain index `I`. +""" +chain_to_process(state::TemperedState, I...) = chain_to_process(state.chain_to_process, I...) +# NOTE: Array impl. is useful for testing. +chain_to_process(chain2proc::AbstractArray, I...) = chain2proc[I...] + """ transition_for_chain(state[, I...]) @@ -89,7 +107,7 @@ Return the transition corresponding to the chain indexed by `I...`. If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. """ transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) -transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.chain_to_process[I...]][1] +transition_for_chain(state::TemperedState, I...) = transition_for_process(state, chain_to_process(state, I...)) """ transition_for_process(state, I...) @@ -105,7 +123,7 @@ Return the state corresponding to the chain indexed by `I...`. If `I...` is not specified, the state corresponding to `β=1.0` will be returned. """ state_for_chain(state::TemperedState) = state_for_chain(state, 1) -state_for_chain(state::TemperedState, I...) = state.transitions_and_states[I...][2] +state_for_chain(state::TemperedState, I...) = state_for_process(state, chain_to_process(state, I...)) """ state_for_process(state, I...) @@ -121,14 +139,20 @@ Return the β corresponding to the chain indexed by `I...`. If `I...` is not specified, the β corresponding to `β=1.0` will be returned. """ β_for_chain(state::TemperedState) = β_for_chain(state, 1) -β_for_chain(state::TemperedState, I...) = state.inverse_temperatures[state.chain_to_process[I...]] +β_for_chain(state::TemperedState, I...) = β_for_chain(state.inverse_temperatures, I...) +# NOTE: Array impl. is useful for testing. +β_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...] """ β_for_process(state, I...) Return the β corresponding to the process indexed by `I...`. """ -β_for_process(state::TemperedState, I...) = state.inverse_temperatures[I...] +β_for_process(state::TemperedState, I...) = β_for_process(state.inverse_temperatures, state.process_to_chain, I...) +# NOTE: Array impl. is useful for testing. +function β_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...) + return β_for_chain(chain_to_beta, process_to_chain(proc2chain, I...)) +end """ getparams(transition) diff --git a/test/runtests.jl b/test/runtests.jl index 53536ca..a13b255 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -86,6 +86,15 @@ function test_and_sample_model( callback=callback, progress=progress, init_params=init_params ) + # Let's make sure the process ↔ chain mapping is valid. + numtemps = MCMCTempering.numtemps(sampler_tempered) + for state in states_tempered + for i = 1:numtemps + # These two should be inverses of each other. + @test MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i + end + end + # Extract the states that were swapped. states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered) # Swap acceptance ratios should be compared against the target acceptance in case of adaptation. @@ -192,6 +201,52 @@ function compare_chains( end @testset "MCMCTempering.jl" begin + @testset "Swapping" begin + # Chains: process_to_chain chain_to_process chain_to_beta[process_to_chain[i]] + # | | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 + # | | | | + # V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 + # Λ | | + # | V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 + # | Λ | + # Initial values. + process_to_chain = [1, 2, 3, 4] + chain_to_process = [1, 2, 3, 4] + chain_to_beta = [1.0, 0.75, 0.5, 0.25] + + # Make swap chain 1 (now on process 1) ↔ chain 2 (now on process 2) + MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1) + # Expected result: chain 1 is now on process 2, chain 2 is now on process 1. + target_process_to_chain = [2, 1, 3, 4] + @test process_to_chain[chain_to_process] == 1:length(process_to_chain) + @testset "$((process_idx, chain_idx, process_β))" for (process_idx, chain_idx, process_β) in zip( + [1, 2, 3, 4], + target_process_to_chain, + chain_to_beta[target_process_to_chain] + ) + @test MCMCTempering.process_to_chain(process_to_chain, chain_idx) == process_idx + @test MCMCTempering.chain_to_process(chain_to_process, process_idx) == chain_idx + @test MCMCTempering.β_for_chain(chain_to_beta, chain_idx) == process_β + @test MCMCTempering.β_for_process(chain_to_beta, process_to_chain, process_idx) == process_β + end + + # Make swap chain 2 (now on process 1) ↔ chain 3 (now on process 3) + MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2) + # Expected result: chain 3 is now on process 1, chain 2 is now on process 3. + target_process_to_chain = [3, 1, 2, 4] + @test process_to_chain[chain_to_process] == 1:length(process_to_chain) + @testset "$((process_idx, chain_idx, process_β))" for (process_idx, chain_idx, process_β) in zip( + [1, 2, 3, 4], + target_process_to_chain, + chain_to_beta[target_process_to_chain] + ) + @test MCMCTempering.process_to_chain(process_to_chain, process_idx) == chain_idx + @test MCMCTempering.chain_to_process(chain_to_process, chain_idx) == process_idx + @test MCMCTempering.β_for_chain(chain_to_beta, chain_idx) == process_β + @test MCMCTempering.β_for_process(chain_to_beta, process_to_chain, process_idx) == process_β + end + end + @testset "Simple MvNormal with no expected swaps" begin num_iterations = 10_000 d = 1 @@ -346,7 +401,7 @@ end map_parameters!(b, chain_tempered) # TODO: Make it not broken, i.e. produce reasonable results. - compare_chains(chain_hmc, chain_tempered, atol=0.1, compare_std=false, compare_ess=false, isbroken=true) + compare_chains(chain_hmc, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false) end @testset "AdvancedMH.jl" begin @@ -384,7 +439,7 @@ end map_parameters!(b, chain_tempered) # TODO: Make it not broken, i.e. produce reasonable results. - compare_chains(chain_mh, chain_tempered, atol=0.1, compare_std=false, compare_ess=false, isbroken=true) + compare_chains(chain_mh, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false) end end end