Skip to content

Commit

Permalink
feat: Allow sampling in the backgound
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Mar 15, 2024
1 parent 96df571 commit bdd60ff
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 45 deletions.
228 changes: 183 additions & 45 deletions python/nutpie/sample.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import math
import time
from dataclasses import dataclass
from typing import Optional
from typing import Any, Literal, Optional, overload

import arviz
import fastprogress
import numpy as np
import pandas as pd
import pyarrow
from fastprogress.fastprogress import ConsoleProgressBar

from nutpie import _lib

Expand Down Expand Up @@ -125,6 +128,161 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
)


class _BackgroundSampler:
_sampler: Any
_num_divs: int
_tune: int
_chains: int
_compiled_model: CompiledModel
_save_warmup: bool
_progress = ConsoleProgressBar

def __init__(
self,
compiled_model,
sampler,
chains,
draws,
tune,
*,
progress_bar=True,
save_warmup=True,
return_raw_trace=False,
):
self._sampler = sampler
self._num_divs = 0
self._tune = tune
self._chains_tuning = chains
self._chains = chains
self._sampler = sampler
self._compiled_model = compiled_model
self._save_warmup = save_warmup
self._return_raw_trace = return_raw_trace
self._progress = fastprogress.progress_bar(
sampler,
total=chains * (draws + tune),
display=progress_bar,
)

def wait(self, *, timeout=None):
"""Wait until sampling is finished.
KeyboardInterrupt will lead to interrupt the waiting.
This will return after `timeout` seconds even if sampling is
not finished at this point.
"""
if self._sampler is None:
raise ValueError("Sampler is already finalized")

if timeout is None:
timeout = math.inf

start_time = time.time()

try:
for info in self._progress:
current_time = time.time()
if current_time - start_time > timeout:
raise TimeoutError("Sampling did not finish")

if info.draw == self._tune - 1:
self._chains_tuning -= 1
if info.is_diverging and info.draw > self._tune:
self._num_divs += 1
if self._chains_tuning > 0:
count = self._chains_tuning
divs = self._num_divs
self._progress.comment = (
f" Chains in warmup: {count}, Divergences: {divs}"
)
else:
count = self._chains - self._chains_tuning
divs = self._num_divs
self._progress.comment = (
f" Sampling chains: {count}, Divergences: {divs}"
)
except KeyboardInterrupt:
pass

def finalize(self):
"""Free resources of the sampler and return the trace produced so far."""
if self._sampler is None:
raise ValueError("Sampler has already been finalized")

results = self._sampler.finalize()
self._sampler = None

dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()}
dims["mass_matrix_inv"] = ["unconstrained_parameter"]
dims["gradient"] = ["unconstrained_parameter"]
dims["unconstrained_draw"] = ["unconstrained_parameter"]
dims["divergence_start"] = ["unconstrained_parameter"]
dims["divergence_start_gradient"] = ["unconstrained_parameter"]
dims["divergence_end"] = ["unconstrained_parameter"]
dims["divergence_momentum"] = ["unconstrained_parameter"]

if self._return_raw_trace:
return results
else:
return _trace_to_arviz(
results,
self._tune,
self._compiled_model.shapes,
dims=dims,
coords={
name: pd.Index(vals)
for name, vals in self._compiled_model.coords.items()
},
save_warmup=self._save_warmup,
)

def abort(self):
"""Abort sampling and discard progress."""
if self._sampler is not None:
self._sampler.finalize()
self._sampler = None

def __del__(self):
self.abort()


@overload
def sample(
compiled_model: CompiledModel,
*,
draws: int,
tune: int,
chains: int,
cores: int,
seed: Optional[int],
save_warmup: bool,
progress_bar: bool,
init_mean: Optional[np.ndarray],
return_raw_trace: bool,
blocking: Literal[True],
**kwargs,
) -> arviz.InferenceData: ...


@overload
def sample(
compiled_model: CompiledModel,
*,
draws: int,
tune: int,
chains: int,
cores: int,
seed: Optional[int],
save_warmup: bool,
progress_bar: bool,
init_mean: Optional[np.ndarray],
return_raw_trace: bool,
blocking: Literal[False],
**kwargs,
) -> _BackgroundSampler: ...


def sample(
compiled_model: CompiledModel,
*,
Expand All @@ -136,7 +294,8 @@ def sample(
save_warmup: bool = True,
progress_bar: bool = True,
init_mean: Optional[np.ndarray] = None,
return_raw_trace=False,
return_raw_trace: bool = False,
blocking: bool = True,
**kwargs,
) -> arviz.InferenceData:
"""Sample the posterior distribution for a compiled model.
Expand Down Expand Up @@ -213,47 +372,26 @@ def sample(

sampler = compiled_model._make_sampler(settings, init_mean, chains, cores, seed)

sampler = _BackgroundSampler(
compiled_model,
sampler,
chains,
draws,
tune,
progress_bar=progress_bar,
save_warmup=save_warmup,
return_raw_trace=return_raw_trace,
)

if not blocking:
return sampler

try:
num_divs = 0
chains_tuning = chains
bar = fastprogress.progress_bar(
sampler,
total=chains * (draws + tune),
display=progress_bar,
)
try:
for info in bar:
if info.draw == tune - 1:
chains_tuning -= 1
if info.is_diverging and info.draw > tune:
num_divs += 1
bar.comment = (
f" Chains in warmup: {chains_tuning}, Divergences: {num_divs}"
)
except KeyboardInterrupt:
pass
finally:
results = sampler.finalize()

dims = {name: list(dim) for name, dim in compiled_model.dims.items()}
dims["mass_matrix_inv"] = ["unconstrained_parameter"]
dims["gradient"] = ["unconstrained_parameter"]
dims["unconstrained_draw"] = ["unconstrained_parameter"]
dims["divergence_start"] = ["unconstrained_parameter"]
dims["divergence_start_gradient"] = ["unconstrained_parameter"]
dims["divergence_end"] = ["unconstrained_parameter"]
dims["divergence_momentum"] = ["unconstrained_parameter"]

if return_raw_trace:
return results
else:
return _trace_to_arviz(
results,
tune,
compiled_model.shapes,
dims=dims,
coords={
name: pd.Index(vals) for name, vals in compiled_model.coords.items()
},
save_warmup=save_warmup,
)
sampler.wait()
except KeyboardInterrupt:
pass
except:
sampler.abort()
raise

return sampler.finalize()
11 changes: 11 additions & 0 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ def test_pymc_model():
trace.posterior.a # noqa: B018


def test_blocking():
with pm.Model() as model:
pm.Normal("a")

compiled = nutpie.compile_pymc_model(model)
sampler = nutpie.sample(compiled, chains=1, blocking=False)
sampler.wait()
trace = sampler.finalize()
trace.posterior.a # noqa: B018


def test_pymc_model_with_coordinate():
with pm.Model() as model:
model.add_coord("foo", length=5)
Expand Down

0 comments on commit bdd60ff

Please sign in to comment.