From 61613a7dd6f4ec89aa3331745e9a4f9f54c7de00 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Fri, 21 Feb 2025 11:12:24 +0100 Subject: [PATCH] Update progressbar managers with existing fit results from ZarrTrace --- pymc/sampling/mcmc.py | 13 ++++++++++++- pymc/sampling/parallel.py | 4 ++++ pymc/sampling/population.py | 4 ++++ pymc/util.py | 19 ++++++++++++++----- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 416c907f61..9457273944 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1157,13 +1157,24 @@ def _sample_many( with progress_manager: for i in range(chains): + trace = traces[i] + if isinstance(trace, ZarrChain): + progress_manager.set_initial_state(*trace.completed_draws_and_divergences()) + progress_manager._progress.update( + progress_manager.tasks[i], + draws=progress_manager.completed_draws + if progress_manager.combined_progress + else progress_manager.draws, + divergences=progress_manager.divergences, + refresh=True, + ) step.sampling_state = initial_step_state _sample( draws=draws, chain=i, start=start[i], step=step, - trace=traces[i], + trace=trace, rng=rngs[i], callback=callback, progress_manager=progress_manager, diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 173b5aaac2..d0a1f31287 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -509,6 +509,10 @@ def __init__( progressbar=progressbar, progressbar_theme=progressbar_theme, ) + if self.zarr_recording: + self._progress.set_initial_state( + *cast(ZarrChain, zarr_chains)[0].completed_draws_and_divergences() + ) def _make_active(self): while self._inactive and len(self._active) < self._max_active: diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index b88608e08b..f74cbadb78 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -110,6 +110,9 @@ def _sample_population( with CustomProgress(disable=not progressbar) as progress: task = progress.add_task("[red]Sampling...", total=draws) + if isinstance(traces[0], ZarrChain): + completed_draws, _ = traces[0].completed_draws_and_divergences() + progress.update(task, completed=completed_draws) for _ in sampling: progress.update(task) @@ -197,6 +200,7 @@ def __init__( # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) # ): task = self._progress.add_task(description=f"Chain {c}") + self._progress.update(task, completed=first_draw_idx) secondary_end, primary_end = multiprocessing.Pipe() stepper_dumps = cloudpickle.dumps(stepper, protocol=4) process = multiprocessing.Process( diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..b8362e776e 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -812,6 +812,7 @@ def __init__( self._show_progress = show_progress self.divergences = 0 + self.draws = 0 self.completed_draws = 0 self.total_draws = draws + tune self.desc = "Sampling chain" @@ -827,18 +828,26 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): return self._progress.__exit__(exc_type, exc_val, exc_tb) + def set_initial_state(self, draws: int = 0, divergences: int = 0): + self.draws = draws + self.completed_draws += draws + self.divergences = divergences + def _initialize_tasks(self): if self.combined_progress: self.tasks = [ self._progress.add_task( self.desc.format(self), - completed=0, - draws=0, + completed=self.completed_draws, + draws=self.completed_draws, total=self.total_draws * self.chains - 1, chain_idx=0, sampling_speed=0, speed_unit="draws/s", - **{stat: value[0] for stat, value in self.progress_stats.items()}, + **{ + stat: value[0] if stat != "diverging" else self.divergences + for stat, value in self.progress_stats.items() + }, ) ] @@ -846,8 +855,8 @@ def _initialize_tasks(self): self.tasks = [ self._progress.add_task( self.desc.format(self), - completed=0, - draws=0, + completed=self.completed_draws, + draws=self.draws, total=self.total_draws - 1, chain_idx=chain_idx, sampling_speed=0,