From 43a56305eb7b2a0b55f0126952a1a0f218ecc0fd Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:45:21 +0200
Subject: [PATCH 01/56] Add PosteriorStats as dependency

---
 Project.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/Project.toml b/Project.toml
index 5c5bca73..d3eee142 100644
--- a/Project.toml
+++ b/Project.toml
@@ -18,6 +18,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -38,6 +39,7 @@ MCMCDiagnosticTools = "0.3"
 MLJModelInterface = "0.3.5, 0.4, 1.0"
 NaturalSort = "1"
 OrderedCollections = "1.4"
+PosteriorStats = "0.1.2"
 PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2"
 RecipesBase = "0.7, 0.8, 1.0"
 StatsBase = "0.33.2, 0.34"

From 2994a561e89dd34013085bc64a6681d8b04c8e74 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:45:58 +0200
Subject: [PATCH 02/56] Import and reexport PosteriorStats functions

---
 src/MCMCChains.jl | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index 3422620f..50cdf582 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -16,6 +16,7 @@ import MCMCDiagnosticTools
 import MLJModelInterface
 import NaturalSort
 import OrderedCollections
+import PosteriorStats
 import PrettyTables
 import StatsFuns
 import Tables
@@ -32,7 +33,6 @@ export set_section, get_params, sections, sort_sections, setinfo
 export replacenames, namesingroup, group
 export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile
 export ChainDataFrame
-export summarize
 
 # Reexport diagnostics functions
 using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod,
@@ -48,6 +48,10 @@ export rafterydiag
 export rstar
 
 export hpd
+# Reexport stats functions
+using PosteriorStats: default_diagnostics, default_stats, default_summary_stats, hdi,
+    summarize
+export default_diagnostics, default_stats, default_summary_stats, hdi, summarize
 
 """
     Chains

From 9ffb60b24991b01087c15af5f4d09ffa2d9f8f6c Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:47:12 +0200
Subject: [PATCH 03/56] Forward to PosteriorStats.summarize

---
 src/summarize.jl | 44 ++++++++------------------------------------
 1 file changed, 8 insertions(+), 36 deletions(-)

diff --git a/src/summarize.jl b/src/summarize.jl
index 70bba5d8..8e723da3 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -133,56 +133,28 @@ Summarize `chains` in a `ChainsDataFrame`.
 * `summarize(chns; sections=[:parameters])`  : Chain summary of :parameters section
 * `summarize(chns; sections=[:parameters, :internals])` : Chain summary for multiple sections
 """
-function summarize(
+function PosteriorStats.summarize(
     chains::Chains, funs...;
     sections = _default_sections(chains),
-    func_names::AbstractVector{Symbol} = Symbol[],
     append_chains::Bool = true,
-    name::String = "",
-    additional_df = nothing
+    kwargs...
 )
-    # If we weren't given any functions, fall back to summary stats.
-    if isempty(funs)
-        return summarystats(chains; sections, append_chains, name)
-    end
-
     # Generate a chain to work on.
     chn = Chains(chains, _clean_sections(chains, sections))
 
     # Obtain names of parameters.
     names_of_params = names(chn)
 
-    # If no function names were given, make a new list.
-    fnames = isempty(func_names) ? collect(nameof.(funs)) : func_names
-
-    # Obtain the additional named tuple.
-    additional_nt = additional_df === nothing ? NamedTuple() : additional_df.nt
-
     if append_chains
         # Evaluate the functions.
-        data = to_matrix(chn)
-        fvals = [[f(data[:, i]) for i in axes(data, 2)] for f in funs]
-
-        # Build the ChainDataFrame.
-        nt = merge((; parameters = names_of_params, zip(fnames, fvals)...), additional_nt)
-        df = ChainDataFrame(name, nt)
-
-        return df
+        data = _permutedims_diagnostics(chains.value.data)
+        summarize(data, funs...; var_names=names_of_params, kwargs...)
     else
         # Evaluate the functions.
         data = to_vector_of_matrices(chn)
-        vector_of_fvals = [[[f(x[:, i]) for i in axes(x, 2)] for f in funs] for x in data]
-
-        # Build the ChainDataFrames.
-        vector_of_nt = [
-            merge((; parameters = names_of_params, zip(fnames, fvals)...), additional_nt)
-            for fvals in vector_of_fvals
-        ]
-        vector_of_df = [
-            ChainDataFrame(name * " (Chain $i)", nt)
-            for (i, nt) in enumerate(vector_of_nt)
-        ]
-
-        return vector_of_df
+        return map(data) do x
+            z = reshape(x, size(x, 1), 1, size(x, 2))
+            summarize(z, funs...; var_names=names_of_params, kwargs...)
+        end
     end
 end

From a6d16e5cd07f8beb299c08dd29115f06d98a6560 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:48:01 +0200
Subject: [PATCH 04/56] Update docstring

---
 src/summarize.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/summarize.jl b/src/summarize.jl
index 8e723da3..fecf133b 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -122,9 +122,9 @@ function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:Na
 end
 
 """
-    summarize(chains, funs...[; sections, func_names = [], name = "", append_chains = true])
+    summarize(chains, funs...[; sections, name = "", append_chains = true])
 
-Summarize `chains` in a `ChainsDataFrame`.
+Summarize `chains` in a `PosteriorStats.SummaryStats`.
 
 # Examples
 

From 4e99fe8f069f6908ede8f9d3bd8b2994fd22c119 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:57:40 +0200
Subject: [PATCH 05/56] Update docstring

---
 src/summarize.jl | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/summarize.jl b/src/summarize.jl
index fecf133b..f8d47b90 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -122,10 +122,27 @@ function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:Na
 end
 
 """
-    summarize(chains, funs...[; sections, name = "", append_chains = true])
+    summarize(
+        chains[, stats_funs...];
+        append_chains=true,
+        name="SummaryStats",
+        [sections, var_names],
+    )
 
 Summarize `chains` in a `PosteriorStats.SummaryStats`.
 
+`stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)`
+to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a
+`Pair` of the form `name => fun` specifying the name to be used for the statistic or of the
+form `(name1, ...) => fun` when the function returns a collection. When the function returns
+a collection, the names in this latter format must be provided.
+
+If no stats functions are provided, then those specified in [`default_summary_stats`](@ref)
+are computed.
+
+`var_names` specifies the names of the parameters in data. If not provided, the names are
+inferred from data.
+
 # Examples
 
 * `summarize(chns)` : Complete chain summary
@@ -137,13 +154,14 @@ function PosteriorStats.summarize(
     chains::Chains, funs...;
     sections = _default_sections(chains),
     append_chains::Bool = true,
+    var_names=nothing,
     kwargs...
 )
     # Generate a chain to work on.
     chn = Chains(chains, _clean_sections(chains, sections))
 
     # Obtain names of parameters.
-    names_of_params = names(chn)
+    names_of_params = var_names === nothing ? names(chn) : var_names
 
     if append_chains
         # Evaluate the functions.

From 38dbdac51a98db22329e590b90e63c9201bbec71 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:58:11 +0200
Subject: [PATCH 06/56] Forward summarystats to summarize

---
 src/stats.jl | 85 ++++------------------------------------------------
 1 file changed, 6 insertions(+), 79 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index e31044ea..fb0f6f57 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -266,88 +266,15 @@ end
 
 
 """
-    function summarystats(
-        chains;
-        sections = _default_sections(chains),
-        append_chains= true,
-        autocov_method::AbstractAutocovMethod = AutocovMethod(),
-        maxlag = 250,
-        kwargs...
-    )
-
-Compute the mean, standard deviation, Monte Carlo standard error, bulk- and tail- effective
-sample size, and ``\\widehat{R}`` diagnostic for each parameter in the chain.
+    summarystats(chains; kwargs...)
 
-Setting `append_chains=false` will return a vector of dataframes containing the summary
-statistics for each chain.
+Compute default summary statistics from the `chains`.
 
-When estimating the effective sample size, autocorrelations are computed for at most `maxlag` lags.
+`kwargs` are forwarded to [`summarize`](@ref). To customize the summary statistics, see
+`summarize`.
 """
-function summarystats(
-    chains::Chains;
-    sections = _default_sections(chains),
-    append_chains::Bool = true,
-    autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(),
-    maxlag = 250,
-    name = "Summary Statistics",
-    kwargs...
-)
-    # Store everything.
-    funs = [mean∘cskip, std∘cskip]
-    func_names = [:mean, :std]
-
-    # Subset the chain.
-    _chains = Chains(chains, _clean_sections(chains, sections))
-
-    # Calculate MCSE and ESS/R-hat separately.
-    nt_additional = NamedTuple()
-    try
-        mcse_df = MCMCDiagnosticTools.mcse(
-            _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag,
-        )
-        nt_additional = merge(nt_additional, (; mcse=mcse_df.nt.mcse))
-    catch e
-        @warn "MCSE calculation failed: $e"
-    end
-
-    try
-        ess_tail_df = MCMCDiagnosticTools.ess(
-            _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail
-        )
-        nt_additional = merge(nt_additional, (ess_tail=ess_tail_df.nt.ess,))
-    catch e
-        @warn "Tail ESS calculation failed: $e"
-    end
-
-    try
-        ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat(
-            _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank
-        )
-        nt_ess_rhat_rank = (
-            ess_bulk=ess_rhat_rank_df.nt.ess,
-            rhat=ess_rhat_rank_df.nt.rhat,
-            ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec
-        )
-        nt_additional = merge(nt_additional, nt_ess_rhat_rank)
-    catch e
-        @warn "Bulk ESS/R-hat calculation failed: $e"
-    end
-
-    # Possibly re-order the columns to stay backwards-compatible.
-    additional_keys = (:mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec)
-    additional_df = ChainDataFrame("Additional", (; ((k, nt_additional[k]) for k in additional_keys if k ∈ keys(nt_additional))...))
-
-    # Summarize.
-    summary_df = summarize(
-        _chains, funs...;
-        func_names,
-        append_chains,
-        additional_df,
-        name,
-        sections = nothing
-    )
-
-    return summary_df
+function summarystats(chains::Chains; name = "Summary Statistics", kwargs...)
+    return summarize(chains; name, kwargs...)
 end
 
 """

From 395a3d8aa02ae7f313d105dd418ff862d9008578 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 15:58:44 +0200
Subject: [PATCH 07/56] Simplify mean implementation

---
 src/stats.jl | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index fb0f6f57..fa044efb 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -283,19 +283,7 @@ end
 Calculate the mean of a chain.
 """
 function mean(chains::Chains; kwargs...)
-    # Store everything.
-    funs = [mean∘cskip]
-    func_names = [:mean]
-
-    # Summarize.
-    summary_df = summarize(
-        chains, funs...;
-        func_names = func_names,
-        name = "Mean",
-        kwargs...
-    )
-
-    return summary_df
+    return summarize(chains, :mean => mean ∘ cskip; name = "Mean", kwargs...)
 end
 
 mean(chn::Chains, syms) = mean(chn[:, syms, :])

From 4fa204ee80331e2b65d4a9f100a5aff9f1cb4ece Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:00:15 +0200
Subject: [PATCH 08/56] Simplify quantile implementation

---
 src/stats.jl | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index fa044efb..a042e91c 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -245,23 +245,14 @@ for each chain.
 function quantile(
     chains::Chains;
     q::AbstractVector = [0.025, 0.25, 0.5, 0.75, 0.975],
-    append_chains = true,
     kwargs...
 )
     # compute quantiles
-    funs = Function[]
-    func_names = @. Symbol(100 * q, :%)
-    for i in q
-        push!(funs, x -> quantile(cskip(x), i))
+    stats_funs = map(q) do qi
+        nm = Symbol(100 * qi, :%)
+        return nm => Base.Fix2(quantile, qi) ∘ cskip
     end
-
-    return summarize(
-        chains, funs...;
-        func_names = func_names,
-        append_chains = append_chains,
-        name = "Quantiles",
-        kwargs...
-    )
+    return summarize(chains, stats_funs...; name = "Quantiles", kwargs...)
 end
 
 

From 5c0f35d02244a717379b7424d434366c310f08b7 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:01:40 +0200
Subject: [PATCH 09/56] Replace hpd with hdi

---
 src/stats.jl | 40 ++++++++++++----------------------------
 1 file changed, 12 insertions(+), 28 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index a042e91c..d5118286 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -193,45 +193,29 @@ function describe(
     return dfs
 end
 
-function _hpd(x::AbstractVector{<:Real}; alpha::Real=0.05)
-    n = length(x)
-    m = max(1, ceil(Int, alpha * n))
-
-    y = sort(x)
-    a = y[1:m]
-    b = y[(n - m + 1):n]
-    _, i = findmin(b - a)
-
-    return [a[i], b[i]]
-end
-
 """
-    hpd(chn::Chains; alpha::Real=0.05, kwargs...)
+    hdi(chn::Chains; prob::Real=0.94, kwargs...)
 
-Return the highest posterior density interval representing `1-alpha` probability mass.
+Return the unimodal highest density interval (HDI) representing `prob` probability mass.
 
 Note that this will return a single interval and will not return multiple intervals for discontinuous regions.
 
 # Examples
 
-```julia-repl
-julia> val = rand(500, 2, 3);
-julia> chn = Chains(val, [:a, :b]);
+```jldoctest; setup = :(using Random; Random.seed!(582))
+julia> val = rand(500, 2, 3)
 
-julia> hpd(chn)
-HPD
-  parameters     lower     upper 
-      Symbol   Float64   Float64 
+julia> chn = Chains(val, [:a, :b]);
 
-           a    0.0554    0.9944
-           b    0.0114    0.9460
+julia> hdi(chn)
+HDI
+      lower  upper
+ a  0.0749   0.999
+ b  0.00531  0.940
 ```
 """
-function hpd(chn::Chains; alpha::Real=0.05, kwargs...)
-    labels = [:lower, :upper]
-    l(x) = _hpd(x, alpha=alpha)[1]
-    u(x) = _hpd(x, alpha=alpha)[2]
-    return summarize(chn, l, u; name = "HPD", func_names = labels, kwargs...)
+function PosteriorStats.hdi(chn::Chains; prob::Real=0.94, kwargs...)
+    return summarize(chn, (:lower, :upper) => (x -> hdi(x; prob)); name = "HDI", kwargs...)
 end
 
 """

From 3660d0e19b6446ef12e52957e737c5f96f1c07f7 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:01:52 +0200
Subject: [PATCH 10/56] Deprecate hpd

---
 src/MCMCChains.jl | 1 -
 src/stats.jl      | 2 ++
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index 50cdf582..f18646b0 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -47,7 +47,6 @@ export mcse
 export rafterydiag
 export rstar
 
-export hpd
 # Reexport stats functions
 using PosteriorStats: default_diagnostics, default_stats, default_summary_stats, hdi,
     summarize
diff --git a/src/stats.jl b/src/stats.jl
index d5118286..0a132953 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -218,6 +218,8 @@ function PosteriorStats.hdi(chn::Chains; prob::Real=0.94, kwargs...)
     return summarize(chn, (:lower, :upper) => (x -> hdi(x; prob)); name = "HDI", kwargs...)
 end
 
+@deprecate hpd(chn::Chains; alpha::Real=0.05, kwargs...) hdi(chn; prob=1 - alpha, kwargs...)
+
 """
     quantile(chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], append_chains = true, kwargs...])
 

