From 90fc8726cc9227da9e3a78ba27942170ccad9ce9 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Mon, 27 Jan 2025 10:39:52 -0800
Subject: [PATCH] Add in some type annotations for FullyBayesianSingleTaskGP
 models (#2703)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2703

This made it easier for me to check if there were any other issues like #2702.

Brings BoTorch down from 10275 to 10254 type errors

Reviewed By: saitcakmak

Differential Revision: D68711505

fbshipit-source-id: fe4861d9f7c2ff3ea879dc8b182a1947b5c120fc
---
 botorch/fit.py                     |  4 +-
 botorch/models/fully_bayesian.py   | 20 +++++----
 test/models/test_fully_bayesian.py | 69 ++++++++++++++++--------------
 3 files changed, 51 insertions(+), 42 deletions(-)

diff --git a/botorch/fit.py b/botorch/fit.py
index ff006c695d..0c045beeba 100644
--- a/botorch/fit.py
+++ b/botorch/fit.py
@@ -22,7 +22,7 @@
 from botorch.logging import logger
 from botorch.models import SingleTaskGP
 from botorch.models.approximate_gp import ApproximateGPyTorchModel
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
 from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
 from botorch.models.map_saas import get_map_saas_model
 from botorch.models.model_list_gp_regression import ModelListGP
@@ -334,7 +334,7 @@ def _fit_fallback_approximate(


 def fit_fully_bayesian_model_nuts(
-    model: SaasFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
+    model: FullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
     max_tree_depth: int = 6,
     warmup_steps: int = 512,
     num_samples: int = 256,
diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py
index 91a09e2875..bc6f32f506 100644
--- a/botorch/models/fully_bayesian.py
+++ b/botorch/models/fully_bayesian.py
@@ -577,8 +577,8 @@ def __init__(
         validate_input_scaling(
             train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
         )
-        self._num_outputs = train_Y.shape[-1]
-        self._input_batch_shape = train_X.shape[:-2]
+        self._num_outputs: int = train_Y.shape[-1]
+        self._input_batch_shape: torch.Size = train_X.shape[:-2]
         if train_Yvar is not None:
             # Clamp after transforming
             train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
@@ -596,11 +596,11 @@ def __init__(
         )
         self.pyro_model: PyroModel = pyro_model
         if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
+            self.outcome_transform: OutcomeTransform = outcome_transform
         if input_transform is not None:
-            self.input_transform = input_transform
+            self.input_transform: InputTransform = input_transform

-    def _check_if_fitted(self):
+    def _check_if_fitted(self) -> None:
         r"""Raise an exception if the model hasn't been fitted."""
         if self.covar_module is None:
             raise RuntimeError(
@@ -769,7 +769,9 @@ def median_lengthscale(self) -> Tensor:
         lengthscale = self.covar_module.base_kernel.lengthscale.clone()
         return lengthscale.median(0).values.squeeze(0)

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True
+    ) -> None:
         r"""Custom logic for loading the state dict.

         The standard approach of calling `load_state_dict` currently doesn't play well
@@ -894,7 +896,9 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
         else:
             self.input_transform = input_transform

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True
+    ) -> None:
         r"""Custom logic for loading the state dict.

         The standard approach of calling `load_state_dict` currently doesn't play well
@@ -932,7 +936,7 @@ def construct_inputs(
         *,
         use_input_warping: bool = True,
         indices_to_warp: list[int] | None = None,
-    ) -> dict[str, BotorchContainer | Tensor]:
+    ) -> dict[str, BotorchContainer | Tensor | None]:
         r"""Construct `SingleTaskGP` keyword arguments from a `SupervisedDataset`.

         Args:
diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py
index caffbd5a3f..1c2fcc0f85 100644
--- a/test/models/test_fully_bayesian.py
+++ b/test/models/test_fully_bayesian.py
@@ -44,6 +44,7 @@
 from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.fully_bayesian import (
     FullyBayesianLinearSingleTaskGP,
+    FullyBayesianSingleTaskGP,
     LinearPyroModel,
     MCMC_DIM,
     MIN_INFERRED_NOISE_LEVEL,
@@ -82,20 +83,20 @@ class CustomPyroModel(PyroModel):
     def sample(self) -> None:
         pass

-    def postprocess_mcmc_samples(self, mcmc_samples, **kwargs):
+    def postprocess_mcmc_samples(self, mcmc_samples, **kwargs) -> None:
         pass

-    def load_mcmc_samples(self, mcmc_samples):
+    def load_mcmc_samples(self, mcmc_samples) -> None:
         pass


 class TestSaasFullyBayesianSingleTaskGP(BotorchTestCase):
-    model_cls = SaasFullyBayesianSingleTaskGP
-    pyro_model_cls = SaasPyroModel
+    model_cls: type[FullyBayesianSingleTaskGP] = SaasFullyBayesianSingleTaskGP
+    pyro_model_cls: type[PyroModel] = SaasPyroModel
     model_kwargs = {}

     @property
-    def expected_keys(self):
+    def expected_keys(self) -> list[str]:
         return [
             "mean_module.raw_constant",
             "covar_module.raw_outputscale",
@@ -107,17 +108,21 @@ def expected_keys(self):
         ]

     @property
-    def expected_keys_noise(self):
+    def expected_keys_noise(self) -> list[str]:
         return self.expected_keys + [
             "likelihood.noise_covar.raw_noise",
             "likelihood.noise_covar.raw_noise_constraint.lower_bound",
             "likelihood.noise_covar.raw_noise_constraint.upper_bound",
         ]

-    def _test_f(self, X):
+    def _test_f(self, X: torch.Tensor) -> torch.Tensor:
         return torch.sin(X[:, :1])

-    def _get_data_and_model(self, infer_noise: bool, **tkwargs):
+    def _get_data_and_model(
+        self, infer_noise: bool, **tkwargs
+    ) -> tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor | None, FullyBayesianSingleTaskGP
+    ]:
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = torch.rand(10, 4, **tkwargs)
@@ -135,7 +140,9 @@ def _get_data_and_model(self, infer_noise: bool, **tkwargs):
         )
         return train_X, train_Y, train_Yvar, model

-    def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
+    def _get_unnormalized_data(
+        self, infer_noise: bool, **tkwargs
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = 5 + 5 * torch.rand(10, 4, **tkwargs)
@@ -148,7 +155,7 @@ def _get_unnormalized_condition_data(
         self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
         with torch.random.fork_rng():
             torch.manual_seed(0)
             cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
@@ -160,7 +167,7 @@ def _get_unnormalized_condition_data(
     def _get_mcmc_samples(
         self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
-    ):
+    ) -> dict[str, torch.Tensor]:
         mcmc_samples = {
             "lengthscale": torch.rand(num_samples, 1, dim, **tkwargs),
             "outputscale": torch.rand(num_samples, **tkwargs),
@@ -170,7 +177,7 @@ def _get_mcmc_samples(
             mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs)
         return mcmc_samples

-    def test_raises(self):
+    def test_raises(self) -> None:
         tkwargs = {"device": self.device, "dtype": torch.double}
         with self.assertRaisesRegex(
             ValueError,
@@ -230,7 +237,7 @@ def test_raises(self):
         with self.assertRaisesRegex(RuntimeError, not_fitted_error_msg):
             model.posterior(torch.rand(1, 4, **tkwargs))

-    def test_fit_model(self):
+    def test_fit_model(self) -> None:
         for infer_noise, dtype in itertools.product(
             [True, False], [torch.float, torch.double]
         ):
@@ -441,7 +448,7 @@ def test_fit_model(self):
             self.assertIsNone(model.covar_module)
             self.assertIsNone(model.likelihood)

-    def test_empty(self):
+    def test_empty(self) -> None:
         model = self.model_cls(
             train_X=torch.rand(0, 3),
             train_Y=torch.rand(0, 1),
@@ -452,13 +459,12 @@ def test_empty(self):
         )
         self.assertEqual(model.covar_module.outputscale.shape, torch.Size([2]))

-    def test_transforms(self):
+    def test_transforms(self) -> None:
         for infer_noise in [True, False]:
             tkwargs = {"device": self.device, "dtype": torch.double}
             train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
                 infer_noise=infer_noise, **tkwargs
             )
-            n, d = train_X.shape

             lb, ub = train_X.min(dim=0).values, train_X.max(dim=0).values
             mu, sigma = train_Y.mean(), train_Y.std()
@@ -516,7 +522,7 @@ def test_transforms(self):
                 self.assertIsInstance(tf, Normalize)
                 self.assertEqual(tf.center, 0.0)

-    def test_acquisition_functions(self):
+    def test_acquisition_functions(self) -> None:
         tkwargs = {"device": self.device, "dtype": torch.double}
         train_X, train_Y, train_Yvar, model = self._get_data_and_model(
             infer_noise=True, **tkwargs
         )
@@ -655,15 +661,14 @@ def test_acquisition_functions(self):
         )
         self.assertTrue(X_pruned.ndim == 2 and X_pruned.shape[-1] == 4)

-    def test_load_samples(self):
+    def test_load_samples(self) -> None:
         for infer_noise, dtype in itertools.product(
             [True, False], [torch.float, torch.double]
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
-            train_X, train_Y, train_Yvar, model = self._get_data_and_model(
+            train_X, _, train_Yvar, model = self._get_data_and_model(
                 infer_noise=infer_noise, **tkwargs
             )
-            n, d = train_X.shape
             mcmc_samples = self._get_mcmc_samples(
                 num_samples=3, dim=train_X.shape[-1], infer_noise=infer_noise, **tkwargs
             )
@@ -699,7 +704,7 @@ def test_load_samples(self):
                 self.assertAllClose(warp.concentration0, mcmc_samples["c0"])
                 self.assertAllClose(warp.concentration1, mcmc_samples["c1"])

-    def test_construct_inputs(self):
+    def test_construct_inputs(self) -> None:
         for infer_noise, dtype in itertools.product(
             (True, False), (torch.float, torch.double)
         ):
@@ -719,7 +724,7 @@ def test_construct_inputs(self):
             else:
                 self.assertTrue(Yvar.equal(data_dict["train_Yvar"]))

-    def test_custom_pyro_model(self):
+    def test_custom_pyro_model(self) -> None:
         for infer_noise, dtype in itertools.product(
             (True, False), (torch.float, torch.double)
         ):
@@ -770,7 +775,7 @@ def test_custom_pyro_model(self):
                 atol=5e-4,
             )

-    def test_condition_on_observation(self):
+    def test_condition_on_observation(self) -> None:
         # The following conditioned data shapes should work (output describes):
         # training data shape after cond(batch shape in output is req. in gpytorch)
         # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
@@ -885,7 +890,7 @@ def test_condition_on_observation(self) -> None:
                 torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
             )

-    def test_bisect(self):
+    def test_bisect(self) -> None:
         def f(x):
             return 1 + x
@@ -941,14 +946,14 @@ def test_deprecated_posterior(self) -> None:


 class TestPyroCatchNumericalErrors(BotorchTestCase):
-    def tearDown(self):
+    def tearDown(self) -> None:
         super().tearDown()
         # Remove exception handler so they don't affect the tests on rerun
         # TODO: Add functionality to pyro to clear the handlers so this
         # does not require touching the internals.
         del _EXCEPTION_HANDLERS["foo_runtime"]

-    def test_pyro_catch_error(self):
+    def test_pyro_catch_error(self) -> None:
         def potential_fn(z):
             mvn = pyro.distributions.MultivariateNormal(
                 loc=torch.zeros(2),
@@ -971,7 +976,7 @@ def potential_fn(z):

         # Default behavior should catch the LinAlgError when peforming a
         # Cholesky decomposition and return NaN instead
-        def potential_fn_chol(z):
+        def potential_fn_chol(z) -> torch.Tensor:
             return torch.linalg.cholesky(z["K"])

         _, val = potential_grad(potential_fn_chol, z)
@@ -985,7 +990,7 @@ def potential_fn_rterr_foo(z):
             potential_grad(potential_fn_rterr_foo, z)

         # But once we register this specific error then it should
-        def catch_runtime_error(e):
+        def catch_runtime_error(e) -> bool:
             return type(e) is RuntimeError and "foo" in str(e)

         register_exception_handler("foo_runtime", catch_runtime_error)
@@ -1009,7 +1014,7 @@ def _test_f(self, X):
         return X.sum(dim=-1, keepdim=True)

     @property
-    def expected_keys(self):
+    def expected_keys(self) -> list[str]:
         expected_keys = [
             "mean_module.raw_constant",
             "covar_module.raw_variance",
@@ -1045,7 +1050,7 @@ def _get_mcmc_samples(
         self,
         num_samples: int,
         dim: int,
         infer_noise: bool,
         **tkwargs,
-    ):
+    ) -> dict[str, torch.Tensor]:
         mcmc_samples = {
             "weight_variance": torch.rand(num_samples, 1, dim, **tkwargs),
             "mean": torch.randn(num_samples, **tkwargs),
@@ -1057,11 +1062,11 @@ def _get_mcmc_samples(
             mcmc_samples[k] = torch.rand(num_samples, 1, dim, **tkwargs)
         return mcmc_samples

-    def test_custom_pyro_model(self):
+    def test_custom_pyro_model(self) -> None:
         # custom pyro models are not supported by FullyBayesianLinearSingleTaskGP
         pass

-    def test_empty(self):
+    def test_empty(self) -> None:
         # TODO: support empty models with LinearKernels
         pass
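
A minimal usage sketch (not part of the patch) of what the botorch/fit.py change enables: the `model` annotation on `fit_fully_bayesian_model_nuts` is broadened from `SaasFullyBayesianSingleTaskGP` to the `FullyBayesianSingleTaskGP` base class, so non-SAAS subclasses such as `FullyBayesianLinearSingleTaskGP` now type-check as the `model` argument. The data shapes mirror the test helpers above; the small `warmup_steps`/`num_samples` values are illustrative assumptions, not values from the patch (the defaults shown in the diff are 512 and 256).

    import torch

    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian import FullyBayesianLinearSingleTaskGP

    # Toy training data, mirroring the shapes used in the tests above; inputs
    # in [0, 1] suit the input warping this model applies.
    train_X = torch.rand(10, 4, dtype=torch.double)
    train_Y = torch.sin(train_X[:, :1])

    # A FullyBayesianSingleTaskGP subclass that is not a SAAS model; under the
    # old annotation, passing it below was flagged by the type checker even
    # though it worked at runtime.
    model = FullyBayesianLinearSingleTaskGP(train_X=train_X, train_Y=train_Y)
    fit_fully_bayesian_model_nuts(
        model,
        warmup_steps=32,  # illustrative small values to keep the sketch fast
        num_samples=16,
    )
    posterior = model.posterior(torch.rand(5, 4, dtype=torch.double))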