incremental qLogNEI (#2760)
Summary:
Pull Request resolved: #2760

This diff adds an incremental qLogNEI, which addresses many cases where the first candidate in the batch has positive EI (and satisfies the constraints) while subsequent arms violate the constraints (often severely).

The issue appears to stem from optimizing the joint EI of the new candidate and the pending points w.r.t. the current incumbent(s). My hypothesis is that this makes the initialization strategy perform worse and choose bad starting points. Using sequential batch optimization and optimizing the incremental EI of the new arm relative to the pending points (and the current incumbent) avoids the issue by only quantifying the improvement of the current arm being optimized.
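For orientation (not part of this diff), a minimal usage sketch of the new `incremental` flag under sequential greedy optimization; the toy data, bounds, and optimizer settings below are placeholders:

```python
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy problem; data, bounds, and optimizer settings are illustrative only.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
bounds = torch.stack([torch.zeros(3), torch.ones(3)]).to(torch.double)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# With incremental=True (the new default added in this diff), each greedy step
# optimizes the improvement of the current arm over X_baseline plus any pending
# arms, rather than the joint EI of the whole batch.
acqf = qLogNoisyExpectedImprovement(
    model=model,
    X_baseline=train_X,
    prune_baseline=True,
    incremental=True,
)

# sequential=True performs sequential greedy optimization of the q-batch;
# between steps, optimize_acqf passes the already-selected arms to
# acqf.set_X_pending, which folds them into the incumbent set.
candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=4,
    num_restarts=10,
    raw_samples=256,
    sequential=True,
)
```

With `incremental=True`, each greedy step only scores the improvement of the arm currently being optimized, which avoids the bad starting points described above.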

TODO: add this for qNEI in a later diff, but that seems low priority since qLogNEI is widely used.

Reviewed By: esantorella

Differential Revision: D70288526

fbshipit-source-id: 9086bec5f077ac07d090d5fdf54b4284d784204d
sdaulton authored and facebook-github-bot committed Mar 5, 2025
1 parent 78c04e2 commit 290c0ba
Showing 5 changed files with 126 additions and 10 deletions.
5 changes: 5 additions & 0 deletions botorch/acquisition/input_constructors.py
@@ -650,6 +650,7 @@ def construct_inputs_qLogNEI(
fat: bool = True,
tau_max: float = TAU_MAX,
tau_relu: float = TAU_RELU,
incremental: bool = True,
):
r"""Construct kwargs for the `qLogNoisyExpectedImprovement` constructor.
@@ -684,6 +685,9 @@ def construct_inputs_qLogNEI(
approximations to max.
tau_relu: Temperature parameter controlling the sharpness of the smooth
approximations to ReLU.
incremental: Whether to compute incremental EI over the pending points
or compute EI of the joint batch improvement (including pending
points).
Returns:
A dict mapping kwarg names of the constructor to values.
@@ -705,6 +709,7 @@ def construct_inputs_qLogNEI(
"fat": fat,
"tau_max": tau_max,
"tau_relu": tau_relu,
"incremental": incremental,
}


86 changes: 76 additions & 10 deletions botorch/acquisition/logei.py
@@ -250,6 +250,11 @@ class qLogNoisyExpectedImprovement(
where `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`.
For optimizing a batch of `q > 1` points using sequential greedy optimization,
the incremental improvement from the latest point is computed and returned by
default. That is, the pending points are treated as part of `X_baseline`. Often,
the incremental EI is easier to optimize.
Example:
>>> model = SingleTaskGP(train_X, train_Y)
>>> sampler = SobolQMCNormalSampler(1024)
@@ -273,6 +278,7 @@ def __init__(
tau_max: float = TAU_MAX,
tau_relu: float = TAU_RELU,
marginalize_dim: int | None = None,
incremental: bool = True,
) -> None:
r"""q-Noisy Expected Improvement.
Expand Down Expand Up @@ -312,6 +318,9 @@ def __init__(
tau_relu: Temperature parameter controlling the sharpness of the smooth
approximations to ReLU.
marginalize_dim: The dimension to marginalize over.
incremental: Whether to compute incremental EI over the pending points
or compute EI of the joint batch improvement (including pending
points).
TODO: similar to qNEHVI, when we are using sequential greedy candidate
selection, we could incorporate pending points X_baseline and compute
@@ -320,27 +329,34 @@ def __init__(
"""
# TODO: separate out baseline variables initialization and other functions
# in qNEI to avoid duplication of both code and work at runtime.
self.incremental = incremental

super().__init__(
model=model,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
# we set X_pending in init_baseline for incremental NEI
X_pending=X_pending if not incremental else None,
constraints=constraints,
eta=eta,
fat=fat,
tau_max=tau_max,
)
self.tau_relu = tau_relu
self.prune_baseline = prune_baseline
self.marginalize_dim = marginalize_dim
if incremental:
self.X_pending = None # required to initialize attribute for optimize_acqf
self._init_baseline(
model=model,
X_baseline=X_baseline,
# This is ignored when incremental=False
X_pending=X_pending,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
prune_baseline=prune_baseline,
cache_root=cache_root,
marginalize_dim=marginalize_dim,
)

def _sample_forward(self, obj: Tensor) -> Tensor:
@@ -364,26 +380,34 @@ def _init_baseline(
self,
model: Model,
X_baseline: Tensor,
X_pending: Tensor | None = None,
sampler: MCSampler | None = None,
objective: MCAcquisitionObjective | None = None,
posterior_transform: PosteriorTransform | None = None,
prune_baseline: bool = False,
cache_root: bool = True,
marginalize_dim: int | None = None,
) -> None:
CachedCholeskyMCSamplerMixin.__init__(
self, model=model, cache_root=cache_root, sampler=sampler
)
if prune_baseline:
if self.prune_baseline:
X_baseline = prune_inferior_points(
model=model,
X=X_baseline,
objective=objective,
posterior_transform=posterior_transform,
marginalize_dim=marginalize_dim,
marginalize_dim=self.marginalize_dim,
constraints=self._constraints,
)
self.register_buffer("X_baseline", X_baseline)
self.register_buffer("_X_baseline", X_baseline)
# full_X_baseline is the set of points that should be considered as the
# incumbent. For incremental EI, this contains the previously evaluated
# points (X_baseline) and pending points (X_pending). For non-incremental
# EI, this contains the previously evaluated points (X_baseline).
if X_pending is not None and self.incremental:
full_X_baseline = torch.cat([X_baseline, X_pending], dim=-2)
else:
full_X_baseline = X_baseline
self.register_buffer("_full_X_baseline", full_X_baseline)
# registering buffers for _get_samples_and_objectives in the next `if` block
self.register_buffer("baseline_samples", None)
self.register_buffer("baseline_obj", None)
@@ -392,7 +416,7 @@ def _init_baseline(
# set baseline samples
with torch.no_grad(): # this is _get_samples_and_objectives(X_baseline)
posterior = self.model.posterior(
X_baseline, posterior_transform=self.posterior_transform
self.X_baseline, posterior_transform=self.posterior_transform
)
# Note: The root decomposition is cached in two different places. It
# may be confusing to have two different caches, but this is not
@@ -404,7 +428,9 @@ def _init_baseline(
# - self._baseline_L allows a root decomposition to be persisted outside
# this method.
self.baseline_samples = self.get_posterior_samples(posterior)
self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline)
self.baseline_obj = self.objective(
self.baseline_samples, X=self.X_baseline
)

# We make a copy here because we will write an attribute `base_samples`
# to `self.base_sampler.base_samples`, and we don't want to mutate
@@ -418,6 +444,46 @@ def _init_baseline(
)
self._baseline_L = self._compute_root_decomposition(posterior=posterior)

@property
def X_baseline(self) -> Tensor:
"""Returns the set of pointsthat should be considered as the incumbent.
For incremental EI, this contains the previously evaluated points
(X_baseline) and pending points (X_pending). For non-incremental
EI, this contains the previously evaluated points (X_baseline).
"""
return self._full_X_baseline

def set_X_pending(self, X_pending: Tensor | None = None) -> None:
r"""Informs the acquisition function about pending design points.
If `incremental` is True, pending points are concatenated with `X_baseline` and
the incremental NEI is computed; otherwise, the parent implementation is used.
Args:
X_pending: `n x d` Tensor with `n` `d`-dim design points that have
been submitted for evaluation but have not yet been evaluated.
"""
if not self.incremental:
return super().set_X_pending(X_pending=X_pending)
if X_pending is None:
if not hasattr(self, "_full_X_baseline") or (
self._full_X_baseline.shape[-2] == self._X_baseline.shape[-2]
):
return
else:
# reset pending points
X_pending = None
self._init_baseline(
model=self.model,
X_baseline=self._X_baseline,
X_pending=X_pending,
sampler=self.sampler,
objective=self.objective,
posterior_transform=self.posterior_transform,
cache_root=self._cache_root,
)

def compute_best_f(self, obj: Tensor) -> Tensor:
"""Computes the best (feasible) noisy objective value.
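To illustrate the `set_X_pending` semantics added above, here is a hedged sketch with toy data (not taken from the diff): with `incremental=True`, pending arms are absorbed into the incumbent set rather than appended to the joint batch, and `X_pending` itself remains `None`.

```python
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

acqf = qLogNoisyExpectedImprovement(
    model=model,
    X_baseline=train_X,
    prune_baseline=False,  # keep X_baseline intact so the shapes below are exact
    incremental=True,
)

# Pending arms are folded into the incumbent set; X_pending itself stays None.
X_pending = torch.rand(2, 2, dtype=torch.double)
acqf.set_X_pending(X_pending)
assert acqf.X_pending is None
assert acqf.X_baseline.shape[-2] == train_X.shape[-2] + X_pending.shape[-2]

# Passing None resets the incumbent set back to the evaluated points only.
acqf.set_X_pending(None)
assert acqf.X_baseline.shape[-2] == train_X.shape[-2]
```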
5 changes: 5 additions & 0 deletions botorch/acquisition/multi_objective/parego.py
@@ -35,6 +35,7 @@ def __init__(
cache_root: bool = True,
tau_relu: float = TAU_RELU,
tau_max: float = TAU_MAX,
incremental: bool = True,
) -> None:
r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function
utilizes qLogNEI to compute the expected improvement over Chebyshev
@@ -88,6 +89,9 @@ def __init__(
approximations to max.
tau_relu: Temperature parameter controlling the sharpness of the smooth
approximations to ReLU.
incremental: Whether to compute incremental EI over the pending points
or compute EI of the joint batch improvement (including pending
points).
"""
MultiObjectiveMCAcquisitionFunction.__init__(
self,
@@ -134,6 +138,7 @@ def __init__(
cache_root=cache_root,
tau_max=tau_max,
tau_relu=tau_relu,
incremental=incremental,
)
# Set these after __init__ calls so that they're not overwritten / deleted.
# These are intended mainly for easier debugging & transparency.
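Since `qLogNParEGO` forwards the flag to the underlying `qLogNoisyExpectedImprovement`, a minimal sketch (toy two-outcome data, default hyperparameters assumed; not from the diff) of opting back into the previous joint-batch behavior:

```python
import torch
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # two outcomes for ParEGO
model = SingleTaskGP(train_X, train_Y)

# incremental is passed straight through to qLogNoisyExpectedImprovement;
# set it to False to recover the previous joint-batch EI behavior.
acqf = qLogNParEGO(
    model=model,
    X_baseline=train_X,
    incremental=False,
)
```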
5 changes: 5 additions & 0 deletions test/acquisition/multi_objective/test_parego.py
@@ -26,6 +26,7 @@ def base_test_parego(
with_scalarization_weights: bool = False,
with_objective: bool = False,
model: Model | None = None,
incremental: bool = True,
) -> None:
if with_constraints:
assert with_objective, "Objective must be specified if constraints are."
@@ -57,6 +58,7 @@ def base_test_parego(
objective=objective,
constraints=constraints,
prune_baseline=True,
incremental=incremental,
)
self.assertEqual(acqf.Y_baseline.shape, torch.Size([3, 2]))
# Scalarization weights should be set if given and sampled otherwise.
@@ -102,6 +104,9 @@ def test_parego_with_constraints_objective_weights(self) -> None:
with_constraints=True, with_objective=True, with_scalarization_weights=True
)

def test_parego_with_non_incremental_ei(self) -> None:
self.base_test_parego(incremental=False)

def test_parego_with_ensemble_model(self) -> None:
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
models = []
35 changes: 35 additions & 0 deletions test/acquisition/test_logei.py
@@ -404,6 +404,7 @@ def test_q_log_noisy_expected_improvement(self):
"sampler": sampler,
"prune_baseline": False,
"cache_root": False,
"incremental": False,
}
# copy for log version
log_acqf = qLogNoisyExpectedImprovement(**kwargs)
@@ -422,6 +423,40 @@ def test_q_log_noisy_expected_improvement(self):
self.assertEqual(log_acqf.X_pending, X2)
self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)

# test incremental
# Check that adding a pending point is equivalent to adding a point to
# X_baseline
for cache_root in (True, False):
kwargs = {
"model": mm_noisy_pending,
"X_baseline": X_baseline,
"sampler": sampler,
"prune_baseline": False,
"cache_root": cache_root,
"incremental": True,
}
log_acqf = qLogNoisyExpectedImprovement(**kwargs)
log_acqf.set_X_pending(X)
self.assertIsNone(log_acqf.X_pending)
self.assertTrue(
torch.equal(log_acqf.X_baseline, torch.cat([X_baseline, X], dim=0))
)
af_val1 = log_acqf(X2)
kwargs = {
"model": mm_noisy_pending,
"X_baseline": torch.cat([X_baseline, X], dim=-2),
"sampler": sampler,
"prune_baseline": False,
"cache_root": cache_root,
"incremental": False,
}
log_acqf2 = qLogNoisyExpectedImprovement(**kwargs)
af_val2 = log_acqf2(X2)
self.assertAllClose(af_val1.item(), af_val2.item())
# test resetting X_pending
log_acqf.set_X_pending(None)
self.assertTrue(torch.equal(log_acqf.X_baseline, X_baseline))

def test_q_noisy_expected_improvement_batch(self):
for dtype in (torch.float, torch.double):
# the event shape is `b x q x t` = 2 x 3 x 1