From e3d2d16a19530e111ffd3cde0f690101612723dc Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:02:17 +0200
Subject: [PATCH 11/56] Simplify autocor implementation

---
 src/stats.jl | 16 +++-------------
 1 file changed, 3 insertions(+), 13 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 0a132953..1ef63789 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -17,24 +17,14 @@ Setting `append_chains=false` will return a vector of dataframes containing the
 """
 function autocor(
     chains::Chains;
-    append_chains = true,
     demean::Bool = true,
     lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains),
     kwargs...
 )
-    funs = Function[]
-    func_names = @. Symbol("lag ", lags)
-    for i in lags
-        push!(funs, x -> autocor(x, [i], demean=demean)[1])
+    funs = map(lags) do lag
+        return Symbol("lag ", lag) => (x -> autocor(x, [i], demean=demean)[1])
     end
-
-    return summarize(
-        chains, funs...;
-        func_names = func_names,
-        append_chains = append_chains,
-        name = "Autocorrelation",
-        kwargs...
-    )
+    return summarize(chains, funs...; name = "Autocorrelation", kwargs...)
 end
 
 """

From 49af2d9108af3323031b7146f55187e52ef5918c Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:10:29 +0200
Subject: [PATCH 12/56] Remove unused keyword `etype`

---
 src/stats.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 1ef63789..6bb9a46f 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -166,7 +166,6 @@ describe(c::Chains; args...) = describe(stdout, c; args...)
 """
     describe(io, chains[;
              q = [0.025, 0.25, 0.5, 0.75, 0.975],
-             etype = :bm,
              kwargs...])
 
 Print the summary statistics and quantiles for the chain.
@@ -175,7 +174,6 @@ function describe(
     io::IO,
     chains::Chains;
     q = [0.025, 0.25, 0.5, 0.75, 0.975],
-    etype = :bm,
     kwargs...
 )
     dfs = vcat(summarystats(chains; etype = etype, kwargs...),

From bdde660ed54ebbb90caa3d6d654b12300f1760cf Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:11:33 +0200
Subject: [PATCH 13/56] Explicitly build list of stats

---
 src/stats.jl | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 6bb9a46f..4ec40683 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -176,9 +176,8 @@ function describe(
     q = [0.025, 0.25, 0.5, 0.75, 0.975],
     kwargs...
 )
-    dfs = vcat(summarystats(chains; etype = etype, kwargs...),
-               quantile(chains; q = q, kwargs...))
-    return dfs
+    stats = [summarystats(chains; kwargs...), quantile(chains; q = q, kwargs...)]
+    return stats
 end
 
 """

From d851307c6d1fa1a50ffbf5390b4314ac0729cc1d Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:15:36 +0200
Subject: [PATCH 14/56] Simultaneously compute all quantiles

---
 src/stats.jl | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 4ec40683..26f4d255 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -208,7 +208,7 @@ end
 @deprecate hpd(chn::Chains; alpha::Real=0.05, kwargs...) hdi(chn; prob=1 - alpha, kwargs...)
 
 """
-    quantile(chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], append_chains = true, kwargs...])
+    quantile(chains[; q = (0.025, 0.25, 0.5, 0.75, 0.975), append_chains = true, kwargs...])
 
 Compute the quantiles for each parameter in the chain.
 
@@ -217,15 +217,17 @@ for each chain.
 """
 function quantile(
     chains::Chains;
-    q::AbstractVector = [0.025, 0.25, 0.5, 0.75, 0.975],
+    q::Union{Tuple,AbstractVector} = (0.025, 0.25, 0.5, 0.75, 0.975),
     kwargs...
 )
     # compute quantiles
-    stats_funs = map(q) do qi
-        nm = Symbol(100 * qi, :%)
-        return nm => Base.Fix2(quantile, qi) ∘ cskip
-    end
-    return summarize(chains, stats_funs...; name = "Quantiles", kwargs...)
+    func_names = Tuple(Symbol.(100 .* q, :%))
+    return summarize(
+        chains,
+        func_names => (Base.Fix2(quantile, q) ∘ cskip);
+        name="Quantiles",
+        kwargs...,
+    )
 end
 
 

From 17174834fcd4bb835179c66fb67a35f7ec3ebf87 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:16:16 +0200
Subject: [PATCH 15/56] Print an extra newline

---
 src/chains.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/chains.jl b/src/chains.jl
index b6852b67..3172ec22 100644
--- a/src/chains.jl
+++ b/src/chains.jl
@@ -351,6 +351,7 @@ function Base.show(io::IO, mime::MIME"text/plain", chains::Chains)
     # Show summary stats.
     summaries = describe(chains)
     for summary in summaries
+        println(io)
         println(io)
         show(io, mime, summary)
     end

From bf0665353008bbf78cad49dba362982a413b0118 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:17:02 +0200
Subject: [PATCH 16/56] Use and export SummaryStats

---
 src/MCMCChains.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index f18646b0..b1e47634 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -48,9 +48,10 @@ export rafterydiag
 export rstar
 
 # Reexport stats functions
-using PosteriorStats: default_diagnostics, default_stats, default_summary_stats, hdi,
+using PosteriorStats: SummaryStats, default_diagnostics, default_stats,
+    default_summary_stats, hdi, summarize
+export SummaryStats, default_diagnostics, default_stats, default_summary_stats, hdi,
     summarize
-export default_diagnostics, default_stats, default_summary_stats, hdi, summarize
 
 """
     Chains

From 147b56b083fb970557e16400ef99d5e7cb232ec4 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:28:24 +0200
Subject: [PATCH 17/56] Use SummaryStats in place of ChainsDataFrame

---
 src/discretediag.jl | 14 +++++++-------
 src/ess_rhat.jl     | 12 ++++++------
 src/gelmandiag.jl   | 14 +++++++-------
 src/gewekediag.jl   | 10 +++++-----
 src/heideldiag.jl   | 10 +++++-----
 src/mcse.jl         |  4 ++--
 src/rafterydiag.jl  | 10 +++++-----
 7 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/src/discretediag.jl b/src/discretediag.jl
index 67cec670..d262aca8 100644
--- a/src/discretediag.jl
+++ b/src/discretediag.jl
@@ -17,16 +17,16 @@ function MCMCDiagnosticTools.discretediag(
         _permutedims_diagnostics(_chains.value.data); kwargs...
     )
 
-    # Create dataframes
-    parameters = (parameters = names(_chains),)
-    between_chain_df = ChainDataFrame(
+    # Create SummaryStats
+    parameters = (parameter = names(_chains),)
+    between_chain_stats = SummaryStats(
         "Chisq diagnostic - Between chains", merge(parameters, between_chain_vals),
     )
-    within_chain_dfs = map(1:size(_chains, 3)) do i
+    within_chain_stats = map(1:size(_chains, 3)) do i
         vals = map(val -> val[:, i], within_chain_vals)
-        return ChainDataFrame("Chisq diagnostic - Chain $i", merge(parameters, vals))
+        return SummaryStats("Chisq diagnostic - Chain $i", merge(parameters, vals))
     end
-    dfs = vcat(between_chain_df, within_chain_dfs)
+    stats = [between_chain_stats, within_chain_stats...]
 
-    return dfs
+    return stats
 end
diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl
index 70d86d2b..4b203d79 100644
--- a/src/ess_rhat.jl
+++ b/src/ess_rhat.jl
@@ -24,9 +24,9 @@ function MCMCDiagnosticTools.ess(
 
     # Convert to NamedTuple
     ess_per_sec = ess ./ dur
-    nt = merge((parameters = names(_chains),), (; ess, ess_per_sec))
+    nt = merge((parameter = names(_chains),), (; ess, ess_per_sec))
 
-    return ChainDataFrame("ESS", nt)
+    return SummaryStats("ESS", nt)
 end
 
 """
@@ -48,9 +48,9 @@ function MCMCDiagnosticTools.rhat(
     )
 
     # Convert to NamedTuple
-    nt = merge((parameters = names(_chains),), (; rhat))
+    nt = merge((parameter = names(_chains),), (; rhat))
 
