Skip to content

Commit

Permalink
feat: add test that checks that we don't call get_batch_initial_condit…
Browse files Browse the repository at this point in the history
…ions if doing L0 norm
  • Loading branch information
CompRhys committed Dec 13, 2024
1 parent 366c59a commit 6b9a6b6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ax/models/torch/botorch_modular/sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
optimize_acqf_homotopy,
)
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor

CLAMP_TOL = 1e-2
Expand Down Expand Up @@ -296,14 +296,14 @@ def _optimize_with_homotopy(
],
)

if "batch_initial_conditions" not in optimizer_options:
optimizer_options["batch_initial_conditions"] = (
if "batch_initial_conditions" not in optimizer_options_with_defaults:
optimizer_options_with_defaults["batch_initial_conditions"] = (
get_batch_initial_conditions(
acq_function=self.acqf,
raw_samples=optimizer_options_with_defaults["raw_samples"],
inequality_constraints=inequality_constraints,
fixed_features=fixed_features,
X_pareto=self.acqf.X_baseline,
X_pareto=assert_is_instance(self.acqf.X_baseline, Tensor),
target_point=self.target_point,
bounds=bounds,
num_restarts=optimizer_options_with_defaults["num_restarts"],
Expand Down
40 changes: 40 additions & 0 deletions ax/models/torch/tests/test_sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,43 @@ def test_get_batch_initial_conditions(
self.assertEqual(batch_initial_conditions.shape, torch.Size([3, 1, 3]))
self.assertTrue(torch.all(batch_initial_conditions[:1] != 0.5))
self.assertTrue(torch.all(batch_initial_conditions[1:, :, 1] == 0.5))

@mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")
@mock.patch(
f"{SEBOACQUISITION_PATH}.get_batch_initial_conditions",
wraps=get_batch_initial_conditions,
)
def test_optimize_with_provided_batch_initial_conditions(
self, mock_get_batch_initial_conditions: Mock, mock_optimize_acqf_homotopy: Mock
) -> None:
mock_optimize_acqf_homotopy.return_value = (
torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.double),
torch.tensor([1.0], dtype=torch.double),
)

# Create batch initial conditions
batch_ics = torch.rand(3, 1, 3, dtype=torch.double)

acquisition = self.get_acquisition_function(
options={
"target_point": self.target_point,
"penalty": "L0_norm",
},
)

acquisition.optimize(
n=1,
search_space_digest=self.search_space_digest,
optimizer_options={
"batch_initial_conditions": batch_ics,
Keys.NUM_RESTARTS: 3,
Keys.RAW_SAMPLES: 32,
},
)

# Verify get_batch_initial_conditions was not called
mock_get_batch_initial_conditions.assert_not_called()

# Verify the batch_initial_conditions were passed to optimize_acqf_homotopy
call_kwargs = mock_optimize_acqf_homotopy.call_args[1]
self.assertTrue(torch.equal(call_kwargs["batch_initial_conditions"], batch_ics))

0 comments on commit 6b9a6b6

Please sign in to comment.