diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index cfad4a0b6e..980762d871 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -3,7 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Callable
+from typing import Any
+import warnings
 
 import torch
 from botorch.acquisition import AcquisitionFunction
@@ -51,14 +52,11 @@ def optimize_acqf_homotopy(
     bounds: Tensor,
     q: int,
     homotopy: Homotopy,
-    num_restarts: int,
-    raw_samples: int | None = None,
-    fixed_features: dict[int, float] | None = None,
-    options: dict[str, bool | float | int | str] | None = None,
-    final_options: dict[str, bool | float | int | str] | None = None,
-    batch_initial_conditions: Tensor | None = None,
-    post_processing_func: Callable[[Tensor], Tensor] | None = None,
+    *,
     prune_tolerance: float = 1e-4,
+    batch_initial_conditions: Tensor | None = None,
+    optimize_acqf_loop_kwargs: dict[str, Any] | None = None,
+    optimize_acqf_final_kwargs: dict[str, Any] | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Generate a set of candidates via multi-start optimization.
 
@@ -68,19 +66,76 @@
         q: The number of candidates.
         homotopy: Homotopy object that will make the necessary modifications to the
             problem when calling `step()`.
-        num_restarts: The number of starting points for multistart acquisition
-            function optimization.
-        raw_samples: The number of samples for initialization. This is required
-            if `batch_initial_conditions` is not specified.
-        fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
-        options: Options for candidate generation.
-        final_options: Options for candidate generation in the last homotopy step.
+        prune_tolerance: The minimum distance to prune candidates.
         batch_initial_conditions: A tensor to specify the initial conditions. Set
             this if you do not want to use default initialization strategy.
-        post_processing_func: Post processing function (such as rounding or clamping)
-            that is applied before choosing the final candidate.
+        optimize_acqf_loop_kwargs: A dictionary of keyword arguments for
+            `optimize_acqf`. These settings are used in the homotopy loop.
+        optimize_acqf_final_kwargs: A dictionary of keyword arguments for
+            `optimize_acqf`. These settings are used for the final optimization
+            after the homotopy loop.
     """
+    if optimize_acqf_loop_kwargs is None:
+        optimize_acqf_loop_kwargs = {}
+
+    if optimize_acqf_final_kwargs is None:
+        optimize_acqf_final_kwargs = {}
+
+    for kwarg_dict_name, kwarg_dict in [
+        ("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs),
+        ("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs),
+    ]:
+        if kwarg_dict.get("return_best_only", None) is not False:
+            warnings.warn(
+                f"`return_best_only` is not False in `{kwarg_dict_name}`, "
+                "overriding it to False."
+            )
+            kwarg_dict["return_best_only"] = False
+
+        if kwarg_dict.get("q", None) != 1:
+            warnings.warn(
+                f"`q` is not set to 1 in `{kwarg_dict_name}`, overriding it to 1."
+            )
+            kwarg_dict["q"] = 1
+
+        if "batch_initial_conditions" in kwarg_dict:
+            warnings.warn(
+                f"`batch_initial_conditions` is set in `{kwarg_dict_name}`, "
+                "removing it in favor of the `batch_initial_conditions` given to "
+                "`optimize_acqf_homotopy`."
+            )
+            # NOTE: are pops dangerous here given no copy?
+            # If the caller reuses kwarg_dict, popping entries could create issues.
+            kwarg_dict.pop("batch_initial_conditions")
+
+        for arg_name, arg_value in [("acq_function", acq_function), ("bounds", bounds)]:
+            if arg_name in kwarg_dict:
+                warnings.warn(
+                    f"`{arg_name}` is set in `{kwarg_dict_name}` and will be "
+                    "overridden in favor of the value passed to `optimize_acqf_homotopy` "
+                    f"({arg_name} = {arg_value}, cf. {kwarg_dict_name}[{arg_name}] = "
+                    f"{kwarg_dict[arg_name]})."
+                )
+            kwarg_dict[arg_name] = arg_value
+
+    if (
+        batch_initial_conditions is None
+        and optimize_acqf_loop_kwargs.get("raw_samples", None) is None
+    ):
+        raise ValueError(
+            "Must specify `raw_samples` in `optimize_acqf_loop_kwargs` when "
+            "`batch_initial_conditions` is None."
+        )
+
+    if "post_processing_func" in optimize_acqf_loop_kwargs:
+        warnings.warn(
+            "`post_processing_func` is set in `optimize_acqf_loop_kwargs`, "
+            "setting it to None."
+        )
+        optimize_acqf_loop_kwargs["post_processing_func"] = None
+
+    if "raw_samples" in optimize_acqf_final_kwargs:
+        warnings.warn(
+            "`raw_samples` is set in `optimize_acqf_final_kwargs`, "
+            "removing it since `batch_initial_conditions` for the final "
+            "optimization is set to the candidates returned by the homotopy loop."
+        )
+        optimize_acqf_final_kwargs.pop("raw_samples")  # pops without a copy: see the NOTE above
+
     candidate_list, acq_value_list = [], []
     if q > 1:
         base_X_pending = acq_function.X_pending
@@ -91,15 +146,10 @@
         while not homotopy.should_stop:
             candidates, acq_values = optimize_acqf(
-                q=1,
                 acq_function=acq_function,
                 bounds=bounds,
-                num_restarts=num_restarts,
                 batch_initial_conditions=candidates,
-                raw_samples=raw_samples,
-                fixed_features=fixed_features,
-                return_best_only=False,
-                options=options,
+                **optimize_acqf_loop_kwargs,
             )
             homotopy.step()
@@ -111,20 +161,14 @@
             ).unsqueeze(1)
 
         # Optimize one more time with the final options
+        # NOTE: is there a reason not to pass `fixed_features` to this final optimization?
         candidates, acq_values = optimize_acqf(
-            q=1,
             acq_function=acq_function,
             bounds=bounds,
-            num_restarts=num_restarts,
             batch_initial_conditions=candidates,
-            return_best_only=False,
-            options=final_options,
+            **optimize_acqf_final_kwargs,
         )
 
-        # Post-process the candidates and grab the best candidate
-        if post_processing_func is not None:
-            candidates = post_processing_func(candidates)
-            acq_values = acq_function(candidates)
 
         best = torch.argmax(acq_values.view(-1), dim=0)
         candidate, acq_value = candidates[best], acq_values[best]
@@ -132,6 +176,7 @@
         candidate_list.append(candidate)
         acq_value_list.append(acq_value)
         selected_candidates = torch.cat(candidate_list, dim=-2)
+
         if q > 1:
             acq_function.set_X_pending(
                 torch.cat([base_X_pending, selected_candidates], dim=-2)
             )
@@ -141,6 +186,7 @@
     if q > 1:
         # Reset acq_function to previous X_pending state
         acq_function.set_X_pending(base_X_pending)
+
     homotopy.reset()  # Reset the homotopy parameters
 
     return selected_candidates, torch.stack(acq_value_list)
diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
index ac3cd6142a..8a71b7d30e 100644
--- a/test/optim/test_homotopy.py
+++ b/test/optim/test_homotopy.py
@@ -114,14 +114,21 @@ def test_optimize_acqf_homotopy(self):
         )
         model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
         acqf = PosteriorMean(model=model)
+
+        optimize_acqf_core_kwargs = {
+            "num_restarts": 2,
+            "raw_samples": 16,
+        }
         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=2,
-            raw_samples=16,
-            post_processing_func=lambda x: x.round(),
+            optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "post_processing_func": lambda x: x.round(),
+            },
         )
         self.assertEqual(candidate, torch.zeros(1, **tkwargs))
         self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
@@ -137,9 +144,11 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=2,
-            raw_samples=16,
-            fixed_features=fixed_features,
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "fixed_features": fixed_features,  # mimics the old behavior, which was perhaps a bug?
+            },
+            optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))
@@ -150,13 +159,40 @@
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=2,
-            raw_samples=16,
-            fixed_features=fixed_features,
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "fixed_features": fixed_features,
+            },
+            optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
         )
         self.assertEqual(candidate.shape, torch.Size([3, 2]))
         self.assertEqual(acqf_val.shape, torch.Size([3]))
 
+        # with linear constraints
+        constraints = [(  # X[..., 0] + X[..., 1] >= 2.0
+            torch.tensor([0, 1], device=self.device),
+            torch.ones(2, device=self.device, dtype=torch.double),
+            2.0,
+        )]
+
+        acqf = PosteriorMean(model=model)
+        candidate, acqf_val = optimize_acqf_homotopy(
+            q=1,
+            acq_function=acqf,
+            bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+            homotopy=Homotopy(homotopy_parameters=[hp]),
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "inequality_constraints": constraints,
+            },
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "inequality_constraints": constraints,
+            },
+        )
+        self.assertEqual(candidate.shape, torch.Size([1, 2]))
+        self.assertGreaterEqual(candidate.sum(), 2 * torch.ones(1, **tkwargs))
+
     def test_prune_candidates(self):
         tkwargs = {"device": self.device, "dtype": torch.double}
         # no pruning
@@ -202,14 +238,21 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
         )
         model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
         acqf = PosteriorMean(model=model)
+        optimize_acqf_core_kwargs = {
+            "num_restarts": 4,
+            "raw_samples": 16,
+        }
+
        candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=4,
-            raw_samples=16,
-            post_processing_func=lambda x: x.round(),
+            optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "post_processing_func": lambda x: x.round(),
+            },
         )
         # First time we expect to call `prune_candidates` with 4 candidates
         self.assertEqual(
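
Usage sketch (not part of the patch): how a call site migrates from the removed flat arguments (`num_restarts`, `raw_samples`, `fixed_features`, `options`, `final_options`, `post_processing_func`) to the new `optimize_acqf_loop_kwargs` / `optimize_acqf_final_kwargs` dictionaries, mirroring the updated tests above. The `HomotopyParameter` / `LogLinearHomotopySchedule` fixture is not shown in this diff and is assumed here for illustration; treat the exact constructor arguments as assumptions.

# Illustrative migration example; the homotopy fixture below is assumed, not taken from this diff.
import torch
from torch.nn import Parameter
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.homotopy import (
    Homotopy,
    HomotopyParameter,
    LogLinearHomotopySchedule,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy

# Assumed fixture: a homotopy parameter `p` that the model closes over.
p = Parameter(-2 * torch.ones(1, dtype=torch.double))
hp = HomotopyParameter(
    parameter=p,
    schedule=LogLinearHomotopySchedule(start=4, end=0.01, num_steps=5),
)
model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
acqf = PosteriorMean(model=model)

# Old call (removed by this patch):
#   optimize_acqf_homotopy(..., num_restarts=2, raw_samples=16,
#                          post_processing_func=lambda x: x.round())
# New call: per-stage settings are passed through to `optimize_acqf`.
core_kwargs = {"num_restarts": 2, "raw_samples": 16}
candidate, acq_value = optimize_acqf_homotopy(
    q=1,
    acq_function=acqf,
    bounds=torch.tensor([[-10.0], [5.0]], dtype=torch.double),
    homotopy=Homotopy(homotopy_parameters=[hp]),
    optimize_acqf_loop_kwargs={**core_kwargs},
    optimize_acqf_final_kwargs={
        **core_kwargs,
        "post_processing_func": lambda x: x.round(),
    },
)

Note that `q` and `return_best_only` are forced internally (with warnings), and `raw_samples` in the final kwargs is dropped with a warning, since the final optimization reuses the homotopy loop's candidates as `batch_initial_conditions`.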