-    return ChainDataFrame("R-hat", nt)
+    return SummaryStats("R-hat", nt)
 end
 
 """
@@ -79,7 +79,7 @@ function MCMCDiagnosticTools.ess_rhat(
 
     # Convert to NamedTuple
     ess_per_sec = ess_rhat.ess ./ dur
-    nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec))
+    nt = merge((parameter = names(_chains),), ess_rhat, (; ess_per_sec))
 
-    return ChainDataFrame("ESS/R-hat", nt)
+    return SummaryStats("ESS/R-hat", nt)
 end
diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl
index 3fc894f5..39b6d23d 100644
--- a/src/gelmandiag.jl
+++ b/src/gelmandiag.jl
@@ -12,12 +12,12 @@ function MCMCDiagnosticTools.gelmandiag(
     results = MCMCDiagnosticTools.gelmandiag(_permutedims_diagnostics(psi); kwargs...)
 
     # Create a data frame with the results.
-    df = ChainDataFrame(
+    stats = SummaryStats(
         "Gelman, Rubin, and Brooks diagnostic",
-        merge((parameters = names(_chains),), results),
+        merge((parameter = names(_chains),), results),
     )
 
-    return df
+    return stats
 end
 
 function MCMCDiagnosticTools.gelmandiag_multivariate(
@@ -36,11 +36,11 @@ function MCMCDiagnosticTools.gelmandiag_multivariate(
         kwargs...,
     )
 
-    # Create a data frame with the results.
-    df = ChainDataFrame(
+    # Create SummaryStats with the results.
+    stats = SummaryStats(
         "Gelman, Rubin, and Brooks diagnostic",
-        (parameters = names(_chains), psrf = results.psrf, psrfci = results.psrfci),
+        (parameter = names(_chains), psrf = results.psrf, psrfci = results.psrfci),
     )
 
-    return df, results.psrfmultivariate
+    return stats, results.psrfmultivariate
 end
diff --git a/src/gewekediag.jl b/src/gewekediag.jl
index 72cbb5f2..f7c6ecad 100644
--- a/src/gewekediag.jl
+++ b/src/gewekediag.jl
@@ -18,12 +18,12 @@ function MCMCDiagnosticTools.gewekediag(
         return namedtuple_of_vecs
     end
 
-    # Create data frames.
-    parameters = (parameters = names(_chains),)
-    dfs = [
-        ChainDataFrame("Geweke diagnostic - Chain $i", merge(parameters, result))
+    # Create SummaryStats.
+    parameters = (parameter = names(_chains),)
+    stats = [
+        SummaryStats("Geweke diagnostic - Chain $i", merge(parameters, result))
         for (i, result) in enumerate(results)
     ]
 
-    return dfs
+    return stats
 end
diff --git a/src/heideldiag.jl b/src/heideldiag.jl
index 67def6a8..7f7acd18 100644
--- a/src/heideldiag.jl
+++ b/src/heideldiag.jl
@@ -16,14 +16,14 @@ function MCMCDiagnosticTools.heideldiag(
         return namedtuple_of_vecs
     end
 
-    # Create data frames.
-    parameters = (parameters = names(_chains),)
-    dfs = [
-        ChainDataFrame(
+    # Create SummaryStats.
+    parameters = (parameter = names(_chains),)
+    stats = [
+        SummaryStats(
             "Heidelberger and Welch diagnostic - Chain $i", merge(parameters, result)
         )
         for (i, result) in enumerate(results)
     ]
 
-    return dfs
+    return stats
 end
diff --git a/src/mcse.jl b/src/mcse.jl
index 78ae2552..5c5a2731 100644
--- a/src/mcse.jl
+++ b/src/mcse.jl
@@ -16,7 +16,7 @@ function MCMCDiagnosticTools.mcse(
         kwargs...,
     )
 
-    nt = merge((parameters = names(_chains),), (; mcse))
+    nt = merge((parameter = names(_chains),), (; mcse))
 
-    return ChainDataFrame("MCSE", nt)
+    return SummaryStats("MCSE", nt)
 end
diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl
index 95126025..4dfd45a8 100644
--- a/src/rafterydiag.jl
+++ b/src/rafterydiag.jl
@@ -16,14 +16,14 @@ function MCMCDiagnosticTools.rafterydiag(
         return namedtuple_of_vecs
     end
 
-    # Create data frames.
-    parameters = (parameters = names(_chains),)
-    dfs = [
-        ChainDataFrame(
+    # Create SummaryStats.
+    parameters = (parameter = names(_chains),)
+    stats = [
+        SummaryStats(
             "Raftery and Lewis diagnostic - Chain $i", merge(parameters, result)
         )
         for (i, result) in enumerate(results)
     ]
 
-    return dfs
+    return stats
 end

From 706f29cab107b56cd09431dccc6a63a578888cdf Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:39:14 +0200
Subject: [PATCH 18/56] Update and repair changerate

---
 src/stats.jl | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 26f4d255..124bfb96 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -63,11 +63,11 @@ function cor(
     names_of_params = names(_chains)
 
     if append_chains
-        df = chaindataframe_cor("Correlation", names_of_params, to_matrix(_chains))
+        df = summarystats_cor("Correlation", names_of_params, to_matrix(_chains))
         return df
     else
         vector_of_df = [
-            chaindataframe_cor(
+            summarystats_cor(
                 "Correlation - Chain $i", names_of_params, data
             )
             for (i, data) in enumerate(to_vector_of_matrices(_chains))
@@ -76,16 +76,16 @@ function cor(
     end
 end
 
-function chaindataframe_cor(name, names_of_params, chains::AbstractMatrix; kwargs...)
+function summarystats_cor(name, names_of_params, chains::AbstractMatrix; kwargs...)
     # Compute the correlation matrix.
     cormat = cor(chains)
 
     # Summarize the results in a named tuple.
-    nt = (; parameters = names_of_params,
+    nt = (; parameter = names_of_params,
           zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))...)
 
-    # Create a ChainDataFrame.
-    return ChainDataFrame(name, nt; kwargs...)
+    # Create a SummaryStats.
+    return SummaryStats(name, nt; kwargs...)
 end
 
 """
@@ -109,28 +109,28 @@ function changerate(
     names_of_params = names(_chains)
 
     if append_chains
-        df = chaindataframe_changerate("Change Rate", names_of_params, _chains.value.data)
-        return df
+        stats = summarystats_changerate("Change Rate", names_of_params, _chains.value.data)
+        return stats
     else
-        vector_of_df = [
-            chaindataframe_changerate(
+        vector_of_stats = [
+            summarystats_changerate(
                 "Change Rate - Chain $i", names_of_params, data
             )
             for (i, data) in enumerate(to_vector_of_matrices(_chains))
         ]
-        return vector_of_df
+        return vector_of_stats
     end
 end
 
-function chaindataframe_changerate(name, names_of_params, chains; kwargs...)
+function summarystats_changerate(name, names_of_params, chains; kwargs...)
     # Compute the change rates.
     changerates, mvchangerate = changerate(chains)
 
     # Summarize the results in a named tuple.
-    nt = (; zip(names_of_params, changerates)..., multivariate = mvchangerate)
+    nt = (; parameter=names_of_params, changerate=changerates)
 
-    # Create a ChainDataFrame.
-    return ChainDataFrame(name, nt; kwargs...)
+    # Create a SummaryStats.
+    return SummaryStats(name, nt; kwargs...), mvchangerate
 end
 
 changerate(chains::AbstractMatrix{<:Real}) = changerate(reshape(chains, Val(3)))

From 2c07794a3282cb55fd3e85ed8c16d228f8b3836c Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:39:57 +0200
Subject: [PATCH 19/56] Remove ChainDataFrame

---
 src/MCMCChains.jl |   1 -
 src/summarize.jl  | 123 ----------------------------------------------
 src/tables.jl     |  39 ---------------
 3 files changed, 163 deletions(-)

diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index b1e47634..dfbc73b2 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -32,7 +32,6 @@ export setrange, resetrange
 export set_section, get_params, sections, sort_sections, setinfo
 export replacenames, namesingroup, group
 export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile
-export ChainDataFrame
 
 # Reexport diagnostics functions
 using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod,
diff --git a/src/summarize.jl b/src/summarize.jl
index f8d47b90..99f73fd3 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -1,126 +1,3 @@
-struct ChainDataFrame{NT<:NamedTuple}
-    name::String
-    nt::NT
-    nrows::Int
-    ncols::Int
-
-    function ChainDataFrame(name::String, nt::NamedTuple)
-        lengths = length(first(nt))
-        all(x -> length(x) == lengths, nt) || error("Lengths must be equal.")
-
-        return new{typeof(nt)}(name, nt, lengths, length(nt))
-    end
-end
-
-ChainDataFrame(nt::NamedTuple) = ChainDataFrame("", nt)
-
-Base.size(c::ChainDataFrame) = (c.nrows, c.ncols)
-Base.names(c::ChainDataFrame) = collect(keys(c.nt))
-
-# Display
-
-function Base.show(io::IO, df::ChainDataFrame)
-    print(io, df.name, " (", df.nrows, " x ", df.ncols, ")")
-end
-
-function Base.show(io::IO, ::MIME"text/plain", df::ChainDataFrame)
-    digits = get(io, :digits, 4)
-    formatter = PrettyTables.ft_printf("%.$(digits)f")
-
-    println(io, df.name)
-    # Support for PrettyTables 0.9 (`borderless`) and 0.10 (`tf_borderless`)
-    PrettyTables.pretty_table(
-        io, df.nt;
-        formatters = formatter,
-        tf = isdefined(PrettyTables, :borderless) ? PrettyTables.borderless : PrettyTables.tf_borderless,
-    )
-end
-
-Base.isequal(c1::ChainDataFrame, c2::ChainDataFrame) = isequal(c1, c2)
-
-# Index functions
-function Base.getindex(c::ChainDataFrame, s::Union{Colon, Integer, UnitRange}, g::Union{Colon, Integer, UnitRange})
-    convert(Array, getindex(c, c.nt[:parameters][s], collect(keys(c.nt))[g]))
-end
-
-Base.getindex(c::ChainDataFrame, s::Vector{Symbol}, ::Colon) = getindex(c, s)
-function Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}})
-    getindex(c, s, collect(keys(c.nt)))
-end
-
-function Base.getindex(c::ChainDataFrame, s::Union{Colon, Integer, UnitRange}, ks)
-    getindex(c, c.nt[:parameters][s], ks)
-end
-
-# dispatches involing `String` and `AbstractVector{String}`
-Base.getindex(c::ChainDataFrame, s::String, ks) = getindex(c, Symbol(s), ks)
-function Base.getindex(c::ChainDataFrame, s::AbstractVector{String}, ks)
-    return getindex(c, Symbol.(s), ks)
-end
-
-# dispatch for `Symbol`
-Base.getindex(c::ChainDataFrame, s::Symbol, ks) = getindex(c, [s], ks)
-
-function Base.getindex(c::ChainDataFrame, s::AbstractVector{Symbol}, ks::Symbol)
-    return getindex(c, s, [ks])
-end
-
-function Base.getindex(
-    c::ChainDataFrame,
-    s::AbstractVector{Symbol},
-    ks::AbstractVector{Symbol}
-)
-    ind = indexin(s, c.nt[:parameters])
-
-    not_found = map(x -> x === nothing, ind)
-
-    any(not_found) && error("Cannot find parameters $(s[not_found]) in chain")
-
-    # If there are multiple columns, return a new CDF.
-    if length(ks) > 1
-        if !(:parameters in ks)
-            ks = vcat(:parameters, ks)
-        end
-        nt = NamedTuple{tuple(ks...)}(tuple([c.nt[k][ind] for k in ks]...))
-        return ChainDataFrame(c.name, nt)
-    else
-        # Otherwise, return a vector if there's multiple parameters
-        # or just a scalar if there's one parameter.
-        if length(s) == 1
-            return c.nt[ks[1]][ind][1]
-        else
-            return c.nt[ks[1]][ind]
-        end
-    end
-end
-
-function Base.lastindex(c::ChainDataFrame, i::Integer)
-    if i == 1
-        return c.nrows
-    elseif i ==2
-        return c.ncols
-    else
-        error("No such dimension")
-    end
-end
-
-function Base.convert(::Type{Array}, c::C) where C<:ChainDataFrame
-    T = promote_eltype_namedtuple_tail(c.nt)
-    arr = Array{T, 2}(undef, c.nrows, c.ncols - 1)
-
-    for (i, k) in enumerate(Iterators.drop(keys(c.nt), 1))
-        arr[:, i] = c.nt[k]
-    end
-
-    return arr
-end
-
-function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:NamedTuple
-    return mapreduce((x, y) -> cat(x, y; dims = Val(3)), cs) do c
-        reshape(convert(Array, c), Val(3))
-    end
-end
-
 """
     summarize(
         chains[, stats_funs...];
diff --git a/src/tables.jl b/src/tables.jl
index d5ad2e67..7a70db04 100644
--- a/src/tables.jl
+++ b/src/tables.jl
@@ -69,42 +69,3 @@ function IteratorInterfaceExtensions.getiterator(chn::Chains)
 end
 
 TableTraits.isiterabletable(::Chains) = true
-
-####
-#### ChainDataFrame
-####
-
-#### Tables interface
-
-Tables.istable(::Type{<:ChainDataFrame}) = true
-
-# AbstractColumns interface
-
-Tables.columnaccess(::Type{<:ChainDataFrame}) = true
-
-Tables.columns(cdf::ChainDataFrame) = cdf
-
-Tables.columnnames(::ChainDataFrame{<:NamedTuple{names}}) where {names} = names
-
-Tables.getcolumn(cdf::ChainDataFrame, i::Int) = cdf.nt[i]
-Tables.getcolumn(cdf::ChainDataFrame, nm::Symbol) = cdf.nt[nm]
-
-# row access
-
-Tables.rowaccess(::Type{<:ChainDataFrame}) = true
-
-Tables.rows(cdf::ChainDataFrame) = Tables.rows(Tables.columntable(cdf))
-
-function Tables.schema(::ChainDataFrame{NamedTuple{names,T}}) where {names,T}
-    types = ntuple(i -> eltype(fieldtype(T, i)), fieldcount(T))
-    return Tables.Schema(names, types)
-end
-
-#### TableTraits interface
-
-IteratorInterfaceExtensions.isiterable(::ChainDataFrame) = true
-function IteratorInterfaceExtensions.getiterator(cdf::ChainDataFrame)
-    return Tables.datavaluerows(Tables.columntable(cdf))
-end
-
-TableTraits.isiterabletable(::ChainDataFrame) = true

From 58ae52abda0606acd2a6e921e9d4a1a40faebc2e Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:50:34 +0200
Subject: [PATCH 20/56] Update docs

---
 docs/src/stats.md     |  2 +-
 docs/src/summarize.md | 11 +++++++----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/docs/src/stats.md b/docs/src/stats.md
index 7b178ddb..c53ecac0 100644
--- a/docs/src/stats.md
+++ b/docs/src/stats.md
@@ -8,5 +8,5 @@ describe
 mean
 summarystats
 quantile
-hpd
+hdi
 ```
diff --git a/docs/src/summarize.md b/docs/src/summarize.md
index e1865bd6..14ae8101 100644
--- a/docs/src/summarize.md
+++ b/docs/src/summarize.md
@@ -1,8 +1,11 @@
 # Summarize
 
-The methods listed below are defined in `src/summarize.jl`.
+The methods listed below are related to summarizing chains.
 
-```@autodocs
-Modules = [MCMCChains]
-Pages = ["summarize.jl"]
+```@docs
+SummaryStats
+summarize
+default_summary_stats
+default_stats
+default_diagnostics
 ```

From 340a694d3afc8a303d00e20d4935e2a1d925e075 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 16:51:20 +0200
Subject: [PATCH 21/56] Increment major version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index d3eee142..8512d12b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 keywords = ["markov chain monte carlo", "probablistic programming"]
 license = "MIT"
 desc = "Chain types and utility functions for MCMC simulations."
-version = "6.0.3"
+version = "7.0.0"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

From a67ac11d824082352eb5057074312f13fc23b6fd Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 19:23:35 +0200
Subject: [PATCH 22/56] Increment MCMCChains compat for docs

---
 docs/Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/Project.toml b/docs/Project.toml
index 45d71e81..376e68f9 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -15,7 +15,7 @@ CategoricalArrays = "0.8, 0.9, 0.10"
 DataFrames = "0.22, 1"
 Documenter = "0.26, 0.27"
 Gadfly = "1.3.4"
-MCMCChains = "6"
+MCMCChains = "7"
 MLJBase = "0.19, 0.20, 0.21"
 MLJDecisionTreeInterface = "0.3, 0.4"
 StatsPlots = "0.14, 0.15"

From 1af83c8a218639a64f90c9e5813c02a65de74860 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 21:33:03 +0200
Subject: [PATCH 23/56] Refer to processed chains

---
 src/summarize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/summarize.jl b/src/summarize.jl
index 99f73fd3..df2a8b63 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -42,7 +42,7 @@ function PosteriorStats.summarize(
 
     if append_chains
         # Evaluate the functions.
-        data = _permutedims_diagnostics(chains.value.data)
+        data = _permutedims_diagnostics(chn.value.data)
         summarize(data, funs...; var_names=names_of_params, kwargs...)
     else
         # Evaluate the functions.

From 4aed409b5124d762f9043ce6d06613e07bb3cd04 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 21:33:10 +0200
Subject: [PATCH 24/56] Fix doctest

---
 src/stats.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/stats.jl b/src/stats.jl
index 124bfb96..37c970e1 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -190,7 +190,7 @@ Note that this will return a single interval and will not return multiple interv
 # Examples
 
 ```jldoctest; setup = :(using Random; Random.seed!(582))
-julia> val = rand(500, 2, 3)
+julia> val = rand(500, 2, 3);
 
 julia> chn = Chains(val, [:a, :b]);
 

From eef9393af7ad0ce4ab32e8405a4b20e3b80fbb2f Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:22:37 +0200
Subject: [PATCH 25/56] Add back append_chains keyword

---
 src/stats.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/stats.jl b/src/stats.jl
index 37c970e1..eb90415e 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -17,6 +17,7 @@ Setting `append_chains=false` will return a vector of dataframes containing the
 """
 function autocor(
     chains::Chains;
+    append_chains::Bool = true,
     demean::Bool = true,
     lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains),
     kwargs...

