Skip to content

Commit

Permalink
Update progressbar managers with existing fit results from ZarrTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Feb 21, 2025
1 parent 0ab1596 commit 61613a7
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 6 deletions.
13 changes: 12 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,13 +1157,24 @@ def _sample_many(

with progress_manager:
for i in range(chains):
trace = traces[i]
if isinstance(trace, ZarrChain):
progress_manager.set_initial_state(*trace.completed_draws_and_divergences())
progress_manager._progress.update(
progress_manager.tasks[i],
draws=progress_manager.completed_draws
if progress_manager.combined_progress
else progress_manager.draws,
divergences=progress_manager.divergences,
refresh=True,
)
step.sampling_state = initial_step_state
_sample(
draws=draws,
chain=i,
start=start[i],
step=step,
trace=traces[i],
trace=trace,
rng=rngs[i],
callback=callback,
progress_manager=progress_manager,
Expand Down
4 changes: 4 additions & 0 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,10 @@ def __init__(
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)
if self.zarr_recording:
self._progress.set_initial_state(
*cast(ZarrChain, zarr_chains)[0].completed_draws_and_divergences()
)

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
Expand Down
4 changes: 4 additions & 0 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def _sample_population(

with CustomProgress(disable=not progressbar) as progress:
task = progress.add_task("[red]Sampling...", total=draws)
if isinstance(traces[0], ZarrChain):
completed_draws, _ = traces[0].completed_draws_and_divergences()
progress.update(task, completed=completed_draws)
for _ in sampling:
progress.update(task)

Expand Down Expand Up @@ -197,6 +200,7 @@ def __init__(
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
# ):
task = self._progress.add_task(description=f"Chain {c}")
self._progress.update(task, completed=first_draw_idx)
secondary_end, primary_end = multiprocessing.Pipe()
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
process = multiprocessing.Process(
Expand Down
19 changes: 14 additions & 5 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ def __init__(

self._show_progress = show_progress
self.divergences = 0
self.draws = 0
self.completed_draws = 0
self.total_draws = draws + tune
self.desc = "Sampling chain"
Expand All @@ -827,27 +828,35 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
return self._progress.__exit__(exc_type, exc_val, exc_tb)

def set_initial_state(self, draws: int = 0, divergences: int = 0):
self.draws = draws
self.completed_draws += draws
self.divergences = divergences

def _initialize_tasks(self):
if self.combined_progress:
self.tasks = [
self._progress.add_task(
self.desc.format(self),
completed=0,
draws=0,
completed=self.completed_draws,
draws=self.completed_draws,
total=self.total_draws * self.chains - 1,
chain_idx=0,
sampling_speed=0,
speed_unit="draws/s",
**{stat: value[0] for stat, value in self.progress_stats.items()},
**{
stat: value[0] if stat != "diverging" else self.divergences
for stat, value in self.progress_stats.items()
},
)
]

else:
self.tasks = [
self._progress.add_task(
self.desc.format(self),
completed=0,
draws=0,
completed=self.completed_draws,
draws=self.draws,
total=self.total_draws - 1,
chain_idx=chain_idx,
sampling_speed=0,
Expand Down

0 comments on commit 61613a7

Please sign in to comment.