From 3077cde1a3a727bfeba6e2f2e33157e48e9ec314 Mon Sep 17 00:00:00 2001 From: Tristan Fillinger Date: Fri, 1 Nov 2024 03:04:57 +0900 Subject: [PATCH] feat: refactor histplot + new get_plottables function (#534) --- src/mplhep/__init__.py | 2 + src/mplhep/plot.py | 171 ++---------------------- src/mplhep/utils.py | 288 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 302 insertions(+), 159 deletions(-) diff --git a/src/mplhep/__init__.py b/src/mplhep/__init__.py index 948671b6..a4e3d2a7 100644 --- a/src/mplhep/__init__.py +++ b/src/mplhep/__init__.py @@ -27,6 +27,7 @@ yscale_legend, ) from .styles import set_style +from .utils import get_plottables # Configs rcParams = Config( @@ -76,4 +77,5 @@ "sort_legend", "save_variations", "set_style", + "get_plottables", ] diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index 6755f3da..0953f920 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -14,10 +14,10 @@ from mpl_toolkits.axes_grid1 import axes_size, make_axes_locatable from .utils import ( - Plottable, align_marker, get_histogram_axes_title, get_plottable_protocol_bins, + get_plottables, hist_object_handler, isLight, process_histogram_parts, @@ -198,86 +198,18 @@ def histplot( else get_histogram_axes_title(hists[0].axes[0]) ) - plottables = [] - flow_bins = final_bins - for h in hists: - value, variance = np.copy(h.values()), h.variances() - if has_variances := variance is not None: - variance = np.copy(variance) - underflow, overflow = 0.0, 0.0 - underflowv, overflowv = 0.0, 0.0 - # One sided flow bins - hist (uproot hist does not have the over- or underflow traits) - if ( - hasattr(h, "axes") - and (traits := getattr(h.axes[0], "traits", None)) is not None - and hasattr(traits, "underflow") - and hasattr(traits, "overflow") - ): - if traits.overflow: - overflow = np.copy(h.values(flow=True))[-1] - if has_variances: - overflowv = np.copy(h.variances(flow=True))[-1] - if traits.underflow: - underflow = np.copy(h.values(flow=True))[0] - if has_variances: - underflowv = np.copy(h.variances(flow=True))[0] - # Both flow bins exist - uproot - elif hasattr(h, "values") and "flow" in inspect.getfullargspec(h.values).args: - if len(h.values()) + 2 == len( - h.values(flow=True) - ): # easy case, both over/under - underflow, overflow = ( - np.copy(h.values(flow=True))[0], - np.copy(h.values(flow=True))[-1], - ) - if has_variances: - underflowv, overflowv = ( - np.copy(h.variances(flow=True))[0], - np.copy(h.variances(flow=True))[-1], - ) - - # Set plottables - if flow in ("none", "hint"): - plottables.append(Plottable(value, edges=final_bins, variances=variance)) - elif flow == "show": - _flow_bin_size: float = np.max( - [0.05 * (final_bins[-1] - final_bins[0]), np.mean(np.diff(final_bins))] - ) - flow_bins = np.copy(final_bins) - if underflow > 0: - flow_bins = np.r_[flow_bins[0] - _flow_bin_size, flow_bins] - value = np.r_[underflow, value] - if has_variances: - variance = np.r_[underflowv, variance] - if overflow > 0: - flow_bins = np.r_[flow_bins, flow_bins[-1] + _flow_bin_size] - value = np.r_[value, overflow] - if has_variances: - variance = np.r_[variance, overflowv] - plottables.append(Plottable(value, edges=flow_bins, variances=variance)) - elif flow == "sum": - if underflow > 0: - value[0] += underflow - if has_variances: - variance[0] += underflowv - if overflow > 0: - value[-1] += overflow - if has_variances: - variance[-1] += overflowv - plottables.append(Plottable(value, edges=final_bins, variances=variance)) - else: - plottables.append(Plottable(value, edges=final_bins, variances=variance)) - - if w2 is not None: - for _w2, _plottable in zip( - w2.reshape(len(plottables), len(final_bins) - 1), plottables - ): - _plottable.variances = _w2 - _plottable.method = w2method - - if w2 is not None and yerr is not None: - msg = "Can only supply errors or w2" - raise ValueError(msg) + plottables, flow_info = get_plottables( + hists, + bins=final_bins, + w2=w2, + w2method=w2method, + yerr=yerr, + stack=stack, + density=density, + binwnorm=binwnorm, + flow=flow, + ) + flow_bins, underflow, overflow = flow_info _labels: list[str | None] if label is None: @@ -311,52 +243,6 @@ def iterable_not_string(arg): for i in range(len(_chunked_kwargs)): _chunked_kwargs[i][kwarg] = kwargs[kwarg] - ############################ - # # yerr calculation - _yerr: np.ndarray | None - if yerr is not None: - # yerr is array - if hasattr(yerr, "__len__"): - _yerr = np.asarray(yerr) - # yerr is a number - elif isinstance(yerr, (int, float)) and not isinstance(yerr, bool): - _yerr = np.ones((len(plottables), len(final_bins) - 1)) * yerr - # yerr is automatic - else: - _yerr = None - else: - _yerr = None - - if _yerr is not None: - assert isinstance(_yerr, np.ndarray) - if _yerr.ndim == 3: - # Already correct format - pass - elif _yerr.ndim == 2 and len(plottables) == 1: - # Broadcast ndim 2 to ndim 3 - if _yerr.shape[-2] == 2: # [[1,1], [1,1]] - _yerr = _yerr.reshape(len(plottables), 2, _yerr.shape[-1]) - elif _yerr.shape[-2] == 1: # [[1,1]] - _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1]) - else: - msg = "yerr format is not understood" - raise ValueError(msg) - elif _yerr.ndim == 2: - # Broadcast yerr (nh, N) to (nh, 2, N) - _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1]) - elif _yerr.ndim == 1: - # Broadcast yerr (1, N) to (nh, 2, N) - _yerr = np.tile(_yerr, 2 * len(plottables)).reshape( - len(plottables), 2, _yerr.shape[-1] - ) - else: - msg = "yerr format is not understood" - raise ValueError(msg) - - assert _yerr is not None - for yrs, _plottable in zip(_yerr, plottables): - _plottable.fixed_errors(*yrs) - # Sorting if sort is not None: if isinstance(sort, str): @@ -379,34 +265,6 @@ def iterable_not_string(arg): _chunked_kwargs = [_chunked_kwargs[ix] for ix in order] _labels = [_labels[ix] for ix in order] - # ############################ - # # Stacking, norming, density - if density is True and binwnorm is not None: - msg = "Can only set density or binwnorm." - raise ValueError(msg) - if density is True: - if stack: - _total = np.sum( - np.array([plottable.values for plottable in plottables]), axis=0 - ) - for plottable in plottables: - plottable.flat_scale(1.0 / np.sum(np.diff(final_bins) * _total)) - else: - for plottable in plottables: - plottable.density = True - elif binwnorm is not None: - for plottable, norm in zip( - plottables, np.broadcast_to(binwnorm, (len(plottables),)) - ): - plottable.flat_scale(norm) - plottable.binwnorm() - - # Stack - if stack and len(plottables) > 1: - from .utils import stack as stack_fun - - plottables = stack_fun(*plottables) - ########## # Plotting return_artists: list[StairsArtists | ErrorBarArtists] = [] @@ -443,8 +301,7 @@ def iterable_not_string(arg): if "step" in histtype: for i in range(len(plottables)): do_errors = yerr is not False and ( - (yerr is not None or w2 is not None) - or (plottables[i].variances is not None) + (yerr is not None or w2 is not None) or plottables[i]._has_variances ) _kwargs = _chunked_kwargs[i] diff --git a/src/mplhep/utils.py b/src/mplhep/utils.py index 68468d33..81060119 100644 --- a/src/mplhep/utils.py +++ b/src/mplhep/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import inspect import warnings from numbers import Real from typing import TYPE_CHECKING, Any, Iterable, Sequence @@ -52,8 +53,7 @@ def hist_object_handler( hist = (hist, None) hist_obj = ensure_plottable_histogram(hist) elif isinstance(hist, PlottableHistogram): - msg = "Cannot give bins with existing histogram" - raise TypeError(msg) + hist_obj = hist else: hist_obj = ensure_plottable_histogram((hist, *bins)) @@ -135,14 +135,297 @@ def get_histogram_axes_title(axis: Any) -> str: return "" +def get_plottables( + H, + bins=None, + yerr: ArrayLike | bool | None = None, + w2=None, + w2method=None, + flow="hint", + stack=False, + density=False, + binwnorm=None, +): + """ + Generate plottable histograms from various histogram data sources. + + Parameters + ---------- + H : object + Histogram object with containing values and optionally bins. Can be: + + - `np.histogram` tuple + - PlottableProtocol histogram object + - `boost_histogram` classic (<0.13) histogram object + - raw histogram values, provided `bins` is specified. + + Or list thereof. + bins : iterable, optional + Histogram bins, if not part of ``H``. + yerr : iterable or bool, optional + Histogram uncertainties. Following modes are supported: + - True, sqrt(N) errors or poissonian interval when ``w2`` is specified + - shape(N) array of for one sided errors or list thereof + - shape(Nx2) array of for two sided errors or list thereof + w2 : iterable, optional + Sum of the histogram weights squared for poissonian interval error + calculation + w2method: callable, optional + Function calculating CLs with signature ``low, high = fcn(w, w2)``. Here + ``low`` and ``high`` are given in absolute terms, not relative to w. + Default is ``None``. If w2 has integer values (likely to be data) poisson + interval is calculated, otherwise the resulting error is symmetric + ``sqrt(w2)``. Specifying ``poisson`` or ``sqrt`` will force that behaviours. + flow : str, optional { "show", "sum", "hint", "none"} + Whether plot the under/overflow bin. If "show", add additional under/overflow bin. + If "sum", add the under/overflow bin content to first/last bin. + stack : bool, optional + Whether to stack or overlay non-axis dimension (if it exists). N.B. in + contrast to ROOT, stacking is performed in a single call aka + ``histplot([h1, h2, ...], stack=True)`` as opposed to multiple calls. + density : bool, optional + If true, convert sum weights to probability density (i.e. integrates to 1 + over domain of axis) (Note: this option conflicts with ``binwnorm``) + binwnorm : float, optional + If true, convert sum weights to bin-width-normalized, with unit equal to + supplied value (usually you want to specify 1.) + + Returns + ------- + plottables : list of Plottable + Processed histogram objects ready for plotting. + (flow_bins, underflow, overflow) : tuple + Flow bin information for handling underflow and overflow values. + """ + plottables = [] + flow_bins = np.copy(bins) + + hists = list(process_histogram_parts(H, bins)) + final_bins, _ = get_plottable_protocol_bins(hists[0].axes[0]) + + for h in hists: + value, variance = np.copy(h.values()), h.variances() + if has_variances := variance is not None: + variance = np.copy(variance) + underflow, overflow = 0.0, 0.0 + underflowv, overflowv = 0.0, 0.0 + # One sided flow bins - hist (uproot hist does not have the over- or underflow traits) + if ( + hasattr(h, "axes") + and (traits := getattr(h.axes[0], "traits", None)) is not None + and hasattr(traits, "underflow") + and hasattr(traits, "overflow") + ): + if traits.overflow: + overflow = np.copy(h.values(flow=True))[-1] + if has_variances: + overflowv = np.copy(h.variances(flow=True))[-1] + if traits.underflow: + underflow = np.copy(h.values(flow=True))[0] + if has_variances: + underflowv = np.copy(h.variances(flow=True))[0] + # Both flow bins exist - uproot + elif hasattr(h, "values") and "flow" in inspect.getfullargspec(h.values).args: + if len(h.values()) + 2 == len( + h.values(flow=True) + ): # easy case, both over/under + underflow, overflow = ( + np.copy(h.values(flow=True))[0], + np.copy(h.values(flow=True))[-1], + ) + if has_variances: + underflowv, overflowv = ( + np.copy(h.variances(flow=True))[0], + np.copy(h.variances(flow=True))[-1], + ) + + # Set plottables + if flow in ("none", "hint"): + plottables.append(Plottable(value, edges=final_bins, variances=variance)) + elif flow == "show": + _flow_bin_size: float = np.max( + [0.05 * (final_bins[-1] - final_bins[0]), np.mean(np.diff(final_bins))] + ) + flow_bins = np.copy(final_bins) + if underflow > 0: + flow_bins = np.r_[flow_bins[0] - _flow_bin_size, flow_bins] + value = np.r_[underflow, value] + if has_variances: + variance = np.r_[underflowv, variance] + if overflow > 0: + flow_bins = np.r_[flow_bins, flow_bins[-1] + _flow_bin_size] + value = np.r_[value, overflow] + if has_variances: + variance = np.r_[variance, overflowv] + plottables.append(Plottable(value, edges=flow_bins, variances=variance)) + elif flow == "sum": + if underflow > 0: + value[0] += underflow + if has_variances: + variance[0] += underflowv + if overflow > 0: + value[-1] += overflow + if has_variances: + variance[-1] += overflowv + plottables.append(Plottable(value, edges=final_bins, variances=variance)) + else: + plottables.append(Plottable(value, edges=final_bins, variances=variance)) + + if w2 is not None: + for _w2, _plottable in zip( + w2.reshape(len(plottables), len(final_bins) - 1), plottables + ): + _plottable.variances = _w2 + _plottable.method = w2method + + if w2 is not None and yerr is not None: + msg = "Can only supply errors or w2" + raise ValueError(msg) + + yerr_plottables(plottables, final_bins, yerr) + norm_stack_plottables(plottables, final_bins, stack, density, binwnorm) + + return plottables, (flow_bins, underflow, overflow) + + +def yerr_plottables(plottables, bins, yerr=None): + """ + Calculate and format y-axis errors for Plottables. + + Parameters + ---------- + plottables : list of Plottable + List of Plottable objects. + bins : iterable + Plottable bins. + yerr : iterable or bool, optional + Histogram uncertainties. Following modes are supported: + - True, sqrt(N) errors or poissonian interval when ``w2`` is specified + - shape(N) array of for one sided errors or list thereof + - shape(Nx2) array of for two sided errors or list thereof + + Raises + ------ + ValueError + If `yerr` has an unrecognized format. + """ + + _yerr: np.ndarray | None + if yerr is not None: + # yerr is array + if hasattr(yerr, "__len__"): + _yerr = np.asarray(yerr) + # yerr is a number + elif isinstance(yerr, (int, float)) and not isinstance(yerr, bool): + _yerr = np.ones((len(plottables), len(bins) - 1)) * yerr + # yerr is automatic + else: + _yerr = None + for _plottable in plottables: + _plottable.errors() + else: + _yerr = None + if _yerr is not None: + assert isinstance(_yerr, np.ndarray) + if _yerr.ndim == 3: + # Already correct format + pass + elif _yerr.ndim == 2 and len(plottables) == 1: + # Broadcast ndim 2 to ndim 3 + if _yerr.shape[-2] == 2: # [[1,1], [1,1]] + _yerr = _yerr.reshape(len(plottables), 2, _yerr.shape[-1]) + elif _yerr.shape[-2] == 1: # [[1,1]] + _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1]) + else: + msg = "yerr format is not understood" + raise ValueError(msg) + elif _yerr.ndim == 2: + # Broadcast yerr (nh, N) to (nh, 2, N) + _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1]) + elif _yerr.ndim == 1: + # Broadcast yerr (1, N) to (nh, 2, N) + _yerr = np.tile(_yerr, 2 * len(plottables)).reshape( + len(plottables), 2, _yerr.shape[-1] + ) + else: + msg = "yerr format is not understood" + raise ValueError(msg) + + assert _yerr is not None + for yrs, _plottable in zip(_yerr, plottables): + _plottable.fixed_errors(*yrs) + + +def norm_stack_plottables(plottables, bins, stack=False, density=False, binwnorm=None): + """ + Normalize and stack histogram data with optional density or bin-width normalization. + + Parameters + ---------- + plottables : list of Plottable + List of Plottable objects. + bins : iterable + Plottable bins. + stack : bool, optional + Whether to stack or overlay non-axis dimension (if it exists). N.B. in + contrast to ROOT, stacking is performed in a single call aka + ``histplot([h1, h2, ...], stack=True)`` as opposed to multiple calls. + density : bool, optional + If true, convert sum weights to probability density (i.e. integrates to 1 + over domain of axis) (Note: this option conflicts with ``binwnorm``) + binwnorm : float, optional + If true, convert sum weights to bin-width-normalized, with unit equal to + supplied value (usually you want to specify 1.) + + Raises + ------ + ValueError + If both `density` and `binwnorm` are set, as they are mutually exclusive. + + Notes + ----- + Density and bin-width normalization cannot both be applied simultaneously. + For stacked histograms, this function uses an external utility to compute + the cumulative stacked values. + """ + + if density is True and binwnorm is not None: + msg = "Can only set density or binwnorm." + raise ValueError(msg) + if density is True: + if stack: + _total = np.sum( + np.array([plottable.values for plottable in plottables]), axis=0 + ) + for plottable in plottables: + plottable.flat_scale(1.0 / np.sum(np.diff(bins) * _total)) + else: + for plottable in plottables: + plottable.density = True + elif binwnorm is not None: + for plottable, norm in zip( + plottables, np.broadcast_to(binwnorm, (len(plottables),)) + ): + plottable.flat_scale(norm) + plottable.binwnorm() + + # Stack + if stack and len(plottables) > 1: + from .utils import stack as stack_fun + + plottables = stack_fun(*plottables) + + class Plottable: def __init__(self, values, *, edges=None, variances=None, yerr=None): self._values = np.array(values).astype(float) self.variances = None self._variances = None + self._has_variances = False if variances is not None: self._variances = np.array(variances).astype(float) self.variances = np.array(variances).astype(float) + self._has_variances = True self._density = False self.values = np.array(values).astype(float) @@ -221,6 +504,7 @@ def calculate_relative(method_fcn, variances): raise RuntimeError(msg) self.yerr_lo = np.nan_to_num(self.yerr_lo, 0) self.yerr_hi = np.nan_to_num(self.yerr_hi, 0) + self.variances = self.values if not self._has_variances else self.variances def fixed_errors(self, yerr_lo, yerr_hi): self.yerr_lo = yerr_lo