From 6a1b482a02b0ff3e24d5865cfe6b292c13ab7b41 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:22:57 +0200
Subject: [PATCH 26/56] Compute all lags simultaneously

---
 src/stats.jl | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index eb90415e..a32817e9 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -22,10 +22,9 @@ function autocor(
     lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains),
     kwargs...
 )
-    funs = map(lags) do lag
-        return Symbol("lag ", lag) => (x -> autocor(x, [i], demean=demean)[1])
-    end
-    return summarize(chains, funs...; name = "Autocorrelation", kwargs...)
+    fun_names = Tuple(Symbol.("lag", lags))
+    fun = (x -> autocor(x, lags; demean=demean))
+    return summarize(chains, fun_names => fun; name = "Autocorrelation", append_chains, kwargs...)
 end
 
 """

From c7d2021b2db87ff2965d9613aae24a9a489ca706 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:23:12 +0200
Subject: [PATCH 27/56] Vectorize before autocor

---
 src/stats.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/stats.jl b/src/stats.jl
index a32817e9..6566de31 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -23,7 +23,7 @@ function autocor(
     kwargs...
 )
     fun_names = Tuple(Symbol.("lag", lags))
-    fun = (x -> autocor(x, lags; demean=demean))
+    fun = (x -> autocor(vec(x), lags; demean=demean))
     return summarize(chains, fun_names => fun; name = "Autocorrelation", append_chains, kwargs...)
 end
 

From bfe8deb08cb5f432c88fae71a3283b6328d4379c Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:57:01 +0200
Subject: [PATCH 28/56] Correctly insert chain id into name

---
 src/summarize.jl | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/summarize.jl b/src/summarize.jl
index df2a8b63..1e54a481 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -32,6 +32,7 @@ function PosteriorStats.summarize(
     sections = _default_sections(chains),
     append_chains::Bool = true,
     var_names=nothing,
+    name::AbstractString = "SummaryStats",
     kwargs...
 )
     # Generate a chain to work on.
@@ -43,13 +44,14 @@ function PosteriorStats.summarize(
     if append_chains
         # Evaluate the functions.
         data = _permutedims_diagnostics(chn.value.data)
-        summarize(data, funs...; var_names=names_of_params, kwargs...)
+        summarize(data, funs...; var_names=names_of_params, name, kwargs...)
     else
         # Evaluate the functions.
         data = to_vector_of_matrices(chn)
-        return map(data) do x
+        return map(enumerate(data)) do (i, x)
             z = reshape(x, size(x, 1), 1, size(x, 2))
-            summarize(z, funs...; var_names=names_of_params, kwargs...)
+            name_chain = name * " (Chain $i)"
+            summarize(z, funs...; var_names=names_of_params, name=name_chain, kwargs...)
         end
     end
 end

From 646e10823451c6dc4868b0d9c6bf916c458a4a44 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:57:47 +0200
Subject: [PATCH 29/56] Update diagnostic tests

---
 test/diagnostic_tests.jl | 66 ++++++++++++++++++++++------------------
 1 file changed, 36 insertions(+), 30 deletions(-)

diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl
index a7a4d443..706aac3e 100644
--- a/test/diagnostic_tests.jl
+++ b/test/diagnostic_tests.jl
@@ -68,8 +68,8 @@ chn_disc = Chains(val_disc, start = 1, thin = 2)
     @test all(MCMCChains.indiscretesupport(chn) .== [false, false, false, true])
     @test setinfo(chn, NamedTuple{(:A, :B)}((1,2))).info == NamedTuple{(:A, :B)}((1,2))
     @test isa(set_section(chn, Dict(:internals => ["param_1"])), AbstractChains)
-    @test mean(chn) isa ChainDataFrame
-    @test mean(chn, ["param_1", "param_3"]) isa ChainDataFrame
+    @test mean(chn) isa SummaryStats
+    @test mean(chn, ["param_1", "param_3"]) isa SummaryStats
     @test 0.95 ≤ mean(chn, "param_1") ≤ 1.05
 end
 
@@ -167,7 +167,7 @@ end
 @testset "function tests" begin
     tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"]))
 
-    @test eltype(discretediag(chn_disc[:,2:2,:])) <: ChainDataFrame
+    @test eltype(discretediag(chn_disc[:,2:2,:])) <: SummaryStats
 
     gelman = gelmandiag(tchain)
     gelmanmv = gelmandiag_multivariate(tchain)
@@ -176,18 +176,21 @@ end
     raferty = rafterydiag(tchain)
 
     # test raw return values
-    @test typeof(gelman) <: ChainDataFrame
-    @test typeof(gelmanmv) <: Tuple{ChainDataFrame,Float64}
-    @test typeof(geweke) <: Array{<:ChainDataFrame}
-    @test typeof(heidel) <: Array{<:ChainDataFrame}
-    @test typeof(raferty) <: Array{<:ChainDataFrame}
-
-    # test ChainDataFrame sizes
-    @test size(gelman) == (2,3)
-    @test size(gelmanmv[1]) == (2,3)
-    @test size(geweke[1]) == (2,3)
-    @test size(heidel[1]) == (2,7)
-    @test size(raferty[1]) == (2,6)
+    @test typeof(gelman) <: SummaryStats
+    @test typeof(gelmanmv) <: Tuple{SummaryStats,Float64}
+    @test typeof(geweke) <: Array{<:SummaryStats}
+    @test typeof(heidel) <: Array{<:SummaryStats}
+    @test typeof(raferty) <: Array{<:SummaryStats}
+
+    # test SummaryStats sizes
+    for s in (gelman, gelmanmv[1], geweke[1], heidel[1], raferty[1])
+        @test length(s[:parameter]) == 2
+    end
+    @test length(keys(gelman)) == 3
+    @test length(keys(gelmanmv[1])) == 3
+    @test length(keys(geweke[1])) == 3
+    @test length(keys(heidel[1])) == 7
+    @test length(keys(raferty[1])) == 6
 end
 
 @testset "stats tests" begin
@@ -203,31 +206,34 @@ end
             @test lags == filter!(x -> x < n, [1, 5, 10, 50])
 
             acor = autocor(c; append_chains=append_chains)
-            # Number of columns in the ChainDataFrame(s): lags + parameters
+            # Number of columns in the SummaryStats: lags + parameters
             ncols = length(lags) + 1
             if append_chains
-                @test acor isa ChainDataFrame
-                @test size(acor)[2] == ncols
+                @test acor isa SummaryStats
+                @test length(keys(acor)) == ncols
             else
-                @test acor isa Vector{<:ChainDataFrame}
-                @test all(size(a)[2] == ncols for a in acor)
+                @test acor isa Vector{<:SummaryStats}
+                @test all(length(keys(a)) == ncols for a in acor)
             end
         end
-        @test autocor(c) isa ChainDataFrame
-        @test convert(Array, autocor(c)) == convert(Array, autocor(c; append_chains=true))
+        @test autocor(c) isa SummaryStats
+        @test autocor(c) == autocor(c; append_chains=true)
     end
 
-    @test MCMCChains.cor(chn) isa ChainDataFrame
-    @test MCMCChains.cor(chn; append_chains = false) isa Vector{<:ChainDataFrame}
+    @test MCMCChains.cor(chn) isa SummaryStats
+    @test MCMCChains.cor(chn; append_chains = false) isa Vector{<:SummaryStats}
+
+    @test MCMCChains.changerate(chn) isa Tuple{SummaryStats,Float64}
+    @test MCMCChains.changerate(chn; append_chains = false) isa Vector{<:Tuple{SummaryStats,Float64}}
 
-    @test MCMCChains.changerate(chn) isa ChainDataFrame
-    @test MCMCChains.changerate(chn; append_chains = false) isa Vector{<:ChainDataFrame}
+    @test hdi(chn) isa SummaryStats
+    @test hdi(chn; append_chains = false) isa Vector{<:SummaryStats}
 
-    @test hpd(chn) isa ChainDataFrame
-    @test hpd(chn; append_chains = false) isa Vector{<:ChainDataFrame}
+    result = hdi(chn)
+    @test all(result[:upper] .> result[:lower])
 
-    result = hpd(chn)
-    @test all(result.nt.upper .> result.nt.lower)
+    @test_deprecated hpd(chn)
+    @test hpd(chn) == hdi(chn; prob=0.95)
 end
 
 @testset "vector of vectors" begin

From 8066b460c1650d0e1917d0301de481da15156520 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:58:02 +0200
Subject: [PATCH 30/56] Update ess_rhat_tests.jl

---
 test/ess_rhat_tests.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/ess_rhat_tests.jl b/test/ess_rhat_tests.jl
index 2545c494..955f9429 100644
--- a/test/ess_rhat_tests.jl
+++ b/test/ess_rhat_tests.jl
@@ -18,8 +18,8 @@ using Test
 
     for f in (ess, ess_rhat)
         s = f(c)
-        @test length(s[:,:ess_per_sec]) == 5
-        @test all(map(!ismissing, s[:,:ess_per_sec]))
+        @test length(s[:ess_per_sec]) == 5
+        @test all(map(!ismissing, s[:ess_per_sec]))
     end
 end
 
@@ -37,8 +37,8 @@ end
         ess_array, rhat_array = ess_rhat(
             permutedims(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind,
         )
-        @test ess_df[:,2] == ess_rhat_df[:,2] == ess_array
-        @test rhat_df[:,2] == ess_rhat_df[:,3] == rhat_array
+        @test ess_df[:ess] == ess_rhat_df[:ess] == ess_array
+        @test rhat_df[:rhat] == ess_rhat_df[:rhat] == rhat_array
     end
 end
 
@@ -51,5 +51,5 @@ end
         @test_throws ArgumentError ess(chain; autocov_method = autocov_method)
         @test_throws ArgumentError ess_rhat(chain; autocov_method = autocov_method)
     end
-    @test all(isnan, rhat(chain)[:, 2])
+    @test all(isnan, rhat(chain)[:rhat])
 end

From 34b1da275374ec3322ade90895da3361998b29b5 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:58:06 +0200
Subject: [PATCH 31/56] Update mcse_tests.jl

---
 test/mcse_tests.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/mcse_tests.jl b/test/mcse_tests.jl
index fdbefbb8..84f183d8 100644
--- a/test/mcse_tests.jl
+++ b/test/mcse_tests.jl
@@ -20,7 +20,7 @@ mymean(x) = mean(x)
                 mcse_array = mcse(
                     PermutedDimsArray(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind,
                 )
-                @test mcse_df[:,2] == mcse_array
+                @test mcse_df[:mcse] == mcse_array
             end
         else
             # analyze chain
@@ -28,7 +28,7 @@ mymean(x) = mean(x)
 
             # analyze array
             mcse_array = mcse(PermutedDimsArray(x, (1, 3, 2)); kind = kind)
-            @test mcse_df[:,2] == mcse_array
+            @test mcse_df[:mcse] == mcse_array
         end
     end
 end

From 5b0db148583f2fa2c87348540a2e76e500d920f4 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:58:17 +0200
Subject: [PATCH 32/56] Increment MCMCChains compat for tests

---
 test/Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/Project.toml b/test/Project.toml
index 15bc6c96..95386cbe 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -30,7 +30,7 @@ Documenter = "0.26, 0.27"
 FFTW = "1.1"
 IteratorInterfaceExtensions = "1"
 KernelDensity = "0.6.2"
-MCMCChains = "6"
+MCMCChains = "7"
 MLJBase = "0.18, 0.19, 0.20, 0.21"
 MLJDecisionTreeInterface = "0.3, 0.4"
 StatsBase = "0.33.2, 0.34"

From fb04e613c4036b6bfee156bfd7f4e6a9156de752 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:58:32 +0200
Subject: [PATCH 33/56] Remove ChainDataFrames to tables tests

---
 test/tables_tests.jl | 79 --------------------------------------------
 1 file changed, 79 deletions(-)

diff --git a/test/tables_tests.jl b/test/tables_tests.jl
index ba3cb71c..7792d14f 100644
--- a/test/tables_tests.jl
+++ b/test/tables_tests.jl
@@ -128,83 +128,4 @@ using DataFrames
             @test isequal(Tables.columntable(df), Tables.columntable(chn))
         end
     end
-
-    @testset "ChainDataFrames" begin
-        val = rand(1000, 8, 4)
-        colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
-        internal_colnames = ["c", "d", "e", "f", "g", "h"]
-        chn = Chains(val, colnames, Dict(:internals => internal_colnames))
-        cdf = describe(chn)[1]
-
-        @testset "Tables interface" begin
-            @test Tables.istable(typeof(cdf))
-
-            @testset "column access" begin
-                @test Tables.columnaccess(typeof(cdf))
-                @test Tables.columns(cdf) === cdf
-                @test Tables.columnnames(cdf) == keys(cdf.nt)
-                for (k, v) in pairs(cdf.nt)
-                    @test isequal(Tables.getcolumn(cdf, k), v)
-                end
-                @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1])
-                @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2])
-                @test_throws Exception Tables.getcolumn(cdf, :blah)
-                @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1)
-            end
-
-            @testset "row access" begin
-                @test Tables.rowaccess(typeof(cdf))
-                @test Tables.rows(cdf) isa Tables.RowIterator
-                @test eltype(Tables.rows(cdf)) <: Tables.AbstractRow
-                rows = collect(Tables.rows(cdf))
-                @test eltype(rows) <: Tables.AbstractRow
-                @test size(rows) === (2,)
-                @testset for i in 1:2
-                    row = rows[i]
-                    @test Tables.columnnames(row) == keys(cdf.nt)
-                    for j in length(cdf.nt)
-                        @test isequal(Tables.getcolumn(row, j), cdf.nt[j][i])
-                        @test isequal(Tables.getcolumn(row, keys(cdf.nt)[j]), cdf.nt[j][i])
-                    end
-                end
-            end
-
-            @testset "integration tests" begin
-                @test length(Tables.rowtable(cdf)) == length(cdf.nt[1])
-                @test isequal(Tables.columntable(cdf), cdf.nt)
-                nt = Tables.rowtable(cdf)[1]
-                @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...))
-                @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1])
-                nt = Tables.rowtable(cdf)[2]
-                @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...))
-                @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2])
-                @test isequal(
-                    Tables.matrix(Tables.rowtable(cdf)),
-                    Tables.matrix(Tables.columntable(cdf)),
-                )
-            end
-
-            @testset "schema" begin
-                @test Tables.schema(cdf) isa Tables.Schema
-                @test Tables.schema(cdf).names === keys(cdf.nt)
-                @test Tables.schema(cdf).types === eltype.(values(cdf.nt))
-            end
-        end
-
-        @testset "TableTraits interface" begin
-            @test IteratorInterfaceExtensions.isiterable(cdf)
-            @test TableTraits.isiterabletable(cdf)
-            nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 1))[1]
-            @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...))
-            nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 2))[2]
-            @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...))
-        end
-
-        @testset "DataFrames.DataFrame constructor" begin
-            @inferred DataFrame(cdf)
-            df = DataFrame(cdf)
-            @test df isa DataFrame
-            @test isequal(Tables.columntable(df), cdf.nt)
-        end
-    end
 end

From faf667bab1a2c5ea745d0571485bca6e22bf9513 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 22:58:41 +0200
Subject: [PATCH 34/56] Update summarize_tests.jl

---
 test/summarize_tests.jl | 47 +++++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/test/summarize_tests.jl b/test/summarize_tests.jl
index d847ec43..315d6d2e 100644
--- a/test/summarize_tests.jl
+++ b/test/summarize_tests.jl
@@ -1,54 +1,55 @@
 using MCMCChains, Test
 using Statistics: std
 
-@testset "Summarize to DataFrame tests" begin
+@testset "Summarize tests" begin
     val = rand(1000, 8, 4)
+    parm_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
     chns = Chains(
         val,
-        ["a", "b", "c", "d", "e", "f", "g", "h"],
+        parm_names,
         Dict(:internals => ["c", "d", "e", "f", "g", "h"])
     )
 
     parm_df = summarize(chns, sections=[:parameters])
+    parm_array_df = summarize(PermutedDimsArray(val[:, 1:2, :], (1, 3, 2)); var_names=[:a, :b])
 
-    # check that display of ChainDataFrame does not error
+    # check that display of SummaryStats does not error
     println("compact display:")
     show(stdout, parm_df)
     println("\nverbose display:")
     show(stdout, "text/plain", parm_df)
 
-    @test 0.48 < parm_df[:a, :mean][1] < 0.52
-    @test names(parm_df) == [:parameters, :mean, :std, :mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec]
+    @test 0.48 < parm_df[:mean][1] < 0.52
+    @test parm_df == parm_array_df
 
     # Indexing tests
-    @test isequal(convert(Array, parm_df[:a, :]), convert(Array, parm_df[:a]))
-    @test parm_df[:a, :][:,:parameters] == :a
-    @test parm_df[[:a, :b], :][:,:parameters] == [:a, :b]
+    @test parm_df[:parameter] == [:a, :b]
 
     all_sections_df = summarize(chns, sections=[:parameters, :internals])
-    @test all_sections_df isa ChainDataFrame
-    @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h]
-    @test size(all_sections_df) == (8, 8)
-    @test all_sections_df.name == ""
+    all_sections_array_df = summarize(PermutedDimsArray(val, (1, 3, 2)); var_names=Symbol.(parm_names))
+    @test all_sections_df isa SummaryStats
+    @test all_sections_df[:parameter] == Symbol.(parm_names)
+    @test all_sections_array_df == all_sections_df
+    @test all_sections_df.name == "SummaryStats"
 
     all_sections_dfs = summarize(chns, sections=[:parameters, :internals], name = "Summary", append_chains = false)
-    @test all_sections_dfs isa Vector{<:ChainDataFrame}
+    @test all_sections_dfs isa Vector{<:SummaryStats}
     for (i, all_sections_df) in enumerate(all_sections_dfs)
-        @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h]
-        @test size(all_sections_df) == (8, 8)
+        @test all_sections_df[:parameter] == Symbol.(parm_names)
+        @test length(keys(all_sections_df)) == length(keys(all_sections_array_df))
         @test all_sections_df.name == "Summary (Chain $i)"
     end
 
     two_parms_two_funs_df = summarize(chns[[:a, :b]], mean, std)
-    @test two_parms_two_funs_df[:, :parameters] == [:a, :b]
-    @test size(two_parms_two_funs_df) == (2, 3)
+    @test two_parms_two_funs_df[:parameter] == [:a, :b]
+    @test keys(two_parms_two_funs_df) == (:parameter, :mean, :std)
 
     three_parms_df = summarize(chns[[:a, :b, :c]], mean, std, sections=[:parameters, :internals])
-    @test three_parms_df[:, :parameters] == [:a, :b, :c]
-    @test size(three_parms_df) == (3, 3)
+    @test three_parms_df[:parameter] == [:a, :b, :c]
+    @test keys(three_parms_df) == (:parameter, :mean, :std)
 
-    three_parms_df_2 = summarize(chns[[:a, :b, :g]], mean, std,
-    sections=[:parameters, :internals], func_names=[:mean, :sd])
-    @test three_parms_df_2[:, :parameters] == [:a, :b, :g]
-    @test size(three_parms_df_2) == (3, 3)
+    three_parms_df_2 = summarize(chns[[:a, :b, :g]], :mymean => mean, :mystd => std,
+    sections=[:parameters, :internals])
+    @test three_parms_df_2[:parameter] == [:a, :b, :g]
+    @test keys(three_parms_df_2) == (:parameter, :mymean, :mystd)
 end

From fba84a2c2565216e0e3bfe7345a697e4e0bb0693 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 23:00:04 +0200
Subject: [PATCH 35/56] Use crossreference

---
 src/summarize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/summarize.jl b/src/summarize.jl
index 1e54a481..69515141 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -6,7 +6,7 @@
         [sections, var_names],
     )
 
-Summarize `chains` in a `PosteriorStats.SummaryStats`.
+Summarize `chains` in a [`SummaryStats`](@ref).
 
 `stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)`
 to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a

From d51dd659111765ff3d14db559fde120fd1df8dca Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 23:16:51 +0200
Subject: [PATCH 36/56] Update plotting functions

---
 src/plot.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/plot.jl b/src/plot.jl
index 1575846c..0a89266f 100644
--- a/src/plot.jl
+++ b/src/plot.jl
@@ -64,7 +64,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
         lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
         # Chains are already appended in `c` if desired, hence we use `append_chains=false`
         ac = autocor(c; sections = nothing, lags = lags, append_chains=false)
-        ac_mat = convert(Array, ac)
+        ac_mat = cat(reduce.(hcat, Iterators.drop.(ac, 1))...; dims=3)
         val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
         _AutocorPlot(lags, val)
     elseif st ∈ supportedplots
@@ -209,7 +209,7 @@ function _compute_plot_data(
     ordered = false
 )
 
-    chain_dic = Dict(zip(quantile(chains)[:,1], quantile(chains)[:,4]))
+    chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5]))
     sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic))))
     sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)]
     par = (ordered ? sorted_par : par_names)
@@ -217,9 +217,9 @@ function _compute_plot_data(
 
     chain_sections = MCMCChains.group(chains, Symbol(par[i]))
     chain_vec = vec(chain_sections.value.data)
-    lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.lower
+    lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:lower]
         for j in 1:length(hpdi)]
-    upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.upper
+    upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:upper]
         for j in 1:length(hpdi)]
     h = _riser + spacer*(i-1)
     qs = quantile(chain_vec, q)

From 5d13530315ec6d2b1c38c15fc77bea75ce7daa1b Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 23:36:25 +0200
Subject: [PATCH 37/56] Update to use hdi and hdi_prob

---
 src/plot.jl | 66 ++++++++++++++++++++++++++---------------------------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/src/plot.jl b/src/plot.jl
index 0a89266f..9727e536 100644
--- a/src/plot.jl
+++ b/src/plot.jl
@@ -195,7 +195,7 @@ function _compute_plot_data(
     i::Integer,
     chains::Chains,
     par_names::AbstractVector{Symbol};
-    hpd_val = [0.05, 0.2],
+    hdi_prob = [0.94, 0.8],
     q = [0.1, 0.9],
     spacer = 0.4,
     _riser = 0.2,
@@ -203,9 +203,9 @@ function _compute_plot_data(
     show_mean = true,
     show_median = true,
     show_qi = false,
-    show_hpdi = true,
+    show_hdii = true,
     fill_q = true,
-    fill_hpd = false,
+    fill_hdi = false,
     ordered = false
 )
 
@@ -213,19 +213,19 @@ function _compute_plot_data(
     sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic))))
     sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)]
     par = (ordered ? sorted_par : par_names)
-    hpdi = sort(hpd_val)
+    hdii = sort(hdi_prob)
 
     chain_sections = MCMCChains.group(chains, Symbol(par[i]))
     chain_vec = vec(chain_sections.value.data)
-    lower_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:lower]
-        for j in 1:length(hpdi)]
-    upper_hpd = [MCMCChains.hpd(chain_sections, alpha = hpdi[j])[:upper]
-        for j in 1:length(hpdi)]
+    lower_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:lower]
+        for j in 1:length(hdii)]
+    upper_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:upper]
+        for j in 1:length(hdii)]
     h = _riser + spacer*(i-1)
     qs = quantile(chain_vec, q)
     k_density = kde(chain_vec)
-    if fill_hpd
-        x_int = filter(x -> lower_hpd[1][1] <= x <= upper_hpd[1][1], k_density.x)
+    if fill_hdi
+        x_int = filter(x -> lower_hdi[1][1] <= x <= upper_hdi[1][1], k_density.x)
         val = pdf(k_density, x_int) .+ h
     elseif fill_q
         x_int = filter(x -> qs[1] <= x <= qs[2], k_density.x)
@@ -239,22 +239,22 @@ function _compute_plot_data(
     min = minimum(k_density.density .+ h)
     q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med])
 
-    return par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
+    return par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med,
         chain_mean, min, q_int
 end
 
 @recipe function f(
     p::RidgelinePlot;
-    hpd_val = [0.05, 0.2],
+    hdi_prob = [0.94, 0.8],
     q = [0.1, 0.9],
     spacer = 0.5,
     _riser = 0.2,
     show_mean = true,
     show_median = true,
     show_qi = false,
-    show_hpdi = true,
+    show_hdii = true,
     fill_q = true,
-    fill_hpd = false,
+    fill_hdi = false,
     ordered = false
 )
 
@@ -262,10 +262,10 @@ end
     par_names = p.args[2]
 
     for i in 1:length(par_names)
-        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, chain_mean,
-            min, q_int = _compute_plot_data(i, chn, par_names; hpd_val = hpd_val, q = q,
+        par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean,
+            min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q,
             spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median,
-            show_qi = show_qi, show_hpdi = show_hpdi, fill_q = fill_q, fill_hpd = fill_hpd,
+            show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi,
             ordered = ordered)
 
         yticks --> (length(par_names) > 1 ?
@@ -322,28 +322,28 @@ end
         end
         @series begin
             seriestype := :path
-            label := (show_hpdi ? (i == 1 ? "$(Integer((1-hpdi[1])*100))% HPDI" : nothing)
+            label := (show_hdii ? (i == 1 ? "$(round(Int, hdii[i]*100))% HDI" : nothing)
                 : nothing)
-            linewidth --> (show_hpdi ? 2 : 0)
+            linewidth --> (show_hdii ? 2 : 0)
             seriesalpha --> 0.80
             linecolor --> :darkblue
-            [lower_hpd[1][1], upper_hpd[1][1]], [h, h]
+            [lower_hdi[1][1], upper_hdi[1][1]], [h, h]
         end
     end
 end
 
 @recipe function f(
     p::ForestPlot;
-    hpd_val = [0.05, 0.2],
+    hdi_prob = [0.94, 0.8],
     q = [0.1, 0.9],
     spacer = 0.5,
     _riser = 0.2,
     show_mean = true,
     show_median = true,
     show_qi = false,
-    show_hpdi = true,
+    show_hdii = true,
     fill_q = true,
-    fill_hpd = false,
+    fill_hdi = false,
     ordered = false
 )
 
@@ -351,25 +351,25 @@ end
     par_names = p.args[2]
 
     for i in 1:length(par_names)
-        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med, chain_mean,
-            min, q_int = _compute_plot_data(i, chn, par_names; hpd_val = hpd_val, q = q,
+        par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean,
+            min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q,
             spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median,
-            show_qi = show_qi, show_hpdi = show_hpdi, fill_q = fill_q, fill_hpd = fill_hpd,
+            show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi,
             ordered = ordered)
 
         yticks --> (length(par_names) > 1 ?
             (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
         yaxis --> (length(par_names) > 1 ? "Parameters" : "Density" )
 
-        for j in 1:length(hpdi)
+        for j in 1:length(hdii)
             @series begin
                 seriestype := :path
-                label := (show_hpdi ?
-                    (i == 1 ? "$(Integer((1-hpdi[j])*100))% HPDI" : nothing) : nothing)
+                label := (show_hdii ?
+                    (i == 1 ? "$(round(Int, hdii[j]*100))% HDI" : nothing) : nothing)
                 linecolor --> j
-                linewidth --> (show_hpdi ? 1.5*j : 0)
+                linewidth --> (show_hdii ? 1.5*j : 0)
                 seriesalpha --> 0.80
-                [lower_hpd[j][1], upper_hpd[j][1]], [h, h]
+                [lower_hdi[j][1], upper_hdi[j][1]], [h, h]
             end
         end
         @series begin
@@ -377,7 +377,7 @@ end
             label := (show_median ? (i == 1 ? "Median" : nothing) : nothing)
             markershape --> :diamond
             markercolor --> "#000000"
-            markersize --> (show_median ? length(hpdi) : 0)
+            markersize --> (show_median ? length(hdii) : 0)
             [chain_med], [h]
         end
         @series begin
@@ -385,7 +385,7 @@ end
             label := (show_mean ? (i == 1 ? "Mean" : nothing) : nothing)
             markershape --> :circle
             markercolor --> :gray
-            markersize --> (show_mean ? length(hpdi) : 0)
+            markersize --> (show_mean ? length(hdii) : 0)
             [chain_mean], [h]
         end
         @series begin

From 57a292168f46f4230e3d84da8ce77547ee3faedc Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 20 Aug 2023 23:48:07 +0200
Subject: [PATCH 38/56] Fix missing tests

---
 test/missing_tests.jl | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/test/missing_tests.jl b/test/missing_tests.jl
index 46b6227c..89a7ba00 100644
--- a/test/missing_tests.jl
+++ b/test/missing_tests.jl
@@ -5,9 +5,7 @@ using Random
 
 # Tests for missing values.
 function testdiff(cdf1, cdf2)
-    m1 = convert(Array, cdf1)
-    m2 = convert(Array, cdf2)
-    return all(((x, y),) -> isapprox(x, y; atol=1e-2), zip(m1, m2))
+    return all(((x, y),) -> isapprox(x, y; atol=1e-2), Iterators.drop(zip(cdf1, cdf2), 1))
 end
 
 @testset "utils" begin
@@ -35,9 +33,9 @@ end
     rf_2 = rafterydiag(chn_m)
 
     @testset "diagnostics missing tests" for i in 1:nchains
-        @test testdiff(gw_1, gw_2)
-        @test testdiff(hd_1, hd_2)
-        @test testdiff(rf_1, rf_2)
+        @test all(Base.splat(testdiff), zip(gw_1, gw_2))
+        @test all(Base.splat(testdiff), zip(hd_1, hd_2))
+        @test all(Base.splat(testdiff), zip(rf_1, rf_2))
     end
 
     @test_throws MethodError discretediag(chn_m)

From fc96b86fad93af4b836360dadc3ec7cccd829f27 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Mon, 21 Aug 2023 00:08:15 +0200
Subject: [PATCH 39/56] Remove references not defined here.

---
 docs/src/summarize.md | 6 +-----
 src/summarize.jl      | 6 +++---
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/docs/src/summarize.md b/docs/src/summarize.md
index 14ae8101..4e30ae2d 100644
--- a/docs/src/summarize.md
+++ b/docs/src/summarize.md
@@ -1,11 +1,7 @@
 # Summarize
 
-The methods listed below are related to summarizing chains.
+The methods listed below are defined in `src/summarize.jl`.
 
 ```@docs
-SummaryStats
 summarize
-default_summary_stats
-default_stats
-default_diagnostics
 ```
diff --git a/src/summarize.jl b/src/summarize.jl
index 69515141..34775926 100644
--- a/src/summarize.jl
+++ b/src/summarize.jl
@@ -6,7 +6,7 @@
         [sections, var_names],
     )
 
-Summarize `chains` in a [`SummaryStats`](@ref).
+Summarize `chains` in a `PosteriorStats.SummaryStats`.
 
 `stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)`
 to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a
@@ -14,8 +14,8 @@ to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` m
 form `(name1, ...) => fun` when the function returns a collection. When the function returns
 a collection, the names in this latter format must be provided.
 
-If no stats functions are provided, then those specified in [`default_summary_stats`](@ref)
-are computed.
+If no stats functions are provided, then those specified in
+`PosteriorStats.default_summary_stats` are computed.
 
 `var_names` specifies the names of the parameters in data. If not provided, the names are
 inferred from data.

From cbd06d20395d6413724c0fbbe6817420625469ee Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Mon, 21 Aug 2023 00:20:17 +0200
Subject: [PATCH 40/56] Improve vertical spacing

---
 src/chains.jl | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/chains.jl b/src/chains.jl
index 3172ec22..aede0ee7 100644
--- a/src/chains.jl
+++ b/src/chains.jl
@@ -346,11 +346,13 @@ function Base.show(io::IO, chains::Chains)
 end
 
 function Base.show(io::IO, mime::MIME"text/plain", chains::Chains)
-    print(io, "Chains ", chains, ":\n\n", header(chains))
+    println(io, "Chains ", chains, ":\n\n", header(chains))
 
     # Show summary stats.
     summaries = describe(chains)
-    for summary in summaries
+    summary, others = Iterators.peel(summaries)
+    show(io, mime, summary)
+    for summary in others
         println(io)
         println(io)
         show(io, mime, summary)

From f2bdc6888fbb9ccbb62271800aad0f30c73c8803 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Mon, 21 Aug 2023 01:00:45 +0200
Subject: [PATCH 41/56] Bump PosteriorStats compat

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 8512d12b..87c248f4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -39,7 +39,7 @@ MCMCDiagnosticTools = "0.3"
 MLJModelInterface = "0.3.5, 0.4, 1.0"
 NaturalSort = "1"
 OrderedCollections = "1.4"
-PosteriorStats = "0.1.2"
+PosteriorStats = "0.1.3"
 PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2"
 RecipesBase = "0.7, 0.8, 1.0"
 StatsBase = "0.33.2, 0.34"

From c19b13401ee4dbe37584d029cf087334fa6c4de1 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 17:57:17 -0800
Subject: [PATCH 42/56] Bump PosteriorStats compat

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index dca08380..9f1e8663 100644
--- a/Project.toml
+++ b/Project.toml
@@ -39,7 +39,7 @@ MCMCDiagnosticTools = "0.3"
 MLJModelInterface = "0.3.5, 0.4, 1.0"
 NaturalSort = "1"
 OrderedCollections = "1.4"
-PosteriorStats = "0.1.3"
+PosteriorStats = "0.2"
 PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2"
 RecipesBase = "0.7, 0.8, 1.0"
 StatsBase = "0.33.2, 0.34"

From 02be9c70f49830b34ee03f8daf1f69383c402f2d Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 18:01:23 -0800
Subject: [PATCH 43/56] Make stack available for older Julia versions

---
 Project.toml      | 2 ++
 src/MCMCChains.jl | 1 +
 2 files changed, 3 insertions(+)

diff --git a/Project.toml b/Project.toml
index 9f1e8663..c094fadb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ version = "7.0.0"
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
@@ -31,6 +32,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 [compat]
 AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5"
 AxisArrays = "0.4.4"
+Compat = "4.2.0"
 Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
 Formatting = "0.4"
 IteratorInterfaceExtensions = "0.1.1, 1"
diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index dfbc73b2..10038d8f 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -1,5 +1,6 @@
 module MCMCChains
 
+using Compat: stack
 using AxisArrays
 const axes = Base.axes
 import AbstractMCMC

From ea2548e8e0b8ba077a16ee2e9798ded63d046ae9 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 18:11:52 -0800
Subject: [PATCH 44/56] Update SummaryStats constructor user

---
 src/discretediag.jl |  6 +++---
 src/ess_rhat.jl     | 12 ++++++------
 src/gelmandiag.jl   |  6 ++++--
 src/gewekediag.jl   |  3 +--
 src/heideldiag.jl   |  3 +--
 src/mcse.jl         |  4 ++--
 src/rafterydiag.jl  |  3 +--
 src/stats.jl        |  6 +++---
 8 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/src/discretediag.jl b/src/discretediag.jl
index d262aca8..5cb7e054 100644
--- a/src/discretediag.jl
+++ b/src/discretediag.jl
@@ -18,13 +18,13 @@ function MCMCDiagnosticTools.discretediag(
     )
 
     # Create SummaryStats
-    parameters = (parameter = names(_chains),)
+    param_names = names(_chains)
     between_chain_stats = SummaryStats(
-        "Chisq diagnostic - Between chains", merge(parameters, between_chain_vals),
+        "Chisq diagnostic - Between chains", between_chain_vals, param_names,
     )
     within_chain_stats = map(1:size(_chains, 3)) do i
         vals = map(val -> val[:, i], within_chain_vals)
-        return SummaryStats("Chisq diagnostic - Chain $i", merge(parameters, vals))
+        return SummaryStats("Chisq diagnostic - Chain $i", vals, param_names)
     end
     stats = [between_chain_stats, within_chain_stats...]
 
diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl
index 4b203d79..19ee1774 100644
--- a/src/ess_rhat.jl
+++ b/src/ess_rhat.jl
@@ -24,9 +24,9 @@ function MCMCDiagnosticTools.ess(
 
     # Convert to NamedTuple
     ess_per_sec = ess ./ dur
-    nt = merge((parameter = names(_chains),), (; ess, ess_per_sec))
+    nt = (; ess, ess_per_sec)
 
-    return SummaryStats("ESS", nt)
+    return SummaryStats("ESS", nt, names(_chains))
 end
 
 """
