From e0048bf3a2d4074e0db60d2fc61ef62f44e36f2a Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Mon, 13 Jan 2025 20:11:04 +1100 Subject: [PATCH] ENH: Add jitter_scale parameter for initial point generation (#7555) --- pymc/initial_point.py | 8 +++++++- pymc/sampling/mcmc.py | 4 +++- tests/test_initial_point.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 241409f6834..a59d7355d16 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -66,6 +66,7 @@ def make_initial_point_fns_per_chain( model, overrides: StartDict | Sequence[StartDict | None] | None, jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, chains: int, ) -> list[Callable]: """Create an initial point function for each chain, as defined by initvals. @@ -96,6 +97,7 @@ def make_initial_point_fns_per_chain( model=model, overrides=overrides, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, return_transformed=True, ) ] * chains @@ -104,6 +106,7 @@ def make_initial_point_fns_per_chain( make_initial_point_fn( model=model, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, overrides=chain_overrides, return_transformed=True, ) @@ -122,6 +125,7 @@ def make_initial_point_fn( model, overrides: StartDict | None = None, jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, default_strategy: str = "support_point", return_transformed: bool = True, ) -> Callable: @@ -150,6 +154,7 @@ def make_initial_point_fn( rvs_to_transforms=model.rvs_to_transforms, initval_strategies=initval_strats, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, default_strategy=default_strategy, return_transformed=return_transformed, ) @@ -188,6 +193,7 @@ def make_initial_point_expression( rvs_to_transforms: dict[TensorVariable, Transform], initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None], jitter_rvs: set[TensorVariable] 
| None = None, + jitter_scale: float = 1.0, default_strategy: str = "support_point", return_transformed: bool = False, ) -> list[TensorVariable]: @@ -265,7 +271,7 @@ def make_initial_point_expression( value = transform.forward(value, *variable.owner.inputs) if variable in jitter_rvs: - jitter = pt.random.uniform(-1, 1, size=value.shape) + jitter = pt.random.uniform(-jitter_scale, jitter_scale, size=value.shape) jitter.name = f"{variable.name}_jitter" value = value + jitter diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ca91325ff16..397fc3e3172 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1423,10 +1423,11 @@ def _init_jitter( initvals: StartDict | Sequence[StartDict | None] | None, seeds: Sequence[int] | np.ndarray, jitter: bool, + jitter_scale: float = 1.0, jitter_max_retries: int, logp_dlogp_func=None, ) -> list[PointType]: - """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. + """Apply a uniform jitter in [-jitter_scale, jitter_scale] to the test value as starting point in each chain. ``model.check_start_vals`` is used to test whether the jittered starting values produce a finite log probability. 
Invalid values are resampled @@ -1449,6 +1450,7 @@ def _init_jitter( model=model, overrides=initvals, jitter_rvs=set(model.free_RVs) if jitter else set(), + jitter_scale=jitter_scale if jitter else 1.0, chains=len(seeds), ) diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index 9138f37b3e7..8f6bc56d29d 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -152,6 +152,34 @@ def test_adds_jitter(self): assert fn(0) == fn(0) assert fn(0) != fn(1) + def test_jitter_scale(self): + with pm.Model() as pmodel: + A = pm.HalfFlat("A", initval="support_point") + + jitter_scale_tests = np.array([1.0, 2.0, 5.0]) + fns = [] + for jitter_scale in jitter_scale_tests: + fns.append( + make_initial_point_fn( + model=pmodel, + jitter_rvs=set(pmodel.free_RVs), + jitter_scale=jitter_scale, + return_transformed=True, + ) + ) + + n_draws = 1000 + jitter_samples = np.empty((n_draws, len(fns))) + for j, fn in enumerate(fns): + # start and end to ensure random samples, otherwise jitter_samples across different jitter_scale will be an exact scale of each other + start = j * n_draws + end = start + n_draws + jitter_samples[:, j] = np.asarray([fn(i)["A_log__"] for i in range(start, end)]) + + init_standardised = np.mean((jitter_samples / jitter_scale_tests), axis=0) + + assert np.all((-0.05 < init_standardised) & (init_standardised < 0.05)) + def test_respects_overrides(self): with pm.Model() as pmodel: A = pm.Flat("A", initval="support_point")