From 68110b38a96c9ccd8cff359be397b320a349b603 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 15 Mar 2024 13:53:08 +0100 Subject: [PATCH] feat: Allow sampling in the background --- python/nutpie/sample.py | 230 ++++++++++++++++++++++++++++++++-------- tests/test_pymc.py | 11 ++ 2 files changed, 196 insertions(+), 45 deletions(-) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 3cbcf96..8d7514d 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -1,11 +1,14 @@ +import math +import time from dataclasses import dataclass -from typing import Optional +from typing import Any, Literal, Optional, overload import arviz import fastprogress import numpy as np import pandas as pd import pyarrow +from fastprogress.fastprogress import ConsoleProgressBar from nutpie import _lib @@ -125,6 +128,163 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs): ) +class _BackgroundSampler: + _sampler: Any + _num_divs: int + _tune: int + _chains: int + _compiled_model: CompiledModel + _save_warmup: bool + _progress = ConsoleProgressBar + + def __init__( + self, + compiled_model, + sampler, + chains, + draws, + tune, + *, + progress_bar=True, + save_warmup=True, + return_raw_trace=False, + ): + self._sampler = sampler + self._num_divs = 0 + self._tune = tune + self._chains_tuning = chains + self._chains = chains + self._sampler = sampler + self._compiled_model = compiled_model + self._save_warmup = save_warmup + self._return_raw_trace = return_raw_trace + self._progress = fastprogress.progress_bar( + sampler, + total=chains * (draws + tune), + display=progress_bar, + ) + + def wait(self, *, timeout=None): + """Wait until sampling is finished. + + KeyboardInterrupt will lead to interrupt the waiting. + + This will return after `timeout` seconds even if sampling is + not finished at this point. 
+ """ + if self._sampler is None: + raise ValueError("Sampler is already finalized") + + if timeout is None: + timeout = math.inf + + start_time = time.time() + + try: + for info in self._progress: + current_time = time.time() + if current_time - start_time > timeout: + raise TimeoutError("Sampling did not finish") + + if info.draw == self._tune - 1: + self._chains_tuning -= 1 + if info.is_diverging and info.draw > self._tune: + self._num_divs += 1 + if self._chains_tuning > 0: + count = self._chains_tuning + divs = self._num_divs + self._progress.comment = ( + f" Chains in warmup: {count}, Divergences: {divs}" + ) + else: + count = self._chains - self._chains_tuning + divs = self._num_divs + self._progress.comment = ( + f" Sampling chains: {count}, Divergences: {divs}" + ) + except KeyboardInterrupt: + pass + + def finalize(self): + """Free resources of the sampler and return the trace produced so far.""" + if self._sampler is None: + raise ValueError("Sampler has already been finalized") + + results = self._sampler.finalize() + self._sampler = None + + dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()} + dims["mass_matrix_inv"] = ["unconstrained_parameter"] + dims["gradient"] = ["unconstrained_parameter"] + dims["unconstrained_draw"] = ["unconstrained_parameter"] + dims["divergence_start"] = ["unconstrained_parameter"] + dims["divergence_start_gradient"] = ["unconstrained_parameter"] + dims["divergence_end"] = ["unconstrained_parameter"] + dims["divergence_momentum"] = ["unconstrained_parameter"] + + if self._return_raw_trace: + return results + else: + return _trace_to_arviz( + results, + self._tune, + self._compiled_model.shapes, + dims=dims, + coords={ + name: pd.Index(vals) + for name, vals in self._compiled_model.coords.items() + }, + save_warmup=self._save_warmup, + ) + + def abort(self): + """Abort sampling and discard progress.""" + if self._sampler is not None: + self._sampler.finalize() + self._sampler = None + + def 
__del__(self): + self.abort() + + +@overload +def sample( + compiled_model: CompiledModel, + *, + draws: int, + tune: int, + chains: int, + cores: int, + seed: Optional[int], + save_warmup: bool, + progress_bar: bool, + init_mean: Optional[np.ndarray], + return_raw_trace: bool, + blocking: Literal[True], + **kwargs, +) -> arviz.InferenceData: + ... + + +@overload +def sample( + compiled_model: CompiledModel, + *, + draws: int, + tune: int, + chains: int, + cores: int, + seed: Optional[int], + save_warmup: bool, + progress_bar: bool, + init_mean: Optional[np.ndarray], + return_raw_trace: bool, + blocking: Literal[False], + **kwargs, +) -> _BackgroundSampler: + ... + + def sample( compiled_model: CompiledModel, *, @@ -136,7 +296,8 @@ def sample( save_warmup: bool = True, progress_bar: bool = True, init_mean: Optional[np.ndarray] = None, - return_raw_trace=False, + return_raw_trace: bool = False, + blocking: bool = True, **kwargs, ) -> arviz.InferenceData: """Sample the posterior distribution for a compiled model. 
@@ -213,47 +374,26 @@ def sample( sampler = compiled_model._make_sampler(settings, init_mean, chains, cores, seed) + sampler = _BackgroundSampler( + compiled_model, + sampler, + chains, + draws, + tune, + progress_bar=progress_bar, + save_warmup=save_warmup, + return_raw_trace=return_raw_trace, + ) + + if not blocking: + return sampler + try: - num_divs = 0 - chains_tuning = chains - bar = fastprogress.progress_bar( - sampler, - total=chains * (draws + tune), - display=progress_bar, - ) - try: - for info in bar: - if info.draw == tune - 1: - chains_tuning -= 1 - if info.is_diverging and info.draw > tune: - num_divs += 1 - bar.comment = ( - f" Chains in warmup: {chains_tuning}, Divergences: {num_divs}" - ) - except KeyboardInterrupt: - pass - finally: - results = sampler.finalize() - - dims = {name: list(dim) for name, dim in compiled_model.dims.items()} - dims["mass_matrix_inv"] = ["unconstrained_parameter"] - dims["gradient"] = ["unconstrained_parameter"] - dims["unconstrained_draw"] = ["unconstrained_parameter"] - dims["divergence_start"] = ["unconstrained_parameter"] - dims["divergence_start_gradient"] = ["unconstrained_parameter"] - dims["divergence_end"] = ["unconstrained_parameter"] - dims["divergence_momentum"] = ["unconstrained_parameter"] - - if return_raw_trace: - return results - else: - return _trace_to_arviz( - results, - tune, - compiled_model.shapes, - dims=dims, - coords={ - name: pd.Index(vals) for name, vals in compiled_model.coords.items() - }, - save_warmup=save_warmup, - ) + sampler.wait() + except KeyboardInterrupt: + pass + except: + sampler.abort() + raise + + return sampler.finalize() diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 101f1a0..fc14d6d 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -15,6 +15,17 @@ def test_pymc_model(): trace.posterior.a # noqa: B018 +def test_blocking(): + with pm.Model() as model: + pm.Normal("a") + + compiled = nutpie.compile_pymc_model(model) + sampler = nutpie.sample(compiled, 
chains=1, blocking=False) + sampler.wait() + trace = sampler.finalize() + trace.posterior.a # noqa: B018 + + def test_pymc_model_with_coordinate(): with pm.Model() as model: model.add_coord("foo", length=5)