@@ -48,9 +48,9 @@ function MCMCDiagnosticTools.rhat(
     )
 
     # Convert to NamedTuple
-    nt = merge((parameter = names(_chains),), (; rhat))
+    nt = (; rhat)
 
-    return SummaryStats("R-hat", nt)
+    return SummaryStats("R-hat", nt, names(_chains))
 end
 
 """
@@ -79,7 +79,7 @@ function MCMCDiagnosticTools.ess_rhat(
 
     # Convert to NamedTuple
     ess_per_sec = ess_rhat.ess ./ dur
-    nt = merge((parameter = names(_chains),), ess_rhat, (; ess_per_sec))
+    nt = merge(ess_rhat, (; ess_per_sec))
 
-    return SummaryStats("ESS/R-hat", nt)
+    return SummaryStats("ESS/R-hat", nt, names(_chains))
 end
diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl
index 39b6d23d..0d40c6b2 100644
--- a/src/gelmandiag.jl
+++ b/src/gelmandiag.jl
@@ -14,7 +14,8 @@ function MCMCDiagnosticTools.gelmandiag(
     # Create a data frame with the results.
     stats = SummaryStats(
         "Gelman, Rubin, and Brooks diagnostic",
-        merge((parameter = names(_chains),), results),
+        results,
+        names(_chains),
     )
 
     return stats
@@ -39,7 +40,8 @@ function MCMCDiagnosticTools.gelmandiag_multivariate(
     # Create SummaryStats with the results.
     stats = SummaryStats(
         "Gelman, Rubin, and Brooks diagnostic",
-        (parameter = names(_chains), psrf = results.psrf, psrfci = results.psrfci),
+        (psrf = results.psrf, psrfci = results.psrfci),
+        names(_chains),
     )
 
     return stats, results.psrfmultivariate
diff --git a/src/gewekediag.jl b/src/gewekediag.jl
index f7c6ecad..83af5b67 100644
--- a/src/gewekediag.jl
+++ b/src/gewekediag.jl
@@ -19,9 +19,8 @@ function MCMCDiagnosticTools.gewekediag(
     end
 
     # Create SummaryStats.
-    parameters = (parameter = names(_chains),)
     stats = [
-        SummaryStats("Geweke diagnostic - Chain $i", merge(parameters, result))
+        SummaryStats("Geweke diagnostic - Chain $i", result, names(_chains))
         for (i, result) in enumerate(results)
     ]
 
diff --git a/src/heideldiag.jl b/src/heideldiag.jl
index 7f7acd18..5db6cae4 100644
--- a/src/heideldiag.jl
+++ b/src/heideldiag.jl
@@ -17,10 +17,9 @@ function MCMCDiagnosticTools.heideldiag(
     end
 
     # Create SummaryStats.
-    parameters = (parameter = names(_chains),)
     stats = [
         SummaryStats(
-            "Heidelberger and Welch diagnostic - Chain $i", merge(parameters, result)
+            "Heidelberger and Welch diagnostic - Chain $i", result, names(_chains),
         )
         for (i, result) in enumerate(results)
     ]
diff --git a/src/mcse.jl b/src/mcse.jl
index 5c5a2731..677f40b0 100644
--- a/src/mcse.jl
+++ b/src/mcse.jl
@@ -16,7 +16,7 @@ function MCMCDiagnosticTools.mcse(
         kwargs...,
     )
 
-    nt = merge((parameter = names(_chains),), (; mcse))
+    nt = (; mcse)
 
-    return SummaryStats("MCSE", nt)
+    return SummaryStats("MCSE", nt, names(_chains))
 end
diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl
index 4dfd45a8..7fd97a5e 100644
--- a/src/rafterydiag.jl
+++ b/src/rafterydiag.jl
@@ -17,10 +17,9 @@ function MCMCDiagnosticTools.rafterydiag(
     end
 
     # Create SummaryStats.
-    parameters = (parameter = names(_chains),)
     stats = [
         SummaryStats(
-            "Raftery and Lewis diagnostic - Chain $i", merge(parameters, result)
+            "Raftery and Lewis diagnostic - Chain $i", result, names(_chains),
         )
         for (i, result) in enumerate(results)
     ]
diff --git a/src/stats.jl b/src/stats.jl
index 6566de31..a3957838 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -122,15 +122,15 @@ function changerate(
     end
 end
 
-function summarystats_changerate(name, names_of_params, chains; kwargs...)
+function summarystats_changerate(name, names_of_params, chains)
     # Compute the change rates.
     changerates, mvchangerate = changerate(chains)
 
     # Summarize the results in a named tuple.
-    nt = (; parameter=names_of_params, changerate=changerates)
+    nt = (; changerate=changerates)
 
     # Create a SummaryStats.
-    return SummaryStats(name, nt; kwargs...), mvchangerate
+    return SummaryStats(name, nt, names_of_params), mvchangerate
 end
 
 changerate(chains::AbstractMatrix{<:Real}) = changerate(reshape(chains, Val(3)))

From 929c6546ee3684bd8bcb94304a41308653abc582 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 18:12:12 -0800
Subject: [PATCH 45/56] Use dict backing for cor summary

---
 src/stats.jl | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index a3957838..16f26605 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -76,16 +76,15 @@ function cor(
     end
 end
 
-function summarystats_cor(name, names_of_params, chains::AbstractMatrix; kwargs...)
+function summarystats_cor(name, names_of_params, chains::AbstractMatrix)
     # Compute the correlation matrix.
     cormat = cor(chains)
 
-    # Summarize the results in a named tuple.
-    nt = (; parameter = names_of_params,
-          zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))...)
+    # Summarize the results in a dict
+    dict = OrderedCollections.OrderedDict(zip(names_of_params, eachcol(cormat)))
 
     # Create a SummaryStats.
-    return SummaryStats(name, nt; kwargs...)
+    return SummaryStats(name, dict, names_of_params)
 end
 
 """

