From d48d4b7f5fa95124a28a9f0c37513c42e32765e1 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Fri, 3 Jan 2025 21:35:01 -0800 Subject: [PATCH] Make fitting compatible with scipy 1.15 optimization changes (#2667) Summary: Resolves https://github.com/pytorch/botorch/issues/2666 by updating the regexp that is used to filter out the optimization warning emitted when the maximium number of iterations is reached. NOTE: This kind of regexp filtering is kind of brittle, but it's not necessarily obvious how to do this differently if scipy doesn't return these in a more structured form. Pull Request resolved: https://github.com/pytorch/botorch/pull/2667 Reviewed By: saitcakmak Differential Revision: D67816653 Pulled By: Balandat fbshipit-source-id: aa6f83298e50e4e1ebedde7893888ecd82c4bae1 --- botorch/optim/core.py | 4 ++- requirements.txt | 2 +- test/generation/test_gen.py | 8 ++++-- test/optim/test_core.py | 17 ++++++++--- test/optim/test_fit.py | 13 +++++++-- test/optim/test_optimize.py | 55 ++++++++++++++++++++---------------- test/test_fit.py | 17 +++++++++-- test/test_utils/test_mock.py | 11 ++++++-- 8 files changed, 86 insertions(+), 41 deletions(-) diff --git a/botorch/optim/core.py b/botorch/optim/core.py index 73d18e8a9e..e2062a3b73 100644 --- a/botorch/optim/core.py +++ b/botorch/optim/core.py @@ -34,7 +34,9 @@ _LBFGSB_MAXITER_MAXFUN_REGEX = re.compile( # regex for maxiter and maxfun messages - "TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)" + # Note that the messages changed with scipy 1.15, hence the different matching here. + "TOTAL NO. (of|OF) " + + "(ITERATIONS REACHED LIMIT|(f AND g|F,G) EVALUATIONS EXCEEDS LIMIT)" ) diff --git a/requirements.txt b/requirements.txt index ffe12962c0..61559fe624 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ gpytorch==1.13 linear_operator==0.5.3 torch>=2.0.1 pyro-ppl>=1.8.4 -scipy<1.15 +scipy multipledispatch diff --git a/test/generation/test_gen.py b/test/generation/test_gen.py index eb12bbd32e..255b7c0dac 100644 --- a/test/generation/test_gen.py +++ b/test/generation/test_gen.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +import re import warnings from unittest import mock @@ -225,13 +226,14 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self): def test_gen_candidates_scipy_warns_opt_failure(self): with warnings.catch_warnings(record=True) as ws: self.test_gen_candidates(options={"maxls": 1}) - expected_msg = ( + expected_msg = re.compile( + # The message changed with scipy 1.15, hence the different matching here. "Optimization failed within `scipy.optimize.minimize` with status 2" - " and message ABNORMAL_TERMINATION_IN_LNSRCH." + " and message ABNORMAL(|_TERMINATION_IN_LNSRCH)." ) expected_warning_raised = any( issubclass(w.category, OptimizationWarning) - and expected_msg in str(w.message) + and expected_msg.search(str(w.message)) for w in ws ) self.assertTrue(expected_warning_raised) diff --git a/test/optim/test_core.py b/test/optim/test_core.py index 4288980959..b3a7225a97 100644 --- a/test/optim/test_core.py +++ b/test/optim/test_core.py @@ -135,11 +135,20 @@ def _callback(parameters, result, out) -> None: def test_post_processing(self): closure = next(iter(self.closures.values())) wrapper = NdarrayOptimizationClosure(closure, closure.parameters) + + # Scipy changed return values and messages in v1.15, so we check both + # old and new versions here. + status_msgs = [ + # scipy >=1.15 + (OptimizationStatus.FAILURE, "ABNORMAL_TERMINATION_IN_LNSRCH"), + (OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"), + # scipy <1.15 + (OptimizationStatus.FAILURE, "ABNORMAL "), + (OptimizationStatus.STOPPED, "TOTAL NO. OF ITERATIONS REACHED LIMIT"), + ] + with patch.object(core, "minimize_with_timeout") as mock_minimize_with_timeout: - for status, msg in ( - (OptimizationStatus.FAILURE, b"ABNORMAL_TERMINATION_IN_LNSRCH"), - (OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"), - ): + for status, msg in status_msgs: mock_minimize_with_timeout.return_value = OptimizeResult( x=wrapper.state, fun=1.0, diff --git a/test/optim/test_fit.py b/test/optim/test_fit.py index a4e0c6f6dc..fa2c2d3540 100644 --- a/test/optim/test_fit.py +++ b/test/optim/test_fit.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +import re from unittest.mock import MagicMock, patch from warnings import catch_warnings @@ -20,6 +21,11 @@ from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from scipy.optimize import OptimizeResult +MAX_ITER_MSG_REGEX = re.compile( + # Note that the message changed with scipy 1.15, hence the different matching here. + "TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT" +) + class TestFitGPyTorchMLLScipy(BotorchTestCase): def setUp(self, suppress_input_warnings: bool = True) -> None: @@ -63,7 +69,8 @@ def _test_fit_gpytorch_mll_scipy(self, mll): ) # Test maxiter warning message - self.assertTrue(any("TOTAL NO. of" in str(w.message) for w in ws)) + + self.assertTrue(any(MAX_ITER_MSG_REGEX.search(str(w.message)) for w in ws)) self.assertTrue( any(issubclass(w.category, OptimizationWarning) for w in ws) ) @@ -71,7 +78,9 @@ def _test_fit_gpytorch_mll_scipy(self, mll): # Test iteration tracking self.assertIsInstance(result, OptimizationResult) self.assertLessEqual(result.step, options["maxiter"]) - self.assertEqual(sum(1 for w in ws if "TOTAL NO. of" in str(w.message)), 1) + self.assertEqual( + sum(1 for w in ws if MAX_ITER_MSG_REGEX.search(str(w.message))), 1 + ) # Test that user provided bounds are respected with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt): diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 4cb541722b..2971b32137 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import itertools +import re import warnings from functools import partial from itertools import product @@ -724,19 +725,20 @@ def test_optimize_acqf_warns_on_opt_failure(self): raw_samples=raw_samples, batch_initial_conditions=initial_conditions, ) - message = ( - "Optimization failed in `gen_candidates_scipy` with the following " - "warning(s):\n[OptimizationWarning('Optimization failed within " - "`scipy.optimize.minimize` with status 2 and message " - "ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified " - "`batch_initial_conditions` larger than required `num_restarts`, " - "optimization will not be retried with new initial conditions and " - "will proceed with the current solution. Suggested remediation: " - "Try again with different `batch_initial_conditions`, don't provide " - "`batch_initial_conditions`, or increase `num_restarts`." + message_regex = re.compile( + r"Optimization failed in `gen_candidates_scipy` with the following " + r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within " + r"`scipy.optimize.minimize` with status 2 and message " + r"ABNORMAL(: |_TERMINATION_IN_LNSRCH).'\)]\nBecause you specified " + r"`batch_initial_conditions` larger than required `num_restarts`, " + r"optimization will not be retried with new initial conditions and " + r"will proceed with the current solution. Suggested remediation: " + r"Try again with different `batch_initial_conditions`, don't provide " + r"`batch_initial_conditions`, or increase `num_restarts`." ) expected_warning_raised = any( - issubclass(w.category, RuntimeWarning) and message in str(w.message) + issubclass(w.category, RuntimeWarning) + and message_regex.search(str(w.message)) for w in ws ) self.assertTrue(expected_warning_raised) @@ -774,14 +776,16 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self): # more likely options={"maxls": 2}, ) - message = ( - "Optimization failed in `gen_candidates_scipy` with the following " - "warning(s):\n[OptimizationWarning('Optimization failed within " - "`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION" - "_IN_LNSRCH.')]\nTrying again with a new set of initial conditions." + message_regex = re.compile( + r"Optimization failed in `gen_candidates_scipy` with the following " + r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within " + r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |" + r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of " + r"initial conditions." ) expected_warning_raised = any( - issubclass(w.category, RuntimeWarning) and message in str(w.message) + issubclass(w.category, RuntimeWarning) + and message_regex.search(str(w.message)) for w in ws ) self.assertTrue(expected_warning_raised) @@ -803,7 +807,8 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self): retry_on_optimization_warning=False, ) expected_warning_raised = any( - issubclass(w.category, RuntimeWarning) and message in str(w.message) + issubclass(w.category, RuntimeWarning) + and message_regex.search(str(w.message)) for w in ws ) self.assertFalse(expected_warning_raised) @@ -840,11 +845,12 @@ def test_optimize_acqf_warns_on_second_opt_failure(self): options={"maxls": 2}, ) - message_1 = ( - "Optimization failed in `gen_candidates_scipy` with the following " - "warning(s):\n[OptimizationWarning('Optimization failed within " - "`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION" - "_IN_LNSRCH.')]\nTrying again with a new set of initial conditions." + message_1_regex = re.compile( + r"Optimization failed in `gen_candidates_scipy` with the following " + r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within " + r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |" + r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of " + r"initial conditions." ) message_2 = ( @@ -852,7 +858,8 @@ def test_optimize_acqf_warns_on_second_opt_failure(self): "of initial conditions." ) first_expected_warning_raised = any( - issubclass(w.category, RuntimeWarning) and message_1 in str(w.message) + issubclass(w.category, RuntimeWarning) + and message_1_regex.search(str(w.message)) for w in ws ) second_expected_warning_raised = any( diff --git a/test/test_fit.py b/test/test_fit.py index b161bff537..c5bc0389cf 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +import re from collections.abc import Callable, Iterable from contextlib import ExitStack, nullcontext from copy import deepcopy @@ -30,7 +31,10 @@ from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO from linear_operator.utils.errors import NotPSDError -MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT" +MAX_ITER_MSG_REGEX = re.compile( + # Note that the message changed with scipy 1.15, hence the different matching here. + "TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT" +) class MockOptimizer: @@ -215,7 +219,12 @@ def _test_warnings(self, mll, ckpt): optimizer = MockOptimizer(randomize_requires_grad=False) optimizer.warnings = [ WarningMessage("test_runtime_warning", RuntimeWarning, __file__, 0), - WarningMessage(MAX_ITER_MSG, OptimizationWarning, __file__, 0), + WarningMessage( + "STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT", + OptimizationWarning, + __file__, + 0, + ), WarningMessage( "Optimization timed out after X", OptimizationWarning, __file__, 0 ), @@ -260,7 +269,9 @@ def _test_warnings(self, mll, ckpt): {str(w.message) for w in rethrown + unresolved}, ) if logs: # test that default filter logs certain warnings - self.assertTrue(any(MAX_ITER_MSG in log for log in logs.output)) + self.assertTrue( + any(MAX_ITER_MSG_REGEX.search(log) for log in logs.output) + ) # Test default of retrying upon encountering an uncaught OptimizationWarning optimizer.warnings.append( diff --git a/test/test_utils/test_mock.py b/test/test_utils/test_mock.py index 43867bbeea..e72b5d9912 100644 --- a/test/test_utils/test_mock.py +++ b/test/test_utils/test_mock.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. +import re import warnings from unittest.mock import patch @@ -26,6 +27,12 @@ from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction +MAX_ITER_MSG = re.compile( + # Note that the message changed with scipy 1.15, hence the different matching here. + "TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT" +) + + class SinAcqusitionFunction(MockAcquisitionFunction): """Simple acquisition function with known numerical properties.""" @@ -56,9 +63,7 @@ def closure(): with mock_optimize_context_manager(): result = scipy_minimize(closure=closure, parameters={"x": x}) - self.assertEqual( - result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT" - ) + self.assertTrue(MAX_ITER_MSG.search(result.message)) with self.subTest("optimize_acqf"): with mock_optimize_context_manager():