Skip to content

Commit

Permalink
mypy + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 25, 2025
1 parent 4e535d4 commit b9b0583
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
10 changes: 6 additions & 4 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _sample_external_nuts(
initvals: StartDict | Sequence[StartDict | None] | None,
model: Model,
var_names: Sequence[str] | None,
progressbar: bool,
progressbar: bool | ProgressType,
idata_kwargs: dict | None,
compute_convergence_checks: bool,
nuts_sampler_kwargs: dict | None,
Expand Down Expand Up @@ -401,7 +401,7 @@ def _sample_external_nuts(
initvals=initvals,
model=model,
var_names=var_names,
progressbar=progressbar,
progressbar=True if progressbar else False,
nuts_sampler=sampler,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
Expand Down Expand Up @@ -488,7 +488,7 @@ def sample(
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressType = True,
progressbar_theme: Theme | None = default_progress_theme,
progressbar_theme: Theme | None = None,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -831,7 +831,9 @@ def joined_blas_limiter():
n_init=n_init,
model=model,
random_seed=random_seed_list,
progressbar=progressbar,
progressbar=True
if progressbar
else False, # ADVI doesn't use the ProgressManager; pass a bool only
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
Expand Down
9 changes: 6 additions & 3 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from pymc.exceptions import BlockModelAccessError

if TYPE_CHECKING:
from pymc import BlockedStep
from pymc.step_methods.compound import BlockedStep, CompoundStep


ProgressType = Literal[
Expand Down Expand Up @@ -727,12 +727,12 @@ class ProgressManager:

def __init__(
self,
step_method: BlockedStep,
step_method: "BlockedStep" | "CompoundStep",
chains: int,
draws: int,
tune: int,
progressbar: bool | ProgressType = True,
progressbar_theme: Theme = default_progress_theme,
progressbar_theme: Theme | None = None,
):
"""
Manage progress bars displayed during sampling.
Expand Down Expand Up @@ -770,6 +770,9 @@ def __init__(
progressbar_theme: Theme, optional
The theme to use for the progress bar. Defaults to the default theme.
"""
if progressbar_theme is None:
progressbar_theme = default_progress_theme

self.combined_progress = False
self.full_stats = True
show_progress = True
Expand Down

0 comments on commit b9b0583

Please sign in to comment.