From 5f7aa7c0ebfbc00e6f27a5ee39a2eb2dd5679f20 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 18:15:13 -0800
Subject: [PATCH 46/56] Remove unused kwargs

---
 src/stats.jl | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 16f26605..002a3684 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -43,7 +43,7 @@ function _default_lags(chains::Chains, append_chains::Bool)
 end
 
 """
-    cor(chains[; sections, append_chains = true, kwargs...])
+    cor(chains[; sections, append_chains = true])
 
 Compute the Pearson correlation matrix for the chain.
 
@@ -54,7 +54,6 @@ function cor(
     chains::Chains;
     sections = _default_sections(chains),
     append_chains = true,
-    kwargs...
 )
     # Subset the chain.
     _chains = Chains(chains, _clean_sections(chains, sections))
@@ -88,7 +87,7 @@ function summarystats_cor(name, names_of_params, chains::AbstractMatrix)
 end
 
 """
-    changerate(chains[; sections, append_chains = true, kwargs...])
+    changerate(chains[; sections, append_chains = true])
 
 Compute the change rate for the chain.
 
@@ -99,7 +98,6 @@ function changerate(
     chains::Chains{<:Real};
     sections = _default_sections(chains),
     append_chains = true,
-    kwargs...
 )
     # Subset the chain.
     _chains = Chains(chains, _clean_sections(chains, sections))

