Skip to content

Commit

Permalink
Calculate chainidxs inside the loop
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Jan 8, 2025
1 parent ec621b7 commit 81e0fe7
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,17 +397,9 @@ function mcmcsample(
models = [deepcopy(model) for _ in interval]
samplers = [deepcopy(sampler) for _ in interval]

# Distribute chains amongst the chunks. If nchains/nchunks = m with
# remainder n, then the first n chunks will have m + 1 chains, and the rest
# will have m chains.
# If nchains/nchunks = m with remainder n, then the first n chunks will
# have m + 1 chains, and the rest will have m chains.
m, n = divrem(nchains, nchunks)
chain_index_groups = UnitRange{Int}[]
current_index = 1
for i in interval
nchains_this_chunk = i <= n ? m + 1 : m
push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1))
current_index += nchains_this_chunk
end

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand Down Expand Up @@ -447,8 +439,18 @@ function mcmcsample(

Distributed.@async begin
try
Distributed.@sync for (chainidxs, _rng, _model, _sampler) in
zip(chain_index_groups, rngs, models, samplers)
Distributed.@sync for (i, _rng, _model, _sampler) in
zip(interval, rngs, models, samplers)
if i <= n
chainidx_hi = i * (m + 1)
nchains_chunk = m + 1
else
chainidx_hi = n * (m + 1) + (i - n) * m
nchains_chunk = m
end
chainidx_lo = chainidx_hi - nchains_chunk + 1
chainidxs = chainidx_lo:chainidx_hi

Threads.@spawn for chainidx in chainidxs
# Seed the chunk-specific random number generator with the pre-made seed.
Random.seed!(_rng, seeds[chainidx])
Expand Down

0 comments on commit 81e0fe7

Please sign in to comment.