From 048fc7e0a9614f2e03cbbb31d5be3cf4f85c7c38 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Wed, 16 Oct 2024 17:25:16 -0400
Subject: [PATCH 1/5] feat: expose all kwargs to optimize_acqf in
 optimize_acqf_homotopy

---
 botorch/optim/optimize_homotopy.py | 97 +++++++++++++++++-------------
 1 file changed, 55 insertions(+), 42 deletions(-)

diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index cfad4a0b6e..1d1cdcad86 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -3,7 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Callable
+from typing import Any
+import warnings
 
 import torch
 from botorch.acquisition import AcquisitionFunction
@@ -51,14 +52,11 @@ def optimize_acqf_homotopy(
     bounds: Tensor,
     q: int,
     homotopy: Homotopy,
-    num_restarts: int,
-    raw_samples: int | None = None,
-    fixed_features: dict[int, float] | None = None,
-    options: dict[str, bool | float | int | str] | None = None,
-    final_options: dict[str, bool | float | int | str] | None = None,
-    batch_initial_conditions: Tensor | None = None,
-    post_processing_func: Callable[[Tensor], Tensor] | None = None,
+    *,
     prune_tolerance: float = 1e-4,
+    batch_initial_conditions: Tensor | None = None,
+    optimize_acqf_loop_kwargs: dict[str, Any] | None = None,
+    optimize_acqf_final_kwargs: dict[str, Any] | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Generate a set of candidates via multi-start optimization.
 
@@ -68,19 +66,56 @@
         q: The number of candidates.
         homotopy: Homotopy object that will make the necessary modifications
             to the problem when calling `step()`.
-        num_restarts: The number of starting points for multistart acquisition
-            function optimization.
-        raw_samples: The number of samples for initialization. This is required
-            if `batch_initial_conditions` is not specified.
-        fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
-        options: Options for candidate generation.
-        final_options: Options for candidate generation in the last homotopy step.
+        prune_tolerance: The minimum distance to prune candidates.
         batch_initial_conditions: A tensor to specify the initial conditions. Set
             this if you do not want to use default initialization strategy.
-        post_processing_func: Post processing function (such as rounding or clamping)
-            that is applied before choosing the final candidate.
+        optimize_acqf_loop_kwargs: A dictionary of keyword arguments for
+            `optimize_acqf`. These settings are used in the homotopy loop.
+        optimize_acqf_final_kwargs: A dictionary of keyword arguments for
+            `optimize_acqf`. These settings are used for the final optimization
+            after the homotopy loop.
     """
+    if optimize_acqf_loop_kwargs is None:
+        optimize_acqf_loop_kwargs = {}
+
+    if optimize_acqf_final_kwargs is None:
+        optimize_acqf_final_kwargs = {}
+
+    for kwarg_dict_name, kwarg_dict in [("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs), ("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs)]:
+
+        if "return_best_only" in kwarg_dict:
+            warnings.warn(
+                f"`return_best_only` is set to True in `{kwarg_dict_name}`, setting to False."
+            )
+            kwarg_dict["return_best_only"] = False
+
+        if "q" in kwarg_dict:
+            warnings.warn(
+                f"`q` is set in `{kwarg_dict_name}`, setting to 1."
+            )
+            kwarg_dict["q"] = 1
+
+        if "batch_initial_conditions" in kwarg_dict:
+            warnings.warn(
+                f"`batch_initial_conditions` is set in `{kwarg_dict_name}`, setting to None."
+            )
+            kwarg_dict.pop("batch_initial_conditions")
+
+        for arg_name, arg_value in [("acq_function", acq_function), ("bounds", bounds)]:
+            if arg_name in kwarg_dict:
+                warnings.warn(
+                    f"`{arg_name}` is set in `{kwarg_dict_name}` and will be "
+                    "overridden in favor of the value in `optimize_acqf_homotopy`. "
+                    f"({arg_name} = {arg_value} cf. {kwarg_dict_name}[{arg_name}] = {kwarg_dict[arg_name]})"
+                )
+            kwarg_dict[arg_name] = arg_value
+
+    if "post_processing_func" in optimize_acqf_loop_kwargs:
+        warnings.warn(
+            "`post_processing_func` is set in `optimize_acqf_loop_kwargs`, setting to None."
+        )
+        optimize_acqf_loop_kwargs["post_processing_func"] = None
+
     candidate_list, acq_value_list = [], []
     if q > 1:
         base_X_pending = acq_function.X_pending
@@ -90,17 +125,7 @@
         homotopy.restart()
 
         while not homotopy.should_stop:
-            candidates, acq_values = optimize_acqf(
-                q=1,
-                acq_function=acq_function,
-                bounds=bounds,
-                num_restarts=num_restarts,
-                batch_initial_conditions=candidates,
-                raw_samples=raw_samples,
-                fixed_features=fixed_features,
-                return_best_only=False,
-                options=options,
-            )
+            candidates, acq_values = optimize_acqf(batch_initial_conditions=candidates, **optimize_acqf_loop_kwargs)
             homotopy.step()
 
             # Prune candidates
@@ -111,20 +136,8 @@
             ).unsqueeze(1)
 
     # Optimize one more time with the final options
-    candidates, acq_values = optimize_acqf(
-        q=1,
-        acq_function=acq_function,
-        bounds=bounds,
-        num_restarts=num_restarts,
-        batch_initial_conditions=candidates,
-        return_best_only=False,
-        options=final_options,
-    )
+    candidates, acq_values = optimize_acqf(batch_initial_conditions=candidates, **optimize_acqf_final_kwargs)
 
-    # Post-process the candidates and grab the best candidate
-    if post_processing_func is not None:
-        candidates = post_processing_func(candidates)
-        acq_values = acq_function(candidates)
     best = torch.argmax(acq_values.view(-1), dim=0)
     candidate, acq_value = candidates[best], acq_values[best]

From b7607b71c8dd2beeb8cc107c51a97473c6ae8d17 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Thu, 17 Oct 2024 10:32:57 -0400
Subject: [PATCH 2/5] fix: more warnings on mis-set args

---
 botorch/optim/optimize_homotopy.py | 48 ++++++++++++++++++++++++------
 test/optim/test_homotopy.py        | 32 ++++++++++++--------
 2 files changed, 59 insertions(+), 21 deletions(-)

diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index 1d1cdcad86..b527f9a910 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -81,23 +81,25 @@ def optimize_acqf_homotopy(
     if optimize_acqf_final_kwargs is None:
         optimize_acqf_final_kwargs = {}
 
-    for kwarg_dict_name, kwarg_dict in [("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs), ("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs)]:
-
+    for kwarg_dict_name, kwarg_dict in [
+        ("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs),
+        ("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs),
+    ]:
         if "return_best_only" in kwarg_dict:
             warnings.warn(
                 f"`return_best_only` is set to True in `{kwarg_dict_name}`, setting to False."
             )
             kwarg_dict["return_best_only"] = False
 
-        if "q" in kwarg_dict:
-            warnings.warn(
-                f"`q` is set in `{kwarg_dict_name}`, setting to 1."
-            )
+        if kwarg_dict.get("q", None) != 1:
+            warnings.warn(f"`q` is set in `{kwarg_dict_name}`, setting to 1.")
             kwarg_dict["q"] = 1
 
         if "batch_initial_conditions" in kwarg_dict:
             warnings.warn(
-                f"`batch_initial_conditions` is set in `{kwarg_dict_name}`, setting to None."
+                f"`batch_initial_conditions` is set in `{kwarg_dict_name}`, "
+                "removing it in favour of `batch_initial_conditions` given to "
+                "`optimize_acqf_homotopy`."
             )
             kwarg_dict.pop("batch_initial_conditions")
 
@@ -110,12 +112,29 @@ def optimize_acqf_homotopy(
                 )
             kwarg_dict[arg_name] = arg_value
 
+    if (
+        batch_initial_conditions is None
+        and optimize_acqf_loop_kwargs.get("raw_samples", None) is None
+    ):
+        raise ValueError(
+            "Must specify `raw_samples` in `optimize_acqf_loop_kwargs` when "
+            "`batch_initial_conditions` is None."
+        )
+
     if "post_processing_func" in optimize_acqf_loop_kwargs:
         warnings.warn(
             "`post_processing_func` is set in `optimize_acqf_loop_kwargs`, setting to None."
        )
         optimize_acqf_loop_kwargs["post_processing_func"] = None
 
+    if "raw_samples" in optimize_acqf_final_kwargs:
+        warnings.warn(
+            "`raw_samples` is set in `optimize_acqf_final_kwargs`, "
+            "removing it as we set `batch_initial_conditions` to the candidates "
+            "returned by homotopy loop for the final optimization."
+        )
+        optimize_acqf_final_kwargs.pop("raw_samples")
+
     candidate_list, acq_value_list = [], []
     if q > 1:
         base_X_pending = acq_function.X_pending
@@ -125,7 +144,12 @@ def optimize_acqf_homotopy(
         homotopy.restart()
 
         while not homotopy.should_stop:
-            candidates, acq_values = optimize_acqf(batch_initial_conditions=candidates, **optimize_acqf_loop_kwargs)
+            candidates, acq_values = optimize_acqf(
+                acq_function=acq_function,
+                bounds=bounds,
+                batch_initial_conditions=candidates,
+                **optimize_acqf_loop_kwargs
+            )
             homotopy.step()
 
             # Prune candidates
@@ -136,7 +160,11 @@ def optimize_acqf_homotopy(
             ).unsqueeze(1)
 
     # Optimize one more time with the final options
-    candidates, acq_values = optimize_acqf(batch_initial_conditions=candidates, **optimize_acqf_final_kwargs)
+    candidates, acq_values = optimize_acqf(
+        acq_function=acq_function,
+        bounds=bounds,
+        batch_initial_conditions=candidates, **optimize_acqf_final_kwargs
+    )
 
     best = torch.argmax(acq_values.view(-1), dim=0)
     candidate, acq_value = candidates[best], acq_values[best]
@@ -145,6 +173,7 @@ def optimize_acqf_homotopy(
         candidate_list.append(candidate)
         acq_value_list.append(acq_value)
         selected_candidates = torch.cat(candidate_list, dim=-2)
+
         if q > 1:
             acq_function.set_X_pending(
                 torch.cat([base_X_pending, selected_candidates], dim=-2)
@@ -154,6 +183,7 @@ def optimize_acqf_homotopy(
     if q > 1:
         # Reset acq_function to previous X_pending state
         acq_function.set_X_pending(base_X_pending)
+
     homotopy.reset()  # Reset the homotopy parameters
 
     return selected_candidates, torch.stack(acq_value_list)

diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
index ac3cd6142a..e2f2d1c611 100644
--- a/test/optim/test_homotopy.py
+++ b/test/optim/test_homotopy.py
@@ -114,14 +114,18 @@ def test_optimize_acqf_homotopy(self):
         )
         model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
         acqf = PosteriorMean(model=model)
+
+        optimize_acqf_core_kwargs = {
+            "num_restarts": 2,
+            "raw_samples": 16,
+        }
         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=2,
-            raw_samples=16,
-            post_processing_func=lambda x: x.round(),
+            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
         )
         self.assertEqual(candidate, torch.zeros(1, **tkwargs))
         self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
@@ -137,9 +141,8 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=2,
-            raw_samples=16,
-            fixed_features=fixed_features,
+            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))
 
@@ -150,13 +153,14 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=2,
-            raw_samples=16,
-            fixed_features=fixed_features,
+            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
         )
         self.assertEqual(candidate.shape, torch.Size([3, 2]))
         self.assertEqual(acqf_val.shape, torch.Size([3]))
 
+        # with linear constraints
+
     def test_prune_candidates(self):
         tkwargs = {"device": self.device, "dtype": torch.double}
         # no pruning
@@ -202,14 +206,18 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
         )
         model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
         acqf = PosteriorMean(model=model)
+        optimize_acqf_core_kwargs = {
+            "num_restarts":2,
+            "raw_samples":16,
+        }
+
         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            num_restarts=4,
-            raw_samples=16,
-            post_processing_func=lambda x: x.round(),
+            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
         )
         # First time we expect to call `prune_candidates` with 4 candidates
         self.assertEqual(

From 038a735f3313ed796dfa134574e7c91c4fb68c4b Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Thu, 17 Oct 2024 10:43:05 -0400
Subject: [PATCH 3/5] fix: tweaks to set q

---
 botorch/optim/optimize_homotopy.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index b527f9a910..87ba7b41cc 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -87,12 +87,12 @@ def optimize_acqf_homotopy(
     ]:
         if "return_best_only" in kwarg_dict:
             warnings.warn(
-                f"`return_best_only` is set to True in `{kwarg_dict_name}`, setting to False."
+                f"`return_best_only` is set to True in `{kwarg_dict_name}`, override to False."
             )
             kwarg_dict["return_best_only"] = False
 
         if kwarg_dict.get("q", None) != 1:
-            warnings.warn(f"`q` is set in `{kwarg_dict_name}`, setting to 1.")
+            warnings.warn(f"`q` is not set to 1 in `{kwarg_dict_name}`, override to 1.")
             kwarg_dict["q"] = 1
 
         if "batch_initial_conditions" in kwarg_dict:
@@ -160,10 +160,12 @@ def optimize_acqf_homotopy(
             ).unsqueeze(1)
 
     # Optimize one more time with the final options
+    # NOTE is there any reason we don't want to pass fixed features to final?
     candidates, acq_values = optimize_acqf(
         acq_function=acq_function,
         bounds=bounds,
-        batch_initial_conditions=candidates, **optimize_acqf_final_kwargs
+        batch_initial_conditions=candidates,
+        **optimize_acqf_final_kwargs
     )
 
     best = torch.argmax(acq_values.view(-1), dim=0)

From fc544a6d5329c625c738fbec3b3d4a6d944cb234 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Thu, 17 Oct 2024 11:00:05 -0400
Subject: [PATCH 4/5] fix: restore the pruning test

---
 botorch/optim/optimize_homotopy.py |  2 +-
 test/optim/test_homotopy.py        | 28 ++++++++++++++++++++--------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index 87ba7b41cc..5f218eeca6 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -85,7 +85,7 @@ def optimize_acqf_homotopy(
         ("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs),
         ("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs),
     ]:
-        if "return_best_only" in kwarg_dict:
+        if kwarg_dict.get("return_best_only", None) is not False:
             warnings.warn(
                 f"`return_best_only` is set to True in `{kwarg_dict_name}`, override to False."
             )
diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
index e2f2d1c611..a531e525fb 100644
--- a/test/optim/test_homotopy.py
+++ b/test/optim/test_homotopy.py
@@ -125,7 +125,10 @@ def test_optimize_acqf_homotopy(self):
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "post_processing_func": lambda x: x.round(),
+            },
         )
         self.assertEqual(candidate, torch.zeros(1, **tkwargs))
         self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
@@ -144,8 +147,11 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "fixed_features": fixed_features,
+            },
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))
 
@@ -159,8 +165,11 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "fixed_features": fixed_features,
+            },
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
         )
         self.assertEqual(candidate.shape, torch.Size([3, 2]))
         self.assertEqual(acqf_val.shape, torch.Size([3]))
@@ -216,17 +225,20 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
         model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
         acqf = PosteriorMean(model=model)
         optimize_acqf_core_kwargs = {
-            "num_restarts":2,
-            "raw_samples":16,
+            "num_restarts": 4,
+            "raw_samples": 16,
         }
 
         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "post_processing_func": lambda x: x.round(),
+            },
         )
         # First time we expect to call `prune_candidates` with 4 candidates
         self.assertEqual(

From d088be672d6ea970145eed237faa5ad23c9da71a Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Fri, 18 Oct 2024 10:49:48 -0400
Subject: [PATCH 5/5] test: add a test using a linear constraint

---
 botorch/optim/optimize_homotopy.py |  5 +++--
 test/optim/test_homotopy.py        | 33 +++++++++++++++++++++++++-----
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index 5f218eeca6..980762d871 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -87,7 +87,7 @@ def optimize_acqf_homotopy(
     ]:
         if kwarg_dict.get("return_best_only", None) is not False:
             warnings.warn(
-                f"`return_best_only` is set to True in `{kwarg_dict_name}`, override to False."
+                f"`return_best_only` is not False in `{kwarg_dict_name}`, override to False."
             )
             kwarg_dict["return_best_only"] = False
 
@@ -101,6 +101,7 @@ def optimize_acqf_homotopy(
                 "removing it in favour of `batch_initial_conditions` given to "
                 "`optimize_acqf_homotopy`."
             )
+            # are pops dangerous here given no copy? if repeatedly reusing kwarg_dict it could create issues
            kwarg_dict.pop("batch_initial_conditions")
 
         for arg_name, arg_value in [("acq_function", acq_function), ("bounds", bounds)]:
@@ -134,7 +135,7 @@ def optimize_acqf_homotopy(
             "removing it as we set `batch_initial_conditions` to the candidates "
             "returned by homotopy loop for the final optimization."
         )
-        optimize_acqf_final_kwargs.pop("raw_samples")
+        optimize_acqf_final_kwargs.pop("raw_samples")  # are pops dangerous here given no copy? see above
 
     candidate_list, acq_value_list = [], []
     if q > 1:
diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
index a531e525fb..8a71b7d30e 100644
--- a/test/optim/test_homotopy.py
+++ b/test/optim/test_homotopy.py
@@ -124,7 +124,7 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
+            optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
             optimize_acqf_final_kwargs={
                 **optimize_acqf_core_kwargs,
                 "post_processing_func": lambda x: x.round(),
@@ -146,9 +146,9 @@ def test_optimize_acqf_homotopy(self):
             homotopy=Homotopy(homotopy_parameters=[hp]),
             optimize_acqf_loop_kwargs={
                 **optimize_acqf_core_kwargs,
-                "fixed_features": fixed_features,
+                "fixed_features": fixed_features,  # this is done to mimic old behaviour which was perhaps a bug?
             },
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
+            optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))
 
@@ -161,12 +161,35 @@ def test_optimize_acqf_homotopy(self):
                 **optimize_acqf_core_kwargs,
                 "fixed_features": fixed_features,
             },
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
+            optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
         )
         self.assertEqual(candidate.shape, torch.Size([3, 2]))
         self.assertEqual(acqf_val.shape, torch.Size([3]))
 
         # with linear constraints
+        constraints = [(  # X[..., 0] + X[..., 1] >= 2.
+ torch.tensor([0, 1], device=self.device), + torch.ones(2, device=self.device, dtype=torch.double), + 2.0, + )] + + acqf = PosteriorMean(model=model) + candidate, acqf_val = optimize_acqf_homotopy( + q=1, + acq_function=acqf, + bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs), + homotopy=Homotopy(homotopy_parameters=[hp]), + optimize_acqf_loop_kwargs={ + **optimize_acqf_core_kwargs, + "inequality_constraints": constraints, + }, + optimize_acqf_final_kwargs={ + **optimize_acqf_core_kwargs, + "inequality_constraints": constraints, + }, + ) + self.assertEqual(candidate.shape, torch.Size([1, 2])) + self.assertGreaterEqual(candidate.sum(), 2 * torch.ones(1, **tkwargs)) def test_prune_candidates(self): tkwargs = {"device": self.device, "dtype": torch.double} @@ -225,7 +248,7 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock): acq_function=acqf, bounds=torch.tensor([[-10], [5]]).to(**tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), - optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs, + optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs}, optimize_acqf_final_kwargs={ **optimize_acqf_core_kwargs, "post_processing_func": lambda x: x.round(),