diff --git a/mcbackend/__init__.py b/mcbackend/__init__.py index 5b9800b..7ffdf23 100644 --- a/mcbackend/__init__.py +++ b/mcbackend/__init__.py @@ -12,7 +12,7 @@ except ModuleNotFoundError: pass -__version__ = "0.5.1" +__version__ = "0.5.2" __all__ = [ "NumPyBackend", "Backend", diff --git a/mcbackend/core.py b/mcbackend/core.py index 30593fa..85970b8 100644 --- a/mcbackend/core.py +++ b/mcbackend/core.py @@ -23,6 +23,9 @@ _log = logging.getLogger(__file__) +__all__ = ("is_rigid", "chain_id", "Chain", "Run", "Backend") + + def is_rigid(nshape: Optional[Shape]): """Determines wheather the shape is constant. @@ -133,6 +136,20 @@ def sample_stats(self) -> Dict[str, Variable]: return {var.name: var for var in self.rmeta.sample_stats} +def get_tune_mask(chain: Chain, slc: slice = slice(None)) -> numpy.ndarray: + """Load the tuning mask from either a ``"tune"``, or a ``"*__tune"`` stat. + + Raises + ------ + KeyError + When no matching stat is found. + """ + for sname in chain.sample_stats: + if sname.endswith("__tune") or sname == "tune": + return chain.get_stats(sname, slc).astype(bool) + raise KeyError("No tune stat found.") + + class Run: """A handle on one MCMC run.""" @@ -231,14 +248,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> slc = slice(0, min_clen) # Obtain a mask by which draws can be split into warmup/posterior - if "tune" in chain.sample_stats: - tune = chain.get_stats("tune", slc).astype(bool) - else: + try: + # Use the same slice to avoid shape issues in case the chain is still active + tune = get_tune_mask(chain, slc) + except KeyError: if c == 0: _log.warning( "No 'tune' stat found. Assuming all iterations are posterior draws." ) - tune = numpy.full((chain_lengths[chain.cid],), False) + tune = numpy.full((slc.stop,), False) # Split all variables draws into warmup/posterior for var in variables: