Skip to content

Commit

Permalink
ENH: Add jitter_scale parameter for initial point generation (#7555)
Browse files Browse the repository at this point in the history
  • Loading branch information
aphc14 committed Jan 13, 2025
1 parent 671d704 commit e0048bf
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
8 changes: 7 additions & 1 deletion pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def make_initial_point_fns_per_chain(
model,
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
jitter_scale: float = 1.0,
chains: int,
) -> list[Callable]:
"""Create an initial point function for each chain, as defined by initvals.
Expand Down Expand Up @@ -96,6 +97,7 @@ def make_initial_point_fns_per_chain(
model=model,
overrides=overrides,
jitter_rvs=jitter_rvs,
jitter_scale=jitter_scale,
return_transformed=True,
)
] * chains
Expand All @@ -104,6 +106,7 @@ def make_initial_point_fns_per_chain(
make_initial_point_fn(
model=model,
jitter_rvs=jitter_rvs,
jitter_scale=jitter_scale,
overrides=chain_overrides,
return_transformed=True,
)
Expand All @@ -122,6 +125,7 @@ def make_initial_point_fn(
model,
overrides: StartDict | None = None,
jitter_rvs: set[TensorVariable] | None = None,
jitter_scale: float = 1.0,
default_strategy: str = "support_point",
return_transformed: bool = True,
) -> Callable:
Expand Down Expand Up @@ -150,6 +154,7 @@ def make_initial_point_fn(
rvs_to_transforms=model.rvs_to_transforms,
initval_strategies=initval_strats,
jitter_rvs=jitter_rvs,
jitter_scale=jitter_scale,
default_strategy=default_strategy,
return_transformed=return_transformed,
)
Expand Down Expand Up @@ -188,6 +193,7 @@ def make_initial_point_expression(
rvs_to_transforms: dict[TensorVariable, Transform],
initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None],
jitter_rvs: set[TensorVariable] | None = None,
jitter_scale: float = 1.0,
default_strategy: str = "support_point",
return_transformed: bool = False,
) -> list[TensorVariable]:
Expand Down Expand Up @@ -265,7 +271,7 @@ def make_initial_point_expression(
value = transform.forward(value, *variable.owner.inputs)

if variable in jitter_rvs:
jitter = pt.random.uniform(-1, 1, size=value.shape)
jitter = pt.random.uniform(-jitter_scale, jitter_scale, size=value.shape)
jitter.name = f"{variable.name}_jitter"
value = value + jitter

Expand Down
4 changes: 3 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,10 +1423,11 @@ def _init_jitter(
initvals: StartDict | Sequence[StartDict | None] | None,
seeds: Sequence[int] | np.ndarray,
jitter: bool,
jitter_scale: float,
jitter_max_retries: int,
logp_dlogp_func=None,
) -> list[PointType]:
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
"""Apply a uniform jitter in [-jitter_scale, jitter_scale] to the test value as starting point in each chain.
``model.check_start_vals`` is used to test whether the jittered starting
values produce a finite log probability. Invalid values are resampled
Expand All @@ -1449,6 +1450,7 @@ def _init_jitter(
model=model,
overrides=initvals,
jitter_rvs=set(model.free_RVs) if jitter else set(),
jitter_scale=jitter_scale if jitter else 1.0,
chains=len(seeds),
)

Expand Down
28 changes: 28 additions & 0 deletions tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,34 @@ def test_adds_jitter(self):
assert fn(0) == fn(0)
assert fn(0) != fn(1)

def test_jitter_scale(self):
    """Jitter magnitude scales linearly with ``jitter_scale``.

    Builds one initial-point function per tested scale, draws many jittered
    initial points from each, and checks that the draws — once divided by
    their scale — all have a mean near zero (the uniform jitter in
    ``[-scale, scale]`` is centred on the support point, which is 0 on the
    log scale for ``HalfFlat``).
    """
    with pm.Model() as pmodel:
        A = pm.HalfFlat("A", initval="support_point")

    scales = np.array([1.0, 2.0, 5.0])
    point_fns = [
        make_initial_point_fn(
            model=pmodel,
            jitter_rvs=set(pmodel.free_RVs),
            jitter_scale=scale,
            return_transformed=True,
        )
        for scale in scales
    ]

    n_draws = 1000
    draws = np.empty((n_draws, len(point_fns)))
    for col, point_fn in enumerate(point_fns):
        # Use a disjoint seed range per scale: reusing the same seeds would
        # make the columns exact multiples of one another and could mask a
        # broken jitter_scale implementation.
        first_seed = col * n_draws
        draws[:, col] = np.asarray(
            [point_fn(seed)["A_log__"] for seed in range(first_seed, first_seed + n_draws)]
        )

    standardised = np.mean(draws / scales, axis=0)

    assert np.all((-0.05 < standardised) & (standardised < 0.05))

def test_respects_overrides(self):
with pm.Model() as pmodel:
A = pm.Flat("A", initval="support_point")
Expand Down

0 comments on commit e0048bf

Please sign in to comment.