Commit: Add type annotations for FullyBayesianSingleTaskGP models (#2703)
Summary:
Pull Request resolved: #2703

Adding these annotations made it easier to check whether there were any other issues like #2702, and brings BoTorch down from 10275 to 10254 type errors.
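As an illustration of the pattern this commit applies (a minimal sketch with hypothetical names, not BoTorch code): annotating instance attributes explicitly lets a checker such as Pyre verify their downstream uses instead of treating them as unknown.

    import torch

    class ExampleModel:
        def __init__(self, train_X: torch.Tensor, train_Y: torch.Tensor) -> None:
            # Without the explicit ": int" / ": torch.Size" annotations, a
            # checker may infer these attributes as unknown types and silently
            # accept misuse at call sites.
            self._num_outputs: int = train_Y.shape[-1]
            self._input_batch_shape: torch.Size = train_X.shape[:-2]

        def summary(self) -> str:
            # With the attribute types known, this formatting is checked.
            return f"{self._num_outputs} outputs, batch {tuple(self._input_batch_shape)}"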

Reviewed By: saitcakmak

Differential Revision: D68711505

fbshipit-source-id: fe4861d9f7c2ff3ea879dc8b182a1947b5c120fc
esantorella authored and facebook-github-bot committed Jan 27, 2025
1 parent 81155da commit 90fc872
Showing 3 changed files with 51 additions and 42 deletions.
4 changes: 2 additions & 2 deletions botorch/fit.py
@@ -22,7 +22,7 @@
 from botorch.logging import logger
 from botorch.models import SingleTaskGP
 from botorch.models.approximate_gp import ApproximateGPyTorchModel
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
 from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
 from botorch.models.map_saas import get_map_saas_model
 from botorch.models.model_list_gp_regression import ModelListGP
@@ -334,7 +334,7 @@ def _fit_fallback_approximate(


 def fit_fully_bayesian_model_nuts(
-    model: SaasFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
+    model: FullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
     max_tree_depth: int = 6,
     warmup_steps: int = 512,
     num_samples: int = 256,
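Widening the `model` annotation from `SaasFullyBayesianSingleTaskGP` to the `FullyBayesianSingleTaskGP` base class means the checker no longer flags fitting other fully Bayesian variants. A usage sketch (assuming the BoTorch API at this commit; the small sample counts are only to keep the example quick):

    import torch
    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian import FullyBayesianLinearSingleTaskGP

    train_X = torch.rand(10, 4, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)

    # A FullyBayesianSingleTaskGP subclass that is not a SAAS model; under the
    # old annotation this call was a type error even though it runs fine.
    model = FullyBayesianLinearSingleTaskGP(train_X=train_X, train_Y=train_Y)
    fit_fully_bayesian_model_nuts(
        model, warmup_steps=32, num_samples=32, thinning=16, disable_progbar=True
    )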
20 changes: 12 additions & 8 deletions botorch/models/fully_bayesian.py
@@ -577,8 +577,8 @@ def __init__(
         validate_input_scaling(
             train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
         )
-        self._num_outputs = train_Y.shape[-1]
-        self._input_batch_shape = train_X.shape[:-2]
+        self._num_outputs: int = train_Y.shape[-1]
+        self._input_batch_shape: torch.Size = train_X.shape[:-2]
         if train_Yvar is not None:  # Clamp after transforming
             train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
@@ -596,11 +596,11 @@ def __init__(
         )
         self.pyro_model: PyroModel = pyro_model
         if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
+            self.outcome_transform: OutcomeTransform = outcome_transform
         if input_transform is not None:
-            self.input_transform = input_transform
+            self.input_transform: InputTransform = input_transform

-    def _check_if_fitted(self):
+    def _check_if_fitted(self) -> None:
         r"""Raise an exception if the model hasn't been fitted."""
         if self.covar_module is None:
             raise RuntimeError(
@@ -769,7 +769,9 @@ def median_lengthscale(self) -> Tensor:
         lengthscale = self.covar_module.base_kernel.lengthscale.clone()
         return lengthscale.median(0).values.squeeze(0)

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True
+    ) -> None:
         r"""Custom logic for loading the state dict.

         The standard approach of calling `load_state_dict` currently doesn't play well
@@ -894,7 +896,9 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
         else:
             self.input_transform = input_transform

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True
+    ) -> None:
         r"""Custom logic for loading the state dict.

         The standard approach of calling `load_state_dict` currently doesn't play well
@@ -932,7 +936,7 @@ def construct_inputs(
         *,
         use_input_warping: bool = True,
         indices_to_warp: list[int] | None = None,
-    ) -> dict[str, BotorchContainer | Tensor]:
+    ) -> dict[str, BotorchContainer | Tensor | None]:
         r"""Construct `SingleTaskGP` keyword arguments from a `SupervisedDataset`.

         Args:
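The `| None` added to the return type reflects that `train_Yvar` can be None when the dataset carries no noise observations. A caller sketch (assuming `SupervisedDataset` and `construct_inputs` behave as in current BoTorch; the toy data is arbitrary):

    import torch
    from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
    from botorch.utils.datasets import SupervisedDataset

    X = torch.rand(10, 4, dtype=torch.double)
    Y = torch.sin(X[:, :1])
    dataset = SupervisedDataset(
        X=X, Y=Y, feature_names=[f"x{i}" for i in range(4)], outcome_names=["y"]
    )

    inputs = SaasFullyBayesianSingleTaskGP.construct_inputs(dataset)
    # With no Yvar in the dataset, train_Yvar is None (or absent), hence the
    # `| None` in the annotated return type.
    assert inputs.get("train_Yvar") is None
    model = SaasFullyBayesianSingleTaskGP(**inputs)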
69 changes: 37 additions & 32 deletions test/models/test_fully_bayesian.py
@@ -44,6 +44,7 @@
 from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.fully_bayesian import (
     FullyBayesianLinearSingleTaskGP,
+    FullyBayesianSingleTaskGP,
     LinearPyroModel,
     MCMC_DIM,
     MIN_INFERRED_NOISE_LEVEL,
@@ -82,20 +83,20 @@ class CustomPyroModel(PyroModel):
     def sample(self) -> None:
         pass

-    def postprocess_mcmc_samples(self, mcmc_samples, **kwargs):
+    def postprocess_mcmc_samples(self, mcmc_samples, **kwargs) -> None:
         pass

-    def load_mcmc_samples(self, mcmc_samples):
+    def load_mcmc_samples(self, mcmc_samples) -> None:
         pass


 class TestSaasFullyBayesianSingleTaskGP(BotorchTestCase):
-    model_cls = SaasFullyBayesianSingleTaskGP
-    pyro_model_cls = SaasPyroModel
+    model_cls: type[FullyBayesianSingleTaskGP] = SaasFullyBayesianSingleTaskGP
+    pyro_model_cls: type[PyroModel] = SaasPyroModel
     model_kwargs = {}

     @property
-    def expected_keys(self):
+    def expected_keys(self) -> list[str]:
         return [
             "mean_module.raw_constant",
             "covar_module.raw_outputscale",
@@ -107,17 +108,21 @@ def expected_keys(self):
         ]

     @property
-    def expected_keys_noise(self):
+    def expected_keys_noise(self) -> list[str]:
         return self.expected_keys + [
             "likelihood.noise_covar.raw_noise",
             "likelihood.noise_covar.raw_noise_constraint.lower_bound",
             "likelihood.noise_covar.raw_noise_constraint.upper_bound",
         ]

-    def _test_f(self, X):
+    def _test_f(self, X: torch.Tensor) -> torch.Tensor:
         return torch.sin(X[:, :1])

-    def _get_data_and_model(self, infer_noise: bool, **tkwargs):
+    def _get_data_and_model(
+        self, infer_noise: bool, **tkwargs
+    ) -> tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor | None, FullyBayesianSingleTaskGP
+    ]:
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = torch.rand(10, 4, **tkwargs)
@@ -135,7 +140,9 @@ def _get_data_and_model(self, infer_noise: bool, **tkwargs):
         )
         return train_X, train_Y, train_Yvar, model

-    def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
+    def _get_unnormalized_data(
+        self, infer_noise: bool, **tkwargs
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = 5 + 5 * torch.rand(10, 4, **tkwargs)
@@ -148,7 +155,7 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):

     def _get_unnormalized_condition_data(
         self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
         with torch.random.fork_rng():
             torch.manual_seed(0)
             cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
@@ -160,7 +167,7 @@ def _get_unnormalized_condition_data(

     def _get_mcmc_samples(
         self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
-    ):
+    ) -> dict[str, torch.Tensor]:
         mcmc_samples = {
             "lengthscale": torch.rand(num_samples, 1, dim, **tkwargs),
             "outputscale": torch.rand(num_samples, **tkwargs),
@@ -170,7 +177,7 @@ def _get_mcmc_samples(
             mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs)
         return mcmc_samples

-    def test_raises(self):
+    def test_raises(self) -> None:
         tkwargs = {"device": self.device, "dtype": torch.double}
         with self.assertRaisesRegex(
             ValueError,
@@ -230,7 +237,7 @@ def test_raises(self):
         with self.assertRaisesRegex(RuntimeError, not_fitted_error_msg):
             model.posterior(torch.rand(1, 4, **tkwargs))

-    def test_fit_model(self):
+    def test_fit_model(self) -> None:
         for infer_noise, dtype in itertools.product(
             [True, False], [torch.float, torch.double]
         ):
@@ -441,7 +448,7 @@ def test_fit_model(self):
             self.assertIsNone(model.covar_module)
             self.assertIsNone(model.likelihood)

-    def test_empty(self):
+    def test_empty(self) -> None:
         model = self.model_cls(
             train_X=torch.rand(0, 3),
             train_Y=torch.rand(0, 1),
@@ -452,13 +459,12 @@ def test_empty(self):
         )
         self.assertEqual(model.covar_module.outputscale.shape, torch.Size([2]))

-    def test_transforms(self):
+    def test_transforms(self) -> None:
         for infer_noise in [True, False]:
             tkwargs = {"device": self.device, "dtype": torch.double}
             train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
                 infer_noise=infer_noise, **tkwargs
             )
-            n, d = train_X.shape

             lb, ub = train_X.min(dim=0).values, train_X.max(dim=0).values
             mu, sigma = train_Y.mean(), train_Y.std()
@@ -516,7 +522,7 @@ def test_transforms(self):
             self.assertIsInstance(tf, Normalize)
             self.assertEqual(tf.center, 0.0)

-    def test_acquisition_functions(self):
+    def test_acquisition_functions(self) -> None:
         tkwargs = {"device": self.device, "dtype": torch.double}
         train_X, train_Y, train_Yvar, model = self._get_data_and_model(
             infer_noise=True, **tkwargs
Expand Down Expand Up @@ -655,15 +661,14 @@ def test_acquisition_functions(self):
)
self.assertTrue(X_pruned.ndim == 2 and X_pruned.shape[-1] == 4)

def test_load_samples(self):
def test_load_samples(self) -> None:
for infer_noise, dtype in itertools.product(
[True, False], [torch.float, torch.double]
):
tkwargs = {"device": self.device, "dtype": dtype}
train_X, train_Y, train_Yvar, model = self._get_data_and_model(
train_X, _, train_Yvar, model = self._get_data_and_model(
infer_noise=infer_noise, **tkwargs
)
n, d = train_X.shape
mcmc_samples = self._get_mcmc_samples(
num_samples=3, dim=train_X.shape[-1], infer_noise=infer_noise, **tkwargs
)
@@ -699,7 +704,7 @@ def test_load_samples(self):
             self.assertAllClose(warp.concentration0, mcmc_samples["c0"])
             self.assertAllClose(warp.concentration1, mcmc_samples["c1"])

-    def test_construct_inputs(self):
+    def test_construct_inputs(self) -> None:
         for infer_noise, dtype in itertools.product(
             (True, False), (torch.float, torch.double)
         ):
@@ -719,7 +724,7 @@ def test_construct_inputs(self):
             else:
                 self.assertTrue(Yvar.equal(data_dict["train_Yvar"]))

-    def test_custom_pyro_model(self):
+    def test_custom_pyro_model(self) -> None:
         for infer_noise, dtype in itertools.product(
             (True, False), (torch.float, torch.double)
         ):
@@ -770,7 +775,7 @@ def test_custom_pyro_model(self):
                 atol=5e-4,
             )

-    def test_condition_on_observation(self):
+    def test_condition_on_observation(self) -> None:
         # The following conditioned data shapes should work (output describes):
         # training data shape after cond (batch shape in output is req. in gpytorch)
         # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
@@ -885,7 +890,7 @@ def test_condition_on_observation(self):
             torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
         )

-    def test_bisect(self):
+    def test_bisect(self) -> None:
         def f(x):
             return 1 + x

@@ -941,14 +946,14 @@ def test_deprecated_posterior(self) -> None:


 class TestPyroCatchNumericalErrors(BotorchTestCase):
-    def tearDown(self):
+    def tearDown(self) -> None:
         super().tearDown()
         # Remove exception handler so they don't affect the tests on rerun
         # TODO: Add functionality to pyro to clear the handlers so this
         # does not require touching the internals.
         del _EXCEPTION_HANDLERS["foo_runtime"]

-    def test_pyro_catch_error(self):
+    def test_pyro_catch_error(self) -> None:
         def potential_fn(z):
             mvn = pyro.distributions.MultivariateNormal(
                 loc=torch.zeros(2),
@@ -971,7 +976,7 @@ def potential_fn(z):

         # Default behavior should catch the LinAlgError when performing a
         # Cholesky decomposition and return NaN instead
-        def potential_fn_chol(z):
+        def potential_fn_chol(z) -> torch.Tensor:
             return torch.linalg.cholesky(z["K"])

         _, val = potential_grad(potential_fn_chol, z)
@@ -985,7 +990,7 @@ def potential_fn_rterr_foo(z):
             potential_grad(potential_fn_rterr_foo, z)

         # But once we register this specific error then it should
-        def catch_runtime_error(e):
+        def catch_runtime_error(e) -> bool:
             return type(e) is RuntimeError and "foo" in str(e)

         register_exception_handler("foo_runtime", catch_runtime_error)
@@ -1009,7 +1014,7 @@ def _test_f(self, X):
         return X.sum(dim=-1, keepdim=True)

     @property
-    def expected_keys(self):
+    def expected_keys(self) -> list[str]:
         expected_keys = [
             "mean_module.raw_constant",
             "covar_module.raw_variance",
@@ -1045,7 +1050,7 @@ def _get_mcmc_samples(
         dim: int,
         infer_noise: bool,
         **tkwargs,
-    ):
+    ) -> dict[str, torch.Tensor]:
         mcmc_samples = {
             "weight_variance": torch.rand(num_samples, 1, dim, **tkwargs),
             "mean": torch.randn(num_samples, **tkwargs),
@@ -1057,11 +1062,11 @@ def _get_mcmc_samples(
             mcmc_samples[k] = torch.rand(num_samples, 1, dim, **tkwargs)
         return mcmc_samples

-    def test_custom_pyro_model(self):
+    def test_custom_pyro_model(self) -> None:
         # custom pyro models are not supported by FullyBayesianLinearSingleTaskGP
         pass

-    def test_empty(self):
+    def test_empty(self) -> None:
         # TODO: support empty models with LinearKernels
         pass