diff --git a/Project.toml b/Project.toml index 7d7e129..9bb2749 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.0" +version = "5.6.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 2324604..5b8e052 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -391,13 +391,24 @@ function mcmcsample( # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) - chunksize = cld(nchains, nchunks) interval = 1:nchunks # `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899 rngs = [copy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] + # Distribute chains amongst the chunks. If nchains/nchunks = m with + # remainder n, then the first n chunks will have m + 1 chains, and the rest + # will have m chains. + m, n = divrem(nchains, nchunks) + chain_index_groups = UnitRange{Int}[] + current_index = 1 + for i in interval + nchains_this_chunk = i <= n ? m + 1 : m + push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1)) + current_index += nchains_this_chunk + end + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -436,13 +447,8 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (i, _rng, _model, _sampler) in - zip(1:nchunks, rngs, models, samplers) - chainidxs = if i == nchunks - ((i - 1) * chunksize + 1):nchains - else - ((i - 1) * chunksize + 1):(i * chunksize) - end + Distributed.@sync for (chainidxs, _rng, _model, _sampler) in + zip(chain_index_groups, rngs, models, samplers) Threads.@spawn for chainidx in chainidxs # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seeds[chainidx])