diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 64d6829fc8..ae1787686f 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -338,6 +338,7 @@ def _sample_external_nuts( UserWarning, ) compile_kwargs = {} + nuts_sampler_kwargs = nuts_sampler_kwargs.copy() for kwarg in ("backend", "gradient_backend"): if kwarg in nuts_sampler_kwargs: compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)