Skip to content

Commit

Permalink
Fixed significant bugs (#141)
Browse files Browse the repository at this point in the history
* fixed bugs so now the package is actually useable

* added proper testing of the swapping capabilities

* bumped patch version

* added some docstrings
  • Loading branch information
torfjelde authored Feb 16, 2023
1 parent ca658ec commit fe446d1
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCTempering"
uuid = "ce233488-44ea-4441-b732-192676ce2298"
authors = ["Harrison Wilde <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ If `I...` is not specified, the sampler corresponding to `β=1.0` will be return
"""
sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1)
function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
return getsampler(sampler.sampler, state.chain_to_process[I...])
return getsampler(sampler.sampler, chain_to_process(state, I...))
end

"""
Expand Down
34 changes: 29 additions & 5 deletions src/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ temperatures between workers rather than the full states.
This implementation follows approach (2).
Here's an exemplar realisation of five steps of sampling and swap-attempts:
Here's an example realisation of five steps of sampling and swap-attempts:
```
Chains: process_to_chain chain_to_process inverse_temperatures[process_to_chain[i]]
Expand Down Expand Up @@ -82,14 +82,32 @@ The indices here are exactly those represented by `states[k].chain_to_process[1]
swap_acceptance_ratios
end

"""
process_to_chain(state, I...)
Return the chain index corresponding to the process index `I`.
"""
process_to_chain(state::TemperedState, I...) = process_to_chain(state.process_to_chain, I...)
# NOTE: Array impl. is useful for testing.
process_to_chain(proc2chain::AbstractArray, I...) = proc2chain[I...]

"""
chain_to_process(state, I...)
Return the process index corresponding to the chain index `I`.
"""
chain_to_process(state::TemperedState, I...) = chain_to_process(state.chain_to_process, I...)
# NOTE: Array impl. is useful for testing.
chain_to_process(chain2proc::AbstractArray, I...) = chain2proc[I...]

"""
transition_for_chain(state[, I...])
Return the transition corresponding to the chain indexed by `I...`.
If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`.
"""
transition_for_chain(state::TemperedState) = transition_for_chain(state, 1)
transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.chain_to_process[I...]][1]
transition_for_chain(state::TemperedState, I...) = transition_for_process(state, chain_to_process(state, I...))

"""
transition_for_process(state, I...)
Expand All @@ -105,7 +123,7 @@ Return the state corresponding to the chain indexed by `I...`.
If `I...` is not specified, the state corresponding to `β=1.0` will be returned.
"""
state_for_chain(state::TemperedState) = state_for_chain(state, 1)
state_for_chain(state::TemperedState, I...) = state.transitions_and_states[I...][2]
state_for_chain(state::TemperedState, I...) = state_for_process(state, chain_to_process(state, I...))

"""
state_for_process(state, I...)
Expand All @@ -121,14 +139,20 @@ Return the β corresponding to the chain indexed by `I...`.
If `I...` is not specified, the β corresponding to `β=1.0` will be returned.
"""
β_for_chain(state::TemperedState) = β_for_chain(state, 1)
β_for_chain(state::TemperedState, I...) = state.inverse_temperatures[state.chain_to_process[I...]]
β_for_chain(state::TemperedState, I...) = β_for_chain(state.inverse_temperatures, I...)
# NOTE: Array impl. is useful for testing.
β_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...]

"""
β_for_process(state, I...)
Return the β corresponding to the process indexed by `I...`.
"""
β_for_process(state::TemperedState, I...) = state.inverse_temperatures[I...]
β_for_process(state::TemperedState, I...) = β_for_process(state.inverse_temperatures, state.process_to_chain, I...)
# NOTE: Array impl. is useful for testing.
function β_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...)
return β_for_chain(chain_to_beta, process_to_chain(proc2chain, I...))
end