From 49fe1c88bb6b97ee1de0d4d64a0580f2acd0781e Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 18:15:33 -0800
Subject: [PATCH 47/56] Refactor autocor to avoid large namedtuple

---
 src/stats.jl | 37 +++++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/src/stats.jl b/src/stats.jl
index 002a3684..7858fe2a 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -17,14 +17,43 @@ Setting `append_chains=false` will return a vector of dataframes containing the
 """
 function autocor(
     chains::Chains;
+    sections = _default_sections(chains),
     append_chains::Bool = true,
     demean::Bool = true,
     lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains),
-    kwargs...
+    var_names = nothing,
+    kwargs...,
 )
-    fun_names = Tuple(Symbol.("lag", lags))
-    fun = (x -> autocor(vec(x), lags; demean=demean))
-    return summarize(chains, fun_names => fun; name = "Autocorrelation", append_chains, kwargs...)
+    chn = Chains(chains, _clean_sections(chains, sections))
+
+    # Obtain names of parameters.
+    names_of_params = var_names === nothing ? names(chn) : var_names
+
+    # set up the functions to be evaluated
+    col_names = Symbol.("lag", lags)
+
+    # avoids using summarize directly to support simultaneously computing a large number of
+    # lags without constructing a huge NamedTuple
+    if append_chains
+        # Evaluate the functions.
+        data = _permutedims_diagnostics(chn.value.data)
+        vals = stack(map(eachslice(data; dims=3)) do x
+            return autocor(vec(x), lags; demean=demean)
+        end)
+        table = Tables.table(vals'; header=col_names)
+        return SummaryStats("Autocorrelation", table, names_of_params)
+    else
+        # Evaluate the functions.
+        data = to_vector_of_matrices(chn)
+        return map(enumerate(data)) do (i, x)
+            name_chain = "Autocorrelation (Chain $i)"
+            vals = stack(map(eachslice(x; dims=2)) do xi
+                return autocor(xi, lags; demean=demean)
+            end)
+            table = Tables.table(vals'; header=col_names)
+            return SummaryStats(name_chain, table, names_of_params)
+        end
+    end
 end
 
 """

