Skip to content

Commit

Permalink
Improvements to state_from and more extensive docs (#155)
Browse files Browse the repository at this point in the history
* added example of adding support for AdvancedHMC to docs, also
demonstrating how to specialize certain methods to improve efficiency

* removed unintended addition

* fixed bug in state_from for MultiModel

* added some convenience methods for SwapTransition

* hopefully fix test failure

* more testing for bundle_samples of compositions

* slight modification to the docs
  • Loading branch information
torfjelde authored Mar 21, 2023
1 parent f9ab139 commit ed1ca98
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 42 deletions.
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MCMCTempering = "ce233488-44ea-4441-b732-192676ce2298"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
18 changes: 14 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ MCMCTempering.tempered
MCMCTempering.TemperedSampler
```

Under the hood, [`MCMCTempering.TemperedSampler`](@ref) is actually just a "fancy" representation of a composition (represented using a [`MCMCTempering.CompositionSampler`](@ref)) of a [`MultiSampler`](@ref) and a [`SwapSampler`](@ref).
Under the hood, [`MCMCTempering.TemperedSampler`](@ref) is actually just a "fancy" representation of a composition (represented using a [`MCMCTempering.CompositionSampler`](@ref)) of a [`MCMCTempering.MultiSampler`](@ref) and a [`MCMCTempering.SwapSampler`](@ref).

Roughly speaking, the implementation of `AbstractMCMC.step` for [`MCMCTempering.TemperedSampler`](@ref) is basically

Expand All @@ -26,7 +26,17 @@ which in this case is provided by repeated calls to [`MCMCTempering.make_tempere
MCMCTempering.make_tempered_model
```

This should be overloaded if you have some custom model-type that does not support the LogDensityProblems.jl-interface.
This should be overloaded if you have some custom model-type that does not support the LogDensityProblems.jl-interface. In the case where the model _does_ support the LogDensityProblems.jl-interface, then the following will automatically be constructed

```@docs
MCMCTempering.TemperedLogDensityProblem
```

In addition, for computation of the tempered logdensities, we have

```@docs
MCMCTempering.compute_logdensities
```

## Swapping

Expand Down Expand Up @@ -80,13 +90,13 @@ MCMCTempering.CompositionState

Large compositions can have unfortunate effects on the compilation times in Julia.

To alleviate this issue we also have the [`RepeatedSampler`](@ref):
To alleviate this issue we also have the [`MCMCTempering.RepeatedSampler`](@ref):

```@docs
MCMCTempering.RepeatedSampler
```

In the case where [`saveall`](@ref) returns `false`, `step` for a [`MCMCTempering.RepeatedSampler`](@ref) simply returns the last transition and state; if it returns `true`, then the transition is of type [`MCMCTempering.SequentialTransitions`](@ref) and the state is of type [`MCMCTempering.SequentialStates`](@ref).
In the case where [`MCMCTempering.saveall`](@ref) returns `false`, `step` for a [`MCMCTempering.RepeatedSampler`](@ref) simply returns the last transition and state; if it returns `true`, then the transition is of type [`MCMCTempering.SequentialTransitions`](@ref) and the state is of type [`MCMCTempering.SequentialStates`](@ref).

```@docs
MCMCTempering.SequentialTransitions
Expand Down
160 changes: 157 additions & 3 deletions docs/src/getting-started.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# Getting started

## Mixture of Gaussians
# Getting started: a simple Mixture of Gaussians example

Suppose we have a mixture of Gaussians, e.g. something like

Expand Down Expand Up @@ -38,6 +36,8 @@ LogDensityProblems.capabilities(::Type{<:DistributionLogDensity}) = LogDensityPr
target_model = DistributionLogDensity(target_distribution)
```

## Metropolis-Hastings (AdvancedMH.jl)

Immediately one might reach for a standard sampler, e.g. a random-walk Metropolis-Hastings (RWMH) from [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl) and start sampling using `sample`:

```@example gmm
Expand Down Expand Up @@ -149,3 +149,157 @@ end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)
```

## HMC (AdvancedHMC.jl)

We also do this with AdvancedHMC.jl.

```@example gmm
using AdvancedHMC: AdvancedHMC
using ForwardDiff: ForwardDiff # for automatic differentation of the logdensity
# Creation of the sampler.
metric = AdvancedHMC.DiagEuclideanMetric(1)
integrator = AdvancedHMC.Leapfrog(0.1)
proposal = AdvancedHMC.StaticTrajectory(integrator, 8)
sampler = AdvancedHMC.HMCSampler(proposal, metric)
sampler_tempered = MCMCTempering.TemperedSampler(sampler, inverse_temperatures)
# Sample!
num_iterations = 5_000
chain = sample(
rng,
target_model, sampler, num_iterations;
chain_type=MCMCChains.Chains,
param_names=["x"],
)
plot(chain, size=figsize)
```

Then if we want to make it work with MCMCTempering, we define the same methods as before:

```@example gmm
# Provides a convenient way of "mutating" (read: reconstructing) types with different values
# for specified fields; see usage below.
using Setfield: Setfield
function MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState)
t = state.transition
return t.z.θ, t.z.ℓπ.value
end
function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, logprob)
# NOTE: Need to recompute the gradient because it might be used in the next integration step.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end
```

And then, just as before, we can `sample`:

```@example gmm
chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"]
);
```

```@example gmm
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)
```


Works like a charm!

_But_ we're recomputing both the logdensity and the gradient of the logdensity upon every [`MCMCTempering.setparams_and_logprob!!`](@ref) above! This seems wholly unnecessary in the tempering case, since

```math
\pi_{\beta_1}(x) = \pi(x)^{\beta_1} = \big( \pi(x)^{\beta_2} \big)^{\beta_1 / \beta_2} = \pi_{\beta_2}^{\beta_1 / \beta_2}
```

i.e. if `model` in the above is tempered with ``\beta_1`` and the `params` are coming from a model with ``\beta_2``, we can could just compute it as

```julia
(β_1 / β_2) * logprob
```

and similarly for the gradient! Luckily, it's possible to tell MCMCTempering that this should be done by overloading the [`MCMCTempering.state_from`](@ref) method. In particular, we'll specify that when we're working with two models of type [`MCMCTempering.TemperedLogDensityProblem`](@ref) and two states of type `AdvancedHMC.HMCState`, then we can just re-use scale the logdensity and gradient computation from the [`MCMCTempering.state_from`](@ref) to get the quantities we want, thus avoiding unnecessary computations:

```@docs
MCMCTempering.state_from
```

```@example gmm
using AbstractMCMC: AbstractMCMC
function MCMCTempering.state_from(
# AdvancedHMC.jl works with `LogDensityModel`, and by default `AbstractMCMC` will wrap
# the input model with `LogDensityModel`, thus asusming it implements the
# LogDensityProblems.jl-interface, by default.
model::AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem},
model_from::AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem},
state::AdvancedHMC.HMCState,
state_from::AdvancedHMC.HMCState,
)
# We'll need the momentum and the kinetic energy from `ze.`
z = state.transition.z
# From this, we'll need everything else.
z_from = state_from.transition.z
params_from = z_from.θ
logprob_from = z_from.ℓπ.value
gradient_from = z_from.ℓπ.gradient
# `logprob` is actually `β * actual_logprob`, and we want it to be `β_from * actual_logprob`, so
# we can compute the "new" logprob as `(β_from / β) * logprob_from`.
beta = model.logdensity.beta
beta_from = model_from.logdensity.beta
delta_beta = beta / beta_from
logprob_new = delta_beta * logprob_from
gradient_new = delta_beta .* gradient_from
# Construct `PhasePoint`. Note that we keep `r` and `ℓκ` from the original state.
return Setfield.@set state.transition.z = AdvancedHMC.PhasePoint(
params_from,
z.r,
AdvancedHMC.DualValue(logprob_new, gradient_new),
z.ℓκ
)
end
```

!!! note
For a general model we'd also have to do the same for [`MCMCTempering.compute_logdensities`](@ref) if we want to completely eliminate unnecessary computations, but for `AbstractMCMC.LogDensity{<:MCMCTempering.TemperedLogDensityProblem}` this is already implemented in MCMCTempering.

Now we can do the same but slightly faster:

```@example gmm
chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"]
);
```

```@example gmm
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)
```
24 changes: 15 additions & 9 deletions src/MCMCTempering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,12 @@ function AbstractMCMC.bundle_samples(
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
# TODO: Implement special one for `Vector{MCMCChains.Chains}`.
if bundle_resolve_swaps
return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
end
# NOTE: If we can't resolve the swaps, there's not really much we can do in terms
# of bundling the samples.
# TODO: Is this the best we can do?
!bundle_resolve_swaps && return ts

# TODO: Do better?
return ts
return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
end

function AbstractMCMC.bundle_samples(
Expand Down Expand Up @@ -188,15 +187,19 @@ function AbstractMCMC.bundle_samples(
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
# NOTE: If we can't resolve the swaps, there's not really much we can do in terms
# of bundling the samples.
# TODO: Is this the best we can do?
!bundle_resolve_swaps && return ts

# Resolve the swaps.
sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
# Resolve the swaps (using the already implemented resolution in `composition_transition`
# for this particular sampler but without `saveall`).
sampler_without_saveall = @set sampler.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

AbstractMCMC.bundle_samples(
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
Expand All @@ -212,6 +215,9 @@ function AbstractMCMC.bundle_samples(
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
# NOTE: If we can't resolve the swaps, there's not really much we can do in terms
# of bundling the samples.
# TODO: Is this the best we can do?
!bundle_resolve_swaps && return ts

# Resolve the swaps (using the already implemented resolution in `composition_transition`
Expand Down
13 changes: 7 additions & 6 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,18 @@ See also: [`getparams_and_logprob`](@ref).
setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob)

"""
state_from(model, state_target, state_source[, transition_source, transition_target])
state_from(model_source, state_target, state_source)
state_from(model_source, model_target, state_target, state_source)
Return a new state similar to `state_target` but updated from `state_source`, which could be
a different type of state.
"""
function state_from(model, state_target, state_source, transition_target, transition_source)
return state_from(model, state_target, state_source)
function state_from(model_target, model_source, state_target, state_source)
return state_from(model_target, state_target, state_source)
end
function state_from(model, state_target, state_source)
params, logp = getparams_and_logprob(model, state_source)
return setparams_and_logprob!!(model, state_target, params, logp)
function state_from(model_target, state_target, state_source)
params, logp = getparams_and_logprob(state_source)
return setparams_and_logprob!!(model_target, state_target, params, logp)
end

"""
Expand Down
20 changes: 20 additions & 0 deletions src/samplers/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,26 @@ function setparams_and_logprob!!(model::MultiModel, state::MultipleStates, param
return @set state.states = map(setparams_and_logprob!!, model.models, state.states, params, logprobs)
end

# NOTE: Is this too general; should we specify the types of the states?
function state_from(model::MultiModel, state::MultipleStates, state_other::MultipleStates)
@assert length(model.models) == length(state.states) == length(state_other.states) "The number of models and states must match."
return @set state.states = map(model.models, state.states, state_other.states) do m, s1, s2
state_from(m, s1, s2)
end
end

function state_from(model::MultiModel, model_other::MultiModel, state::MultipleStates, state_other::MultipleStates)
@assert length(model.models) == length(model_other.models) == length(state.states) == length(state_other.states) "The number of models and states must match."
return @set state.states = map(
model.models,
model_other.models,
state.states,
state_other.states,
) do m1, m2, s1, s2
state_from(m1, m2, s1, s2)
end
end

# TODO: Clean this up.
initparams(model::MultiModel, init_params) = map(Base.Fix1(get_init_params, init_params), 1:length(model.models))
initparams(model::MultiModel{<:Tuple}, init_params) = ntuple(length(model.models)) do i
Expand Down
8 changes: 6 additions & 2 deletions src/stepping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function AbstractMCMC.step(
for i in 1:numtemps(sampler)
])
multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
multistate = last(AbstractMCMC.step(rng, multimodel, multisampler; kwargs...))
multitransition, multistate = AbstractMCMC.step(rng, multimodel, multisampler; kwargs...)

# Make sure to collect, because we'll be using `setindex!(!)` later.
process_to_chain = collect(1:length(sampler.chain_to_beta))
Expand All @@ -39,7 +39,11 @@ function AbstractMCMC.step(
Dict{Int,Float64}(),
)

return AbstractMCMC.step(rng, model, sampler, TemperedState(swapstate, multistate, sampler.chain_to_beta))
swaptransition = SwapTransition(deepcopy(swapstate.chain_to_process), deepcopy(swapstate.process_to_chain))
return (
TemperedTransition(swaptransition, multitransition),
TemperedState(swapstate, multistate, sampler.chain_to_beta)
)
end

function AbstractMCMC.step(
Expand Down
Loading

2 comments on commit ed1ca98

@yebai
Copy link
Member

@yebai yebai commented on ed1ca98 Apr 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/81099

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.2 -m "<description of version>" ed1ca9886d1c49aece23aacf2bc9fcaf77725fd5
git push origin v0.3.2

Please sign in to comment.