diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 02e5ea8550..3c006729c8 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -474,7 +474,7 @@ def __iter__(self): self._desc.format(self), completed=0, draws=0, - total=self._total_draws, + total=self._total_draws - 1, chain_idx=chain_idx, sampling_speed=0, speed_unit="draws/s", diff --git a/pymc/util.py b/pymc/util.py index 32b29e8298..3db30819bf 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -31,7 +31,14 @@ from pytensor.graph.utils import ValidatingScratchpad from rich.box import SIMPLE_HEAD from rich.console import Console -from rich.progress import BarColumn, Progress, Task, TextColumn +from rich.progress import ( + BarColumn, + Progress, + Task, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) from rich.style import Style from rich.table import Column, Table from rich.theme import Theme @@ -58,6 +65,8 @@ def __getattr__(name): { "bar.complete": "#1764f4", "bar.finished": "green", + "progress.remaining": "none", + "progress.elapsed": "none", } ) @@ -694,7 +703,9 @@ def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_t TextColumn( "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", table_column=Column("Sampling Speed", ratio=1), - ) + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), ] return CustomProgress(