From 741cf36b6c0e0de51b9ea8d14dbfb21c077aa62c Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 23:49:08 +0800 Subject: [PATCH] Explicit case handling for `progressbar` argument --- pymc/util.py | 50 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index 3d727d78c4..8f41e47f38 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -704,28 +704,38 @@ def callbacks(self, task: "Task"): class ProgressManager: def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme): - mode = "chain" - stats = "full" - - if isinstance(progressbar, bool): - show_progress = progressbar - else: - show_progress = True - - if "+" in progressbar: - mode, stats = progressbar.split("+") - else: - mode = progressbar - stats = "full" - - if mode not in ["chain", "combined"]: - raise ValueError('Invalid mode. Valid values are "chain" and "combined"') - if stats not in ["full", "simple"]: - raise ValueError('Invalid stats. Valid values are "full" and "simple"') + self.combined_progress = False + self.full_stats = True + show_progress = True + + match progressbar: + case True: + show_progress = True + case False: + show_progress = False + case "combined": + self.combined_progress = True + case "chain": + self.combined_progress = False + case "combined+full": + self.combined_progress = True + self.full_stats = True + case "combined+simple": + self.combined_progress = True + self.full_stats = False + case "chain+full": + self.combined_progress = False + self.full_stats = True + case "chain+simple": + self.combined_progress = False + self.full_stats = False + case _: + raise ValueError( + "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " + "or one of 'combined', 'chain', 'combined+full', 'combined+simple', 'chain+full', 'chain+simple'." + ) progress_columns, progress_stats = step_method._progressbar_config(chains) - self.combined_progress = mode == "combined" - self.full_stats = stats == "full" self._progress = self.create_progress_bar( progress_columns,