Skip to content

Commit

Permalink
Add plot_psense_dist (#93)
Browse files Browse the repository at this point in the history
* Add plot_psense_dist

* Update src/arviz_plots/plots/psensedistplot.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Update src/arviz_plots/plots/psensedistplot.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* concat da and simplify logic

* set sample_dims to sample

* refactor

* minor fixes and update pyproject to install from GH

* support sample_dims argument and all backends

* add minigallery to docstring

* tweak example

* ensure pointinterval only plot does not have yticks

* add initial test for psense plot

* add test and example

* fix test

* fix docstring

* rename __group__

---------

Co-authored-by: Oriol Abril-Pla <[email protected]>
  • Loading branch information
aloctavodia and OriolAbril authored Oct 23, 2024
1 parent 71223f5 commit 2758b92
Show file tree
Hide file tree
Showing 12 changed files with 479 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_ess
plot_ess_evolution
plot_forest
plot_psense_dist
plot_ridge
plot_trace
plot_trace_dist
26 changes: 26 additions & 0 deletions docs/source/gallery/model_criticism/plot_psense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
# Power scaling prior sensitivity plot
Plot of power scaling prior sensitivity distribution
---
:::{seealso}
API Documentation: {func}`~arviz_plots.plot_psense_dist`
:::
"""
from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-clean")

idata = load_arviz_data("rugby")
pc = azp.plot_psense_dist(
idata,
var_names=["defs", "sd_att", "sd_def"],
coords={"team": ["Scotland", "Wales"]},
pc_kwargs={"y": [-2, -1, 0]},
backend="none",
)
pc.show()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
]
dynamic = ["version", "description"]
dependencies = [
"arviz-base==0.2",
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats",
]

Expand Down
36 changes: 31 additions & 5 deletions src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create_plotting_grid(
number,
rows=1,
cols=1,
*,
figsize=None,
figsize_units="inches",
squeeze=True,
Expand Down Expand Up @@ -188,21 +189,46 @@ def create_plotting_grid(
width_ratios /= width_ratios.sum()
plot_widths = np.ceil(chart_width * width_ratios).astype(int)

shared_xrange = {}
shared_yrange = {}
for row in range(rows):
for col in range(cols):
subplot_kws_i = subplot_kws.copy()
if col != 0 and sharex == "row":
subplot_kws_i["x_range"] = shared_xrange[row]
if row != 0 and sharex == "col":
subplot_kws_i["x_range"] = shared_xrange[col]
if col != 0 and sharey == "row":
subplot_kws_i["y_range"] = shared_yrange[row]
if row != 0 and sharey == "col":
subplot_kws_i["y_range"] = shared_yrange[col]
if width_ratios is not None:
subplot_kws["width"] = plot_widths[col]
if (row == 0) and (col == 0) and (sharex or sharey):
p = figure(**subplot_kws) # pylint: disable=invalid-name
if (row == 0) and (col == 0) and (sharex is True or sharey is True):
p = figure(**subplot_kws_i) # pylint: disable=invalid-name
figures[row, col] = p
if sharex:
if sharex is True:
subplot_kws["x_range"] = p.x_range
if sharey:
if sharey is True:
subplot_kws["y_range"] = p.y_range
elif col == 0 and (sharex == "row" or sharey == "row"):
p = figure(**subplot_kws_i) # pylint: disable=invalid-name
figures[row, col] = p
if sharex == "row":
shared_xrange[row] = p.x_range
if sharey == "row":
shared_yrange[row] = p.y_range
elif row == 0 and (sharex == "col" or sharey == "col"):
p = figure(**subplot_kws_i) # pylint: disable=invalid-name
figures[row, col] = p
if sharex == "col":
shared_xrange[col] = p.x_range
if sharey == "col":
shared_yrange[col] = p.y_range
elif row * cols + (col + 1) > number:
figures[row, col] = None
else:
figures[row, col] = figure(**subplot_kws)
figures[row, col] = figure(**subplot_kws_i)
if squeeze and figures.size == 1:
return None, figures[0, 0]
layout = gridplot(figures.tolist(), **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create_plotting_grid(
number,
rows=1,
cols=1,
*,
figsize=None,
figsize_units="inches",
squeeze=True,
Expand Down
1 change: 1 addition & 0 deletions src/arviz_plots/backend/none/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def create_plotting_grid(
number,
rows=1,
cols=1,
*,
figsize=None,
figsize_units="inches",
squeeze=True,
Expand Down
4 changes: 4 additions & 0 deletions src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def create_plotting_grid(
number, # pylint: disable=unused-argument
rows=1,
cols=1,
*,
figsize=None,
figsize_units="inches",
squeeze=True,
Expand Down Expand Up @@ -242,6 +243,9 @@ def create_plotting_grid(
layout_kwargs["height"] = figsize[1]

kwargs["figure"] = go.Figure(layout=layout_kwargs)
share_lookup = {True: "all", "col": "columns", "row": "rows"}
sharex = share_lookup.get(sharex, sharex)
sharey = share_lookup.get(sharey, sharey)

chart = make_subplots(
rows=int(rows),
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def generate_aes_dt(self, aes=None, **kwargs):
var_name: set(dims) <= set(da.dims) for var_name, da in self.data.items()
}
if not any(aes_dims_in_var.values()):
warnings.warning(
warnings.warn(
"Provided mapping for {aes_key} will only use the neutral element"
)
aes_shape = [self.data.sizes[dim] for dim in dims]
Expand Down
2 changes: 2 additions & 0 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .essplot import plot_ess
from .evolutionplot import plot_ess_evolution
from .forestplot import plot_forest
from .psensedistplot import plot_psense_dist
from .ridgeplot import plot_ridge
from .tracedistplot import plot_trace_dist
from .traceplot import plot_trace
Expand All @@ -18,4 +19,5 @@
"plot_ess",
"plot_ess_evolution",
"plot_ridge",
"plot_psense_dist",
]
Loading

0 comments on commit 2758b92

Please sign in to comment.