Skip to content

Commit

Permalink
Fix indexing for chains in different threads
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Jan 8, 2025
1 parent 5a3b155 commit ec621b7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.6.0"
version = "5.6.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
22 changes: 14 additions & 8 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,24 @@ function mcmcsample(

# Copy the random number generator, model, and sample for each thread
nchunks = min(nchains, Threads.nthreads())
chunksize = cld(nchains, nchunks)
interval = 1:nchunks
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
rngs = [copy(rng) for _ in interval]
models = [deepcopy(model) for _ in interval]
samplers = [deepcopy(sampler) for _ in interval]

# Distribute chains amongst the chunks. If nchains/nchunks = m with
# remainder n, then the first n chunks will have m + 1 chains, and the rest
# will have m chains.
m, n = divrem(nchains, nchunks)
chain_index_groups = UnitRange{Int}[]
current_index = 1
for i in interval
nchains_this_chunk = i <= n ? m + 1 : m
push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1))
current_index += nchains_this_chunk
end

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -436,13 +447,8 @@ function mcmcsample(

Distributed.@async begin
try
Distributed.@sync for (i, _rng, _model, _sampler) in
zip(1:nchunks, rngs, models, samplers)
chainidxs = if i == nchunks
((i - 1) * chunksize + 1):nchains
else
((i - 1) * chunksize + 1):(i * chunksize)
end
Distributed.@sync for (chainidxs, _rng, _model, _sampler) in
zip(chain_index_groups, rngs, models, samplers)
Threads.@spawn for chainidx in chainidxs
# Seed the chunk-specific random number generator with the pre-made seed.
Random.seed!(_rng, seeds[chainidx])
Expand Down

0 comments on commit ec621b7

Please sign in to comment.