feat: Allow sampling in the background #94

Merged
merged 6 commits on Mar 18, 2024
599 changes: 329 additions & 270 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
@@ -22,21 +22,21 @@ crate-type = ["cdylib"]

[dependencies]
nuts-rs = "0.7.0"
numpy = "0.19.0"
numpy = "0.20.0"
ndarray = "0.15.6"
rand = "0.8.5"
thiserror = "1.0.44"
rand_chacha = "0.3.1"
rayon = "1.7.0"
arrow2 = "0.17.3"
rayon = "1.9.0"
arrow2 = "0.17.0"
anyhow = "1.0.72"
itertools = "0.11.0"
itertools = "0.12.0"
bridgestan = "2.1.2"
rand_distr = "0.4.3"
smallvec = "1.11.0"

[dependencies.pyo3]
version = "0.19.2"
version = "0.20.0"
features = ["extension-module", "anyhow"]

[dev-dependencies]
20 changes: 20 additions & 0 deletions README.md
@@ -94,6 +94,26 @@ trace_pymc = nutpie.sample(compiled_model)
`trace_pymc` now contains an ArviZ `InferenceData` object, including sampling
statistics and the posterior of the variables defined above.

We can also control the sampler in a non-blocking way:

```python
# The sampler will now run in the background
sampler = nutpie.sample(compiled_model, blocking=False)

# Pause and resume the sampling
sampler.pause()
sampler.resume()

# Wait for the sampler to finish (up to timeout seconds)
# sampler.wait(timeout=0.1)

# or abort the sampler (and return the incomplete trace)
incomplete_trace = sampler.abort()

# or cancel and discard all progress:
sampler.cancel()
```
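
For long runs it can be convenient to poll the sampler while doing other work. Here is a minimal sketch using only the `is_finished` property and `wait()` method shown above; the polling interval is arbitrary:

```python
import time

# Start sampling in the background
sampler = nutpie.sample(compiled_model, blocking=False)

# Poll until all chains are done, doing other work in between
while not sampler.is_finished:
    time.sleep(1.0)  # arbitrary polling interval

# Once sampling has finished, wait() returns the assembled trace immediately
trace = sampler.wait()
```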

## Usage with Stan

In order to sample from a Stan model, `bridgestan` needs to be installed.
4 changes: 2 additions & 2 deletions python/nutpie/compile_pymc.py
@@ -86,9 +86,9 @@ def with_data(self, **updates):
user_data=user_data,
)

def _make_sampler(self, settings, init_mean, chains, cores, seed):
def _make_sampler(self, settings, init_mean, chains, cores, seed, callback=None):
model = self._make_model(init_mean)
return _lib.PySampler.from_pymc(settings, chains, cores, model, seed)
return _lib.PySampler.from_pymc(settings, chains, cores, model, seed, callback)

def _make_model(self, init_mean):
expand_fn = _lib.ExpandFunc(
4 changes: 2 additions & 2 deletions python/nutpie/compile_stan.py
@@ -80,9 +80,9 @@ def _make_model(self, init_mean):
return self.with_data().model
return self.model

def _make_sampler(self, settings, init_mean, chains, cores, seed):
def _make_sampler(self, settings, init_mean, chains, cores, seed, callback=None):
model = self._make_model(init_mean)
return _lib.PySampler.from_stan(settings, chains, cores, model, seed)
return _lib.PySampler.from_stan(settings, chains, cores, model, seed, callback)

@property
def n_dim(self):
272 changes: 226 additions & 46 deletions python/nutpie/sample.py
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import Optional
from threading import Condition, Event
from typing import Any, Literal, Optional, overload

import arviz
import fastprogress
import numpy as np
import pandas as pd
import pyarrow
from fastprogress.fastprogress import ConsoleProgressBar

from nutpie import _lib

@@ -125,6 +127,203 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
)


class _BackgroundSampler:
_sampler: Any
_num_divs: int
_tune: int
_draws: int
_chains: int
_chains_finished: int
_compiled_model: CompiledModel
_save_warmup: bool
_progress = ConsoleProgressBar

def __init__(
self,
compiled_model,
settings,
init_mean,
chains,
cores,
seed,
draws,
tune,
*,
progress_bar=True,
save_warmup=True,
return_raw_trace=False,
):
self._num_divs = 0
self._tune = settings.num_tune
self._draws = settings.num_draws
self._settings = settings
self._chains_tuning = chains
self._chains_finished = 0
self._chains = chains
self._compiled_model = compiled_model
self._save_warmup = save_warmup
self._return_raw_trace = return_raw_trace
total_draws = (self._draws + self._tune) * self._chains
self._progress = fastprogress.progress_bar(
range(total_draws),
total=total_draws,
display=progress_bar,
)
# fastprogress seems to reset the progress bar
# if we create a new iterator, but we don't want
# this for multiple calls to wait.
self._bar = iter(self._progress)

self._exit_event = Event()
self._pause_event = Event()
self._continue = Condition()

self._finished_draws = 0

next(self._bar)

def progress_callback(info):
if info.draw == self._tune - 1:
self._chains_tuning -= 1
if info.draw == self._tune + self._draws - 1:
self._chains_finished += 1
if info.is_diverging and info.draw > self._tune:
self._num_divs += 1
if self._chains_tuning > 0:
count = self._chains_tuning
divs = self._num_divs
self._progress.comment = (
f" Chains in warmup: {count}, Divergences: {divs}"
)
else:
count = self._chains - self._chains_finished
divs = self._num_divs
self._progress.comment = (
f" Sampling chains: {count}, Divergences: {divs}"
)
try:
next(self._bar)
except StopIteration:
pass
self._finished_draws += 1

if progress_bar:
callback = progress_callback
else:
callback = None

self._sampler = compiled_model._make_sampler(
settings,
init_mean,
chains,
cores,
seed,
callback=callback,
)

def wait(self, *, timeout=None):
"""Wait until sampling is finished and return the trace.

A KeyboardInterrupt will interrupt the waiting.

This will return after `timeout` seconds even if sampling is
not finished at this point.

This resumes the sampler in case it had been paused.
"""
self._sampler.wait(timeout)
self._sampler.finalize()
return self._extract()

def _extract(self):
results = self._sampler.extract_results()

dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()}
dims["mass_matrix_inv"] = ["unconstrained_parameter"]
dims["gradient"] = ["unconstrained_parameter"]
dims["unconstrained_draw"] = ["unconstrained_parameter"]
dims["divergence_start"] = ["unconstrained_parameter"]
dims["divergence_start_gradient"] = ["unconstrained_parameter"]
dims["divergence_end"] = ["unconstrained_parameter"]
dims["divergence_momentum"] = ["unconstrained_parameter"]

if self._return_raw_trace:
return results
else:
return _trace_to_arviz(
results,
self._tune,
self._compiled_model.shapes,
dims=dims,
coords={
name: pd.Index(vals)
for name, vals in self._compiled_model.coords.items()
},
save_warmup=self._save_warmup,
)

def pause(self):
"""Pause the sampler."""
self._sampler.pause()

def resume(self):
"""Resume a paused sampler."""
self._sampler.resume()

@property
def is_finished(self):
return self._sampler.is_finished()

def abort(self):
"""Abort sampling and return the trace produced so far."""
self._sampler.abort()
return self._extract()

def cancel(self):
"""Abort sampling and discard progress."""
self._sampler.abort()

def __del__(self):
if not self._sampler.is_empty():
self.cancel()


@overload
def sample(
compiled_model: CompiledModel,
*,
draws: int,
tune: int,
chains: int,
cores: int,
seed: Optional[int],
save_warmup: bool,
progress_bar: bool,
init_mean: Optional[np.ndarray],
return_raw_trace: bool,
blocking: Literal[True],
**kwargs,
) -> arviz.InferenceData: ...


@overload
def sample(
compiled_model: CompiledModel,
*,
draws: int,
tune: int,
chains: int,
cores: int,
seed: Optional[int],
save_warmup: bool,
progress_bar: bool,
init_mean: Optional[np.ndarray],
return_raw_trace: bool,
blocking: Literal[False],
**kwargs,
) -> _BackgroundSampler: ...


def sample(
compiled_model: CompiledModel,
*,
@@ -136,7 +335,8 @@ def sample(
save_warmup: bool = True,
progress_bar: bool = True,
init_mean: Optional[np.ndarray] = None,
return_raw_trace=False,
return_raw_trace: bool = False,
blocking: bool = True,
**kwargs,
) -> arviz.InferenceData:
"""Sample the posterior distribution for a compiled model.
@@ -211,49 +411,29 @@ def sample(
if init_mean is None:
init_mean = np.zeros(compiled_model.n_dim)

sampler = compiled_model._make_sampler(settings, init_mean, chains, cores, seed)
sampler = _BackgroundSampler(
compiled_model,
settings,
init_mean,
chains,
cores,
seed,
draws,
tune,
progress_bar=progress_bar,
save_warmup=save_warmup,
return_raw_trace=return_raw_trace,
)

if not blocking:
return sampler

try:
num_divs = 0
chains_tuning = chains
bar = fastprogress.progress_bar(
sampler,
total=chains * (draws + tune),
display=progress_bar,
)
try:
for info in bar:
if info.draw == tune - 1:
chains_tuning -= 1
if info.is_diverging and info.draw > tune:
num_divs += 1
bar.comment = (
f" Chains in warmup: {chains_tuning}, Divergences: {num_divs}"
)
except KeyboardInterrupt:
pass
finally:
results = sampler.finalize()

dims = {name: list(dim) for name, dim in compiled_model.dims.items()}
dims["mass_matrix_inv"] = ["unconstrained_parameter"]
dims["gradient"] = ["unconstrained_parameter"]
dims["unconstrained_draw"] = ["unconstrained_parameter"]
dims["divergence_start"] = ["unconstrained_parameter"]
dims["divergence_start_gradient"] = ["unconstrained_parameter"]
dims["divergence_end"] = ["unconstrained_parameter"]
dims["divergence_momentum"] = ["unconstrained_parameter"]

if return_raw_trace:
return results
else:
return _trace_to_arviz(
results,
tune,
compiled_model.shapes,
dims=dims,
coords={
name: pd.Index(vals) for name, vals in compiled_model.coords.items()
},
save_warmup=save_warmup,
)
result = sampler.wait()
except KeyboardInterrupt:
result = sampler.abort()
except:
sampler.cancel()
raise

return result