diff --git a/Cargo.lock b/Cargo.lock index f0bc08d..29e4883 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,7 +88,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.58", + "syn 2.0.59", "which", ] @@ -147,7 +147,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] @@ -179,9 +179,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.37" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "num-traits", ] @@ -388,7 +388,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] @@ -408,7 +408,7 @@ checksum = "60d08acb9849f7fb4401564f251be5a526829183a3645a90197dea8e786cf3ae" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] @@ -953,6 +953,8 @@ dependencies = [ "rayon", "smallvec", "thiserror", + "time-humanize", + "upon", ] [[package]] @@ -1044,7 +1046,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] @@ -1100,19 +1102,19 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.2.17" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" +checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550" dependencies = [ "proc-macro2", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] name = "proc-macro2" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +checksum = "a56dea16b0a29e94408b9aa5e2940a4eedbd128a1ba20e8f7ae60fd3d465af0e" dependencies = [ "unicode-ident", ] @@ -1144,9 +1146,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a8b1990bd018761768d5e608a13df8bd1ac5f678456e0f301bb93e5f3ea16b" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "anyhow", "cfg-if", @@ -1163,9 +1165,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "650dca34d463b6cdbdb02b1d71bfd6eb6b6816afc708faebb3bac1380ff4aef7" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -1173,9 +1175,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a7da8fc04a8a2084909b59f29e1b8474decac98b951d77b80b26dc45f046ad" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -1183,27 +1185,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8a199fce11ebb28e3569387228836ea98110e43a804a530a9fd83ade36d513" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] name = "pyo3-macros-backend" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fbbfd7eb553d10036513cb122b888dcd362a945a00b06c165f2ab480d4cc3b" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] @@ -1412,14 +1414,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] [[package]] name = "serde_json" -version = "1.0.115" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", @@ -1468,9 +1470,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.58" +version = "2.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" +checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a" dependencies = [ "proc-macro2", "quote", @@ -1520,9 +1522,15 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] +[[package]] +name = "time-humanize" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e32d019b4f7c100bcd5494e40a27119d45b71fba2b07a4684153129279a4647" + [[package]] name = "tinytemplate" version = "1.2.1" @@ -1557,6 +1565,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +[[package]] +name = "upon" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" + [[package]] name = "version_check" version = "0.9.4" @@ -1600,7 +1614,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", "wasm-bindgen-shared", ] @@ -1622,7 +1636,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1833,5 +1847,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.59", ] diff --git a/Cargo.toml b/Cargo.toml index 7b44020..f616664 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,8 @@ itertools = "0.12.0" bridgestan = "2.1.2" rand_distr = "0.4.3" smallvec = "1.11.0" +upon = { version = "0.8.1", default-features = false, features = [] } +time-humanize = { version = "0.1.3", default-features = false } [dependencies.pyo3] version = "0.21.0" diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index aef52c9..e41058b 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -87,9 +87,11 @@ def with_data(self, **updates): user_data=user_data, ) - def _make_sampler(self, settings, init_mean, cores, callback=None): + def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None): model = self._make_model(init_mean) - return _lib.PySampler.from_pymc(settings, cores, model, callback) + return _lib.PySampler.from_pymc( + settings, cores, model, template, rate, callback + ) def _make_model(self, init_mean): expand_fn = _lib.ExpandFunc( diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 5911b2c..03f5444 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -80,9 +80,11 @@ def _make_model(self, init_mean): return self.with_data().model return self.model - def _make_sampler(self, settings, init_mean, cores, callback=None): + def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None): model = self._make_model(init_mean) - return _lib.PySampler.from_stan(settings, cores, model, callback) + return _lib.PySampler.from_stan( + settings, cores, model, template, rate, callback + ) @property def n_dim(self): diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 0ff2c54..d32a494 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -1,10 +1,8 @@ import os from dataclasses import dataclass -from threading import Condition, Event from typing import Any, Literal, Optional, overload import arviz -import fastprogress import numpy as np import pandas as pd import pyarrow @@ -128,85 +126,182 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs): ) -class _ChainProgress: - bar: Any - total: int - finished: bool - tuning: bool - draws: int - started: bool - chain_id: int - num_divs: int - step_size: float - num_steps: int - - def __init__(self, total, chain_id): - self.bar = fastprogress.progress_bar(range(total)) - self.total = total - self.finished = False - self.tuning = True - self.draws = 0 - self.started = False - self.chain_id = chain_id - self.num_divs = 0 - self.step_size = 0. - self.num_steps = 0 - self.bar.update(0) - - def callback(self, info): +_progress_style = """ + +""" + + +_progress_template = """ +
+

Sampler Progress

+

Total Chains: {{ num_chains }}

+

Active Chains: {{ running_chains }}

+

+ Finished Chains: + {{ finished_chains }} +

+

Sampling for {{ time_sampling }}

+

+ Estimated Time to Completion: + {{ time_remaining_estimate }} +

+ + + + + + + + + + + + + + + {% for chain in chains %} + + + + + + + + {% endfor %} + + +
ProgressChainDivergencesStep SizeGradients/Draw
+ + + {{ chain.chain_index }}{{ chain.divergences }}{{ chain.step_size }}{{ chain.latest_num_steps }}
+
+""" + + +# Adapted from fastprogress +def in_notebook(): + def in_colab(): + "Check if the code is running in Google Colaboratory" try: - if self.finished: - return - - if info.finished_draws == info.total_draws: - self.finished = True - - if not info.tuning: - self.tuning = False - - self.draws = info.finished_draws - if info.started: - self.started = True - - self.num_divs = info.divergences - self.step_size = info.step_size - self.num_steps = info.num_steps - - if self.tuning: - state = "warmup" - else: - state = "sampling" - - self.bar.comment = ( - f"Chain {self.chain_id:2} {state}: " - f"trajectory {self.num_steps:3} / " - f"diverging {self.num_divs: 2} / " - f"step {self.step_size:.2g}" - ) - self.bar.update(self.draws) - except Exception as e: - print(e) - - -class _DetailedProgress: - chains: list[_ChainProgress] - - def __init__(self, total_draws, num_chains): - self.chains = [_ChainProgress(total_draws, i) for i in range(num_chains)] - - def callback(self, info): - for chain, chain_info in zip(self.chains, info): - chain.callback(chain_info) - + from google import colab # noqa: F401 + return True + except ImportError: + return False -class _SummaryProgress: - bar: Any - - def __init__(self, total_draws, num_chains): - pass - - def callback(self, info): - return None + if in_colab(): + return True + try: + shell = get_ipython().__class__.__name__ + if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole + try: + from IPython.display import HTML, clear_output, display # noqa: F401 + + return True + except ImportError: + import warnings + + warnings.warn( + "Couldn't import ipywidgets properly, " + "progress bar will be disabled", + stacklevel=2, + ) + return False + elif shell == "TerminalInteractiveShell": + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False # Probably standard Python interpreter class _BackgroundSampler: @@ -230,25 +325,46 @@ def __init__( progress_bar=True, save_warmup=True, return_raw_trace=False, + progress_template=None, + progress_style=None, + progress_rate=100, ): self._settings = settings self._compiled_model = compiled_model self._save_warmup = save_warmup self._return_raw_trace = return_raw_trace - total_draws = settings.num_draws + settings.num_tune + self._html = None - if progress_bar: - self._progress = _DetailedProgress(total_draws, settings.num_chains) - callback = self._progress.callback - else: - self._progress = None + if progress_template is None: + progress_template = _progress_template + + if progress_style is None: + progress_style = _progress_style + + if not progress_bar or not in_notebook(): + progress_template = "" callback = None + else: + import IPython + + self._html = "" + + if progress_style is not None: + IPython.display.display(IPython.display.HTML(progress_style)) + + self.display_id = IPython.display.display(self, display_id=True) + + def callback(formatted): + self._html = formatted + self.display_id.update(self) self._sampler = compiled_model._make_sampler( settings, init_mean, cores, + progress_template, + progress_rate, callback=callback, ) @@ -322,6 +438,9 @@ def __del__(self): if not self._sampler.is_empty(): self.cancel() + def _repr_html_(self): + return self._html + @overload def sample( @@ -372,6 +491,9 @@ def sample( init_mean: Optional[np.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, + progress_template: Optional[str] = None, + progress_style: Optional[str] = None, + progress_rate: int = 100, **kwargs, ) -> arviz.InferenceData: """Sample the posterior distribution for a compiled model. @@ -432,6 +554,14 @@ def sample( Use a mass matrix estimate that is based on draw and gradient variance. Set to `False` to get mass matrix adaptation more similar to PyMC and Stan. + progress_template: str + This is only exposed for experimentation. upon template + for the html progress representation. + progress_style: str + This is only exposed for experimentation. Common HTML + for the progress bar (eg CSS). + progress_rate: int, default=500 + Rate in ms at which the progress should be updated. **kwargs Pass additional arguments to nutpie._lib.PySamplerArgs @@ -467,6 +597,9 @@ def sample( progress_bar=progress_bar, save_warmup=save_warmup, return_raw_trace=return_raw_trace, + progress_template=progress_template, + progress_style=progress_style, + progress_rate=progress_rate, ) if not blocking: diff --git a/src/lib.rs b/src/lib.rs index 2250b01..2a94874 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod progress; mod pymc; mod stan; mod wrapper; diff --git a/src/progress.rs b/src/progress.rs new file mode 100644 index 0000000..8c130e5 --- /dev/null +++ b/src/progress.rs @@ -0,0 +1,223 @@ +use std::{collections::BTreeMap, time::Duration}; + +use anyhow::{Context, Result}; +use nuts_rs::{ChainProgress, ProgressCallback}; +use pyo3::{Py, PyAny, Python}; +use time_humanize::{Accuracy, Tense}; +use upon::{Engine, Value}; + +pub struct ProgressHandler { + engine: Engine<'static>, + template: String, + callback: Py, + rate: Duration, + n_cores: usize, +} + +impl ProgressHandler { + pub fn new(callback: Py, rate: Duration, template: String, n_cores: usize) -> Self { + let engine = Engine::new(); + Self { + engine, + callback, + rate, + template, + n_cores, + } + } + + pub fn into_callback(self) -> Result { + let template = self + .engine + .compile(self.template) + .context("Could not compile progress template")?; + + let mut finished = false; + let mut progress_update_count = 0; + + let callback = move |time_sampling, progress: Box<[ChainProgress]>| { + if finished { + return; + } + if progress + .iter() + .all(|chain| chain.finished_draws == chain.total_draws) + { + finished = true; + } + let progress = + progress_to_value(progress_update_count, self.n_cores, time_sampling, progress); + let rendered = template.render_from(&self.engine, &progress).to_string(); + let rendered = rendered.unwrap_or_else(|err| format!("{}", err)); + let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,))); + progress_update_count += 1; + }; + + Ok(ProgressCallback { + callback: Box::new(callback), + rate: self.rate, + }) + } +} + +fn progress_to_value( + progress_update_count: usize, + n_cores: usize, + time_sampling: Duration, + progress: Box<[ChainProgress]>, +) -> Value { + let chains: Vec<_> = progress + .iter() + .enumerate() + .map(|(i, chain)| { + let mut values = BTreeMap::new(); + values.insert("chain_index".into(), Value::Integer(i as i64)); + values.insert( + "finished_draws".into(), + Value::Integer(chain.finished_draws as i64), + ); + values.insert( + "total_draws".into(), + Value::Integer(chain.total_draws as i64), + ); + values.insert( + "divergences".into(), + Value::Integer(chain.divergences as i64), + ); + values.insert("tuning".into(), Value::Bool(chain.tuning)); + values.insert("started".into(), Value::Bool(chain.started)); + values.insert( + "finished".into(), + Value::Bool(chain.total_draws == chain.finished_draws), + ); + values.insert( + "latest_num_steps".into(), + Value::Integer(chain.latest_num_steps as i64), + ); + values.insert( + "total_num_steps".into(), + Value::Integer(chain.total_num_steps as i64), + ); + values.insert( + "step_size".into(), + Value::String(format!("{:.2}", chain.step_size)), + ); + values.insert( + "divergent_draws".into(), + Value::List( + chain + .divergent_draws + .iter() + .map(|&idx| Value::Integer(idx as _)) + .collect(), + ), + ); + upon::Value::Map(values) + }) + .collect(); + + let mut map = BTreeMap::new(); + map.insert("chains".into(), Value::List(chains)); + map.insert( + "total_draws".into(), + Value::Integer( + progress + .iter() + .map(|chain| chain.total_draws) + .sum::() as i64, + ), + ); + map.insert( + "total_finished_draws".into(), + Value::Integer( + progress + .iter() + .map(|chain| chain.finished_draws) + .sum::() as i64, + ), + ); + map.insert( + "time_sampling".into(), + Value::String( + time_humanize::HumanTime::from(time_sampling) + .to_text_en(Accuracy::Rough, Tense::Present), + ), + ); + + let remaining = estimate_remaining_time(n_cores, time_sampling, &progress); + map.insert( + "time_remaining_estimate".into(), + match remaining { + Some(remaining) => Value::String( + time_humanize::HumanTime::from(remaining) + .to_text_en(Accuracy::Rough, Tense::Present), + ), + None => Value::None, + }, + ); + + map.insert("num_cores".into(), Value::Integer(n_cores as _)); + + let finished_chains = progress + .iter() + .map(|chain| (chain.finished_draws == chain.total_draws) as u64) + .sum::(); + map.insert( + "finished_chains".into(), + Value::Integer(finished_chains as _), + ); + map.insert( + "running_chains".into(), + Value::Integer( + progress + .iter() + .map(|chain| (chain.started & (chain.finished_draws < chain.total_draws)) as u64) + .sum::() as i64, + ), + ); + map.insert("num_chains".into(), Value::Integer(progress.len() as _)); + map.insert( + "finished".into(), + Value::Bool(progress.len() == finished_chains as usize), + ); + map.insert( + "progress_update_count".into(), + Value::Integer(progress_update_count as i64), + ); + + Value::Map(map) +} + +fn estimate_remaining_time( + n_cores: usize, + time_sampling: Duration, + progress: &[ChainProgress], +) -> Option { + let finished_draws: f64 = progress + .iter() + .map(|chain| chain.finished_draws as f64) + .sum(); + if !(finished_draws > 0.) { + return None; + } + + // TODO this assumes that so far all cores were used all the time + let time_per_draw = time_sampling.mul_f64((n_cores as f64) / finished_draws); + + let mut core_times = vec![Duration::ZERO; n_cores]; + + progress + .iter() + .map(|chain| time_per_draw.mul_f64((chain.total_draws - chain.finished_draws) as f64)) + .for_each(|time| { + let min_index = core_times + .iter() + .enumerate() + .min_by_key(|&(_, v)| v) + .unwrap() + .0; + core_times[min_index] += time; + }); + + Some(core_times.into_iter().max().unwrap_or(Duration::ZERO)) +} diff --git a/src/wrapper.rs b/src/wrapper.rs index 3da5b84..e631f55 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1,6 +1,7 @@ use std::time::{Duration, Instant}; use crate::{ + progress::ProgressHandler, pymc::{ExpandFunc, LogpFunc, PyMcModel}, stan::{StanLibrary, StanModel}, }; @@ -239,27 +240,20 @@ pub(crate) enum SamplerState { #[pyclass] struct PySampler(SamplerState); -fn make_callback(callback: Option>) -> Option { +fn make_callback( + template: String, + n_cores: usize, + rate: Duration, + callback: Option>, +) -> Result> { match callback { Some(callback) => { - let callback = Box::new(move |stats: Box<[ChainProgress]>| { - let _ = Python::with_gil(|py| { - let args = PyList::new_bound( - py, - stats - .into_vec() - .into_iter() - .map(|prog| PyChainProgress(prog).into_py(py)), - ); - callback.call1(py, (args,)) - }); - }); - Some(ProgressCallback { - callback, - rate: Duration::from_millis(500), - }) + let handler = ProgressHandler::new(callback, rate, template, n_cores); + let callback = handler.into_callback()?; + + Ok(Some(callback)) } - None => None, + None => Ok(None), } } @@ -270,9 +264,13 @@ impl PySampler { settings: PyDiagGradNutsSettings, cores: usize, model: PyMcModel, + template: String, + rate: u64, callback: Option>, ) -> PyResult { - let sampler = Sampler::new(model, settings.0, cores, make_callback(callback))?; + let rate = Duration::from_millis(rate); + let callback = make_callback(template, cores, rate, callback)?; + let sampler = Sampler::new(model, settings.0, cores, callback)?; Ok(PySampler(SamplerState::Running(sampler))) } @@ -281,9 +279,13 @@ impl PySampler { settings: PyDiagGradNutsSettings, cores: usize, model: StanModel, + template: String, + rate: u64, callback: Option>, ) -> PyResult { - let sampler = Sampler::new(model, settings.0, cores, make_callback(callback))?; + let rate = Duration::from_millis(rate); + let callback = make_callback(template, cores, rate, callback)?; + let sampler = Sampler::new(model, settings.0, cores, callback)?; Ok(PySampler(SamplerState::Running(sampler))) }