From 1cef3ba24c7c3ee6139f18ecb22c503bfb457a21 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 18:15:53 -0800
Subject: [PATCH 48/56] Use stack for autocor to avoid large compile times

---
 src/plot.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/plot.jl b/src/plot.jl
index 9727e536..6a3bb38b 100644
--- a/src/plot.jl
+++ b/src/plot.jl
@@ -64,7 +64,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
         lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
         # Chains are already appended in `c` if desired, hence we use `append_chains=false`
         ac = autocor(c; sections = nothing, lags = lags, append_chains=false)
-        ac_mat = cat(reduce.(hcat, Iterators.drop.(ac, 1))...; dims=3)
+        ac_mat = stack(map(stack ∘ Base.Fix2(Iterators.drop, 1), ac))
         val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
         _AutocorPlot(lags, val)
     elseif st ∈ supportedplots

From 17e08ed4bc9016923fedbefdb326a939be8211a3 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 19:24:49 -0800
Subject: [PATCH 49/56] Improve type inference of OrderedDict

---
 src/stats.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/stats.jl b/src/stats.jl
index 7858fe2a..41a814dd 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -109,7 +109,7 @@ function summarystats_cor(name, names_of_params, chains::AbstractMatrix)
     cormat = cor(chains)
 
     # Summarize the results in a dict
-    dict = OrderedCollections.OrderedDict(zip(names_of_params, eachcol(cormat)))
+    dict = OrderedCollections.OrderedDict(k => v for (k, v) in zip(names_of_params, eachcol(cormat)))
 
     # Create a SummaryStats.
     return SummaryStats(name, dict, names_of_params)

From 1e3e96a52382ca7b9b04f99e09eaecc76e9df8e7 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 19:51:34 -0800
Subject: [PATCH 50/56] Make doctest reproducible

---
 docs/Project.toml |  2 ++
 src/stats.jl      | 12 +++++++-----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/docs/Project.toml b/docs/Project.toml
index 52ca0f6a..506ee3e4 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -7,6 +7,7 @@ Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
 
 [compat]
@@ -18,5 +19,6 @@ Gadfly = "1.3.4"
 MCMCChains = "7"
 MLJBase = "0.19, 0.20, 0.21, 1"
 MLJDecisionTreeInterface = "0.3, 0.4"
+StableRNGs = "1"
 StatsPlots = "0.14, 0.15"
 julia = "1.7"
diff --git a/src/stats.jl b/src/stats.jl
index 41a814dd..bd5834b5 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -215,16 +215,18 @@ Note that this will return a single interval and will not return multiple interv
 
 # Examples
 
-```jldoctest; setup = :(using Random; Random.seed!(582))
-julia> val = rand(500, 2, 3);
+```jldoctest
+julia> using StableRNGs; rng = StableRNG(42);
+
+julia> val = rand(rng, 500, 2, 3);
 
 julia> chn = Chains(val, [:a, :b]);
 
 julia> hdi(chn)
 HDI
-      lower  upper
- a  0.0749   0.999
- b  0.00531  0.940
+     lower  upper
+ a  0.0630  0.994
+ b  0.0404  0.968
 ```
 """
 function PosteriorStats.hdi(chn::Chains; prob::Real=0.94, kwargs...)

From d00189449f0794f88462a09f5ad9935b50625be8 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sat, 23 Dec 2023 23:13:24 -0800
Subject: [PATCH 51/56] Add StableRNGs as test dependency

---
 test/Project.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/Project.toml b/test/Project.toml
index f9f4e97e..d9a20ba1 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,6 +14,7 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
@@ -33,6 +34,7 @@ KernelDensity = "0.6.2"
 MCMCChains = "7"
 MLJBase = "0.18, 0.19, 0.20, 0.21, 1"
 MLJDecisionTreeInterface = "0.3, 0.4"
+StableRNGs = "1"
 StatsBase = "0.33.2, 0.34"
 StatsPlots = "0.14.17, 0.15"
 TableTraits = "1"

From b3a31431cbfb5189608d7d7361708d50d775d45d Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 11 Feb 2024 00:47:32 +0100
Subject: [PATCH 52/56] Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/MCMCChains.jl | 3 +--
 src/stats.jl      | 4 +---
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index 10038d8f..0ccad7b6 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -50,8 +50,7 @@ export rstar
 # Reexport stats functions
 using PosteriorStats: SummaryStats, default_diagnostics, default_stats,
     default_summary_stats, hdi, summarize
-export SummaryStats, default_diagnostics, default_stats, default_summary_stats, hdi,
-    summarize
+export SummaryStats, hdi, summarize
 
 """
     Chains
diff --git a/src/stats.jl b/src/stats.jl
index bd5834b5..17f5f035 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -29,13 +29,12 @@ function autocor(
     # Obtain names of parameters.
     names_of_params = var_names === nothing ? names(chn) : var_names
 
-    # set up the functions to be evaluated
+    # Construct column names for lags.
     col_names = Symbol.("lag", lags)
 
     # avoids using summarize directly to support simultaneously computing a large number of
     # lags without constructing a huge NamedTuple
     if append_chains
-        # Evaluate the functions.
         data = _permutedims_diagnostics(chn.value.data)
         vals = stack(map(eachslice(data; dims=3)) do x
             return autocor(vec(x), lags; demean=demean)
@@ -43,7 +42,6 @@ function autocor(
         table = Tables.table(vals'; header=col_names)
         return SummaryStats("Autocorrelation", table, names_of_params)
     else
-        # Evaluate the functions.
         data = to_vector_of_matrices(chn)
         return map(enumerate(data)) do (i, x)
             name_chain = "Autocorrelation (Chain $i)"

From 0fad7f9a9e4722b90ae764cfd204fababc42a5ad Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 11 Feb 2024 11:44:40 +0100
Subject: [PATCH 53/56] Avoid splatting

---
 src/discretediag.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/discretediag.jl b/src/discretediag.jl
index 5cb7e054..0b3d699c 100644
--- a/src/discretediag.jl
+++ b/src/discretediag.jl
@@ -26,7 +26,7 @@ function MCMCDiagnosticTools.discretediag(
         vals = map(val -> val[:, i], within_chain_vals)
         return SummaryStats("Chisq diagnostic - Chain $i", vals, param_names)
     end
-    stats = [between_chain_stats, within_chain_stats...]
+    stats = vcat([between_chain_stats], within_chain_stats)
 
     return stats
 end

From 3579c55b355b2e59e9c428b4e24dbe9fb22de421 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 11 Feb 2024 18:52:36 +0100
Subject: [PATCH 54/56] Avoid recomputing all medians for every parameter

---
 src/plot.jl | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/plot.jl b/src/plot.jl
index 6a3bb38b..19452003 100644
--- a/src/plot.jl
+++ b/src/plot.jl
@@ -206,16 +206,10 @@ function _compute_plot_data(
     show_hdii = true,
     fill_q = true,
     fill_hdi = false,
-    ordered = false
 )
+    hdii = sort(hdi_prob; rev=true)
 
-    chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5]))
-    sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic))))
-    sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)]
-    par = (ordered ? sorted_par : par_names)
-    hdii = sort(hdi_prob)
-
-    chain_sections = MCMCChains.group(chains, Symbol(par[i]))
+    chain_sections = MCMCChains.group(chains, Symbol(par_names[i]))
     chain_vec = vec(chain_sections.value.data)
     lower_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:lower]
         for j in 1:length(hdii)]
@@ -239,7 +233,7 @@ function _compute_plot_data(
     min = minimum(k_density.density .+ h)
     q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med])
 
-    return par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med,
+    return par_names, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med,
         chain_mean, min, q_int
 end
 
@@ -261,12 +255,16 @@ end
     chn = p.args[1]
     par_names = p.args[2]
 
+    if ordered
+        par_table_names, par_medians = summarize(chn[:, par_names, :], median)
+        par_names = par_table_names[sortperm(par_medians)]
+    end
+
     for i in 1:length(par_names)
         par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean,
             min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q,
             spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median,
-            show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi,
-            ordered = ordered)
+            show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi)
 
         yticks --> (length(par_names) > 1 ?
             (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
@@ -350,12 +348,16 @@ end
     chn = p.args[1]
     par_names = p.args[2]
 
+    if ordered
+        par_table_names, par_medians = summarize(chn[:, par_names, :], median)
+        par_names = par_table_names[sortperm(par_medians)]
+    end
+
     for i in 1:length(par_names)
         par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean,
             min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q,
             spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median,
-            show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi,
-            ordered = ordered)
+            show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi)
 
         yticks --> (length(par_names) > 1 ?
             (_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)

From 53329164285b10edcf646f2aaf98a678636da6a6 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 11 Feb 2024 20:22:27 +0100
Subject: [PATCH 55/56] Add test for show method

---
 test/runtests.jl   |  4 ++++
 test/show_tests.jl | 23 +++++++++++++++++++++++
 2 files changed, 27 insertions(+)
 create mode 100644 test/show_tests.jl

diff --git a/test/runtests.jl b/test/runtests.jl
index dfa37395..0c0e8819 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -60,6 +60,10 @@ Random.seed!(0)
     println("Model statistics")
     @time include("modelstats_test.jl")
 
+    # run tests for show methods
+    println("Show methods")
+    @time include("show_tests.jl")
+
     # run tests for concatenation
     println("Concatenation")
     @time include("concatenation_tests.jl")
diff --git a/test/show_tests.jl b/test/show_tests.jl
new file mode 100644
index 00000000..f1a6b9b8
--- /dev/null
+++ b/test/show_tests.jl
@@ -0,0 +1,23 @@
+using Test
+using MCMCChains
+
+@testset "Show tests" begin
+    rng = MersenneTwister(1234)
+    val = rand(rng, 100, 4, 4)
+    parm_names = ["a", "b", "c", "d"]
+    chns = Chains(val, parm_names, Dict(:internals => ["b", "d"]))[1:2:99, :, :]
+    str = sprint(show, "text/plain", chns)
+    stats_str = sprint(show, "text/plain", summarystats(chns))
+    quantile_str = sprint(show, "text/plain", quantile(chns))
+    @test str == """Chains MCMC chain (50×4×4 Array{Float64, 3}):
+
+    Iterations        = 1:2:99
+    Number of chains  = 4
+    Samples per chain = 50
+    parameters        = a, c
+    internals         = b, d
+
+    $stats_str
+
+    $quantile_str"""
+end

From d4882200aa05dee6f34ce7c0bbfa473c35794c30 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Sun, 17 Nov 2024 22:02:27 +0100
Subject: [PATCH 56/56] Update ess_rhat tests

---
 test/ess_rhat_tests.jl | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/ess_rhat_tests.jl b/test/ess_rhat_tests.jl
index df487f7d..63adb79c 100644
--- a/test/ess_rhat_tests.jl
+++ b/test/ess_rhat_tests.jl
@@ -49,15 +49,15 @@ end
     for autocov_method in (AutocovMethod(), FFTAutocovMethod(), BDAAutocovMethod())
         # analyze chain
         ess_df = ess(chain; autocov_method = autocov_method)
-        @test isequal(ess_df[:, :ess], fill(NaN, 5))
-        @test isequal(ess_df[:, :ess_per_sec], fill(missing, 5))
+        @test isequal(ess_df[:ess], fill(NaN, 5))
+        @test isequal(ess_df[:ess_per_sec], fill(missing, 5))
         
         ess_rhat_df = ess_rhat(chain; autocov_method = autocov_method)
-        @test isequal(ess_rhat_df[:, :ess], fill(NaN, 5))
-        @test isequal(ess_rhat_df[:, :rhat], fill(NaN, 5))
-        @test isequal(ess_rhat_df[:, :ess_per_sec], fill(missing, 5))
+        @test isequal(ess_rhat_df[:ess], fill(NaN, 5))
+        @test isequal(ess_rhat_df[:rhat], fill(NaN, 5))
+        @test isequal(ess_rhat_df[:ess_per_sec], fill(missing, 5))
     end
 
     rhat_df = rhat(chain)
-    @test isequal(rhat_df[:, :rhat], fill(NaN, 5))
+    @test isequal(rhat_df[:rhat], fill(NaN, 5))
 end