Check for feasibility in gen_candidates_scipy and error out for infeasible candidates (#2737)

Summary:

As titled. Previously, the optimizer could return infeasible candidates to the user, with or without a warning alerting them to the issue. With this diff, the optimizer errors out when infeasible candidates are generated, so that the user can adjust the setup as needed.

Resolves #2708

Also includes a couple lint fixes in optimizer tests.
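
For illustration, a minimal, self-contained sketch of the new behavior (the toy acquisition function and tensor shapes are illustrative, not part of this diff): with an unsatisfiable constraint, gen_candidates_scipy now raises a CandidateGenerationError instead of returning infeasible candidates.

import torch
from botorch.exceptions.errors import CandidateGenerationError
from botorch.generation.gen import gen_candidates_scipy


class ToyAcqf(torch.nn.Module):
    # A differentiable stand-in for a real acquisition function.
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return -X.sum(dim=(-2, -1))


ics = torch.rand(2, 1, 3)  # 2 restarts, q=1, d=3
try:
    gen_candidates_scipy(
        initial_conditions=ics,
        acquisition_function=ToyAcqf(),
        lower_bounds=0.0,
        upper_bounds=1.0,
        # X[..., 0] >= 2.0 can never hold within the [0, 1] bounds.
        inequality_constraints=[(torch.tensor([0]), torch.tensor([1.0]), 2.0)],
    )
except CandidateGenerationError as e:
    print(e)  # "The SLSQP optimizer produced infeasible candidates. ..."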

Reviewed By: esantorella

Differential Revision: D69314159
saitcakmak authored and facebook-github-bot committed Feb 7, 2025
1 parent 8770fa4 commit 8e2a8c1
Showing 6 changed files with 81 additions and 67 deletions.
46 changes: 23 additions & 23 deletions botorch/generation/gen.py
@@ -20,16 +20,19 @@
 import numpy.typing as npt
 import torch
 from botorch.acquisition import AcquisitionFunction
-from botorch.exceptions.errors import OptimizationGradientError
+from botorch.exceptions.errors import (
+    CandidateGenerationError,
+    OptimizationGradientError,
+)
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.generation.utils import _remove_fixed_features_from_optimization
 from botorch.logging import logger
 from botorch.optim.parameter_constraints import (
     _arrayify,
+    evaluate_feasibility,
     make_scipy_bounds,
     make_scipy_linear_constraints,
     make_scipy_nonlinear_inequality_constraints,
-    nonlinear_constraint_is_feasible,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
 from botorch.optim.utils import columnwise_clamp, fix_features
@@ -237,11 +240,12 @@ def f_np_wrapper(x: npt.NDArray, f: Callable):
     def f(x):
         return -acquisition_function(x)
 
+    method = options.get("method", "SLSQP" if constraints else "L-BFGS-B")
     res = minimize_with_timeout(
         fun=f_np_wrapper,
         args=(f,),
         x0=x0,
-        method=options.get("method", "SLSQP" if constraints else "L-BFGS-B"),
+        method=method,
         jac=with_grad,
         bounds=bounds,
         constraints=constraints,
@@ -260,26 +264,22 @@ def f(x):
         fixed_features=fixed_features,
     )
 
-    # SLSQP sometimes fails in the line search or may just fail to find a feasible
-    # candidate in which case we just return the starting point. This happens rarely,
-    # so it shouldn't be an issue given enough restarts.
-    if nonlinear_inequality_constraints:
-        for con, is_intrapoint in nonlinear_inequality_constraints:
-            if not (
-                feasible := nonlinear_constraint_is_feasible(
-                    con, is_intrapoint=is_intrapoint, x=candidates
-                )
-            ).all():
-                # Replace the infeasible batches with feasible ICs.
-                candidates[~feasible] = (
-                    torch.from_numpy(x0).to(candidates).reshape(shapeX)[~feasible]
-                )
-                warnings.warn(
-                    "SLSQP failed to converge to a solution the satisfies the "
-                    "non-linear constraints. Returning the feasible starting point.",
-                    OptimizationWarning,
-                    stacklevel=2,
-                )
+    # SLSQP can sometimes fail to produce a feasible candidate. Check for
+    # feasibility and error out if necessary.
+    if not (
+        is_feasible := evaluate_feasibility(
+            X=candidates,
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+            nonlinear_inequality_constraints=nonlinear_inequality_constraints,
+        )
+    ).all():
+        raise CandidateGenerationError(
+            f"The {method} optimizer produced infeasible candidates. "
+            f"{(~is_feasible).sum().item()} out of {is_feasible.numel()} batches "
+            "of candidates were infeasible. Please make sure the constraints are "
+            "satisfiable and relax them if needed. "
+        )
 
     clamped_candidates = columnwise_clamp(
         X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True
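
For reference, a small sketch of the helper that backs this check (the tensor values are illustrative). evaluate_feasibility takes a `batch x q x d` tensor and returns a boolean tensor of shape `batch`; each (indices, coefficients, rhs) tuple below encodes the intra-point constraint sum(X[..., indices] * coefficients) >= rhs.

import torch
from botorch.optim.parameter_constraints import evaluate_feasibility

X = torch.tensor([[[0.5, 0.9]], [[0.1, 0.2]]])  # batch=2, q=1, d=2
is_feasible = evaluate_feasibility(
    X=X,
    # x_0 + x_1 >= 1.0: holds for the first batch only.
    inequality_constraints=[(torch.tensor([0, 1]), torch.ones(2), 1.0)],
)
print(is_feasible)  # tensor([ True, False])
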
2 changes: 1 addition & 1 deletion botorch/optim/optimize.py
@@ -1027,7 +1027,7 @@ def optimize_acqf_mixed(
 
     if isinstance(acq_function, OneShotAcquisitionFunction):
         if not hasattr(acq_function, "evaluate") and q > 1:
-            raise ValueError(
+            raise UnsupportedError(
                 "`OneShotAcquisitionFunction`s that do not implement `evaluate` "
                 "are currently not supported when `q > 1`. This is needed to "
                 "compute the joint acquisition value."
18 changes: 11 additions & 7 deletions botorch/optim/parameter_constraints.py
@@ -647,11 +647,11 @@ def evaluate_feasibility(
             intra-point or inter-point constraint (`True` for intra-point. `False` for
             inter-point). For more information on intra-point vs inter-point
             constraints, see the docstring of the `inequality_constraints` argument.
-        tolerance: The tolerance used to check the feasibility of equality constraints
-            and non-linear inequality constraints. For equality constraints, we check
-            if `abs(const(X) - rhs) < tolerance`. For non-linear inequality constraints,
-            we check if `const(X) >= -tolerance`. This avoids marking the candidates as
-            infeasible due to tiny violations.
+        tolerance: The tolerance used to check the feasibility of constraints.
+            For inequality constraints, we check if `const(X) >= rhs - tolerance`.
+            For equality constraints, we check if `abs(const(X) - rhs) < tolerance`.
+            For non-linear inequality constraints, we check if `const(X) >= -tolerance`.
+            This avoids marking the candidates as infeasible due to tiny violations.
 
     Returns:
         A boolean tensor of shape `batch` indicating if the corresponding candidate of
@@ -662,10 +662,14 @@
         for idx, coef, rhs in inequality_constraints:
             if idx.ndim == 1:
                 # Intra-point constraints.
-                is_feasible &= ((X[..., idx] * coef).sum(dim=-1) >= rhs).all(dim=-1)
+                is_feasible &= (
+                    (X[..., idx] * coef).sum(dim=-1) >= rhs - tolerance
+                ).all(dim=-1)
             else:
                 # Inter-point constraints.
-                is_feasible &= (X[..., idx[:, 0], idx[:, 1]] * coef).sum(dim=-1) >= rhs
+                is_feasible &= (X[..., idx[:, 0], idx[:, 1]] * coef).sum(
+                    dim=-1
+                ) >= rhs - tolerance
     if equality_constraints is not None:
         for idx, coef, rhs in equality_constraints:
             if idx.ndim == 1:
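
The tolerance semantics above in miniature (a hypothetical candidate that undershoots the right-hand side by 1e-9): the violation is smaller than the tolerance, so the candidate still counts as feasible rather than being rejected over floating-point noise.

import torch
from botorch.optim.parameter_constraints import evaluate_feasibility

# x_0 >= 0.5, with x_0 falling short of the bound by only 1e-9.
X = torch.tensor([[[0.5 - 1e-9]]], dtype=torch.float64)  # batch=1, q=1, d=1
ok = evaluate_feasibility(
    X=X,
    inequality_constraints=[(torch.tensor([0]), torch.tensor([1.0]), 0.5)],
    tolerance=1e-8,
)
print(ok)  # tensor([True]): 0.5 - 1e-9 >= 0.5 - 1e-8
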
10 changes: 5 additions & 5 deletions botorch/test_utils/mock.py
@@ -38,11 +38,11 @@ def mock_optimize_context_manager(
     USE RESPONSIBLY.
     """
 
-    def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
+    def two_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
         if kwargs["options"] is None:
             kwargs["options"] = {}
 
-        kwargs["options"]["maxiter"] = 1
+        # Using two iterations here to allow SLSQP to adapt to constraints.
+        kwargs["options"]["maxiter"] = 2
         return minimize_with_timeout(*args, **kwargs)
 
     def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
@@ -64,7 +64,7 @@ def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
     mock_generation = es.enter_context(
         mock.patch(
             "botorch.generation.gen.minimize_with_timeout",
-            wraps=one_iteration_minimize,
+            wraps=two_iteration_minimize,
         )
     )
 
@@ -73,7 +73,7 @@ def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
     mock_fit = es.enter_context(
         mock.patch(
             "botorch.optim.core.minimize_with_timeout",
-            wraps=one_iteration_minimize,
+            wraps=two_iteration_minimize,
         )
     )
 
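
For context, this helper exists so unit tests can exercise the full optimize_acqf code path cheaply. A sketch of typical usage, assuming an illustrative GP and acquisition setup (not part of this diff):

import torch
from botorch.acquisition import qExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.test_utils.mock import mock_optimize_context_manager

train_X = torch.rand(4, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acqf = qExpectedImprovement(model=model, best_f=train_Y.max())
bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(train_X)

with mock_optimize_context_manager():
    # scipy's minimize is capped at two iterations (up from one, so that
    # SLSQP can adapt to constraints) and initial-condition generation is
    # minimal, keeping the test fast while running the full code path.
    candidate, value = optimize_acqf(
        acq_function=acqf, bounds=bounds, q=1, num_restarts=1, raw_samples=2
    )
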
38 changes: 35 additions & 3 deletions test/generation/test_gen.py
@@ -11,7 +11,10 @@
 
 import torch
 from botorch.acquisition import qExpectedImprovement, qKnowledgeGradient
-from botorch.exceptions.errors import OptimizationGradientError
+from botorch.exceptions.errors import (
+    CandidateGenerationError,
+    OptimizationGradientError,
+)
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.generation.gen import (
@@ -211,8 +214,16 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
             initial_conditions=self.initial_conditions.reshape(1, 1, -1),
             acquisition_function=qEI,
             inequality_constraints=[
-                (torch.tensor([0]), torch.tensor([1]), 0),
-                (torch.tensor([1]), torch.tensor([-1]), -1),
+                (
+                    torch.tensor([0], device=self.device),
+                    torch.tensor([1], device=self.device),
+                    0,
+                ),
+                (
+                    torch.tensor([1], device=self.device),
+                    torch.tensor([-1], device=self.device),
+                    -1,
+                ),
             ],
             fixed_features={1: 0.25},
             options=options,
@@ -378,6 +389,27 @@ def test_gen_candidates_scipy_invalid_method(self) -> None:
                 upper_bounds=1,
             )
 
+    def test_gen_candidates_scipy_infeasible_candidates(self) -> None:
+        # Check for error when infeasible candidates are generated.
+        ics = torch.rand(2, 3, 1, device=self.device)
+        with mock.patch(
+            "botorch.generation.gen.minimize_with_timeout",
+            return_value=OptimizeResult(x=ics.view(-1).cpu().numpy()),
+        ), self.assertRaisesRegex(
+            CandidateGenerationError, "infeasible candidates. 2 out of 2"
+        ):
+            gen_candidates_scipy(
+                initial_conditions=ics,
+                acquisition_function=MockAcquisitionFunction(),
+                inequality_constraints=[
+                    (  # X[..., 0] >= 2.0, which is infeasible.
+                        torch.tensor([0], device=self.device),
+                        torch.tensor([1.0], device=self.device),
+                        2.0,
+                    )
+                ],
+            )
+
 
 class TestRandomRestartOptimization(TestBaseCandidateGeneration):
     def test_random_restart_optimization(self):
34 changes: 6 additions & 28 deletions test/optim/test_optimize.py
@@ -623,10 +623,10 @@ def test_optimize_acqf_batch_limit(self) -> None:
 
         for ic_shape, expected_shape in [((2, 1, dim), 2), ((2, dim), 1)]:
             with self.subTest(gen_candidates=gen_candidates):
+                ics = torch.ones((ic_shape))
                 with self.assertWarnsRegex(
                     RuntimeWarning, "botorch will default to old behavior"
                 ):
-                    ics = torch.ones((ic_shape))
                     _candidates, acq_value_list = optimize_acqf(
                         acq_function=SinOneOverXAcqusitionFunction(),
                         bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
@@ -638,8 +637,7 @@ def test_optimize_acqf_batch_limit(self) -> None:
                         gen_candidates=gen_candidates,
                         batch_initial_conditions=ics,
                     )
-
-                self.assertEqual(acq_value_list.shape, (expected_shape,))
+                self.assertEqual(acq_value_list.shape, (expected_shape,))
 
     def test_optimize_acqf_runs_given_batch_initial_conditions(self):
         num_restarts, raw_samples, dim = 1, 2, 3
@@ -915,27 +914,6 @@ def nlc1(x):
             torch.allclose(acq_value, torch.tensor([4], **tkwargs), atol=1e-3)
         )
 
-        # Make sure we return the initial solution if SLSQP fails to return
-        # a feasible point.
-        with mock.patch(
-            "botorch.generation.gen.minimize_with_timeout"
-        ) as mock_minimize:
-            # By setting "success" to True and "status" to 0, we prevent a
-            # warning that `minimize` failed, which isn't the behavior
-            # we're looking to test here.
-            mock_minimize.return_value = OptimizeResult(
-                x=np.array([4, 4, 4]), success=True, status=0
-            )
-            candidates, acq_value = optimize_acqf(
-                acq_function=mock_acq_function,
-                bounds=bounds,
-                q=1,
-                nonlinear_inequality_constraints=[(nlc1, True)],
-                batch_initial_conditions=batch_initial_conditions,
-                num_restarts=1,
-            )
-            self.assertAllClose(candidates, batch_initial_conditions[0, ...])
-
         # Constrain all variables to be >= 1. The global optimum is 2.45 and
         # is attained by some permutation of [1, 1, 2]
         def nlc2(x):
@@ -1685,10 +1663,10 @@ def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):
         self.assertTrue(torch.equal(acq_value, expected_acq_value))
 
     def test_optimize_acqf_mixed_empty_ff(self):
+        mock_acq_function = MockAcquisitionFunction()
         with self.assertRaisesRegex(
             ValueError, expected_regex="fixed_features_list must be non-empty."
         ):
-            mock_acq_function = MockAcquisitionFunction()
             optimize_acqf_mixed(
                 acq_function=mock_acq_function,
                 q=1,
@@ -1715,9 +1693,9 @@ def test_optimize_acqf_mixed_return_best_only_q2(self):
             )
 
     def test_optimize_acqf_one_shot_large_q(self):
-        with self.assertRaises(ValueError):
-            mock_acq_function = MockOneShotAcquisitionFunction()
-            fixed_features_list = [{i: i * 0.1} for i in range(2)]
+        mock_acq_function = MockOneShotAcquisitionFunction()
+        fixed_features_list = [{i: i * 0.1} for i in range(2)]
+        with self.assertRaisesRegex(UnsupportedError, "OneShotAcquisitionFunction"):
             optimize_acqf_mixed(
                 acq_function=mock_acq_function,
                 q=2,
