From caa9501b45d6cefa9ef888bdf4f7ae3126a14d5a Mon Sep 17 00:00:00 2001 From: Adarsh Dubey <84132532+inclinedadarsh@users.noreply.github.com> Date: Mon, 20 Jan 2025 10:59:36 +0530 Subject: [PATCH 1/3] fix: deep copy nuts_sampler_kwarg to prevent pop side effects --- pymc/sampling/mcmc.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 64d6829fc8..ffa6310f8e 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -309,8 +309,11 @@ def _sample_external_nuts( nuts_sampler_kwargs: dict | None, **kwargs, ): - if nuts_sampler_kwargs is None: - nuts_sampler_kwargs = {} + import copy + + nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs) + if nuts_sampler_kwargs_copy is None: + nuts_sampler_kwargs_copy = {} if sampler == "nutpie": try: @@ -339,8 +342,8 @@ def _sample_external_nuts( ) compile_kwargs = {} for kwarg in ("backend", "gradient_backend"): - if kwarg in nuts_sampler_kwargs: - compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) + if kwarg in nuts_sampler_kwargs_copy: + compile_kwargs[kwarg] = nuts_sampler_kwargs_copy.pop(kwarg) compiled_model = nutpie.compile_pymc_model( model, **compile_kwargs, @@ -354,7 +357,7 @@ def _sample_external_nuts( target_accept=target_accept, seed=_get_seeds_per_chain(random_seed, 1)[0], progress_bar=progressbar, - **nuts_sampler_kwargs, + **nuts_sampler_kwargs_copy, ) t_sample = time.time() - t_start # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed @@ -406,7 +409,7 @@ def _sample_external_nuts( nuts_sampler=sampler, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, - **nuts_sampler_kwargs, + **nuts_sampler_kwargs_copy, ) return idata @@ -686,6 +689,9 @@ def sample( mean sd hdi_3% hdi_97% p 0.609 0.047 0.528 0.699 """ + import copy + + nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs) if "start" in kwargs: if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.") @@ -695,8 +701,8 @@ def sample( stacklevel=2, ) initvals = kwargs.pop("start") - if nuts_sampler_kwargs is None: - nuts_sampler_kwargs = {} + if nuts_sampler_kwargs_copy is None: + nuts_sampler_kwargs_copy = {} if "target_accept" in kwargs: if "nuts" in kwargs and "target_accept" in kwargs["nuts"]: raise ValueError( @@ -808,7 +814,7 @@ def joined_blas_limiter(): progressbar=progressbar, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, - nuts_sampler_kwargs=nuts_sampler_kwargs, + nuts_sampler_kwargs=nuts_sampler_kwargs_copy, **kwargs, ) From e6c5b1a20996e4ba5ac5cd54305dea8b5a8ee6ea Mon Sep 17 00:00:00 2001 From: Adarsh Dubey <84132532+inclinedadarsh@users.noreply.github.com> Date: Mon, 20 Jan 2025 19:43:58 +0530 Subject: [PATCH 2/3] fix: replace deep copy with shalow copy --- pymc/sampling/mcmc.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ffa6310f8e..6e8f142898 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -309,11 +309,9 @@ def _sample_external_nuts( nuts_sampler_kwargs: dict | None, **kwargs, ): - import copy - - nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs) - if nuts_sampler_kwargs_copy is None: - nuts_sampler_kwargs_copy = {} + nuts_sampler_kwargs = nuts_sampler_kwargs.copy() + if nuts_sampler_kwargs is None: + nuts_sampler_kwargs = {} if sampler == "nutpie": try: @@ -342,8 +340,8 @@ def _sample_external_nuts( ) compile_kwargs = {} for kwarg in ("backend", "gradient_backend"): - if kwarg in nuts_sampler_kwargs_copy: - compile_kwargs[kwarg] = nuts_sampler_kwargs_copy.pop(kwarg) + if kwarg in nuts_sampler_kwargs: + compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) compiled_model = nutpie.compile_pymc_model( model, **compile_kwargs, @@ -357,7 +355,7 @@ def _sample_external_nuts( target_accept=target_accept, seed=_get_seeds_per_chain(random_seed, 1)[0], progress_bar=progressbar, - **nuts_sampler_kwargs_copy, + **nuts_sampler_kwargs, ) t_sample = time.time() - t_start # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed @@ -409,7 +407,7 @@ def _sample_external_nuts( nuts_sampler=sampler, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, - **nuts_sampler_kwargs_copy, + **nuts_sampler_kwargs, ) return idata @@ -689,9 +687,7 @@ def sample( mean sd hdi_3% hdi_97% p 0.609 0.047 0.528 0.699 """ - import copy - - nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs) + nuts_sampler_kwargs = nuts_sampler_kwargs.copy() if "start" in kwargs: if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.") @@ -701,8 +697,8 @@ def sample( stacklevel=2, ) initvals = kwargs.pop("start") - if nuts_sampler_kwargs_copy is None: - nuts_sampler_kwargs_copy = {} + if nuts_sampler_kwargs is None: + nuts_sampler_kwargs = {} if "target_accept" in kwargs: if "nuts" in kwargs and "target_accept" in kwargs["nuts"]: raise ValueError( @@ -814,7 +810,7 @@ def joined_blas_limiter(): progressbar=progressbar, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, - nuts_sampler_kwargs=nuts_sampler_kwargs_copy, + nuts_sampler_kwargs=nuts_sampler_kwargs, **kwargs, ) From 4db8ed642deb169ee20b8b9312acabd44f3441b3 Mon Sep 17 00:00:00 2001 From: Adarsh Dubey <84132532+inclinedadarsh@users.noreply.github.com> Date: Tue, 21 Jan 2025 00:04:09 +0530 Subject: [PATCH 3/3] fix: remove unnecessary shallow copy --- pymc/sampling/mcmc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 6e8f142898..ae1787686f 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -309,7 +309,6 @@ def _sample_external_nuts( nuts_sampler_kwargs: dict | None, **kwargs, ): - nuts_sampler_kwargs = nuts_sampler_kwargs.copy() if nuts_sampler_kwargs is None: nuts_sampler_kwargs = {} @@ -339,6 +338,7 @@ def _sample_external_nuts( UserWarning, ) compile_kwargs = {} + nuts_sampler_kwargs = nuts_sampler_kwargs.copy() for kwarg in ("backend", "gradient_backend"): if kwarg in nuts_sampler_kwargs: compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) @@ -687,7 +687,6 @@ def sample( mean sd hdi_3% hdi_97% p 0.609 0.047 0.528 0.699 """ - nuts_sampler_kwargs = nuts_sampler_kwargs.copy() if "start" in kwargs: if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.")