"""
getparams(transition)
Expand Down
59 changes: 57 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ function test_and_sample_model(
callback=callback, progress=progress, init_params=init_params
)

# Let's make sure the process ↔ chain mapping is valid.
numtemps = MCMCTempering.numtemps(sampler_tempered)
for state in states_tempered
for i = 1:numtemps
# These two should be inverses of each other.
@test MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i
end
end

# Extract the states that were swapped.
states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered)
# Swap acceptance ratios should be compared against the target acceptance in case of adaptation.
Expand Down Expand Up @@ -192,6 +201,52 @@ function compare_chains(
end

@testset "MCMCTempering.jl" begin
@testset "Swapping" begin
# Chains: process_to_chain chain_to_process chain_to_beta[process_to_chain[i]]
# | | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25
# | | | |
# V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25
# Λ | |
# | V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25
# | Λ |
# Initial values.
process_to_chain = [1, 2, 3, 4]
chain_to_process = [1, 2, 3, 4]
chain_to_beta = [1.0, 0.75, 0.5, 0.25]

# Make swap chain 1 (now on process 1) ↔ chain 2 (now on process 2)
MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1)
# Expected result: chain 1 is now on process 2, chain 2 is now on process 1.
target_process_to_chain = [2, 1, 3, 4]
@test process_to_chain[chain_to_process] == 1:length(process_to_chain)
@testset "$((process_idx, chain_idx, process_β))" for (process_idx, chain_idx, process_β) in zip(
[1, 2, 3, 4],
target_process_to_chain,
chain_to_beta[target_process_to_chain]
)
@test MCMCTempering.process_to_chain(process_to_chain, chain_idx) == process_idx
@test MCMCTempering.chain_to_process(chain_to_process, process_idx) == chain_idx
@test MCMCTempering.β_for_chain(chain_to_beta, chain_idx) == process_β
@test MCMCTempering.β_for_process(chain_to_beta, process_to_chain, process_idx) == process_β
end

# Make swap chain 2 (now on process 1) ↔ chain 3 (now on process 3)
MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2)
# Expected result: chain 3 is now on process 1, chain 2 is now on process 3.
target_process_to_chain = [3, 1, 2, 4]
@test process_to_chain[chain_to_process] == 1:length(process_to_chain)
@testset "$((process_idx, chain_idx, process_β))" for (process_idx, chain_idx, process_β) in zip(
[1, 2, 3, 4],
target_process_to_chain,
chain_to_beta[target_process_to_chain]
)
@test MCMCTempering.process_to_chain(process_to_chain, process_idx) == chain_idx
@test MCMCTempering.chain_to_process(chain_to_process, chain_idx) == process_idx
@test MCMCTempering.β_for_chain(chain_to_beta, chain_idx) == process_β
@test MCMCTempering.β_for_process(chain_to_beta, process_to_chain, process_idx) == process_β
end
end

@testset "Simple MvNormal with no expected swaps" begin
num_iterations = 10_000
d = 1
Expand Down Expand Up @@ -346,7 +401,7 @@ end
map_parameters!(b, chain_tempered)

# TODO: Make it not broken, i.e. produce reasonable results.
compare_chains(chain_hmc, chain_tempered, atol=0.1, compare_std=false, compare_ess=false, isbroken=true)
compare_chains(chain_hmc, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false)
end

@testset "AdvancedMH.jl" begin
Expand Down Expand Up @@ -384,7 +439,7 @@ end
map_parameters!(b, chain_tempered)

# TODO: Make it not broken, i.e. produce reasonable results.
compare_chains(chain_mh, chain_tempered, atol=0.1, compare_std=false, compare_ess=false, isbroken=true)
compare_chains(chain_mh, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false)
end
end
end

2 comments on commit fe446d1

@yebai
Copy link
Member

@yebai yebai commented on fe446d1 Feb 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/77815

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" fe446d140a5eeaf273de9922f52b861563facd16
git push origin v0.3.1

Please sign in to comment.