Expose all kwargs to optimize_acqf in optimize_acqf_homotopy #2580

Closed · wants to merge 6 commits
Changes from 5 commits
110 changes: 78 additions & 32 deletions botorch/optim/optimize_homotopy.py
@@ -3,7 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Callable
from typing import Any
import warnings

import torch
from botorch.acquisition import AcquisitionFunction
@@ -51,14 +52,11 @@
bounds: Tensor,
q: int,
homotopy: Homotopy,
num_restarts: int,
raw_samples: int | None = None,
fixed_features: dict[int, float] | None = None,
options: dict[str, bool | float | int | str] | None = None,
final_options: dict[str, bool | float | int | str] | None = None,
batch_initial_conditions: Tensor | None = None,
post_processing_func: Callable[[Tensor], Tensor] | None = None,
*,
prune_tolerance: float = 1e-4,
batch_initial_conditions: Tensor | None = None,
optimize_acqf_loop_kwargs: dict[str, Any] | None = None,
optimize_acqf_final_kwargs: dict[str, Any] | None = None,
) -> tuple[Tensor, Tensor]:
r"""Generate a set of candidates via multi-start optimization.

@@ -68,19 +66,76 @@
q: The number of candidates.
homotopy: Homotopy object that will make the necessary modifications to the
problem when calling `step()`.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of samples for initialization. This is required
if `batch_initial_conditions` is not specified.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
options: Options for candidate generation.
final_options: Options for candidate generation in the last homotopy step.
prune_tolerance: The minimum distance to prune candidates.
batch_initial_conditions: A tensor to specify the initial conditions. Set
this if you do not want to use default initialization strategy.
post_processing_func: Post processing function (such as rounding or clamping)
that is applied before choosing the final candidate.
optimize_acqf_loop_kwargs: A dictionary of keyword arguments for
`optimize_acqf`. These settings are used in the homotopy loop.
optimize_acqf_final_kwargs: A dictionary of keyword arguments for
`optimize_acqf`. These settings are used for the final optimization
after the homotopy loop.
"""
if optimize_acqf_loop_kwargs is None:
optimize_acqf_loop_kwargs = {}

if optimize_acqf_final_kwargs is None:
optimize_acqf_final_kwargs = {}

for kwarg_dict_name, kwarg_dict in [
("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs),
("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs),
]:
if kwarg_dict.get("return_best_only", None) is not False:
warnings.warn(
f"`return_best_only` is not False in `{kwarg_dict_name}`, override to False."
)
kwarg_dict["return_best_only"] = False

if kwarg_dict.get("q", None) != 1:
warnings.warn(f"`q` is not set to 1 in `{kwarg_dict_name}`, override to 1.")
kwarg_dict["q"] = 1

if "batch_initial_conditions" in kwarg_dict:
warnings.warn(
f"`batch_initial_conditions` is set in `{kwarg_dict_name}`, "
"removing it in favour of `batch_initial_conditions` given to "
"`optimize_acqf_homotopy`."
)
# Are pops dangerous here given no copy? If the same kwarg_dict is reused repeatedly, this could create issues.
kwarg_dict.pop("batch_initial_conditions")

for arg_name, arg_value in [("acq_function", acq_function), ("bounds", bounds)]:
if arg_name in kwarg_dict:
warnings.warn(
f"`{arg_name}` is set in `{kwarg_dict_name}` and will be "
"overridden in favor of the value in `optimize_acqf_homotopy`. "
f"({arg_name} = {arg_value} c.f. {kwarg_dict_name}[{arg_name}] = {kwarg_dict[arg_name]})"
)
kwarg_dict[arg_name] = arg_value

if (
batch_initial_conditions is None
and optimize_acqf_loop_kwargs.get("raw_samples", None) is None
):
raise ValueError(
"Must specify `raw_samples` in `optimize_acqf_loop_kwargs` when "
"`batch_initial_conditions` is None`."
)

if "post_processing_func" in optimize_acqf_loop_kwargs:
warnings.warn(
"`post_processing_func` is set in `optimize_acqf_loop_kwargs`, setting to None."
)
optimize_acqf_loop_kwargs["post_processing_func"] = None

if "raw_samples" in optimize_acqf_final_kwargs:
warnings.warn(
"`raw_samples` is set in `optimize_acqf_final_kwargs`, "
"removing it as we set `batch_initial_conditions` to the candidates "
"returned by homotopy loop for the final optimization."
)
optimize_acqf_final_kwargs.pop("raw_samples") # Are pops dangerous here given no copy? See the note above.

candidate_list, acq_value_list = [], []
if q > 1:
base_X_pending = acq_function.X_pending
@@ -91,15 +146,10 @@

while not homotopy.should_stop:
candidates, acq_values = optimize_acqf(
q=1,
acq_function=acq_function,
bounds=bounds,
num_restarts=num_restarts,
batch_initial_conditions=candidates,
raw_samples=raw_samples,
fixed_features=fixed_features,
return_best_only=False,
options=options,
**optimize_acqf_loop_kwargs
)
homotopy.step()

@@ -111,27 +161,22 @@
).unsqueeze(1)

# Optimize one more time with the final options
# NOTE is there any reason we don't want to pass fixed features to final?
candidates, acq_values = optimize_acqf(
q=1,
acq_function=acq_function,
bounds=bounds,
num_restarts=num_restarts,
batch_initial_conditions=candidates,
return_best_only=False,
options=final_options,
**optimize_acqf_final_kwargs
)

# Post-process the candidates and grab the best candidate
if post_processing_func is not None:
candidates = post_processing_func(candidates)
acq_values = acq_function(candidates)
best = torch.argmax(acq_values.view(-1), dim=0)
candidate, acq_value = candidates[best], acq_values[best]

# Keep the new candidate and update the pending points
candidate_list.append(candidate)
acq_value_list.append(acq_value)
selected_candidates = torch.cat(candidate_list, dim=-2)

if q > 1:
acq_function.set_X_pending(
torch.cat([base_X_pending, selected_candidates], dim=-2)
@@ -141,6 +186,7 @@

if q > 1: # Reset acq_function to previous X_pending state
acq_function.set_X_pending(base_X_pending)

homotopy.reset() # Reset the homotopy parameters

return selected_candidates, torch.stack(acq_value_list)
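
For reviewers, here is a minimal migration sketch (not part of this PR's diff): it shows how a call site moves from the removed top-level arguments to the new kwargs dicts. The helper name `run_homotopy_optimization` and the option values are placeholders, and `acqf`, `bounds`, and `homotopy` are assumed to be constructed as in the tests below.

```python
# A sketch, assuming `acqf`, `bounds`, and `homotopy` are built as in the
# tests in test/optim/test_homotopy.py (placeholder values throughout).
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy


def run_homotopy_optimization(acqf, bounds, homotopy):
    # Old call shape (arguments removed by this diff), for comparison:
    #   optimize_acqf_homotopy(
    #       q=1, acq_function=acqf, bounds=bounds, homotopy=homotopy,
    #       num_restarts=2, raw_samples=16, options={"maxiter": 200},
    #       final_options={"maxiter": 500},
    #       post_processing_func=lambda x: x.round(),
    #   )
    # New call shape: per-stage settings are forwarded to `optimize_acqf`.
    # The wrapper forces `q=1` and `return_best_only=False` inside both dicts
    # (with a warning) and disallows `post_processing_func` in the loop
    # kwargs, so rounding is passed via the final kwargs only.
    return optimize_acqf_homotopy(
        q=1,
        acq_function=acqf,
        bounds=bounds,
        homotopy=homotopy,
        optimize_acqf_loop_kwargs={
            "num_restarts": 2,
            "raw_samples": 16,
            "options": {"maxiter": 200},
        },
        optimize_acqf_final_kwargs={
            "num_restarts": 2,
            "options": {"maxiter": 500},
            "post_processing_func": lambda x: x.round(),
        },
    )
```

The dict-based form is what the updated tests below exercise; other `optimize_acqf` arguments such as `fixed_features` and `inequality_constraints` flow through the same way.
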
67 changes: 55 additions & 12 deletions test/optim/test_homotopy.py
Original file line number Diff line number Diff line change
@@ -114,14 +114,21 @@ def test_optimize_acqf_homotopy(self):
)
model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
acqf = PosteriorMean(model=model)

optimize_acqf_core_kwargs = {
"num_restarts": 2,
"raw_samples": 16,
}
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
post_processing_func=lambda x: x.round(),
optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"post_processing_func": lambda x: x.round(),
},
)
self.assertEqual(candidate, torch.zeros(1, **tkwargs))
self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
@@ -137,9 +144,11 @@
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features=fixed_features,
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"fixed_features": fixed_features, # this is done to mimic old behaviour which was perhaps a bug?
},
optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

@@ -150,13 +159,40 @@
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features=fixed_features,
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"fixed_features": fixed_features,
},
optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
)
self.assertEqual(candidate.shape, torch.Size([3, 2]))
self.assertEqual(acqf_val.shape, torch.Size([3]))

# with linear constraints
constraints = [( # X[..., 0] + X[..., 1] >= 2.
torch.tensor([0, 1], device=self.device),
torch.ones(2, device=self.device, dtype=torch.double),
2.0,
)]

acqf = PosteriorMean(model=model)
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"inequality_constraints": constraints,
},
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"inequality_constraints": constraints,
},
)
self.assertEqual(candidate.shape, torch.Size([1, 2]))
self.assertGreaterEqual(candidate.sum(), 2 * torch.ones(1, **tkwargs))

def test_prune_candidates(self):
tkwargs = {"device": self.device, "dtype": torch.double}
# no pruning
@@ -202,14 +238,21 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
)
model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
acqf = PosteriorMean(model=model)
optimize_acqf_core_kwargs = {
"num_restarts": 4,
"raw_samples": 16,
}

candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=4,
raw_samples=16,
post_processing_func=lambda x: x.round(),
optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"post_processing_func": lambda x: x.round(),
},
)
# First time we expect to call `prune_candidates` with 4 candidates
self.assertEqual(
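
On the two inline questions above ("are pops dangerous here given no copy?"): since the wrapper pops and overrides keys in the caller-supplied dicts, a dict reused across calls is mutated in place. Below is a minimal sketch, assuming a shallow copy is sufficient because the values themselves are never mutated, of how a caller could guard against this; `copy_optimizer_kwargs` and the template dict are hypothetical and not part of this PR.

```python
from typing import Any


def copy_optimizer_kwargs(kwargs: dict[str, Any] | None) -> dict[str, Any]:
    """Shallow-copy a kwargs dict so in-place pops/overrides stay local."""
    return dict(kwargs) if kwargs is not None else {}


# Reuse one template across many calls without it being silently mutated.
loop_kwargs_template = {"num_restarts": 2, "raw_samples": 16}
for _ in range(3):
    loop_kwargs = copy_optimizer_kwargs(loop_kwargs_template)
    # ... pass `loop_kwargs` as `optimize_acqf_loop_kwargs`; any pop or
    # override inside `optimize_acqf_homotopy` then touches only the copy.
assert loop_kwargs_template == {"num_restarts": 2, "raw_samples": 16}
```

Making the same copy at the top of `optimize_acqf_homotopy` itself would be one way to resolve both questions.
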