From 43a56305eb7b2a0b55f0126952a1a0f218ecc0fd Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:45:21 +0200 Subject: [PATCH 01/56] Add PosteriorStats as dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 5c5bca73..d3eee142 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" @@ -38,6 +39,7 @@ MCMCDiagnosticTools = "0.3" MLJModelInterface = "0.3.5, 0.4, 1.0" NaturalSort = "1" OrderedCollections = "1.4" +PosteriorStats = "0.1.2" PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2" RecipesBase = "0.7, 0.8, 1.0" StatsBase = "0.33.2, 0.34" From 2994a561e89dd34013085bc64a6681d8b04c8e74 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:45:58 +0200 Subject: [PATCH 02/56] Import and reexport PosteriorStats functions --- src/MCMCChains.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index 3422620f..50cdf582 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -16,6 +16,7 @@ import MCMCDiagnosticTools import MLJModelInterface import NaturalSort import OrderedCollections +import PosteriorStats import PrettyTables import StatsFuns import Tables @@ -32,7 +33,6 @@ export set_section, get_params, sections, sort_sections, setinfo export replacenames, namesingroup, group export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile export ChainDataFrame -export summarize # Reexport diagnostics functions using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod, @@ -48,6 +48,10 @@ export rafterydiag export rstar export hpd +# Reexport stats functions +using PosteriorStats: default_diagnostics, default_stats, default_summary_stats, hdi, + summarize +export default_diagnostics, default_stats, default_summary_stats, hdi, summarize """ Chains From 9ffb60b24991b01087c15af5f4d09ffa2d9f8f6c Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:47:12 +0200 Subject: [PATCH 03/56] Forward to PosteriorStats.summarize --- src/summarize.jl | 44 ++++++++------------------------------------ 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/src/summarize.jl b/src/summarize.jl index 70bba5d8..8e723da3 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -133,56 +133,28 @@ Summarize `chains` in a `ChainsDataFrame`. * `summarize(chns; sections=[:parameters])` : Chain summary of :parameters section * `summarize(chns; sections=[:parameters, :internals])` : Chain summary for multiple sections """ -function summarize( +function PosteriorStats.summarize( chains::Chains, funs...; sections = _default_sections(chains), - func_names::AbstractVector{Symbol} = Symbol[], append_chains::Bool = true, - name::String = "", - additional_df = nothing + kwargs... ) - # If we weren't given any functions, fall back to summary stats. - if isempty(funs) - return summarystats(chains; sections, append_chains, name) - end - # Generate a chain to work on. chn = Chains(chains, _clean_sections(chains, sections)) # Obtain names of parameters. names_of_params = names(chn) - # If no function names were given, make a new list. - fnames = isempty(func_names) ? collect(nameof.(funs)) : func_names - - # Obtain the additional named tuple. - additional_nt = additional_df === nothing ? NamedTuple() : additional_df.nt - if append_chains # Evaluate the functions. - data = to_matrix(chn) - fvals = [[f(data[:, i]) for i in axes(data, 2)] for f in funs] - - # Build the ChainDataFrame. - nt = merge((; parameters = names_of_params, zip(fnames, fvals)...), additional_nt) - df = ChainDataFrame(name, nt) - - return df + data = _permutedims_diagnostics(chains.value.data) + summarize(data, funs...; var_names=names_of_params, kwargs...) else # Evaluate the functions. data = to_vector_of_matrices(chn) - vector_of_fvals = [[[f(x[:, i]) for i in axes(x, 2)] for f in funs] for x in data] - - # Build the ChainDataFrames. - vector_of_nt = [ - merge((; parameters = names_of_params, zip(fnames, fvals)...), additional_nt) - for fvals in vector_of_fvals - ] - vector_of_df = [ - ChainDataFrame(name * " (Chain $i)", nt) - for (i, nt) in enumerate(vector_of_nt) - ] - - return vector_of_df + return map(data) do x + z = reshape(x, size(x, 1), 1, size(x, 2)) + summarize(z, funs...; var_names=names_of_params, kwargs...) + end end end From a6d16e5cd07f8beb299c08dd29115f06d98a6560 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:48:01 +0200 Subject: [PATCH 04/56] Update docstring --- src/summarize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/summarize.jl b/src/summarize.jl index 8e723da3..fecf133b 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -122,9 +122,9 @@ function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:Na end """ - summarize(chains, funs...[; sections, func_names = [], name = "", append_chains = true]) + summarize(chains, funs...[; sections, name = "", append_chains = true]) -Summarize `chains` in a `ChainsDataFrame`. +Summarize `chains` in a `PosteriorStats.SummaryStats`. # Examples From 4e99fe8f069f6908ede8f9d3bd8b2994fd22c119 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:57:40 +0200 Subject: [PATCH 05/56] Update docstring --- src/summarize.jl | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/summarize.jl b/src/summarize.jl index fecf133b..f8d47b90 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -122,10 +122,27 @@ function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:Na end """ - summarize(chains, funs...[; sections, name = "", append_chains = true]) + summarize( + chains[, stats_funs...]; + append_chains=true, + name="SummaryStats", + [sections, var_names], + ) Summarize `chains` in a `PosteriorStats.SummaryStats`. +`stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)` +to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a +`Pair` of the form `name => fun` specifying the name to be used for the statistic or of the +form `(name1, ...) => fun` when the function returns a collection. When the function returns +a collection, the names in this latter format must be provided. + +If no stats functions are provided, then those specified in [`default_summary_stats`](@ref) +are computed. + +`var_names` specifies the names of the parameters in data. If not provided, the names are +inferred from data. + # Examples * `summarize(chns)` : Complete chain summary @@ -137,13 +154,14 @@ function PosteriorStats.summarize( chains::Chains, funs...; sections = _default_sections(chains), append_chains::Bool = true, + var_names=nothing, kwargs... ) # Generate a chain to work on. chn = Chains(chains, _clean_sections(chains, sections)) # Obtain names of parameters. - names_of_params = names(chn) + names_of_params = var_names === nothing ? names(chn) : var_names if append_chains # Evaluate the functions. From 38dbdac51a98db22329e590b90e63c9201bbec71 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:58:11 +0200 Subject: [PATCH 06/56] Forward summarystats to summarize --- src/stats.jl | 85 ++++------------------------------------------------ 1 file changed, 6 insertions(+), 79 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index e31044ea..fb0f6f57 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -266,88 +266,15 @@ end """ - function summarystats( - chains; - sections = _default_sections(chains), - append_chains= true, - autocov_method::AbstractAutocovMethod = AutocovMethod(), - maxlag = 250, - kwargs... - ) - -Compute the mean, standard deviation, Monte Carlo standard error, bulk- and tail- effective -sample size, and ``\\widehat{R}`` diagnostic for each parameter in the chain. + summarystats(chains; kwargs...) -Setting `append_chains=false` will return a vector of dataframes containing the summary -statistics for each chain. +Compute default summary statistics from the `chains`. -When estimating the effective sample size, autocorrelations are computed for at most `maxlag` lags. +`kwargs` are forwarded to [`summarize`](@ref). To customize the summary statistics, see +`summarize`. """ -function summarystats( - chains::Chains; - sections = _default_sections(chains), - append_chains::Bool = true, - autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(), - maxlag = 250, - name = "Summary Statistics", - kwargs... -) - # Store everything. - funs = [mean∘cskip, std∘cskip] - func_names = [:mean, :std] - - # Subset the chain. - _chains = Chains(chains, _clean_sections(chains, sections)) - - # Calculate MCSE and ESS/R-hat separately. - nt_additional = NamedTuple() - try - mcse_df = MCMCDiagnosticTools.mcse( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, - ) - nt_additional = merge(nt_additional, (; mcse=mcse_df.nt.mcse)) - catch e - @warn "MCSE calculation failed: $e" - end - - try - ess_tail_df = MCMCDiagnosticTools.ess( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail - ) - nt_additional = merge(nt_additional, (ess_tail=ess_tail_df.nt.ess,)) - catch e - @warn "Tail ESS calculation failed: $e" - end - - try - ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank - ) - nt_ess_rhat_rank = ( - ess_bulk=ess_rhat_rank_df.nt.ess, - rhat=ess_rhat_rank_df.nt.rhat, - ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec - ) - nt_additional = merge(nt_additional, nt_ess_rhat_rank) - catch e - @warn "Bulk ESS/R-hat calculation failed: $e" - end - - # Possibly re-order the columns to stay backwards-compatible. - additional_keys = (:mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec) - additional_df = ChainDataFrame("Additional", (; ((k, nt_additional[k]) for k in additional_keys if k ∈ keys(nt_additional))...)) - - # Summarize. - summary_df = summarize( - _chains, funs...; - func_names, - append_chains, - additional_df, - name, - sections = nothing - ) - - return summary_df +function summarystats(chains::Chains; name = "Summary Statistics", kwargs...) + return summarize(chains; name, kwargs...) end """ From 395a3d8aa02ae7f313d105dd418ff862d9008578 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 15:58:44 +0200 Subject: [PATCH 07/56] Simplify mean implementation --- src/stats.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index fb0f6f57..fa044efb 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -283,19 +283,7 @@ end Calculate the mean of a chain. """ function mean(chains::Chains; kwargs...) - # Store everything. - funs = [mean∘cskip] - func_names = [:mean] - - # Summarize. - summary_df = summarize( - chains, funs...; - func_names = func_names, - name = "Mean", - kwargs... - ) - - return summary_df + return summarize(chains, :mean => mean ∘ cskip; name = "Mean", kwargs...) end mean(chn::Chains, syms) = mean(chn[:, syms, :]) From 4fa204ee80331e2b65d4a9f100a5aff9f1cb4ece Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:00:15 +0200 Subject: [PATCH 08/56] Simplify quantile implementation --- src/stats.jl | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index fa044efb..a042e91c 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -245,23 +245,14 @@ for each chain. function quantile( chains::Chains; q::AbstractVector = [0.025, 0.25, 0.5, 0.75, 0.975], - append_chains = true, kwargs... ) # compute quantiles - funs = Function[] - func_names = @. Symbol(100 * q, :%) - for i in q - push!(funs, x -> quantile(cskip(x), i)) + stats_funs = map(q) do qi + nm = Symbol(100 * qi, :%) + return nm => Base.Fix2(quantile, qi) ∘ cskip end - - return summarize( - chains, funs...; - func_names = func_names, - append_chains = append_chains, - name = "Quantiles", - kwargs... - ) + return summarize(chains, stats_funs...; name = "Quantiles", kwargs...) end From 5c0f35d02244a717379b7424d434366c310f08b7 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:01:40 +0200 Subject: [PATCH 09/56] Replace hpd with hdi --- src/stats.jl | 40 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index a042e91c..d5118286 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -193,45 +193,29 @@ function describe( return dfs end -function _hpd(x::AbstractVector{<:Real}; alpha::Real=0.05) - n = length(x) - m = max(1, ceil(Int, alpha * n)) - - y = sort(x) - a = y[1:m] - b = y[(n - m + 1):n] - _, i = findmin(b - a) - - return [a[i], b[i]] -end - """ - hpd(chn::Chains; alpha::Real=0.05, kwargs...) + hdi(chn::Chains; prob::Real=0.94, kwargs...) -Return the highest posterior density interval representing `1-alpha` probability mass. +Return the unimodal highest density interval (HDI) representing `prob` probability mass. Note that this will return a single interval and will not return multiple intervals for discontinuous regions. # Examples -```julia-repl -julia> val = rand(500, 2, 3); -julia> chn = Chains(val, [:a, :b]); +```jldoctest; setup = :(using Random; Random.seed!(582)) +julia> val = rand(500, 2, 3) -julia> hpd(chn) -HPD - parameters lower upper - Symbol Float64 Float64 +julia> chn = Chains(val, [:a, :b]); - a 0.0554 0.9944 - b 0.0114 0.9460 +julia> hdi(chn) +HDI + lower upper + a 0.0749 0.999 + b 0.00531 0.940 ``` """ -function hpd(chn::Chains; alpha::Real=0.05, kwargs...) - labels = [:lower, :upper] - l(x) = _hpd(x, alpha=alpha)[1] - u(x) = _hpd(x, alpha=alpha)[2] - return summarize(chn, l, u; name = "HPD", func_names = labels, kwargs...) +function PosteriorStats.hdi(chn::Chains; prob::Real=0.94, kwargs...) + return summarize(chn, (:lower, :upper) => (x -> hdi(x; prob)); name = "HDI", kwargs...) end """ From 3660d0e19b6446ef12e52957e737c5f96f1c07f7 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:01:52 +0200 Subject: [PATCH 10/56] Deprecate hpd --- src/MCMCChains.jl | 1 - src/stats.jl | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index 50cdf582..f18646b0 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -47,7 +47,6 @@ export mcse export rafterydiag export rstar -export hpd # Reexport stats functions using PosteriorStats: default_diagnostics, default_stats, default_summary_stats, hdi, summarize diff --git a/src/stats.jl b/src/stats.jl index d5118286..0a132953 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -218,6 +218,8 @@ function PosteriorStats.hdi(chn::Chains; prob::Real=0.94, kwargs...) return summarize(chn, (:lower, :upper) => (x -> hdi(x; prob)); name = "HDI", kwargs...) end +@deprecate hpd(chn::Chains; alpha::Real=0.05, kwargs...) hdi(chn; prob=1 - alpha, kwargs...) + """ quantile(chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], append_chains = true, kwargs...]) From e3d2d16a19530e111ffd3cde0f690101612723dc Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:02:17 +0200 Subject: [PATCH 11/56] Simplify autocor implementation --- src/stats.jl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 0a132953..1ef63789 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -17,24 +17,14 @@ Setting `append_chains=false` will return a vector of dataframes containing the """ function autocor( chains::Chains; - append_chains = true, demean::Bool = true, lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains), kwargs... ) - funs = Function[] - func_names = @. Symbol("lag ", lags) - for i in lags - push!(funs, x -> autocor(x, [i], demean=demean)[1]) + funs = map(lags) do lag + return Symbol("lag ", lag) => (x -> autocor(x, [i], demean=demean)[1]) end - - return summarize( - chains, funs...; - func_names = func_names, - append_chains = append_chains, - name = "Autocorrelation", - kwargs... - ) + return summarize(chains, funs...; name = "Autocorrelation", kwargs...) end """ From 49af2d9108af3323031b7146f55187e52ef5918c Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:10:29 +0200 Subject: [PATCH 12/56] Remove unused keyword `etype` --- src/stats.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 1ef63789..6bb9a46f 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -166,7 +166,6 @@ describe(c::Chains; args...) = describe(stdout, c; args...) """ describe(io, chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], - etype = :bm, kwargs...]) Print the summary statistics and quantiles for the chain. @@ -175,7 +174,6 @@ function describe( io::IO, chains::Chains; q = [0.025, 0.25, 0.5, 0.75, 0.975], - etype = :bm, kwargs... ) dfs = vcat(summarystats(chains; etype = etype, kwargs...), From bdde660ed54ebbb90caa3d6d654b12300f1760cf Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:11:33 +0200 Subject: [PATCH 13/56] Explicitly build list of stats --- src/stats.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 6bb9a46f..4ec40683 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -176,9 +176,8 @@ function describe( q = [0.025, 0.25, 0.5, 0.75, 0.975], kwargs... ) - dfs = vcat(summarystats(chains; etype = etype, kwargs...), - quantile(chains; q = q, kwargs...)) - return dfs + stats = [summarystats(chains; kwargs...), quantile(chains; q = q, kwargs...)] + return stats end """ From d851307c6d1fa1a50ffbf5390b4314ac0729cc1d Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:15:36 +0200 Subject: [PATCH 14/56] Simultaneously compute all quantiles --- src/stats.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 4ec40683..26f4d255 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -208,7 +208,7 @@ end @deprecate hpd(chn::Chains; alpha::Real=0.05, kwargs...) hdi(chn; prob=1 - alpha, kwargs...) """ - quantile(chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], append_chains = true, kwargs...]) + quantile(chains[; q = (0.025, 0.25, 0.5, 0.75, 0.975), append_chains = true, kwargs...]) Compute the quantiles for each parameter in the chain. @@ -217,15 +217,17 @@ for each chain. """ function quantile( chains::Chains; - q::AbstractVector = [0.025, 0.25, 0.5, 0.75, 0.975], + q::Union{Tuple,AbstractVector} = (0.025, 0.25, 0.5, 0.75, 0.975), kwargs... ) # compute quantiles - stats_funs = map(q) do qi - nm = Symbol(100 * qi, :%) - return nm => Base.Fix2(quantile, qi) ∘ cskip - end - return summarize(chains, stats_funs...; name = "Quantiles", kwargs...) + func_names = Tuple(Symbol.(100 .* q, :%)) + return summarize( + chains, + func_names => (Base.Fix2(quantile, q) ∘ cskip); + name="Quantiles", + kwargs..., + ) end From 17174834fcd4bb835179c66fb67a35f7ec3ebf87 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:16:16 +0200 Subject: [PATCH 15/56] Print an extra newline --- src/chains.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/chains.jl b/src/chains.jl index b6852b67..3172ec22 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -351,6 +351,7 @@ function Base.show(io::IO, mime::MIME"text/plain", chains::Chains) # Show summary stats. summaries = describe(chains) for summary in summaries + println(io) println(io) show(io, mime, summary) end From bf0665353008bbf78cad49dba362982a413b0118 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:17:02 +0200 Subject: [PATCH 16/56] Use and export SummaryStats --- src/MCMCChains.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index f18646b0..b1e47634 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -48,9 +48,10 @@ export rafterydiag export rstar # Reexport stats functions -using PosteriorStats: default_diagnostics, default_stats, default_summary_stats, hdi, +using PosteriorStats: SummaryStats, default_diagnostics, default_stats, + default_summary_stats, hdi, summarize +export SummaryStats, default_diagnostics, default_stats, default_summary_stats, hdi, summarize -export default_diagnostics, default_stats, default_summary_stats, hdi, summarize """ Chains From 147b56b083fb970557e16400ef99d5e7cb232ec4 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:28:24 +0200 Subject: [PATCH 17/56] Use SummaryStats in place of ChainsDataFrame --- src/discretediag.jl | 14 +++++++------- src/ess_rhat.jl | 12 ++++++------ src/gelmandiag.jl | 14 +++++++------- src/gewekediag.jl | 10 +++++----- src/heideldiag.jl | 10 +++++----- src/mcse.jl | 4 ++-- src/rafterydiag.jl | 10 +++++----- 7 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index 67cec670..d262aca8 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -17,16 +17,16 @@ function MCMCDiagnosticTools.discretediag( _permutedims_diagnostics(_chains.value.data); kwargs... ) - # Create dataframes - parameters = (parameters = names(_chains),) - between_chain_df = ChainDataFrame( + # Create SummaryStats + parameters = (parameter = names(_chains),) + between_chain_stats = SummaryStats( "Chisq diagnostic - Between chains", merge(parameters, between_chain_vals), ) - within_chain_dfs = map(1:size(_chains, 3)) do i + within_chain_stats = map(1:size(_chains, 3)) do i vals = map(val -> val[:, i], within_chain_vals) - return ChainDataFrame("Chisq diagnostic - Chain $i", merge(parameters, vals)) + return SummaryStats("Chisq diagnostic - Chain $i", merge(parameters, vals)) end - dfs = vcat(between_chain_df, within_chain_dfs) + stats = [between_chain_stats, within_chain_stats...] - return dfs + return stats end diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 70d86d2b..4b203d79 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -24,9 +24,9 @@ function MCMCDiagnosticTools.ess( # Convert to NamedTuple ess_per_sec = ess ./ dur - nt = merge((parameters = names(_chains),), (; ess, ess_per_sec)) + nt = merge((parameter = names(_chains),), (; ess, ess_per_sec)) - return ChainDataFrame("ESS", nt) + return SummaryStats("ESS", nt) end """ @@ -48,9 +48,9 @@ function MCMCDiagnosticTools.rhat( ) # Convert to NamedTuple - nt = merge((parameters = names(_chains),), (; rhat)) + nt = merge((parameter = names(_chains),), (; rhat)) - return ChainDataFrame("R-hat", nt) + return SummaryStats("R-hat", nt) end """ @@ -79,7 +79,7 @@ function MCMCDiagnosticTools.ess_rhat( # Convert to NamedTuple ess_per_sec = ess_rhat.ess ./ dur - nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec)) + nt = merge((parameter = names(_chains),), ess_rhat, (; ess_per_sec)) - return ChainDataFrame("ESS/R-hat", nt) + return SummaryStats("ESS/R-hat", nt) end diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 3fc894f5..39b6d23d 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -12,12 +12,12 @@ function MCMCDiagnosticTools.gelmandiag( results = MCMCDiagnosticTools.gelmandiag(_permutedims_diagnostics(psi); kwargs...) # Create a data frame with the results. - df = ChainDataFrame( + stats = SummaryStats( "Gelman, Rubin, and Brooks diagnostic", - merge((parameters = names(_chains),), results), + merge((parameter = names(_chains),), results), ) - return df + return stats end function MCMCDiagnosticTools.gelmandiag_multivariate( @@ -36,11 +36,11 @@ function MCMCDiagnosticTools.gelmandiag_multivariate( kwargs..., ) - # Create a data frame with the results. - df = ChainDataFrame( + # Create SummaryStats with the results. + stats = SummaryStats( "Gelman, Rubin, and Brooks diagnostic", - (parameters = names(_chains), psrf = results.psrf, psrfci = results.psrfci), + (parameter = names(_chains), psrf = results.psrf, psrfci = results.psrfci), ) - return df, results.psrfmultivariate + return stats, results.psrfmultivariate end diff --git a/src/gewekediag.jl b/src/gewekediag.jl index 72cbb5f2..f7c6ecad 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -18,12 +18,12 @@ function MCMCDiagnosticTools.gewekediag( return namedtuple_of_vecs end - # Create data frames. - parameters = (parameters = names(_chains),) - dfs = [ - ChainDataFrame("Geweke diagnostic - Chain $i", merge(parameters, result)) + # Create SummaryStats. + parameters = (parameter = names(_chains),) + stats = [ + SummaryStats("Geweke diagnostic - Chain $i", merge(parameters, result)) for (i, result) in enumerate(results) ] - return dfs + return stats end diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 67def6a8..7f7acd18 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -16,14 +16,14 @@ function MCMCDiagnosticTools.heideldiag( return namedtuple_of_vecs end - # Create data frames. - parameters = (parameters = names(_chains),) - dfs = [ - ChainDataFrame( + # Create SummaryStats. + parameters = (parameter = names(_chains),) + stats = [ + SummaryStats( "Heidelberger and Welch diagnostic - Chain $i", merge(parameters, result) ) for (i, result) in enumerate(results) ] - return dfs + return stats end diff --git a/src/mcse.jl b/src/mcse.jl index 78ae2552..5c5a2731 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -16,7 +16,7 @@ function MCMCDiagnosticTools.mcse( kwargs..., ) - nt = merge((parameters = names(_chains),), (; mcse)) + nt = merge((parameter = names(_chains),), (; mcse)) - return ChainDataFrame("MCSE", nt) + return SummaryStats("MCSE", nt) end diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl index 95126025..4dfd45a8 100644 --- a/src/rafterydiag.jl +++ b/src/rafterydiag.jl @@ -16,14 +16,14 @@ function MCMCDiagnosticTools.rafterydiag( return namedtuple_of_vecs end - # Create data frames. - parameters = (parameters = names(_chains),) - dfs = [ - ChainDataFrame( + # Create SummaryStats. + parameters = (parameter = names(_chains),) + stats = [ + SummaryStats( "Raftery and Lewis diagnostic - Chain $i", merge(parameters, result) ) for (i, result) in enumerate(results) ] - return dfs + return stats end From 706f29cab107b56cd09431dccc6a63a578888cdf Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:39:14 +0200 Subject: [PATCH 18/56] Update and repair changerate --- src/stats.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 26f4d255..124bfb96 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -63,11 +63,11 @@ function cor( names_of_params = names(_chains) if append_chains - df = chaindataframe_cor("Correlation", names_of_params, to_matrix(_chains)) + df = summarystats_cor("Correlation", names_of_params, to_matrix(_chains)) return df else vector_of_df = [ - chaindataframe_cor( + summarystats_cor( "Correlation - Chain $i", names_of_params, data ) for (i, data) in enumerate(to_vector_of_matrices(_chains)) @@ -76,16 +76,16 @@ function cor( end end -function chaindataframe_cor(name, names_of_params, chains::AbstractMatrix; kwargs...) +function summarystats_cor(name, names_of_params, chains::AbstractMatrix; kwargs...) # Compute the correlation matrix. cormat = cor(chains) # Summarize the results in a named tuple. - nt = (; parameters = names_of_params, + nt = (; parameter = names_of_params, zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))...) - # Create a ChainDataFrame. - return ChainDataFrame(name, nt; kwargs...) + # Create a SummaryStats. + return SummaryStats(name, nt; kwargs...) end """ @@ -109,28 +109,28 @@ function changerate( names_of_params = names(_chains) if append_chains - df = chaindataframe_changerate("Change Rate", names_of_params, _chains.value.data) - return df + stats = summarystats_changerate("Change Rate", names_of_params, _chains.value.data) + return stats else - vector_of_df = [ - chaindataframe_changerate( + vector_of_stats = [ + summarystats_changerate( "Change Rate - Chain $i", names_of_params, data ) for (i, data) in enumerate(to_vector_of_matrices(_chains)) ] - return vector_of_df + return vector_of_stats end end -function chaindataframe_changerate(name, names_of_params, chains; kwargs...) +function summarystats_changerate(name, names_of_params, chains; kwargs...) # Compute the change rates. changerates, mvchangerate = changerate(chains) # Summarize the results in a named tuple. - nt = (; zip(names_of_params, changerates)..., multivariate = mvchangerate) + nt = (; parameter=names_of_params, changerate=changerates) - # Create a ChainDataFrame. - return ChainDataFrame(name, nt; kwargs...) + # Create a SummaryStats. + return SummaryStats(name, nt; kwargs...), mvchangerate end changerate(chains::AbstractMatrix{<:Real}) = changerate(reshape(chains, Val(3))) From 2c07794a3282cb55fd3e85ed8c16d228f8b3836c Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:39:57 +0200 Subject: [PATCH 19/56] Remove ChainDataFrame --- src/MCMCChains.jl | 1 - src/summarize.jl | 123 ---------------------------------------------- src/tables.jl | 39 --------------- 3 files changed, 163 deletions(-) diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index b1e47634..dfbc73b2 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -32,7 +32,6 @@ export setrange, resetrange export set_section, get_params, sections, sort_sections, setinfo export replacenames, namesingroup, group export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile -export ChainDataFrame # Reexport diagnostics functions using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod, diff --git a/src/summarize.jl b/src/summarize.jl index f8d47b90..99f73fd3 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -1,126 +1,3 @@ -struct ChainDataFrame{NT<:NamedTuple} - name::String - nt::NT - nrows::Int - ncols::Int - - function ChainDataFrame(name::String, nt::NamedTuple) - lengths = length(first(nt)) - all(x -> length(x) == lengths, nt) || error("Lengths must be equal.") - - return new{typeof(nt)}(name, nt, lengths, length(nt)) - end -end - -ChainDataFrame(nt::NamedTuple) = ChainDataFrame("", nt) - -Base.size(c::ChainDataFrame) = (c.nrows, c.ncols) -Base.names(c::ChainDataFrame) = collect(keys(c.nt)) - -# Display - -function Base.show(io::IO, df::ChainDataFrame) - print(io, df.name, " (", df.nrows, " x ", df.ncols, ")") -end - -function Base.show(io::IO, ::MIME"text/plain", df::ChainDataFrame) - digits = get(io, :digits, 4) - formatter = PrettyTables.ft_printf("%.$(digits)f") - - println(io, df.name) - # Support for PrettyTables 0.9 (`borderless`) and 0.10 (`tf_borderless`) - PrettyTables.pretty_table( - io, df.nt; - formatters = formatter, - tf = isdefined(PrettyTables, :borderless) ? PrettyTables.borderless : PrettyTables.tf_borderless, - ) -end - -Base.isequal(c1::ChainDataFrame, c2::ChainDataFrame) = isequal(c1, c2) - -# Index functions -function Base.getindex(c::ChainDataFrame, s::Union{Colon, Integer, UnitRange}, g::Union{Colon, Integer, UnitRange}) - convert(Array, getindex(c, c.nt[:parameters][s], collect(keys(c.nt))[g])) -end - -Base.getindex(c::ChainDataFrame, s::Vector{Symbol}, ::Colon) = getindex(c, s) -function Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}}) - getindex(c, s, collect(keys(c.nt))) -end - -function Base.getindex(c::ChainDataFrame, s::Union{Colon, Integer, UnitRange}, ks) - getindex(c, c.nt[:parameters][s], ks) -end - -# dispatches involing `String` and `AbstractVector{String}` -Base.getindex(c::ChainDataFrame, s::String, ks) = getindex(c, Symbol(s), ks) -function Base.getindex(c::ChainDataFrame, s::AbstractVector{String}, ks) - return getindex(c, Symbol.(s), ks) -end - -# dispatch for `Symbol` -Base.getindex(c::ChainDataFrame, s::Symbol, ks) = getindex(c, [s], ks) - -function Base.getindex(c::ChainDataFrame, s::AbstractVector{Symbol}, ks::Symbol) - return getindex(c, s, [ks]) -end - -function Base.getindex( - c::ChainDataFrame, - s::AbstractVector{Symbol}, - ks::AbstractVector{Symbol} -) - ind = indexin(s, c.nt[:parameters]) - - not_found = map(x -> x === nothing, ind) - - any(not_found) && error("Cannot find parameters $(s[not_found]) in chain") - - # If there are multiple columns, return a new CDF. - if length(ks) > 1 - if !(:parameters in ks) - ks = vcat(:parameters, ks) - end - nt = NamedTuple{tuple(ks...)}(tuple([c.nt[k][ind] for k in ks]...)) - return ChainDataFrame(c.name, nt) - else - # Otherwise, return a vector if there's multiple parameters - # or just a scalar if there's one parameter. - if length(s) == 1 - return c.nt[ks[1]][ind][1] - else - return c.nt[ks[1]][ind] - end - end -end - -function Base.lastindex(c::ChainDataFrame, i::Integer) - if i == 1 - return c.nrows - elseif i ==2 - return c.ncols - else - error("No such dimension") - end -end - -function Base.convert(::Type{Array}, c::C) where C<:ChainDataFrame - T = promote_eltype_namedtuple_tail(c.nt) - arr = Array{T, 2}(undef, c.nrows, c.ncols - 1) - - for (i, k) in enumerate(Iterators.drop(keys(c.nt), 1)) - arr[:, i] = c.nt[k] - end - - return arr -end - -function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:NamedTuple - return mapreduce((x, y) -> cat(x, y; dims = Val(3)), cs) do c - reshape(convert(Array, c), Val(3)) - end -end - """ summarize( chains[, stats_funs...]; diff --git a/src/tables.jl b/src/tables.jl index d5ad2e67..7a70db04 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -69,42 +69,3 @@ function IteratorInterfaceExtensions.getiterator(chn::Chains) end TableTraits.isiterabletable(::Chains) = true - -#### -#### ChainDataFrame -#### - -#### Tables interface - -Tables.istable(::Type{<:ChainDataFrame}) = true - -# AbstractColumns interface - -Tables.columnaccess(::Type{<:ChainDataFrame}) = true - -Tables.columns(cdf::ChainDataFrame) = cdf - -Tables.columnnames(::ChainDataFrame{<:NamedTuple{names}}) where {names} = names - -Tables.getcolumn(cdf::ChainDataFrame, i::Int) = cdf.nt[i] -Tables.getcolumn(cdf::ChainDataFrame, nm::Symbol) = cdf.nt[nm] - -# row access - -Tables.rowaccess(::Type{<:ChainDataFrame}) = true - -Tables.rows(cdf::ChainDataFrame) = Tables.rows(Tables.columntable(cdf)) - -function Tables.schema(::ChainDataFrame{NamedTuple{names,T}}) where {names,T} - types = ntuple(i -> eltype(fieldtype(T, i)), fieldcount(T)) - return Tables.Schema(names, types) -end - -#### TableTraits interface - -IteratorInterfaceExtensions.isiterable(::ChainDataFrame) = true -function IteratorInterfaceExtensions.getiterator(cdf::ChainDataFrame) - return Tables.datavaluerows(Tables.columntable(cdf)) -end - -TableTraits.isiterabletable(::ChainDataFrame) = true From 58ae52abda0606acd2a6e921e9d4a1a40faebc2e Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:50:34 +0200 Subject: [PATCH 20/56] Update docs --- docs/src/stats.md | 2 +- docs/src/summarize.md | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/src/stats.md b/docs/src/stats.md index 7b178ddb..c53ecac0 100644 --- a/docs/src/stats.md +++ b/docs/src/stats.md @@ -8,5 +8,5 @@ describe mean summarystats quantile -hpd +hdi ``` diff --git a/docs/src/summarize.md b/docs/src/summarize.md index e1865bd6..14ae8101 100644 --- a/docs/src/summarize.md +++ b/docs/src/summarize.md @@ -1,8 +1,11 @@ # Summarize -The methods listed below are defined in `src/summarize.jl`. +The methods listed below are related to summarizing chains. -```@autodocs -Modules = [MCMCChains] -Pages = ["summarize.jl"] +```@docs +SummaryStats +summarize +default_summary_stats +default_stats +default_diagnostics ``` From 340a694d3afc8a303d00e20d4935e2a1d925e075 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 16:51:20 +0200 Subject: [PATCH 21/56] Increment major version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d3eee142..8512d12b 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "6.0.3" +version = "7.0.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From a67ac11d824082352eb5057074312f13fc23b6fd Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 19:23:35 +0200 Subject: [PATCH 22/56] Increment MCMCChains compat for docs --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 45d71e81..376e68f9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,7 +15,7 @@ CategoricalArrays = "0.8, 0.9, 0.10" DataFrames = "0.22, 1" Documenter = "0.26, 0.27" Gadfly = "1.3.4" -MCMCChains = "6" +MCMCChains = "7" MLJBase = "0.19, 0.20, 0.21" MLJDecisionTreeInterface = "0.3, 0.4" StatsPlots = "0.14, 0.15" From 1af83c8a218639a64f90c9e5813c02a65de74860 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 21:33:03 +0200 Subject: [PATCH 23/56] Refer to processed chains --- src/summarize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/summarize.jl b/src/summarize.jl index 99f73fd3..df2a8b63 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -42,7 +42,7 @@ function PosteriorStats.summarize( if append_chains # Evaluate the functions. - data = _permutedims_diagnostics(chains.value.data) + data = _permutedims_diagnostics(chn.value.data) summarize(data, funs...; var_names=names_of_params, kwargs...) else # Evaluate the functions. From 4aed409b5124d762f9043ce6d06613e07bb3cd04 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 21:33:10 +0200 Subject: [PATCH 24/56] Fix doctest --- src/stats.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stats.jl b/src/stats.jl index 124bfb96..37c970e1 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -190,7 +190,7 @@ Note that this will return a single interval and will not return multiple interv # Examples ```jldoctest; setup = :(using Random; Random.seed!(582)) -julia> val = rand(500, 2, 3) +julia> val = rand(500, 2, 3); julia> chn = Chains(val, [:a, :b]); From eef9393af7ad0ce4ab32e8405a4b20e3b80fbb2f Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:22:37 +0200 Subject: [PATCH 25/56] Add back append_chains keyword --- src/stats.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/stats.jl b/src/stats.jl index 37c970e1..eb90415e 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -17,6 +17,7 @@ Setting `append_chains=false` will return a vector of dataframes containing the """ function autocor( chains::Chains; + append_chains::Bool = true, demean::Bool = true, lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains), kwargs... From 6a1b482a02b0ff3e24d5865cfe6b292c13ab7b41 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:22:57 +0200 Subject: [PATCH 26/56] Compute all lags simultaneously --- src/stats.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index eb90415e..a32817e9 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -22,10 +22,9 @@ function autocor( lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains), kwargs... ) - funs = map(lags) do lag - return Symbol("lag ", lag) => (x -> autocor(x, [i], demean=demean)[1]) - end - return summarize(chains, funs...; name = "Autocorrelation", kwargs...) + fun_names = Tuple(Symbol.("lag", lags)) + fun = (x -> autocor(x, lags; demean=demean)) + return summarize(chains, fun_names => fun; name = "Autocorrelation", append_chains, kwargs...) end """ From c7d2021b2db87ff2965d9613aae24a9a489ca706 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:23:12 +0200 Subject: [PATCH 27/56] Vectorize before autocor --- src/stats.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stats.jl b/src/stats.jl index a32817e9..6566de31 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -23,7 +23,7 @@ function autocor( kwargs... ) fun_names = Tuple(Symbol.("lag", lags)) - fun = (x -> autocor(x, lags; demean=demean)) + fun = (x -> autocor(vec(x), lags; demean=demean)) return summarize(chains, fun_names => fun; name = "Autocorrelation", append_chains, kwargs...) end From bfe8deb08cb5f432c88fae71a3283b6328d4379c Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:57:01 +0200 Subject: [PATCH 28/56] Correctly insert chain id into name --- src/summarize.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/summarize.jl b/src/summarize.jl index df2a8b63..1e54a481 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -32,6 +32,7 @@ function PosteriorStats.summarize( sections = _default_sections(chains), append_chains::Bool = true, var_names=nothing, + name::AbstractString = "SummaryStats", kwargs... ) # Generate a chain to work on. @@ -43,13 +44,14 @@ function PosteriorStats.summarize( if append_chains # Evaluate the functions. data = _permutedims_diagnostics(chn.value.data) - summarize(data, funs...; var_names=names_of_params, kwargs...) + summarize(data, funs...; var_names=names_of_params, name, kwargs...) else # Evaluate the functions. data = to_vector_of_matrices(chn) - return map(data) do x + return map(enumerate(data)) do (i, x) z = reshape(x, size(x, 1), 1, size(x, 2)) - summarize(z, funs...; var_names=names_of_params, kwargs...) + name_chain = name * " (Chain $i)" + summarize(z, funs...; var_names=names_of_params, name=name_chain, kwargs...) end end end From 646e10823451c6dc4868b0d9c6bf916c458a4a44 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:57:47 +0200 Subject: [PATCH 29/56] Update diagnostic tests --- test/diagnostic_tests.jl | 66 ++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl index a7a4d443..706aac3e 100644 --- a/test/diagnostic_tests.jl +++ b/test/diagnostic_tests.jl @@ -68,8 +68,8 @@ chn_disc = Chains(val_disc, start = 1, thin = 2) @test all(MCMCChains.indiscretesupport(chn) .== [false, false, false, true]) @test setinfo(chn, NamedTuple{(:A, :B)}((1,2))).info == NamedTuple{(:A, :B)}((1,2)) @test isa(set_section(chn, Dict(:internals => ["param_1"])), AbstractChains) - @test mean(chn) isa ChainDataFrame - @test mean(chn, ["param_1", "param_3"]) isa ChainDataFrame + @test mean(chn) isa SummaryStats + @test mean(chn, ["param_1", "param_3"]) isa SummaryStats @test 0.95 ≤ mean(chn, "param_1") ≤ 1.05 end @@ -167,7 +167,7 @@ end @testset "function tests" begin tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"])) - @test eltype(discretediag(chn_disc[:,2:2,:])) <: ChainDataFrame + @test eltype(discretediag(chn_disc[:,2:2,:])) <: SummaryStats gelman = gelmandiag(tchain) gelmanmv = gelmandiag_multivariate(tchain) @@ -176,18 +176,21 @@ end raferty = rafterydiag(tchain) # test raw return values - @test typeof(gelman) <: ChainDataFrame - @test typeof(gelmanmv) <: Tuple{ChainDataFrame,Float64} - @test typeof(geweke) <: Array{<:ChainDataFrame} - @test typeof(heidel) <: Array{<:ChainDataFrame} - @test typeof(raferty) <: Array{<:ChainDataFrame} - - # test ChainDataFrame sizes - @test size(gelman) == (2,3) - @test size(gelmanmv[1]) == (2,3) - @test size(geweke[1]) == (2,3) - @test size(heidel[1]) == (2,7) - @test size(raferty[1]) == (2,6) + @test typeof(gelman) <: SummaryStats + @test typeof(gelmanmv) <: Tuple{SummaryStats,Float64} + @test typeof(geweke) <: Array{<:SummaryStats} + @test typeof(heidel) <: Array{<:SummaryStats} + @test typeof(raferty) <: Array{<:SummaryStats} + + # test SummaryStats sizes + for s in (gelman, gelmanmv[1], geweke[1], heidel[1], raferty[1]) + @test length(s[:parameter]) == 2 + end + @test length(keys(gelman)) == 3 + @test length(keys(gelmanmv[1])) == 3 + @test length(keys(geweke[1])) == 3 + @test length(keys(heidel[1])) == 7 + @test length(keys(raferty[1])) == 6 end @testset "stats tests" begin @@ -203,31 +206,34 @@ end @test lags == filter!(x -> x < n, [1, 5, 10, 50]) acor = autocor(c; append_chains=append_chains) - # Number of columns in the ChainDataFrame(s): lags + parameters + # Number of columns in the SummaryStats: lags + parameters ncols = length(lags) + 1 if append_chains - @test acor isa ChainDataFrame - @test size(acor)[2] == ncols + @test acor isa SummaryStats + @test length(keys(acor)) == ncols else - @test acor isa Vector{<:ChainDataFrame} - @test all(size(a)[2] == ncols for a in acor) + @test acor isa Vector{<:SummaryStats} + @test all(length(keys(a)) == ncols for a in acor) end end - @test autocor(c) isa ChainDataFrame - @test convert(Array, autocor(c)) == convert(Array, autocor(c; append_chains=true)) + @test autocor(c) isa SummaryStats + @test autocor(c) == autocor(c; append_chains=true) end - @test MCMCChains.cor(chn) isa ChainDataFrame - @test MCMCChains.cor(chn; append_chains = false) isa Vector{<:ChainDataFrame} + @test MCMCChains.cor(chn) isa SummaryStats + @test MCMCChains.cor(chn; append_chains = false) isa Vector{<:SummaryStats} + + @test MCMCChains.changerate(chn) isa Tuple{SummaryStats,Float64} + @test MCMCChains.changerate(chn; append_chains = false) isa Vector{<:Tuple{SummaryStats,Float64}} - @test MCMCChains.changerate(chn) isa ChainDataFrame - @test MCMCChains.changerate(chn; append_chains = false) isa Vector{<:ChainDataFrame} + @test hdi(chn) isa SummaryStats + @test hdi(chn; append_chains = false) isa Vector{<:SummaryStats} - @test hpd(chn) isa ChainDataFrame - @test hpd(chn; append_chains = false) isa Vector{<:ChainDataFrame} + result = hdi(chn) + @test all(result[:upper] .> result[:lower]) - result = hpd(chn) - @test all(result.nt.upper .> result.nt.lower) + @test_deprecated hpd(chn) + @test hpd(chn) == hdi(chn; prob=0.95) end @testset "vector of vectors" begin From 8066b460c1650d0e1917d0301de481da15156520 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:58:02 +0200 Subject: [PATCH 30/56] Update ess_rhat_tests.jl --- test/ess_rhat_tests.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/ess_rhat_tests.jl b/test/ess_rhat_tests.jl index 2545c494..955f9429 100644 --- a/test/ess_rhat_tests.jl +++ b/test/ess_rhat_tests.jl @@ -18,8 +18,8 @@ using Test for f in (ess, ess_rhat) s = f(c) - @test length(s[:,:ess_per_sec]) == 5 - @test all(map(!ismissing, s[:,:ess_per_sec])) + @test length(s[:ess_per_sec]) == 5 + @test all(map(!ismissing, s[:ess_per_sec])) end end @@ -37,8 +37,8 @@ end ess_array, rhat_array = ess_rhat( permutedims(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind, ) - @test ess_df[:,2] == ess_rhat_df[:,2] == ess_array - @test rhat_df[:,2] == ess_rhat_df[:,3] == rhat_array + @test ess_df[:ess] == ess_rhat_df[:ess] == ess_array + @test rhat_df[:rhat] == ess_rhat_df[:rhat] == rhat_array end end @@ -51,5 +51,5 @@ end @test_throws ArgumentError ess(chain; autocov_method = autocov_method) @test_throws ArgumentError ess_rhat(chain; autocov_method = autocov_method) end - @test all(isnan, rhat(chain)[:, 2]) + @test all(isnan, rhat(chain)[:rhat]) end From 34b1da275374ec3322ade90895da3361998b29b5 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:58:06 +0200 Subject: [PATCH 31/56] Update mcse_tests.jl --- test/mcse_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mcse_tests.jl b/test/mcse_tests.jl index fdbefbb8..84f183d8 100644 --- a/test/mcse_tests.jl +++ b/test/mcse_tests.jl @@ -20,7 +20,7 @@ mymean(x) = mean(x) mcse_array = mcse( PermutedDimsArray(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind, ) - @test mcse_df[:,2] == mcse_array + @test mcse_df[:mcse] == mcse_array end else # analyze chain @@ -28,7 +28,7 @@ mymean(x) = mean(x) # analyze array mcse_array = mcse(PermutedDimsArray(x, (1, 3, 2)); kind = kind) - @test mcse_df[:,2] == mcse_array + @test mcse_df[:mcse] == mcse_array end end end From 5b0db148583f2fa2c87348540a2e76e500d920f4 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:58:17 +0200 Subject: [PATCH 32/56] Increment MCMCChains compat for tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 15bc6c96..95386cbe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ Documenter = "0.26, 0.27" FFTW = "1.1" IteratorInterfaceExtensions = "1" KernelDensity = "0.6.2" -MCMCChains = "6" +MCMCChains = "7" MLJBase = "0.18, 0.19, 0.20, 0.21" MLJDecisionTreeInterface = "0.3, 0.4" StatsBase = "0.33.2, 0.34" From fb04e613c4036b6bfee156bfd7f4e6a9156de752 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:58:32 +0200 Subject: [PATCH 33/56] Remove ChainDataFrames to tables tests --- test/tables_tests.jl | 79 -------------------------------------------- 1 file changed, 79 deletions(-) diff --git a/test/tables_tests.jl b/test/tables_tests.jl index ba3cb71c..7792d14f 100644 --- a/test/tables_tests.jl +++ b/test/tables_tests.jl @@ -128,83 +128,4 @@ using DataFrames @test isequal(Tables.columntable(df), Tables.columntable(chn)) end end - - @testset "ChainDataFrames" begin - val = rand(1000, 8, 4) - colnames = ["a", "b", "c", "d", "e", "f", "g", "h"] - internal_colnames = ["c", "d", "e", "f", "g", "h"] - chn = Chains(val, colnames, Dict(:internals => internal_colnames)) - cdf = describe(chn)[1] - - @testset "Tables interface" begin - @test Tables.istable(typeof(cdf)) - - @testset "column access" begin - @test Tables.columnaccess(typeof(cdf)) - @test Tables.columns(cdf) === cdf - @test Tables.columnnames(cdf) == keys(cdf.nt) - for (k, v) in pairs(cdf.nt) - @test isequal(Tables.getcolumn(cdf, k), v) - end - @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) - @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2]) - @test_throws Exception Tables.getcolumn(cdf, :blah) - @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1) - end - - @testset "row access" begin - @test Tables.rowaccess(typeof(cdf)) - @test Tables.rows(cdf) isa Tables.RowIterator - @test eltype(Tables.rows(cdf)) <: Tables.AbstractRow - rows = collect(Tables.rows(cdf)) - @test eltype(rows) <: Tables.AbstractRow - @test size(rows) === (2,) - @testset for i in 1:2 - row = rows[i] - @test Tables.columnnames(row) == keys(cdf.nt) - for j in length(cdf.nt) - @test isequal(Tables.getcolumn(row, j), cdf.nt[j][i]) - @test isequal(Tables.getcolumn(row, keys(cdf.nt)[j]), cdf.nt[j][i]) - end - end - end - - @testset "integration tests" begin - @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) - @test isequal(Tables.columntable(cdf), cdf.nt) - nt = Tables.rowtable(cdf)[1] - @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) - @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]) - nt = Tables.rowtable(cdf)[2] - @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) - @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]) - @test isequal( - Tables.matrix(Tables.rowtable(cdf)), - Tables.matrix(Tables.columntable(cdf)), - ) - end - - @testset "schema" begin - @test Tables.schema(cdf) isa Tables.Schema - @test Tables.schema(cdf).names === keys(cdf.nt) - @test Tables.schema(cdf).types === eltype.(values(cdf.nt)) - end - end - - @testset "TableTraits interface" begin - @test IteratorInterfaceExtensions.isiterable(cdf) - @test TableTraits.isiterabletable(cdf) - nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 1))[1] - @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) - nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 2))[2] - @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) - end - - @testset "DataFrames.DataFrame constructor" begin - @inferred DataFrame(cdf) - df = DataFrame(cdf) - @test df isa DataFrame - @test isequal(Tables.columntable(df), cdf.nt) - end - end end From faf667bab1a2c5ea745d0571485bca6e22bf9513 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 22:58:41 +0200 Subject: [PATCH 34/56] Update summarize_tests.jl --- test/summarize_tests.jl | 47 +++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/test/summarize_tests.jl b/test/summarize_tests.jl index d847ec43..315d6d2e 100644 --- a/test/summarize_tests.jl +++ b/test/summarize_tests.jl @@ -1,54 +1,55 @@ using MCMCChains, Test using Statistics: std -@testset "Summarize to DataFrame tests" begin +@testset "Summarize tests" begin val = rand(1000, 8, 4) + parm_names = ["a", "b", "c", "d", "e", "f", "g", "h"] chns = Chains( val, - ["a", "b", "c", "d", "e", "f", "g", "h"], + parm_names, Dict(:internals => ["c", "d", "e", "f", "g", "h"]) ) parm_df = summarize(chns, sections=[:parameters]) + parm_array_df = summarize(PermutedDimsArray(val[:, 1:2, :], (1, 3, 2)); var_names=[:a, :b]) - # check that display of ChainDataFrame does not error + # check that display of SummaryStats does not error println("compact display:") show(stdout, parm_df) println("\nverbose display:") show(stdout, "text/plain", parm_df) - @test 0.48 < parm_df[:a, :mean][1] < 0.52 - @test names(parm_df) == [:parameters, :mean, :std, :mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec] + @test 0.48 < parm_df[:mean][1] < 0.52 + @test parm_df == parm_array_df # Indexing tests - @test isequal(convert(Array, parm_df[:a, :]), convert(Array, parm_df[:a])) - @test parm_df[:a, :][:,:parameters] == :a - @test parm_df[[:a, :b], :][:,:parameters] == [:a, :b] + @test parm_df[:parameter] == [:a, :b] all_sections_df = summarize(chns, sections=[:parameters, :internals]) - @test all_sections_df isa ChainDataFrame - @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] - @test size(all_sections_df) == (8, 8) - @test all_sections_df.name == "" + all_sections_array_df = summarize(PermutedDimsArray(val, (1, 3, 2)); var_names=Symbol.(parm_names)) + @test all_sections_df isa SummaryStats + @test all_sections_df[:parameter] == Symbol.(parm_names) + @test all_sections_array_df == all_sections_df + @test all_sections_df.name == "SummaryStats" all_sections_dfs = summarize(chns, sections=[:parameters, :internals], name = "Summary", append_chains = false) - @test all_sections_dfs isa Vector{<:ChainDataFrame} + @test all_sections_dfs isa Vector{<:SummaryStats} for (i, all_sections_df) in enumerate(all_sections_dfs) - @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] - @test size(all_sections_df) == (8, 8) + @test all_sections_df[:parameter] == Symbol.(parm_names) + @test length(keys(all_sections_df)) == length(keys(all_sections_array_df)) @test all_sections_df.name == "Summary (Chain $i)" end two_parms_two_funs_df = summarize(chns[[:a, :b]], mean, std) - @test two_parms_two_funs_df[:, :parameters] == [:a, :b] - @test size(two_parms_two_funs_df) == (2, 3) + @test two_parms_two_funs_df[:parameter] == [:a, :b] + @test keys(two_parms_two_funs_df) == (:parameter, :mean, :std) three_parms_df = summarize(chns[[:a, :b, :c]], mean, std, sections=[:parameters, :internals]) - @test three_parms_df[:, :parameters] == [:a, :b, :c] - @test size(three_parms_df) == (3, 3) + @test three_parms_df[:parameter] == [:a, :b, :c] + @test keys(three_parms_df) == (:parameter, :mean, :std) - three_parms_df_2 = summarize(chns[[:a, :b, :g]], mean, std, - sections=[:parameters, :internals], func_names=[:mean, :sd]) - @test three_parms_df_2[:, :parameters] == [:a, :b, :g] - @test size(three_parms_df_2) == (3, 3) + three_parms_df_2 = summarize(chns[[:a, :b, :g]], :mymean => mean, :mystd => std, + sections=[:parameters, :internals]) + @test three_parms_df_2[:parameter] == [:a, :b, :g] + @test keys(three_parms_df_2) == (:parameter, :mymean, :mystd) end From fba84a2c2565216e0e3bfe7345a697e4e0bb0693 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 23:00:04 +0200 Subject: [PATCH 35/56] Use crossreference --- src/summarize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/summarize.jl b/src/summarize.jl index 1e54a481..69515141 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -6,7 +6,7 @@ [sections, var_names], ) -Summarize `chains` in a `PosteriorStats.SummaryStats`. +Summarize `chains` in a [`SummaryStats`](@ref). `stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)` to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a From d51dd659111765ff3d14db559fde120fd1df8dca Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 23:16:51 +0200 Subject: [PATCH 36/56] Update plotting functions --- src/plot.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index 1575846c..0a89266f 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -64,7 +64,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag) # Chains are already appended in `c` if desired, hence we use `append_chains=false` ac = autocor(c; sections = nothing, lags = lags, append_chains=false) - ac_mat = convert(Array, ac) + ac_mat = cat(reduce.(hcat, Iterators.drop.(ac, 1))...; dims=3) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) elseif st ∈ supportedplots @@ -209,7 +209,7 @@ function _compute_plot_data( ordered = false ) - chain_dic = Dict(zip(quantile(chains)[:,1], quantile(chains)[:,4])) + chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5])) sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic)))) sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)] par = (ordered ? sorted_par : par_names) @@ -217,9 +217,9 @@ function _compute_plot_data( chain_sections = MCMCChains.group(chains, Symbol(par[i])) chain_vec = vec(chain_sections.value.data) - lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.lower + lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:lower] for j in 1:length(hpdi)] - upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.upper + upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:upper] for j in 1:length(hpdi)] h = _riser + spacer*(i-1) qs = quantile(chain_vec, q) From 5d13530315ec6d2b1c38c15fc77bea75ce7daa1b Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 23:36:25 +0200 Subject: [PATCH 37/56] Update to use hdi and hdi_prob --- src/plot.jl | 66 ++++++++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index 0a89266f..9727e536 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -195,7 +195,7 @@ function _compute_plot_data( i::Integer, chains::Chains, par_names::AbstractVector{Symbol}; - hpd_val = [0.05, 0.2], + hdi_prob = [0.94, 0.8], q = [0.1, 0.9], spacer = 0.4, _riser = 0.2, @@ -203,9 +203,9 @@ function _compute_plot_data( show_mean = true, show_median = true, show_qi = false, - show_hpdi = true, + show_hdii = true, fill_q = true, - fill_hpd = false, + fill_hdi = false, ordered = false ) @@ -213,19 +213,19 @@ function _compute_plot_data( sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic)))) sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)] par = (ordered ? sorted_par : par_names) - hpdi = sort(hpd_val) + hdii = sort(hdi_prob) chain_sections = MCMCChains.group(chains, Symbol(par[i])) chain_vec = vec(chain_sections.value.data) - lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:lower] - for j in 1:length(hpdi)] - upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:upper] - for j in 1:length(hpdi)] + lower_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:lower] + for j in 1:length(hdii)] + upper_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:upper] + for j in 1:length(hdii)] h = _riser + spacer*(i-1) qs = quantile(chain_vec, q) k_density = kde(chain_vec) - if fill_hpd - x_int = filter(x -> lower_hpd[1][1] <= x <= upper_hpd[1][1], k_density.x) + if fill_hdi + x_int = filter(x -> lower_hdi[1][1] <= x <= upper_hdi[1][1], k_density.x) val = pdf(k_density, x_int) .+ h elseif fill_q x_int = filter(x -> qs[1] <= x <= qs[2], k_density.x) @@ -239,22 +239,22 @@ function _compute_plot_data( min = minimum(k_density.density .+ h) q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med]) - return par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, + return par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean, min, q_int end @recipe function f( p::RidgelinePlot; - hpd_val = [0.05, 0.2], + hdi_prob = [0.94, 0.8], q = [0.1, 0.9], spacer = 0.5, _riser = 0.2, show_mean = true, show_median = true, show_qi = false, - show_hpdi = true, + show_hdii = true, fill_q = true, - fill_hpd = false, + fill_hdi = false, ordered = false ) @@ -262,10 +262,10 @@ end par_names = p.args[2] for i in 1:length(par_names) - par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, chain_mean, - min, q_int = _compute_plot_data(i, chn, par_names; hpd_val = hpd_val, q = q, + par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean, + min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q, spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median, - show_qi = show_qi, show_hpdi = show_hpdi, fill_q = fill_q, fill_hpd = fill_hpd, + show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi, ordered = ordered) yticks --> (length(par_names) > 1 ? @@ -322,28 +322,28 @@ end end @series begin seriestype := :path - label := (show_hpdi ? (i == 1 ? "$(Integer((1-hpdi[1])*100))% HPDI" : nothing) + label := (show_hdii ? (i == 1 ? "$(round(Int, hdii[i]*100))% HDI" : nothing) : nothing) - linewidth --> (show_hpdi ? 2 : 0) + linewidth --> (show_hdii ? 2 : 0) seriesalpha --> 0.80 linecolor --> :darkblue - [lower_hpd[1][1], upper_hpd[1][1]], [h, h] + [lower_hdi[1][1], upper_hdi[1][1]], [h, h] end end end @recipe function f( p::ForestPlot; - hpd_val = [0.05, 0.2], + hdi_prob = [0.94, 0.8], q = [0.1, 0.9], spacer = 0.5, _riser = 0.2, show_mean = true, show_median = true, show_qi = false, - show_hpdi = true, + show_hdii = true, fill_q = true, - fill_hpd = false, + fill_hdi = false, ordered = false ) @@ -351,25 +351,25 @@ end par_names = p.args[2] for i in 1:length(par_names) - par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, chain_mean, - min, q_int = _compute_plot_data(i, chn, par_names; hpd_val = hpd_val, q = q, + par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean, + min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q, spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median, - show_qi = show_qi, show_hpdi = show_hpdi, fill_q = fill_q, fill_hpd = fill_hpd, + show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi, ordered = ordered) yticks --> (length(par_names) > 1 ? (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default) yaxis --> (length(par_names) > 1 ? "Parameters" : "Density" ) - for j in 1:length(hpdi) + for j in 1:length(hdii) @series begin seriestype := :path - label := (show_hpdi ? - (i == 1 ? "$(Integer((1-hpdi[j])*100))% HPDI" : nothing) : nothing) + label := (show_hdii ? + (i == 1 ? "$(round(Int, hdii[j]*100))% HDI" : nothing) : nothing) linecolor --> j - linewidth --> (show_hpdi ? 1.5*j : 0) + linewidth --> (show_hdii ? 1.5*j : 0) seriesalpha --> 0.80 - [lower_hpd[j][1], upper_hpd[j][1]], [h, h] + [lower_hdi[j][1], upper_hdi[j][1]], [h, h] end end @series begin @@ -377,7 +377,7 @@ end label := (show_median ? (i == 1 ? "Median" : nothing) : nothing) markershape --> :diamond markercolor --> "#000000" - markersize --> (show_median ? length(hpdi) : 0) + markersize --> (show_median ? length(hdii) : 0) [chain_med], [h] end @series begin @@ -385,7 +385,7 @@ end label := (show_mean ? (i == 1 ? "Mean" : nothing) : nothing) markershape --> :circle markercolor --> :gray - markersize --> (show_mean ? length(hpdi) : 0) + markersize --> (show_mean ? length(hdii) : 0) [chain_mean], [h] end @series begin From 57a292168f46f4230e3d84da8ce77547ee3faedc Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 20 Aug 2023 23:48:07 +0200 Subject: [PATCH 38/56] Fix missing tests --- test/missing_tests.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/missing_tests.jl b/test/missing_tests.jl index 46b6227c..89a7ba00 100644 --- a/test/missing_tests.jl +++ b/test/missing_tests.jl @@ -5,9 +5,7 @@ using Random # Tests for missing values. function testdiff(cdf1, cdf2) - m1 = convert(Array, cdf1) - m2 = convert(Array, cdf2) - return all(((x, y),) -> isapprox(x, y; atol=1e-2), zip(m1, m2)) + return all(((x, y),) -> isapprox(x, y; atol=1e-2), Iterators.drop(zip(cdf1, cdf2), 1)) end @testset "utils" begin @@ -35,9 +33,9 @@ end rf_2 = rafterydiag(chn_m) @testset "diagnostics missing tests" for i in 1:nchains - @test testdiff(gw_1, gw_2) - @test testdiff(hd_1, hd_2) - @test testdiff(rf_1, rf_2) + @test all(Base.splat(testdiff), zip(gw_1, gw_2)) + @test all(Base.splat(testdiff), zip(hd_1, hd_2)) + @test all(Base.splat(testdiff), zip(rf_1, rf_2)) end @test_throws MethodError discretediag(chn_m) From fc96b86fad93af4b836360dadc3ec7cccd829f27 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Mon, 21 Aug 2023 00:08:15 +0200 Subject: [PATCH 39/56] Remove references not defined here. --- docs/src/summarize.md | 6 +----- src/summarize.jl | 6 +++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/src/summarize.md b/docs/src/summarize.md index 14ae8101..4e30ae2d 100644 --- a/docs/src/summarize.md +++ b/docs/src/summarize.md @@ -1,11 +1,7 @@ # Summarize -The methods listed below are related to summarizing chains. +The methods listed below are defined in `src/summarize.jl`. ```@docs -SummaryStats summarize -default_summary_stats -default_stats -default_diagnostics ``` diff --git a/src/summarize.jl b/src/summarize.jl index 69515141..34775926 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -6,7 +6,7 @@ [sections, var_names], ) -Summarize `chains` in a [`SummaryStats`](@ref). +Summarize `chains` in a `PosteriorStats.SummaryStats`. `stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)` to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a @@ -14,8 +14,8 @@ to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` m form `(name1, ...) => fun` when the function returns a collection. When the function returns a collection, the names in this latter format must be provided. -If no stats functions are provided, then those specified in [`default_summary_stats`](@ref) -are computed. +If no stats functions are provided, then those specified in +`PosteriorStats.default_summary_stats` are computed. `var_names` specifies the names of the parameters in data. If not provided, the names are inferred from data. From cbd06d20395d6413724c0fbbe6817420625469ee Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Mon, 21 Aug 2023 00:20:17 +0200 Subject: [PATCH 40/56] Improve vertical spacing --- src/chains.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index 3172ec22..aede0ee7 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -346,11 +346,13 @@ function Base.show(io::IO, chains::Chains) end function Base.show(io::IO, mime::MIME"text/plain", chains::Chains) - print(io, "Chains ", chains, ":\n\n", header(chains)) + println(io, "Chains ", chains, ":\n\n", header(chains)) # Show summary stats. summaries = describe(chains) - for summary in summaries + summary, others = Iterators.peel(summaries) + show(io, mime, summary) + for summary in others println(io) println(io) show(io, mime, summary) From f2bdc6888fbb9ccbb62271800aad0f30c73c8803 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Mon, 21 Aug 2023 01:00:45 +0200 Subject: [PATCH 41/56] Bump PosteriorStats compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8512d12b..87c248f4 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ MCMCDiagnosticTools = "0.3" MLJModelInterface = "0.3.5, 0.4, 1.0" NaturalSort = "1" OrderedCollections = "1.4" -PosteriorStats = "0.1.2" +PosteriorStats = "0.1.3" PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2" RecipesBase = "0.7, 0.8, 1.0" StatsBase = "0.33.2, 0.34" From c19b13401ee4dbe37584d029cf087334fa6c4de1 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 17:57:17 -0800 Subject: [PATCH 42/56] Bump PosteriorStats compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dca08380..9f1e8663 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ MCMCDiagnosticTools = "0.3" MLJModelInterface = "0.3.5, 0.4, 1.0" NaturalSort = "1" OrderedCollections = "1.4" -PosteriorStats = "0.1.3" +PosteriorStats = "0.2" PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2" RecipesBase = "0.7, 0.8, 1.0" StatsBase = "0.33.2, 0.34" From 02be9c70f49830b34ee03f8daf1f69383c402f2d Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 18:01:23 -0800 Subject: [PATCH 43/56] Make stack available for older Julia versions --- Project.toml | 2 ++ src/MCMCChains.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 9f1e8663..c094fadb 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ version = "7.0.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" @@ -31,6 +32,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5" AxisArrays = "0.4.4" +Compat = "4.2.0" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" Formatting = "0.4" IteratorInterfaceExtensions = "0.1.1, 1" diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index dfbc73b2..10038d8f 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -1,5 +1,6 @@ module MCMCChains +using Compat: stack using AxisArrays const axes = Base.axes import AbstractMCMC From ea2548e8e0b8ba077a16ee2e9798ded63d046ae9 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 18:11:52 -0800 Subject: [PATCH 44/56] Update SummaryStats constructor user --- src/discretediag.jl | 6 +++--- src/ess_rhat.jl | 12 ++++++------ src/gelmandiag.jl | 6 ++++-- src/gewekediag.jl | 3 +-- src/heideldiag.jl | 3 +-- src/mcse.jl | 4 ++-- src/rafterydiag.jl | 3 +-- src/stats.jl | 6 +++--- 8 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index d262aca8..5cb7e054 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -18,13 +18,13 @@ function MCMCDiagnosticTools.discretediag( ) # Create SummaryStats - parameters = (parameter = names(_chains),) + param_names = names(_chains) between_chain_stats = SummaryStats( - "Chisq diagnostic - Between chains", merge(parameters, between_chain_vals), + "Chisq diagnostic - Between chains", between_chain_vals, param_names, ) within_chain_stats = map(1:size(_chains, 3)) do i vals = map(val -> val[:, i], within_chain_vals) - return SummaryStats("Chisq diagnostic - Chain $i", merge(parameters, vals)) + return SummaryStats("Chisq diagnostic - Chain $i", vals, param_names) end stats = [between_chain_stats, within_chain_stats...] diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 4b203d79..19ee1774 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -24,9 +24,9 @@ function MCMCDiagnosticTools.ess( # Convert to NamedTuple ess_per_sec = ess ./ dur - nt = merge((parameter = names(_chains),), (; ess, ess_per_sec)) + nt = (; ess, ess_per_sec) - return SummaryStats("ESS", nt) + return SummaryStats("ESS", nt, names(_chains)) end """ @@ -48,9 +48,9 @@ function MCMCDiagnosticTools.rhat( ) # Convert to NamedTuple - nt = merge((parameter = names(_chains),), (; rhat)) + nt = (; rhat) - return SummaryStats("R-hat", nt) + return SummaryStats("R-hat", nt, names(_chains)) end """ @@ -79,7 +79,7 @@ function MCMCDiagnosticTools.ess_rhat( # Convert to NamedTuple ess_per_sec = ess_rhat.ess ./ dur - nt = merge((parameter = names(_chains),), ess_rhat, (; ess_per_sec)) + nt = merge(ess_rhat, (; ess_per_sec)) - return SummaryStats("ESS/R-hat", nt) + return SummaryStats("ESS/R-hat", nt, names(_chains)) end diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 39b6d23d..0d40c6b2 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -14,7 +14,8 @@ function MCMCDiagnosticTools.gelmandiag( # Create a data frame with the results. stats = SummaryStats( "Gelman, Rubin, and Brooks diagnostic", - merge((parameter = names(_chains),), results), + results, + names(_chains), ) return stats @@ -39,7 +40,8 @@ function MCMCDiagnosticTools.gelmandiag_multivariate( # Create SummaryStats with the results. stats = SummaryStats( "Gelman, Rubin, and Brooks diagnostic", - (parameter = names(_chains), psrf = results.psrf, psrfci = results.psrfci), + (psrf = results.psrf, psrfci = results.psrfci), + names(_chains), ) return stats, results.psrfmultivariate diff --git a/src/gewekediag.jl b/src/gewekediag.jl index f7c6ecad..83af5b67 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -19,9 +19,8 @@ function MCMCDiagnosticTools.gewekediag( end # Create SummaryStats. - parameters = (parameter = names(_chains),) stats = [ - SummaryStats("Geweke diagnostic - Chain $i", merge(parameters, result)) + SummaryStats("Geweke diagnostic - Chain $i", result, names(_chains)) for (i, result) in enumerate(results) ] diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 7f7acd18..5db6cae4 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -17,10 +17,9 @@ function MCMCDiagnosticTools.heideldiag( end # Create SummaryStats. - parameters = (parameter = names(_chains),) stats = [ SummaryStats( - "Heidelberger and Welch diagnostic - Chain $i", merge(parameters, result) + "Heidelberger and Welch diagnostic - Chain $i", result, names(_chains), ) for (i, result) in enumerate(results) ] diff --git a/src/mcse.jl b/src/mcse.jl index 5c5a2731..677f40b0 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -16,7 +16,7 @@ function MCMCDiagnosticTools.mcse( kwargs..., ) - nt = merge((parameter = names(_chains),), (; mcse)) + nt = (; mcse) - return SummaryStats("MCSE", nt) + return SummaryStats("MCSE", nt, names(_chains)) end diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl index 4dfd45a8..7fd97a5e 100644 --- a/src/rafterydiag.jl +++ b/src/rafterydiag.jl @@ -17,10 +17,9 @@ function MCMCDiagnosticTools.rafterydiag( end # Create SummaryStats. - parameters = (parameter = names(_chains),) stats = [ SummaryStats( - "Raftery and Lewis diagnostic - Chain $i", merge(parameters, result) + "Raftery and Lewis diagnostic - Chain $i", result, names(_chains), ) for (i, result) in enumerate(results) ] diff --git a/src/stats.jl b/src/stats.jl index 6566de31..a3957838 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -122,15 +122,15 @@ function changerate( end end -function summarystats_changerate(name, names_of_params, chains; kwargs...) +function summarystats_changerate(name, names_of_params, chains) # Compute the change rates. changerates, mvchangerate = changerate(chains) # Summarize the results in a named tuple. - nt = (; parameter=names_of_params, changerate=changerates) + nt = (; changerate=changerates) # Create a SummaryStats. - return SummaryStats(name, nt; kwargs...), mvchangerate + return SummaryStats(name, nt, names_of_params), mvchangerate end changerate(chains::AbstractMatrix{<:Real}) = changerate(reshape(chains, Val(3))) From 929c6546ee3684bd8bcb94304a41308653abc582 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 18:12:12 -0800 Subject: [PATCH 45/56] Use dict backing for cor summary --- src/stats.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index a3957838..16f26605 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -76,16 +76,15 @@ function cor( end end -function summarystats_cor(name, names_of_params, chains::AbstractMatrix; kwargs...) +function summarystats_cor(name, names_of_params, chains::AbstractMatrix) # Compute the correlation matrix. cormat = cor(chains) - # Summarize the results in a named tuple. - nt = (; parameter = names_of_params, - zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))...) + # Summarize the results in a dict + dict = OrderedCollections.OrderedDict(zip(names_of_params, eachcol(cormat))) # Create a SummaryStats. - return SummaryStats(name, nt; kwargs...) + return SummaryStats(name, dict, names_of_params) end """ From 5f7aa7c0ebfbc00e6f27a5ee39a2eb2dd5679f20 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 18:15:13 -0800 Subject: [PATCH 46/56] Remove unused kwargs --- src/stats.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 16f26605..002a3684 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -43,7 +43,7 @@ function _default_lags(chains::Chains, append_chains::Bool) end """ - cor(chains[; sections, append_chains = true, kwargs...]) + cor(chains[; sections, append_chains = true]) Compute the Pearson correlation matrix for the chain. @@ -54,7 +54,6 @@ function cor( chains::Chains; sections = _default_sections(chains), append_chains = true, - kwargs... ) # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) @@ -88,7 +87,7 @@ function summarystats_cor(name, names_of_params, chains::AbstractMatrix) end """ - changerate(chains[; sections, append_chains = true, kwargs...]) + changerate(chains[; sections, append_chains = true]) Compute the change rate for the chain. @@ -99,7 +98,6 @@ function changerate( chains::Chains{<:Real}; sections = _default_sections(chains), append_chains = true, - kwargs... ) # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) From 49fe1c88bb6b97ee1de0d4d64a0580f2acd0781e Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 18:15:33 -0800 Subject: [PATCH 47/56] Refactor autocor to avoid large namedtuple --- src/stats.jl | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 002a3684..7858fe2a 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -17,14 +17,43 @@ Setting `append_chains=false` will return a vector of dataframes containing the """ function autocor( chains::Chains; + sections = _default_sections(chains), append_chains::Bool = true, demean::Bool = true, lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains), - kwargs... + var_names = nothing, + kwargs..., ) - fun_names = Tuple(Symbol.("lag", lags)) - fun = (x -> autocor(vec(x), lags; demean=demean)) - return summarize(chains, fun_names => fun; name = "Autocorrelation", append_chains, kwargs...) + chn = Chains(chains, _clean_sections(chains, sections)) + + # Obtain names of parameters. + names_of_params = var_names === nothing ? names(chn) : var_names + + # set up the functions to be evaluated + col_names = Symbol.("lag", lags) + + # avoids using summarize directly to support simultaneously computing a large number of + # lags without constructing a huge NamedTuple + if append_chains + # Evaluate the functions. + data = _permutedims_diagnostics(chn.value.data) + vals = stack(map(eachslice(data; dims=3)) do x + return autocor(vec(x), lags; demean=demean) + end) + table = Tables.table(vals'; header=col_names) + return SummaryStats("Autocorrelation", table, names_of_params) + else + # Evaluate the functions. + data = to_vector_of_matrices(chn) + return map(enumerate(data)) do (i, x) + name_chain = "Autocorrelation (Chain $i)" + vals = stack(map(eachslice(x; dims=2)) do xi + return autocor(xi, lags; demean=demean) + end) + table = Tables.table(vals'; header=col_names) + return SummaryStats(name_chain, table, names_of_params) + end + end end """ From 1cef3ba24c7c3ee6139f18ecb22c503bfb457a21 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 18:15:53 -0800 Subject: [PATCH 48/56] Use stack for autocor to avoid large compile times --- src/plot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plot.jl b/src/plot.jl index 9727e536..6a3bb38b 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -64,7 +64,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag) # Chains are already appended in `c` if desired, hence we use `append_chains=false` ac = autocor(c; sections = nothing, lags = lags, append_chains=false) - ac_mat = cat(reduce.(hcat, Iterators.drop.(ac, 1))...; dims=3) + ac_mat = stack(map(stack ∘ Base.Fix2(Iterators.drop, 1), ac)) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) elseif st ∈ supportedplots From 17e08ed4bc9016923fedbefdb326a939be8211a3 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 19:24:49 -0800 Subject: [PATCH 49/56] Improve type inference of OrderedDict --- src/stats.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stats.jl b/src/stats.jl index 7858fe2a..41a814dd 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -109,7 +109,7 @@ function summarystats_cor(name, names_of_params, chains::AbstractMatrix) cormat = cor(chains) # Summarize the results in a dict - dict = OrderedCollections.OrderedDict(zip(names_of_params, eachcol(cormat))) + dict = OrderedCollections.OrderedDict(k => v for (k, v) in zip(names_of_params, eachcol(cormat))) # Create a SummaryStats. return SummaryStats(name, dict, names_of_params) From 1e3e96a52382ca7b9b04f99e09eaecc76e9df8e7 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 19:51:34 -0800 Subject: [PATCH 50/56] Make doctest reproducible --- docs/Project.toml | 2 ++ src/stats.jl | 12 +++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 52ca0f6a..506ee3e4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,7 @@ Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" [compat] @@ -18,5 +19,6 @@ Gadfly = "1.3.4" MCMCChains = "7" MLJBase = "0.19, 0.20, 0.21, 1" MLJDecisionTreeInterface = "0.3, 0.4" +StableRNGs = "1" StatsPlots = "0.14, 0.15" julia = "1.7" diff --git a/src/stats.jl b/src/stats.jl index 41a814dd..bd5834b5 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -215,16 +215,18 @@ Note that this will return a single interval and will not return multiple interv # Examples -```jldoctest; setup = :(using Random; Random.seed!(582)) -julia> val = rand(500, 2, 3); +```jldoctest +julia> using StableRNGs; rng = StableRNG(42); + +julia> val = rand(rng, 500, 2, 3); julia> chn = Chains(val, [:a, :b]); julia> hdi(chn) HDI - lower upper - a 0.0749 0.999 - b 0.00531 0.940 + lower upper + a 0.0630 0.994 + b 0.0404 0.968 ``` """ function PosteriorStats.hdi(chn::Chains; prob::Real=0.94, kwargs...) From d00189449f0794f88462a09f5ad9935b50625be8 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sat, 23 Dec 2023 23:13:24 -0800 Subject: [PATCH 51/56] Add StableRNGs as test dependency --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index f9f4e97e..d9a20ba1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" @@ -33,6 +34,7 @@ KernelDensity = "0.6.2" MCMCChains = "7" MLJBase = "0.18, 0.19, 0.20, 0.21, 1" MLJDecisionTreeInterface = "0.3, 0.4" +StableRNGs = "1" StatsBase = "0.33.2, 0.34" StatsPlots = "0.14.17, 0.15" TableTraits = "1" From b3a31431cbfb5189608d7d7361708d50d775d45d Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 11 Feb 2024 00:47:32 +0100 Subject: [PATCH 52/56] Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com> --- src/MCMCChains.jl | 3 +-- src/stats.jl | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index 10038d8f..0ccad7b6 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -50,8 +50,7 @@ export rstar # Reexport stats functions using PosteriorStats: SummaryStats, default_diagnostics, default_stats, default_summary_stats, hdi, summarize -export SummaryStats, default_diagnostics, default_stats, default_summary_stats, hdi, - summarize +export SummaryStats, hdi, summarize """ Chains diff --git a/src/stats.jl b/src/stats.jl index bd5834b5..17f5f035 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -29,13 +29,12 @@ function autocor( # Obtain names of parameters. names_of_params = var_names === nothing ? names(chn) : var_names - # set up the functions to be evaluated + # Construct column names for lags. col_names = Symbol.("lag", lags) # avoids using summarize directly to support simultaneously computing a large number of # lags without constructing a huge NamedTuple if append_chains - # Evaluate the functions. data = _permutedims_diagnostics(chn.value.data) vals = stack(map(eachslice(data; dims=3)) do x return autocor(vec(x), lags; demean=demean) @@ -43,7 +42,6 @@ function autocor( table = Tables.table(vals'; header=col_names) return SummaryStats("Autocorrelation", table, names_of_params) else - # Evaluate the functions. data = to_vector_of_matrices(chn) return map(enumerate(data)) do (i, x) name_chain = "Autocorrelation (Chain $i)" From 0fad7f9a9e4722b90ae764cfd204fababc42a5ad Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 11 Feb 2024 11:44:40 +0100 Subject: [PATCH 53/56] Avoid splatting --- src/discretediag.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index 5cb7e054..0b3d699c 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -26,7 +26,7 @@ function MCMCDiagnosticTools.discretediag( vals = map(val -> val[:, i], within_chain_vals) return SummaryStats("Chisq diagnostic - Chain $i", vals, param_names) end - stats = [between_chain_stats, within_chain_stats...] + stats = vcat([between_chain_stats], within_chain_stats) return stats end From 3579c55b355b2e59e9c428b4e24dbe9fb22de421 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 11 Feb 2024 18:52:36 +0100 Subject: [PATCH 54/56] Avoid recomputing all medians for every parameter --- src/plot.jl | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index 6a3bb38b..19452003 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -206,16 +206,10 @@ function _compute_plot_data( show_hdii = true, fill_q = true, fill_hdi = false, - ordered = false ) + hdii = sort(hdi_prob; rev=true) - chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5])) - sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic)))) - sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)] - par = (ordered ? sorted_par : par_names) - hdii = sort(hdi_prob) - - chain_sections = MCMCChains.group(chains, Symbol(par[i])) + chain_sections = MCMCChains.group(chains, Symbol(par_names[i])) chain_vec = vec(chain_sections.value.data) lower_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:lower] for j in 1:length(hdii)] @@ -239,7 +233,7 @@ function _compute_plot_data( min = minimum(k_density.density .+ h) q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med]) - return par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, + return par_names, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean, min, q_int end @@ -261,12 +255,16 @@ end chn = p.args[1] par_names = p.args[2] + if ordered + par_table_names, par_medians = summarize(chn[:, par_names, :], median) + par_names = par_table_names[sortperm(par_medians)] + end + for i in 1:length(par_names) par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean, min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q, spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median, - show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi, - ordered = ordered) + show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi) yticks --> (length(par_names) > 1 ? (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default) @@ -350,12 +348,16 @@ end chn = p.args[1] par_names = p.args[2] + if ordered + par_table_names, par_medians = summarize(chn[:, par_names, :], median) + par_names = par_table_names[sortperm(par_medians)] + end + for i in 1:length(par_names) par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean, min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q, spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median, - show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi, - ordered = ordered) + show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi) yticks --> (length(par_names) > 1 ? (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default) From 53329164285b10edcf646f2aaf98a678636da6a6 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 11 Feb 2024 20:22:27 +0100 Subject: [PATCH 55/56] Add test for show method --- test/runtests.jl | 4 ++++ test/show_tests.jl | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 test/show_tests.jl diff --git a/test/runtests.jl b/test/runtests.jl index dfa37395..0c0e8819 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,6 +60,10 @@ Random.seed!(0) println("Model statistics") @time include("modelstats_test.jl") + # run tests for show methods + println("Show methods") + @time include("show_tests.jl") + # run tests for concatenation println("Concatenation") @time include("concatenation_tests.jl") diff --git a/test/show_tests.jl b/test/show_tests.jl new file mode 100644 index 00000000..f1a6b9b8 --- /dev/null +++ b/test/show_tests.jl @@ -0,0 +1,23 @@ +using Test +using MCMCChains + +@testset "Show tests" begin + rng = MersenneTwister(1234) + val = rand(rng, 100, 4, 4) + parm_names = ["a", "b", "c", "d"] + chns = Chains(val, parm_names, Dict(:internals => ["b", "d"]))[1:2:99, :, :] + str = sprint(show, "text/plain", chns) + stats_str = sprint(show, "text/plain", summarystats(chns)) + quantile_str = sprint(show, "text/plain", quantile(chns)) + @test str == """Chains MCMC chain (50×4×4 Array{Float64, 3}): + + Iterations = 1:2:99 + Number of chains = 4 + Samples per chain = 50 + parameters = a, c + internals = b, d + + $stats_str + + $quantile_str""" +end From d4882200aa05dee6f34ce7c0bbfa473c35794c30 Mon Sep 17 00:00:00 2001 From: Seth Axen <seth@sethaxen.com> Date: Sun, 17 Nov 2024 22:02:27 +0100 Subject: [PATCH 56/56] Update ess_rhat tests --- test/ess_rhat_tests.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/ess_rhat_tests.jl b/test/ess_rhat_tests.jl index df487f7d..63adb79c 100644 --- a/test/ess_rhat_tests.jl +++ b/test/ess_rhat_tests.jl @@ -49,15 +49,15 @@ end for autocov_method in (AutocovMethod(), FFTAutocovMethod(), BDAAutocovMethod()) # analyze chain ess_df = ess(chain; autocov_method = autocov_method) - @test isequal(ess_df[:, :ess], fill(NaN, 5)) - @test isequal(ess_df[:, :ess_per_sec], fill(missing, 5)) + @test isequal(ess_df[:ess], fill(NaN, 5)) + @test isequal(ess_df[:ess_per_sec], fill(missing, 5)) ess_rhat_df = ess_rhat(chain; autocov_method = autocov_method) - @test isequal(ess_rhat_df[:, :ess], fill(NaN, 5)) - @test isequal(ess_rhat_df[:, :rhat], fill(NaN, 5)) - @test isequal(ess_rhat_df[:, :ess_per_sec], fill(missing, 5)) + @test isequal(ess_rhat_df[:ess], fill(NaN, 5)) + @test isequal(ess_rhat_df[:rhat], fill(NaN, 5)) + @test isequal(ess_rhat_df[:ess_per_sec], fill(missing, 5)) end rhat_df = rhat(chain) - @test isequal(rhat_df[:, :rhat], fill(NaN, 5)) + @test isequal(rhat_df[:rhat], fill(NaN, 5)) end