diff --git a/Cargo.lock b/Cargo.lock
index f0bc08d..29e4883 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -88,7 +88,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
- "syn 2.0.58",
+ "syn 2.0.59",
"which",
]
@@ -147,7 +147,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
@@ -179,9 +179,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
-version = "0.4.37"
+version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e"
+checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"num-traits",
]
@@ -388,7 +388,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
@@ -408,7 +408,7 @@ checksum = "60d08acb9849f7fb4401564f251be5a526829183a3645a90197dea8e786cf3ae"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
@@ -953,6 +953,8 @@ dependencies = [
"rayon",
"smallvec",
"thiserror",
+ "time-humanize",
+ "upon",
]
[[package]]
@@ -1044,7 +1046,7 @@ dependencies = [
"pest_meta",
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
@@ -1100,19 +1102,19 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
-version = "0.2.17"
+version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
+checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550"
dependencies = [
"proc-macro2",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
name = "proc-macro2"
-version = "1.0.79"
+version = "1.0.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
+checksum = "a56dea16b0a29e94408b9aa5e2940a4eedbd128a1ba20e8f7ae60fd3d465af0e"
dependencies = [
"unicode-ident",
]
@@ -1144,9 +1146,9 @@ dependencies = [
[[package]]
name = "pyo3"
-version = "0.21.1"
+version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7a8b1990bd018761768d5e608a13df8bd1ac5f678456e0f301bb93e5f3ea16b"
+checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8"
dependencies = [
"anyhow",
"cfg-if",
@@ -1163,9 +1165,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
-version = "0.21.1"
+version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "650dca34d463b6cdbdb02b1d71bfd6eb6b6816afc708faebb3bac1380ff4aef7"
+checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50"
dependencies = [
"once_cell",
"target-lexicon",
@@ -1173,9 +1175,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
-version = "0.21.1"
+version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "09a7da8fc04a8a2084909b59f29e1b8474decac98b951d77b80b26dc45f046ad"
+checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403"
dependencies = [
"libc",
"pyo3-build-config",
@@ -1183,27 +1185,27 @@ dependencies = [
[[package]]
name = "pyo3-macros"
-version = "0.21.1"
+version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4b8a199fce11ebb28e3569387228836ea98110e43a804a530a9fd83ade36d513"
+checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
name = "pyo3-macros-backend"
-version = "0.21.1"
+version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "93fbbfd7eb553d10036513cb122b888dcd362a945a00b06c165f2ab480d4cc3b"
+checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
@@ -1412,14 +1414,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
[[package]]
name = "serde_json"
-version = "1.0.115"
+version = "1.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd"
+checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
dependencies = [
"itoa",
"ryu",
@@ -1468,9 +1470,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.58"
+version = "2.0.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687"
+checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a"
dependencies = [
"proc-macro2",
"quote",
@@ -1520,9 +1522,15 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
+[[package]]
+name = "time-humanize"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3e32d019b4f7c100bcd5494e40a27119d45b71fba2b07a4684153129279a4647"
+
[[package]]
name = "tinytemplate"
version = "1.2.1"
@@ -1557,6 +1565,12 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
+[[package]]
+name = "upon"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0"
+
[[package]]
name = "version_check"
version = "0.9.4"
@@ -1600,7 +1614,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
"wasm-bindgen-shared",
]
@@ -1622,7 +1636,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -1833,5 +1847,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.58",
+ "syn 2.0.59",
]
diff --git a/Cargo.toml b/Cargo.toml
index 7b44020..f616664 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -34,6 +34,8 @@ itertools = "0.12.0"
bridgestan = "2.1.2"
rand_distr = "0.4.3"
smallvec = "1.11.0"
+upon = { version = "0.8.1", default-features = false, features = [] }
+time-humanize = { version = "0.1.3", default-features = false }
[dependencies.pyo3]
version = "0.21.0"
diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py
index aef52c9..e41058b 100644
--- a/python/nutpie/compile_pymc.py
+++ b/python/nutpie/compile_pymc.py
@@ -87,9 +87,11 @@ def with_data(self, **updates):
user_data=user_data,
)
- def _make_sampler(self, settings, init_mean, cores, callback=None):
+ def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None):
model = self._make_model(init_mean)
- return _lib.PySampler.from_pymc(settings, cores, model, callback)
+ return _lib.PySampler.from_pymc(
+ settings, cores, model, template, rate, callback
+ )
def _make_model(self, init_mean):
expand_fn = _lib.ExpandFunc(
diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py
index 5911b2c..03f5444 100644
--- a/python/nutpie/compile_stan.py
+++ b/python/nutpie/compile_stan.py
@@ -80,9 +80,11 @@ def _make_model(self, init_mean):
return self.with_data().model
return self.model
- def _make_sampler(self, settings, init_mean, cores, callback=None):
+ def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None):
model = self._make_model(init_mean)
- return _lib.PySampler.from_stan(settings, cores, model, callback)
+ return _lib.PySampler.from_stan(
+ settings, cores, model, template, rate, callback
+ )
@property
def n_dim(self):
diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py
index 0ff2c54..d32a494 100644
--- a/python/nutpie/sample.py
+++ b/python/nutpie/sample.py
@@ -1,10 +1,8 @@
import os
from dataclasses import dataclass
-from threading import Condition, Event
from typing import Any, Literal, Optional, overload
import arviz
-import fastprogress
import numpy as np
import pandas as pd
import pyarrow
@@ -128,85 +126,182 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
)
-class _ChainProgress:
- bar: Any
- total: int
- finished: bool
- tuning: bool
- draws: int
- started: bool
- chain_id: int
- num_divs: int
- step_size: float
- num_steps: int
-
- def __init__(self, total, chain_id):
- self.bar = fastprogress.progress_bar(range(total))
- self.total = total
- self.finished = False
- self.tuning = True
- self.draws = 0
- self.started = False
- self.chain_id = chain_id
- self.num_divs = 0
- self.step_size = 0.
- self.num_steps = 0
- self.bar.update(0)
-
- def callback(self, info):
+_progress_style = """
+
+"""
+
+
+_progress_template = """
+
+
Sampler Progress
+
Total Chains: {{ num_chains }}
+
Active Chains: {{ running_chains }}
+
+ Finished Chains:
+ {{ finished_chains }}
+
+
Sampling for {{ time_sampling }}
+
+ Estimated Time to Completion:
+ {{ time_remaining_estimate }}
+
+
+
+
+
+
+ Progress |
+ Chain |
+ Divergences |
+ Step Size |
+ Gradients/Draw |
+
+
+
+ {% for chain in chains %}
+
+
+
+ |
+ {{ chain.chain_index }} |
+ {{ chain.divergences }} |
+ {{ chain.step_size }} |
+ {{ chain.latest_num_steps }} |
+
+ {% endfor %}
+
+
+
+
+"""
+
+
+# Adapted from fastprogress
+def in_notebook():
+ def in_colab():
+ "Check if the code is running in Google Colaboratory"
try:
- if self.finished:
- return
-
- if info.finished_draws == info.total_draws:
- self.finished = True
-
- if not info.tuning:
- self.tuning = False
-
- self.draws = info.finished_draws
- if info.started:
- self.started = True
-
- self.num_divs = info.divergences
- self.step_size = info.step_size
- self.num_steps = info.num_steps
-
- if self.tuning:
- state = "warmup"
- else:
- state = "sampling"
-
- self.bar.comment = (
- f"Chain {self.chain_id:2} {state}: "
- f"trajectory {self.num_steps:3} / "
- f"diverging {self.num_divs: 2} / "
- f"step {self.step_size:.2g}"
- )
- self.bar.update(self.draws)
- except Exception as e:
- print(e)
-
-
-class _DetailedProgress:
- chains: list[_ChainProgress]
-
- def __init__(self, total_draws, num_chains):
- self.chains = [_ChainProgress(total_draws, i) for i in range(num_chains)]
-
- def callback(self, info):
- for chain, chain_info in zip(self.chains, info):
- chain.callback(chain_info)
-
+ from google import colab # noqa: F401
+ return True
+ except ImportError:
+ return False
-class _SummaryProgress:
- bar: Any
-
- def __init__(self, total_draws, num_chains):
- pass
-
- def callback(self, info):
- return None
+ if in_colab():
+ return True
+ try:
+ shell = get_ipython().__class__.__name__
+ if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole
+ try:
+ from IPython.display import HTML, clear_output, display # noqa: F401
+
+ return True
+ except ImportError:
+ import warnings
+
+ warnings.warn(
+ "Couldn't import ipywidgets properly, "
+ "progress bar will be disabled",
+ stacklevel=2,
+ )
+ return False
+ elif shell == "TerminalInteractiveShell":
+ return False # Terminal running IPython
+ else:
+ return False # Other type (?)
+ except NameError:
+ return False # Probably standard Python interpreter
class _BackgroundSampler:
@@ -230,25 +325,46 @@ def __init__(
progress_bar=True,
save_warmup=True,
return_raw_trace=False,
+ progress_template=None,
+ progress_style=None,
+ progress_rate=100,
):
self._settings = settings
self._compiled_model = compiled_model
self._save_warmup = save_warmup
self._return_raw_trace = return_raw_trace
- total_draws = settings.num_draws + settings.num_tune
+ self._html = None
- if progress_bar:
- self._progress = _DetailedProgress(total_draws, settings.num_chains)
- callback = self._progress.callback
- else:
- self._progress = None
+ if progress_template is None:
+ progress_template = _progress_template
+
+ if progress_style is None:
+ progress_style = _progress_style
+
+ if not progress_bar or not in_notebook():
+ progress_template = ""
callback = None
+ else:
+ import IPython
+
+ self._html = ""
+
+ if progress_style is not None:
+ IPython.display.display(IPython.display.HTML(progress_style))
+
+ self.display_id = IPython.display.display(self, display_id=True)
+
+ def callback(formatted):
+ self._html = formatted
+ self.display_id.update(self)
self._sampler = compiled_model._make_sampler(
settings,
init_mean,
cores,
+ progress_template,
+ progress_rate,
callback=callback,
)
@@ -322,6 +438,9 @@ def __del__(self):
if not self._sampler.is_empty():
self.cancel()
+ def _repr_html_(self):
+ return self._html
+
@overload
def sample(
@@ -372,6 +491,9 @@ def sample(
init_mean: Optional[np.ndarray] = None,
return_raw_trace: bool = False,
blocking: bool = True,
+ progress_template: Optional[str] = None,
+ progress_style: Optional[str] = None,
+ progress_rate: int = 100,
**kwargs,
) -> arviz.InferenceData:
"""Sample the posterior distribution for a compiled model.
@@ -432,6 +554,14 @@ def sample(
Use a mass matrix estimate that is based on draw and gradient
variance. Set to `False` to get mass matrix adaptation more
similar to PyMC and Stan.
+ progress_template: str
+ This is only exposed for experimentation. upon template
+ for the html progress representation.
+ progress_style: str
+ This is only exposed for experimentation. Common HTML
+ for the progress bar (eg CSS).
+ progress_rate: int, default=500
+ Rate in ms at which the progress should be updated.
**kwargs
Pass additional arguments to nutpie._lib.PySamplerArgs
@@ -467,6 +597,9 @@ def sample(
progress_bar=progress_bar,
save_warmup=save_warmup,
return_raw_trace=return_raw_trace,
+ progress_template=progress_template,
+ progress_style=progress_style,
+ progress_rate=progress_rate,
)
if not blocking:
diff --git a/src/lib.rs b/src/lib.rs
index 2250b01..2a94874 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod progress;
mod pymc;
mod stan;
mod wrapper;
diff --git a/src/progress.rs b/src/progress.rs
new file mode 100644
index 0000000..8c130e5
--- /dev/null
+++ b/src/progress.rs
@@ -0,0 +1,223 @@
+use std::{collections::BTreeMap, time::Duration};
+
+use anyhow::{Context, Result};
+use nuts_rs::{ChainProgress, ProgressCallback};
+use pyo3::{Py, PyAny, Python};
+use time_humanize::{Accuracy, Tense};
+use upon::{Engine, Value};
+
+pub struct ProgressHandler {
+ engine: Engine<'static>,
+ template: String,
+ callback: Py,
+ rate: Duration,
+ n_cores: usize,
+}
+
+impl ProgressHandler {
+ pub fn new(callback: Py, rate: Duration, template: String, n_cores: usize) -> Self {
+ let engine = Engine::new();
+ Self {
+ engine,
+ callback,
+ rate,
+ template,
+ n_cores,
+ }
+ }
+
+ pub fn into_callback(self) -> Result {
+ let template = self
+ .engine
+ .compile(self.template)
+ .context("Could not compile progress template")?;
+
+ let mut finished = false;
+ let mut progress_update_count = 0;
+
+ let callback = move |time_sampling, progress: Box<[ChainProgress]>| {
+ if finished {
+ return;
+ }
+ if progress
+ .iter()
+ .all(|chain| chain.finished_draws == chain.total_draws)
+ {
+ finished = true;
+ }
+ let progress =
+ progress_to_value(progress_update_count, self.n_cores, time_sampling, progress);
+ let rendered = template.render_from(&self.engine, &progress).to_string();
+ let rendered = rendered.unwrap_or_else(|err| format!("{}", err));
+ let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,)));
+ progress_update_count += 1;
+ };
+
+ Ok(ProgressCallback {
+ callback: Box::new(callback),
+ rate: self.rate,
+ })
+ }
+}
+
+fn progress_to_value(
+ progress_update_count: usize,
+ n_cores: usize,
+ time_sampling: Duration,
+ progress: Box<[ChainProgress]>,
+) -> Value {
+ let chains: Vec<_> = progress
+ .iter()
+ .enumerate()
+ .map(|(i, chain)| {
+ let mut values = BTreeMap::new();
+ values.insert("chain_index".into(), Value::Integer(i as i64));
+ values.insert(
+ "finished_draws".into(),
+ Value::Integer(chain.finished_draws as i64),
+ );
+ values.insert(
+ "total_draws".into(),
+ Value::Integer(chain.total_draws as i64),
+ );
+ values.insert(
+ "divergences".into(),
+ Value::Integer(chain.divergences as i64),
+ );
+ values.insert("tuning".into(), Value::Bool(chain.tuning));
+ values.insert("started".into(), Value::Bool(chain.started));
+ values.insert(
+ "finished".into(),
+ Value::Bool(chain.total_draws == chain.finished_draws),
+ );
+ values.insert(
+ "latest_num_steps".into(),
+ Value::Integer(chain.latest_num_steps as i64),
+ );
+ values.insert(
+ "total_num_steps".into(),
+ Value::Integer(chain.total_num_steps as i64),
+ );
+ values.insert(
+ "step_size".into(),
+ Value::String(format!("{:.2}", chain.step_size)),
+ );
+ values.insert(
+ "divergent_draws".into(),
+ Value::List(
+ chain
+ .divergent_draws
+ .iter()
+ .map(|&idx| Value::Integer(idx as _))
+ .collect(),
+ ),
+ );
+ upon::Value::Map(values)
+ })
+ .collect();
+
+ let mut map = BTreeMap::new();
+ map.insert("chains".into(), Value::List(chains));
+ map.insert(
+ "total_draws".into(),
+ Value::Integer(
+ progress
+ .iter()
+ .map(|chain| chain.total_draws)
+ .sum::() as i64,
+ ),
+ );
+ map.insert(
+ "total_finished_draws".into(),
+ Value::Integer(
+ progress
+ .iter()
+ .map(|chain| chain.finished_draws)
+ .sum::() as i64,
+ ),
+ );
+ map.insert(
+ "time_sampling".into(),
+ Value::String(
+ time_humanize::HumanTime::from(time_sampling)
+ .to_text_en(Accuracy::Rough, Tense::Present),
+ ),
+ );
+
+ let remaining = estimate_remaining_time(n_cores, time_sampling, &progress);
+ map.insert(
+ "time_remaining_estimate".into(),
+ match remaining {
+ Some(remaining) => Value::String(
+ time_humanize::HumanTime::from(remaining)
+ .to_text_en(Accuracy::Rough, Tense::Present),
+ ),
+ None => Value::None,
+ },
+ );
+
+ map.insert("num_cores".into(), Value::Integer(n_cores as _));
+
+ let finished_chains = progress
+ .iter()
+ .map(|chain| (chain.finished_draws == chain.total_draws) as u64)
+ .sum::();
+ map.insert(
+ "finished_chains".into(),
+ Value::Integer(finished_chains as _),
+ );
+ map.insert(
+ "running_chains".into(),
+ Value::Integer(
+ progress
+ .iter()
+ .map(|chain| (chain.started & (chain.finished_draws < chain.total_draws)) as u64)
+ .sum::() as i64,
+ ),
+ );
+ map.insert("num_chains".into(), Value::Integer(progress.len() as _));
+ map.insert(
+ "finished".into(),
+ Value::Bool(progress.len() == finished_chains as usize),
+ );
+ map.insert(
+ "progress_update_count".into(),
+ Value::Integer(progress_update_count as i64),
+ );
+
+ Value::Map(map)
+}
+
+fn estimate_remaining_time(
+ n_cores: usize,
+ time_sampling: Duration,
+ progress: &[ChainProgress],
+) -> Option {
+ let finished_draws: f64 = progress
+ .iter()
+ .map(|chain| chain.finished_draws as f64)
+ .sum();
+ if !(finished_draws > 0.) {
+ return None;
+ }
+
+ // TODO this assumes that so far all cores were used all the time
+ let time_per_draw = time_sampling.mul_f64((n_cores as f64) / finished_draws);
+
+ let mut core_times = vec![Duration::ZERO; n_cores];
+
+ progress
+ .iter()
+ .map(|chain| time_per_draw.mul_f64((chain.total_draws - chain.finished_draws) as f64))
+ .for_each(|time| {
+ let min_index = core_times
+ .iter()
+ .enumerate()
+ .min_by_key(|&(_, v)| v)
+ .unwrap()
+ .0;
+ core_times[min_index] += time;
+ });
+
+ Some(core_times.into_iter().max().unwrap_or(Duration::ZERO))
+}
diff --git a/src/wrapper.rs b/src/wrapper.rs
index 3da5b84..e631f55 100644
--- a/src/wrapper.rs
+++ b/src/wrapper.rs
@@ -1,6 +1,7 @@
use std::time::{Duration, Instant};
use crate::{
+ progress::ProgressHandler,
pymc::{ExpandFunc, LogpFunc, PyMcModel},
stan::{StanLibrary, StanModel},
};
@@ -239,27 +240,20 @@ pub(crate) enum SamplerState {
#[pyclass]
struct PySampler(SamplerState);
-fn make_callback(callback: Option>) -> Option {
+fn make_callback(
+ template: String,
+ n_cores: usize,
+ rate: Duration,
+ callback: Option>,
+) -> Result