From 290c0ba796f19e6526c05fa0865e6f6437344664 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Wed, 5 Mar 2025 14:07:29 -0800 Subject: [PATCH] incremental qLogNEI (#2760) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2760 This diff adds an incremental qLogNEI, that addresses many cases where the first candidate in the batch has positive EI (and satisfies the constraints) and subsequent arms violate the constraints (often severely). The issue appears to stem from optimizing the joint EI of the new candidate and the pending points w.r.t the current incumbent(s). My hypothesis is that this makes the initialization strategy perform worse and choose bad starting points. Using sequential batch optimization and optimizing the incremental EI of the new arm relative to the pending points (and the current incumbent) avoids the issue by only quanitifying the improvment of the current arm being optimized. TODO: add this for qNEI in a later diff, but that seems low pri since qLogNEI is widely used. Reviewed By: esantorella Differential Revision: D70288526 fbshipit-source-id: 9086bec5f077ac07d090d5fdf54b4284d784204d --- botorch/acquisition/input_constructors.py | 5 ++ botorch/acquisition/logei.py | 86 ++++++++++++++++--- botorch/acquisition/multi_objective/parego.py | 5 ++ .../multi_objective/test_parego.py | 5 ++ test/acquisition/test_logei.py | 35 ++++++++ 5 files changed, 126 insertions(+), 10 deletions(-) diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py index e9262da92d..f89d6c4ae3 100644 --- a/botorch/acquisition/input_constructors.py +++ b/botorch/acquisition/input_constructors.py @@ -650,6 +650,7 @@ def construct_inputs_qLogNEI( fat: bool = True, tau_max: float = TAU_MAX, tau_relu: float = TAU_RELU, + incremental: bool = True, ): r"""Construct kwargs for the `qLogNoisyExpectedImprovement` constructor. @@ -684,6 +685,9 @@ def construct_inputs_qLogNEI( approximations to max. tau_relu: Temperature parameter controlling the sharpness of the smooth approximations to ReLU. + incremental: Whether to compute incremental EI over the pending points + or compute EI of the joint batch improvement (including pending + points). Returns: A dict mapping kwarg names of the constructor to values. @@ -705,6 +709,7 @@ def construct_inputs_qLogNEI( "fat": fat, "tau_max": tau_max, "tau_relu": tau_relu, + "incremental": incremental, } diff --git a/botorch/acquisition/logei.py b/botorch/acquisition/logei.py index 2d95b9f49f..0254bd90aa 100644 --- a/botorch/acquisition/logei.py +++ b/botorch/acquisition/logei.py @@ -250,6 +250,11 @@ class qLogNoisyExpectedImprovement( where `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`. + For optimizing a batch of `q > 1` points using sequential greedy optimization, + the incremental improvement from the latest point is computed and returned by + default. I.e. the pending points are treated X_baseline. Often, the incremental + EI is easier to optimize. + Example: >>> model = SingleTaskGP(train_X, train_Y) >>> sampler = SobolQMCNormalSampler(1024) @@ -273,6 +278,7 @@ def __init__( tau_max: float = TAU_MAX, tau_relu: float = TAU_RELU, marginalize_dim: int | None = None, + incremental: bool = True, ) -> None: r"""q-Noisy Expected Improvement. @@ -312,6 +318,9 @@ def __init__( tau_relu: Temperature parameter controlling the sharpness of the smooth approximations to ReLU. marginalize_dim: The dimension to marginalize over. + incremental: Whether to compute incremental EI over the pending points + or compute EI of the joint batch improvement (including pending + points). TODO: similar to qNEHVI, when we are using sequential greedy candidate selection, we could incorporate pending points X_baseline and compute @@ -320,27 +329,34 @@ def __init__( """ # TODO: separate out baseline variables initialization and other functions # in qNEI to avoid duplication of both code and work at runtime. + self.incremental = incremental + super().__init__( model=model, sampler=sampler, objective=objective, posterior_transform=posterior_transform, - X_pending=X_pending, + # we set X_pending in init_baseline for incremental NEI + X_pending=X_pending if not incremental else None, constraints=constraints, eta=eta, fat=fat, tau_max=tau_max, ) self.tau_relu = tau_relu + self.prune_baseline = prune_baseline + self.marginalize_dim = marginalize_dim + if incremental: + self.X_pending = None # required to initialize attribute for optimize_acqf self._init_baseline( model=model, X_baseline=X_baseline, + # This is ignored in incremental=False + X_pending=X_pending, sampler=sampler, objective=objective, posterior_transform=posterior_transform, - prune_baseline=prune_baseline, cache_root=cache_root, - marginalize_dim=marginalize_dim, ) def _sample_forward(self, obj: Tensor) -> Tensor: @@ -364,26 +380,34 @@ def _init_baseline( self, model: Model, X_baseline: Tensor, + X_pending: Tensor | None = None, sampler: MCSampler | None = None, objective: MCAcquisitionObjective | None = None, posterior_transform: PosteriorTransform | None = None, - prune_baseline: bool = False, cache_root: bool = True, - marginalize_dim: int | None = None, ) -> None: CachedCholeskyMCSamplerMixin.__init__( self, model=model, cache_root=cache_root, sampler=sampler ) - if prune_baseline: + if self.prune_baseline: X_baseline = prune_inferior_points( model=model, X=X_baseline, objective=objective, posterior_transform=posterior_transform, - marginalize_dim=marginalize_dim, + marginalize_dim=self.marginalize_dim, constraints=self._constraints, ) - self.register_buffer("X_baseline", X_baseline) + self.register_buffer("_X_baseline", X_baseline) + # full_X_baseline is the set of points that should be considered as the + # incumbent. For incremental EI, this contains the previously evaluated + # points (X_baseline) and pending points (X_pending). For non-incremental + # EI, this contains the previously evaluated points (X_baseline). + if X_pending is not None and self.incremental: + full_X_baseline = torch.cat([X_baseline, X_pending], dim=-2) + else: + full_X_baseline = X_baseline + self.register_buffer("_full_X_baseline", full_X_baseline) # registering buffers for _get_samples_and_objectives in the next `if` block self.register_buffer("baseline_samples", None) self.register_buffer("baseline_obj", None) @@ -392,7 +416,7 @@ def _init_baseline( # set baseline samples with torch.no_grad(): # this is _get_samples_and_objectives(X_baseline) posterior = self.model.posterior( - X_baseline, posterior_transform=self.posterior_transform + self.X_baseline, posterior_transform=self.posterior_transform ) # Note: The root decomposition is cached in two different places. It # may be confusing to have two different caches, but this is not @@ -404,7 +428,9 @@ def _init_baseline( # - self._baseline_L allows a root decomposition to be persisted outside # this method. self.baseline_samples = self.get_posterior_samples(posterior) - self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline) + self.baseline_obj = self.objective( + self.baseline_samples, X=self.X_baseline + ) # We make a copy here because we will write an attribute `base_samples` # to `self.base_sampler.base_samples`, and we don't want to mutate @@ -418,6 +444,46 @@ def _init_baseline( ) self._baseline_L = self._compute_root_decomposition(posterior=posterior) + @property + def X_baseline(self) -> Tensor: + """Returns the set of pointsthat should be considered as the incumbent. + + For incremental EI, this contains the previously evaluated points + (X_baseline) and pending points (X_pending). For non-incremental + EI, this contains the previously evaluated points (X_baseline). + """ + return self._full_X_baseline + + def set_X_pending(self, X_pending: Tensor | None = None) -> None: + r"""Informs the acquisition function about pending design points. + + Here pending points are concatenated with X_baseline and incremental + NEI is computed. + + Args: + X_pending: `n x d` Tensor with `n` `d`-dim design points that have + been submitted for evaluation but have not yet been evaluated. + """ + if not self.incremental: + return super().set_X_pending(X_pending=X_pending) + if X_pending is None: + if not hasattr(self, "_full_X_baseline") or ( + self._full_X_baseline.shape[-2] == self._X_baseline.shape[-2] + ): + return + else: + # reset pending points + X_pending = None + self._init_baseline( + model=self.model, + X_baseline=self._X_baseline, + X_pending=X_pending, + sampler=self.sampler, + objective=self.objective, + posterior_transform=self.posterior_transform, + cache_root=self._cache_root, + ) + def compute_best_f(self, obj: Tensor) -> Tensor: """Computes the best (feasible) noisy objective value. diff --git a/botorch/acquisition/multi_objective/parego.py b/botorch/acquisition/multi_objective/parego.py index da8ea0b066..a67f3878d9 100644 --- a/botorch/acquisition/multi_objective/parego.py +++ b/botorch/acquisition/multi_objective/parego.py @@ -35,6 +35,7 @@ def __init__( cache_root: bool = True, tau_relu: float = TAU_RELU, tau_max: float = TAU_MAX, + incremental: bool = True, ) -> None: r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function utilizes qLogNEI to compute the expected improvement over Chebyshev @@ -88,6 +89,9 @@ def __init__( approximations to max. tau_relu: Temperature parameter controlling the sharpness of the smooth approximations to ReLU. + incremental: Whether to compute incremental EI over the pending points + or compute EI of the joint batch improvement (including pending + points). """ MultiObjectiveMCAcquisitionFunction.__init__( self, @@ -134,6 +138,7 @@ def __init__( cache_root=cache_root, tau_max=tau_max, tau_relu=tau_relu, + incremental=incremental, ) # Set these after __init__ calls so that they're not overwritten / deleted. # These are intended mainly for easier debugging & transparency. diff --git a/test/acquisition/multi_objective/test_parego.py b/test/acquisition/multi_objective/test_parego.py index 3a8e99f32d..a3749d1d13 100644 --- a/test/acquisition/multi_objective/test_parego.py +++ b/test/acquisition/multi_objective/test_parego.py @@ -26,6 +26,7 @@ def base_test_parego( with_scalarization_weights: bool = False, with_objective: bool = False, model: Model | None = None, + incremental: bool = True, ) -> None: if with_constraints: assert with_objective, "Objective must be specified if constraints are." @@ -57,6 +58,7 @@ def base_test_parego( objective=objective, constraints=constraints, prune_baseline=True, + incremental=incremental, ) self.assertEqual(acqf.Y_baseline.shape, torch.Size([3, 2])) # Scalarization weights should be set if given and sampled otherwise. @@ -102,6 +104,9 @@ def test_parego_with_constraints_objective_weights(self) -> None: with_constraints=True, with_objective=True, with_scalarization_weights=True ) + def test_parego_with_non_incremental_ei(self) -> None: + self.base_test_parego(incremental=False) + def test_parego_with_ensemble_model(self) -> None: tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double} models = [] diff --git a/test/acquisition/test_logei.py b/test/acquisition/test_logei.py index 6d5ebb7695..667ee3e0c0 100644 --- a/test/acquisition/test_logei.py +++ b/test/acquisition/test_logei.py @@ -404,6 +404,7 @@ def test_q_log_noisy_expected_improvement(self): "sampler": sampler, "prune_baseline": False, "cache_root": False, + "incremental": False, } # copy for log version log_acqf = qLogNoisyExpectedImprovement(**kwargs) @@ -422,6 +423,40 @@ def test_q_log_noisy_expected_improvement(self): self.assertEqual(log_acqf.X_pending, X2) self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1) + # test incremental + # Check that adding a pending point is equivalent to adding a point to + # X_baseline + for cache_root in (True, False): + kwargs = { + "model": mm_noisy_pending, + "X_baseline": X_baseline, + "sampler": sampler, + "prune_baseline": False, + "cache_root": cache_root, + "incremental": True, + } + log_acqf = qLogNoisyExpectedImprovement(**kwargs) + log_acqf.set_X_pending(X) + self.assertIsNone(log_acqf.X_pending) + self.assertTrue( + torch.equal(log_acqf.X_baseline, torch.cat([X_baseline, X], dim=0)) + ) + af_val1 = log_acqf(X2) + kwargs = { + "model": mm_noisy_pending, + "X_baseline": torch.cat([X_baseline, X], dim=-2), + "sampler": sampler, + "prune_baseline": False, + "cache_root": cache_root, + "incremental": False, + } + log_acqf2 = qLogNoisyExpectedImprovement(**kwargs) + af_val2 = log_acqf2(X2) + self.assertAllClose(af_val1.item(), af_val2.item()) + # test reseting X_pending + log_acqf.set_X_pending(None) + self.assertTrue(torch.equal(log_acqf.X_baseline, X_baseline)) + def test_q_noisy_expected_improvement_batch(self): for dtype in (torch.float, torch.double): # the event shape is `b x q x t` = 2 x 3 x 1