diff --git a/botorch/fit.py b/botorch/fit.py index eefdb943fa..0c045beeba 100644 --- a/botorch/fit.py +++ b/botorch/fit.py @@ -15,13 +15,19 @@ from typing import Any from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage +import torch + from botorch.exceptions.errors import ModelFittingError, UnsupportedError from botorch.exceptions.warnings import OptimizationWarning from botorch.logging import logger +from botorch.models import SingleTaskGP from botorch.models.approximate_gp import ApproximateGPyTorchModel -from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP +from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP, SaasFullyBayesianSingleTaskGP from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP +from botorch.models.map_saas import get_map_saas_model from botorch.models.model_list_gp_regression import ModelListGP +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform from botorch.optim.closures import get_loss_closure_with_grads from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch @@ -38,11 +44,13 @@ from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder from gpytorch.likelihoods import Likelihood from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood from linear_operator.utils.errors import NotPSDError from pyro.infer.mcmc import MCMC, NUTS from torch import device, Tensor +from torch.distributions import HalfCauchy from torch.nn import Parameter from torch.utils.data import DataLoader @@ -326,7 +334,7 @@ def _fit_fallback_approximate( def fit_fully_bayesian_model_nuts( - model: SaasFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP, + model: FullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP, max_tree_depth: int = 6, warmup_steps: int = 512, num_samples: int = 256, @@ -382,3 +390,128 @@ def fit_fully_bayesian_model_nuts( # Load the MCMC samples back into the BoTorch model model.load_mcmc_samples(mcmc_samples) model.eval() + + +def get_fitted_map_saas_model( + train_X: Tensor, + train_Y: Tensor, + train_Yvar: Tensor | None = None, + input_transform: InputTransform | None = None, + outcome_transform: OutcomeTransform | None = None, + tau: float | None = None, + optimizer_kwargs: dict[str, Any] | None = None, +) -> SingleTaskGP: + """Get a fitted MAP SAAS model with a Matern kernel. + + Args: + train_X: Tensor of shape `n x d` with training inputs. + train_Y: Tensor of shape `n x 1` with training targets. + train_Yvar: Optional tensor of shape `n x 1` with observed noise, + inferred if None. + input_transform: An optional input transform. + outcome_transform: An optional outcome transform. + tau: Fixed value of the global shrinkage tau. If None, the model + places a HC(0.1) prior on tau. + optimizer_kwargs: A dict of options for the optimizer passed + to fit_gpytorch_mll. + + Returns: + A fitted SingleTaskGP with a Matern kernel.
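+
+    Example:
+        A minimal illustration (assuming `train_X`, `train_Y`, and `test_X`
+        tensors of suitable shapes are available):
+
+        >>> model = get_fitted_map_saas_model(train_X, train_Y)
+        >>> posterior = model.posterior(test_X)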
+ """ + # Make sure optimizer_kwargs is a dict + optimizer_kwargs = optimizer_kwargs or {} + model = get_map_saas_model( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=( + input_transform.train() if input_transform is not None else None + ), + outcome_transform=outcome_transform, + tau=tau, + ) + mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood) + fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs) + return model + + +def get_fitted_map_saas_ensemble( + train_X: Tensor, + train_Y: Tensor, + train_Yvar: Tensor | None = None, + input_transform: InputTransform | None = None, + outcome_transform: OutcomeTransform | None = None, + taus: Tensor | list[float] | None = None, + num_taus: int = 4, + optimizer_kwargs: dict[str, Any] | None = None, +) -> SaasFullyBayesianSingleTaskGP: + """Get a fitted SAAS ensemble using several different tau values. + + Args: + train_X: Tensor of shape `n x d` with training inputs. + train_Y: Tensor of shape `n x 1` with training targets. + train_Yvar: Optional tensor of shape `n x 1` with observed noise, + inferred if None. + input_transform: An optional input transform. + outcome_transform: An optional outcome transform. + taus: Global shrinkage values to use. If None, we sample `num_taus` values + from an HC(0.1) distribution. + num_taus: Number of taus to sample if `taus` is not provided. + optimizer_kwargs: A dict of options for the optimizer passed + to fit_gpytorch_mll. + + Returns: + A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel. + """ + tkwargs = {"device": train_X.device, "dtype": train_X.dtype} + if taus is None: + taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs) + num_samples = len(taus) + if num_samples == 1: + raise ValueError( + "Use `get_fitted_map_saas_model` if you only specify one value of tau" + ) + + mean = torch.zeros(num_samples, **tkwargs) + outputscale = torch.zeros(num_samples, **tkwargs) + lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs) + noise = torch.zeros(num_samples, **tkwargs) + + # Fit a model for each tau and save the hyperparameters + for i, tau in enumerate(taus): + model = get_fitted_map_saas_model( + train_X, + train_Y, + train_Yvar=train_Yvar, + input_transform=input_transform, + outcome_transform=outcome_transform, + tau=tau, + optimizer_kwargs=optimizer_kwargs, + ) + mean[i] = model.mean_module.constant.detach().clone() + outputscale[i] = model.covar_module.outputscale.detach().clone() + lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone() + if train_Yvar is None: + noise[i] = model.likelihood.noise.detach().clone() + + # Load the samples into a fully Bayesian SAAS model + ensemble_model = SaasFullyBayesianSingleTaskGP( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=( + input_transform.train() if input_transform is not None else None + ), + outcome_transform=outcome_transform, + ) + mcmc_samples = { + "mean": mean, + "outputscale": outputscale, + "lengthscale": lengthscale, + } + if train_Yvar is None: + mcmc_samples["noise"] = noise + ensemble_model.train() + ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples) + ensemble_model.eval() + return ensemble_model diff --git a/botorch/models/__init__.py b/botorch/models/__init__.py index 031ce83299..06bd17b5b1 100644 --- a/botorch/models/__init__.py +++ b/botorch/models/__init__.py @@ -21,12 +21,16 @@ from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP from botorch.models.gp_regression_mixed
import MixedSingleTaskGP from botorch.models.higher_order_gp import HigherOrderGP + +from botorch.models.map_saas import add_saas_prior, AdditiveMapSaasSingleTaskGP from botorch.models.model import ModelList from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood __all__ = [ + "add_saas_prior", + "AdditiveMapSaasSingleTaskGP", "AffineDeterministicModel", "AffineFidelityCostModel", "ApproximateGPyTorchModel", diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py index 52361a9e41..bc6f32f506 100644 --- a/botorch/models/fully_bayesian.py +++ b/botorch/models/fully_bayesian.py @@ -33,7 +33,7 @@ import math from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any +from typing import Any, TypeVar import pyro import torch @@ -67,6 +67,11 @@ from pyro.ops.integrator import register_exception_handler from torch import Tensor +# Can replace with Self type once 3.11 is the minimum version +TFullyBayesianSingleTaskGP = TypeVar( + "TFullyBayesianSingleTaskGP", bound="FullyBayesianSingleTaskGP" +) + _sqrt5 = math.sqrt(5) @@ -572,8 +577,8 @@ def __init__( validate_input_scaling( train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar ) - self._num_outputs = train_Y.shape[-1] - self._input_batch_shape = train_X.shape[:-2] + self._num_outputs: int = train_Y.shape[-1] + self._input_batch_shape: torch.Size = train_X.shape[:-2] if train_Yvar is not None: # Clamp after transforming train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL) @@ -591,11 +596,11 @@ def __init__( ) self.pyro_model: PyroModel = pyro_model if outcome_transform is not None: - self.outcome_transform = outcome_transform + self.outcome_transform: OutcomeTransform = outcome_transform if input_transform is not None: - self.input_transform = input_transform + self.input_transform: InputTransform = input_transform - def _check_if_fitted(self): + def _check_if_fitted(self) -> None: r"""Raise an exception if the model hasn't been fitted.""" if self.covar_module is None: raise RuntimeError( @@ -623,13 +628,16 @@ def _aug_batch_shape(self) -> torch.Size: aug_batch_shape += torch.Size([self.num_outputs]) return aug_batch_shape - def train(self, mode: bool = True) -> None: + def train( + self: TFullyBayesianSingleTaskGP, mode: bool = True + ) -> TFullyBayesianSingleTaskGP: r"""Puts the model in `train` mode.""" super().train(mode=mode) if mode: self.mean_module = None self.covar_module = None self.likelihood = None + return self def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None: r"""Load the MCMC hyperparameter samples into the model. @@ -761,7 +769,9 @@ def median_lengthscale(self) -> Tensor: lengthscale = self.covar_module.base_kernel.lengthscale.clone() return lengthscale.median(0).values.squeeze(0) - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True + ) -> None: r"""Custom logic for loading the state dict. 
The standard approach of calling `load_state_dict` currently doesn't play well @@ -886,7 +896,9 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None: else: self.input_transform = input_transform - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True + ) -> None: r"""Custom logic for loading the state dict. The standard approach of calling `load_state_dict` currently doesn't play well @@ -924,7 +936,7 @@ def construct_inputs( *, use_input_warping: bool = True, indices_to_warp: list[int] | None = None, - ) -> dict[str, BotorchContainer | Tensor]: + ) -> dict[str, BotorchContainer | Tensor | None]: r"""Construct `SingleTaskGP` keyword arguments from a `SupervisedDataset`. Args: diff --git a/botorch/models/map_saas.py b/botorch/models/map_saas.py new file mode 100644 index 0000000000..17a69cd231 --- /dev/null +++ b/botorch/models/map_saas.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from botorch.exceptions import UnsupportedError +from botorch.models.gp_regression import SingleTaskGP +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform +from botorch.utils.constraints import LogTransformedInterval +from gpytorch.constraints import Interval +from gpytorch.kernels import AdditiveKernel, Kernel, MaternKernel, ScaleKernel +from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood +from gpytorch.means import ConstantMean +from gpytorch.priors import GammaPrior, HalfCauchyPrior, NormalPrior +from torch import Tensor +from torch.distributions.half_cauchy import HalfCauchy +from torch.nn import Parameter + + +EPS = 1e-8 + + +class SaasPriorHelper: + """Helper class for specifying parameter and setting closures.""" + + def __init__(self, tau: float | None = None): + """Instantiates a new helper object. + + Args: + tau: Value of the global shrinkage parameter. If `None`, the tau will be + a free parameter and inferred from the data. + """ + self._tau = torch.as_tensor(tau) if tau is not None else None + + def tau(self, m: Kernel) -> Tensor: + """The global shrinkage parameter `tau`. + + Args: + m: A kernel object equipped with a lengthscale. + + Returns: + The global shrinkage parameter of the SAAS prior. + """ + return ( + self._tau.to(m.lengthscale) + if self._tau is not None + else m.raw_tau_constraint.transform(m.raw_tau) + ) + + def inv_lengthscale_prior_param_or_closure(self, m: Kernel) -> Tensor: + """Closure to compute the scaled inverse lengthscale parameter (`tau / l^2`) + to which the SAAS prior is applied. + + Args: + m: A kernel object equipped with a lengthscale. + + Returns: + The scaled inverse lengthscale parameter. + """ + tau = self.tau(m) + return tau.view(*tau.shape, 1, 1) / (m.lengthscale**2) + + def inv_lengthscale_prior_setting_closure(self, m: Kernel, value: Tensor) -> None: + """Closure to set the inverse lengthscale prior parameter. + + Args: + m: A kernel object equipped with a lengthscale. + value: The value of the scaled inverse lengthscale parameter, (`tau / l^2`), + used to recover and set the lengthscale of the kernel. + """ + # Lengthscale is batch x m x 1 x d, update tau to avoid unwanted broadcasting. 
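+        # The prior parameter is `value = tau / lengthscale^2`, so the lengthscale is
+        # recovered as `sqrt(tau / value)` and clamped to stay strictly inside the raw
+        # lengthscale bounds.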
+ tau = self.tau(m) + tau = tau.view(*tau.shape, 1, 1) + lb = m.raw_lengthscale_constraint.lower_bound.to(tau) + ub = m.raw_lengthscale_constraint.upper_bound.to(tau) + m._set_lengthscale((tau / value.to(tau)).sqrt().clamp(lb + EPS, ub - EPS)) + + def tau_prior_param_or_closure(self, m: Kernel) -> Tensor: + """Closure to compute the global shrinkage parameter `tau`. + + Args: + m: A kernel object equipped with a `raw_tau` parameter. + + Returns: + The transformed global shrinkage parameter `tau`. + """ + return m.raw_tau_constraint.transform(m.raw_tau) + + def tau_prior_setting_closure(self, m: Kernel, value: Tensor) -> None: + """Closure to set the global shrinkage parameter `tau`. + + Args: + m: A kernel object equipped with a `raw_tau` parameter. + value: The value of the global shrinkage parameter. + """ + lb = m.raw_tau_constraint.lower_bound.to(m.raw_tau) + ub = m.raw_tau_constraint.upper_bound.to(m.raw_tau) + m.raw_tau.data.fill_( + m.raw_tau_constraint.inverse_transform( + value.to(m.raw_tau).clamp(lb + EPS, ub - EPS) + ).item() + ) + + +def add_saas_prior( + base_kernel: Kernel, + tau: float | None = None, + log_scale: bool = True, +) -> Kernel: + """Add a SAAS prior to a given base_kernel. + + The SAAS prior is given by tau / lengthscale^2 ~ HC(1.0). If tau is None, + we place an additional HC(0.1) prior on tau similar to the original SAAS prior + that relies on inference with NUTS. + + Example: + >>> matern_kernel = MaternKernel(...) + >>> add_saas_prior(matern_kernel, tau=None) # Add a SAAS prior + + Args: + base_kernel: Base kernel that has a lengthscale and uses ARD. + Note that this function modifies the kernel object in place. + tau: Value of the global shrinkage. If `None`, infer the global + shrinkage parameter. + log_scale: Set to `True` if the lengthscale and tau should be optimized on + a log-scale without any domain rescaling. That is, we will learn + `raw_lengthscale := log(lengthscale)` and this hyperparameter needs to + satisfy the corresponding bound constraints. Setting this to `True` will + generally improve the numerical stability, but requires an optimizer that + can handle bound constraints, e.g., L-BFGS-B. + + Returns: + Base kernel with SAAS priors added. 
+ """ + if not base_kernel.has_lengthscale: + raise UnsupportedError("base_kernel must have lengthscale(s)") + if hasattr(base_kernel, "lengthscale_prior"): + raise UnsupportedError("base_kernel must not specify a lengthscale prior") + tkwargs = {"device": base_kernel.device, "dtype": base_kernel.dtype} + + batch_shape = base_kernel.raw_lengthscale.shape[:-2] + IntervalClass = LogTransformedInterval if log_scale else Interval + base_kernel.register_constraint( + param_name="raw_lengthscale", + constraint=IntervalClass(0.01, 1e4, initial_value=1), + replace=True, + ) + prior_helper = SaasPriorHelper(tau=tau) + if tau is None:  # Place a HC(0.1) prior on tau + base_kernel.register_parameter( + name="raw_tau", + parameter=Parameter(torch.full(batch_shape, 0.1, **tkwargs)), + ) + base_kernel.register_constraint( + param_name="raw_tau", + constraint=IntervalClass(1e-3, 10, initial_value=0.1), + replace=True, + ) + base_kernel.register_prior( + name="tau_prior", + prior=HalfCauchyPrior(torch.tensor(0.1, **tkwargs)), + param_or_closure=prior_helper.tau_prior_param_or_closure, + setting_closure=prior_helper.tau_prior_setting_closure, + ) + # Place a HC(1) prior on tau / lengthscale^2 + base_kernel.register_prior( + name="inv_lengthscale_prior", + prior=HalfCauchyPrior(torch.tensor(1.0, **tkwargs)), + param_or_closure=prior_helper.inv_lengthscale_prior_param_or_closure, + setting_closure=prior_helper.inv_lengthscale_prior_setting_closure, + ) + return base_kernel + + +def get_map_saas_model( + train_X: Tensor, + train_Y: Tensor, + train_Yvar: Tensor | None = None, + input_transform: InputTransform | None = None, + outcome_transform: OutcomeTransform | None = None, + tau: float | None = None, +) -> SingleTaskGP: + """Helper method for creating an unfitted MAP SAAS model. + + Args: + train_X: Tensor of shape `n x d` with training inputs. + train_Y: Tensor of shape `n x 1` with training targets. + train_Yvar: Optional tensor of shape `n x 1` with observed noise, + inferred if None. + input_transform: An optional input transform. + outcome_transform: An optional outcome transform. + tau: Fixed value of the global shrinkage tau. If None, the model + places a HC(0.1) prior on tau and infers it. + + Returns: + A SingleTaskGP with a Matern kernel and a SAAS prior. + """ + # TODO: Shape checks + _, aug_batch_shape = SingleTaskGP.get_batch_dimensions( + train_X=train_X, train_Y=train_Y + ) + mean_module = get_mean_module_with_normal_prior(batch_shape=aug_batch_shape) + if input_transform is not None: + with torch.no_grad(): + transformed_X = input_transform(train_X) + ard_num_dims = transformed_X.shape[-1] + else: + ard_num_dims = train_X.shape[-1] + base_kernel = MaternKernel( + nu=2.5, ard_num_dims=ard_num_dims, batch_shape=aug_batch_shape + ) + # NOTE: need to call `to` to set device and dtype before calling `add_saas_prior`, + # since the SAAS prior contains tensors that are not parameters of the model, and + # therefore not automatically moved to the correct device with a `to` call on the + # model.
+ base_kernel.to(train_X) + add_saas_prior(base_kernel=base_kernel, tau=tau) + covar_module = ScaleKernel( + base_kernel=base_kernel, + outputscale_constraint=LogTransformedInterval(1e-2, 1e4, initial_value=10), + batch_shape=aug_batch_shape, + ) + if train_Yvar is None: + likelihood = get_gaussian_likelihood_with_gamma_prior( + batch_shape=aug_batch_shape + ) + else: + likelihood = None + model = SingleTaskGP( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + mean_module=mean_module, + covar_module=covar_module, + likelihood=likelihood, + input_transform=input_transform, + outcome_transform=outcome_transform, + ) + model.to(train_X) + return model + + +def get_mean_module_with_normal_prior( + batch_shape: torch.Size | None = None, +) -> ConstantMean: + """Return constant mean with a N(0, 1) prior constrained to [-10, 10]. + + This prior assumes the outputs (targets) have been standardized to have zero mean + and unit variance. + + Args: + batch_shape: Optional batch shape for the constant-mean module. + + Returns: + ConstantMean module. + """ + return ConstantMean( + constant_prior=NormalPrior(loc=0.0, scale=1.0), + constant_constraint=Interval( + -10, + 10, + initial_value=0, + transform=None, + ), + batch_shape=batch_shape or torch.Size(), + ) + + +def get_gaussian_likelihood_with_gamma_prior(batch_shape: torch.Size | None = None): + """Return Gaussian likelihood with a Gamma(0.9, 10) prior. + + This prior prefers small noise, but also has heavy tails. + + Args: + batch_shape: Batch shape for the likelihood. + + Returns: + GaussianLikelihood with Gamma(0.9, 10) prior constrained to [1e-4, 1]. + """ + return GaussianLikelihood( + noise_prior=GammaPrior(0.9, 10.0), + noise_constraint=LogTransformedInterval(1e-4, 1, initial_value=1e-2), + batch_shape=batch_shape or torch.Size(), + ) + + +def get_additive_map_saas_covar_module( + ard_num_dims: int, + num_taus: int = 4, + active_dims: tuple[int, ...] | None = None, + batch_shape: torch.Size | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, +): + """Return an additive MAP SAAS covar module. + + The constructed kernel is an additive kernel with `num_taus` terms. Each term is a + scaled Matern kernel with a SAAS prior and a tau sampled from a HalfCauchy(0.1) + distribution. + + Args: + ard_num_dims: The number of input dimensions. + num_taus: The number of taus to use (4 if omitted). + active_dims: Active dims for the covar module. The kernel will be evaluated + only using these columns of the input tensor. + batch_shape: Batch shape for the covar module. + dtype: The dtype to use for the kernel's tensors. + device: The device to use for the kernel's tensors. + + Returns: + An additive MAP SAAS covar module. + """ + batch_shape = batch_shape or torch.Size() + kernels = [] + for _ in range(num_taus): + base_kernel = MaternKernel( + nu=2.5, + ard_num_dims=ard_num_dims, + batch_shape=batch_shape, + active_dims=active_dims, + ).to(dtype=dtype, device=device) + add_saas_prior(base_kernel=base_kernel, tau=HalfCauchy(0.1).sample(batch_shape)) + scaled_kernel = ScaleKernel( + base_kernel=base_kernel, + outputscale_constraint=LogTransformedInterval(1e-2, 1e4, initial_value=10), + batch_shape=batch_shape, + ) + kernels.append(scaled_kernel) + return AdditiveKernel(*kernels) + + +class AdditiveMapSaasSingleTaskGP(SingleTaskGP): + """An additive MAP SAAS single-task GP. + + This is a maximum-a-posteriori (MAP) version of sparse axis-aligned subspace BO + (SAASBO), see `SaasFullyBayesianSingleTaskGP` for more details.
SAASBO is a + high-dimensional Bayesian optimization approach that uses approximate fully + Bayesian inference via NUTS to learn the model hyperparameters. This works very + well, but is very computationally expensive, which limits the use of SAASBO to a + small (~100) number of trials. Two of the main benefits of SAASBO are: + + (1) A sparse prior on the inverse lengthscales that avoids overfitting. + (2) The ability to sample several (~16) sets of hyperparameters from the + posterior that we can average over when computing the acquisition + function (ensembling). + + The goal of this Additive MAP SAAS model is to retain the main benefits of the SAAS + model while significantly speeding up the time to fit the model. We achieve this by + creating an additive kernel where each kernel in the sum is a Matern-5/2 kernel + with a SAAS prior and a separate outputscale. The sparsity level for each kernel + is sampled from an HC(0.1) distribution, leading to a mix of sparsity levels (as is + often the case for the fully Bayesian SAAS model). We learn all the hyperparameters + using MAP inference, which is significantly faster than using NUTS. + + While we often find that the original SAAS model with NUTS performs better, the + additive MAP SAAS model can be several orders of magnitude faster to fit, which + makes it applicable to problems with potentially thousands of trials. + """ + + def __init__( + self, + train_X: Tensor, + train_Y: Tensor, + train_Yvar: Tensor | None = None, + outcome_transform: OutcomeTransform | None = None, + input_transform: InputTransform | None = None, + num_taus: int = 4, + ) -> None: + """Instantiates an AdditiveMapSaasSingleTaskGP. + + Args: + train_X: A `batch_shape x n x d` tensor of training features. + train_Y: A `batch_shape x n x m` tensor of training observations. + train_Yvar: An optional `batch_shape x n x m` tensor of observed noise. + outcome_transform: An optional outcome transform. + input_transform: An optional input transform. + num_taus: The number of taus to use (4 if omitted). + """ + self._set_dimensions(train_X=train_X, train_Y=train_Y) + mean_module = get_mean_module_with_normal_prior( + batch_shape=self._aug_batch_shape + ) + if train_Yvar is not None: + _, _, train_Yvar = self._transform_tensor_args( + X=train_X, Y=train_Y, Yvar=train_Yvar + ) + likelihood = ( + FixedNoiseGaussianLikelihood( + noise=train_Yvar, batch_shape=self._aug_batch_shape + ) + if train_Yvar is not None + else get_gaussian_likelihood_with_gamma_prior( + batch_shape=self._aug_batch_shape + ) + ) + covar_module = get_additive_map_saas_covar_module( + ard_num_dims=train_X.shape[-1], + num_taus=num_taus, + batch_shape=self._aug_batch_shape, + # Need to pass dtype and device at initialization of the covar_module + # because its priors contain tensors, and priors are currently not moved + # to the correct device/dtype when calling `to` on the model.
+ dtype=train_X.dtype, + device=train_X.device, + ) + + SingleTaskGP.__init__( + self=self, + train_X=train_X, + train_Y=train_Y, + mean_module=mean_module, + covar_module=covar_module, + likelihood=likelihood, + input_transform=input_transform, + outcome_transform=outcome_transform, + ) + # Make sure that all buffers and parameters have the correct device and dtype + self.to(dtype=train_X.dtype, device=train_X.device) diff --git a/botorch/utils/constraints.py b/botorch/utils/constraints.py index 4af5d408ab..72ba696932 100644 --- a/botorch/utils/constraints.py +++ b/botorch/utils/constraints.py @@ -10,11 +10,14 @@ from __future__ import annotations +import math + from collections.abc import Callable from functools import partial import torch + +from gpytorch import settings from gpytorch.constraints import Interval from torch import Tensor @@ -143,3 +146,77 @@ def transform(self, tensor: Tensor) -> Tensor: def inverse_transform(self, transformed_tensor: Tensor) -> Tensor: return transformed_tensor + + +class LogTransformedInterval(Interval): + """Modification of the GPyTorch interval class. + + The Interval class in GPyTorch will map the parameter to the range [0, 1] before + applying the inverse transform. LogTransformedInterval skips this step to avoid + numerical issues, and applies the log transform directly to the parameter values. + GPyTorch automatically recognizes that the bound constraints have not been applied + yet, and passes the bounds to the optimizer instead, which then optimizes + log(parameter) under the constraints log(lower) <= log(parameter) <= log(upper). + """ + + def __init__( + self, + lower_bound: float, + upper_bound: float, + initial_value: float | None = None, + ): + """Constructor of the LogTransformedInterval class. + + Args: + lower_bound: The lower bound of the interval. + upper_bound: The upper bound of the interval. + initial_value: The initial value of the parameter. + """ + super().__init__( + lower_bound=lower_bound, + upper_bound=upper_bound, + transform=torch.exp, + inv_transform=torch.log, + initial_value=initial_value, + ) + + # Save the untransformed initial value + self.register_buffer( + "initial_value_untransformed", + ( + torch.tensor(initial_value).to(self.lower_bound) + if initial_value is not None + else None + ), + ) + + if settings.debug.on(): + max_bound = torch.max(self.upper_bound) + min_bound = torch.min(self.lower_bound) + if max_bound == math.inf or min_bound == -math.inf: + raise RuntimeError( + "Cannot make an Interval directly with non-finite bounds. Use a " + "derived class like GreaterThan or LessThan instead." + ) + + def transform(self, tensor: Tensor) -> Tensor: + """Transform the parameter using the exponential function. + + Args: + tensor: Tensor of parameter values to transform. + + Returns: + Tensor of transformed parameter values. + """ + return self._transform(tensor) + + def inverse_transform(self, transformed_tensor: Tensor) -> Tensor: + """Untransform the parameter using the natural logarithm. + + Args: + transformed_tensor: Tensor of transformed parameter values to untransform. + + Returns: + Tensor of untransformed parameter values. + """ + return self._inv_transform(transformed_tensor) diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index dbeb43998d..e904fd26e9 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -39,29 +39,34 @@ Cost Models (for cost-aware optimization) .. automodule:: botorch.models.cost :members: -GP Regression Models -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. 
automodule:: botorch.models.gp_regression +Contextual GP Models with Aggregate Rewards +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.contextual :members: -Multi-Fidelity GP Regression Models -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.gp_regression_fidelity +Contextual GP Models with Context Rewards +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.contextual_multioutput :members: -GP Regression Models for Mixed Parameter Spaces +Fully Bayesian GP Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.gp_regression_mixed +.. automodule:: botorch.models.fully_bayesian :members: -Model List GP Regression Models -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.model_list_gp_regression +Fully Bayesian Multitask GP Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.fully_bayesian_multitask :members: -Multitask GP Models +GP Regression Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.multitask +.. automodule:: botorch.models.gp_regression + :members: + +GP Regression Models for Mixed Parameter Spaces +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.gp_regression_mixed :members: Higher Order GP Models @@ -74,39 +79,39 @@ Latent Kronecker GP Models .. automodule:: botorch.models.latent_kronecker_gp :members: -Pairwise GP Models +Model List GP Regression Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.pairwise_gp +.. automodule:: botorch.models.model_list_gp_regression :members: -Contextual GP Models with Aggregate Rewards -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.contextual +Multitask GP Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.multitask :members: -Contextual GP Models with Context Rewards -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.contextual_multioutput +Multi-Fidelity GP Regression Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.gp_regression_fidelity :members: -Variational GP Models -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.approximate_gp +Pairwise GP Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.pairwise_gp :members: -Fully Bayesian GP Models +Relevance Pursuit Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.fully_bayesian +.. automodule:: botorch.models.relevance_pursuit :members: -Fully Bayesian Multitask GP Models +Sparse Axis-Aligned Subspaces (SAAS) GP Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.fully_bayesian_multitask +.. automodule:: botorch.models.map_saas :members: -Relevance Pursuit Models +Variational GP Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.relevance_pursuit +.. 
automodule:: botorch.models.approximate_gp :members: Model Components diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py index 43689428de..1c2fcc0f85 100644 --- a/test/models/test_fully_bayesian.py +++ b/test/models/test_fully_bayesian.py @@ -44,6 +44,7 @@ from botorch.models.deterministic import GenericDeterministicModel from botorch.models.fully_bayesian import ( FullyBayesianLinearSingleTaskGP, + FullyBayesianSingleTaskGP, LinearPyroModel, MCMC_DIM, MIN_INFERRED_NOISE_LEVEL, @@ -82,20 +83,20 @@ class CustomPyroModel(PyroModel): def sample(self) -> None: pass - def postprocess_mcmc_samples(self, mcmc_samples, **kwargs): + def postprocess_mcmc_samples(self, mcmc_samples, **kwargs) -> None: pass - def load_mcmc_samples(self, mcmc_samples): + def load_mcmc_samples(self, mcmc_samples) -> None: pass class TestSaasFullyBayesianSingleTaskGP(BotorchTestCase): - model_cls = SaasFullyBayesianSingleTaskGP - pyro_model_cls = SaasPyroModel + model_cls: type[FullyBayesianSingleTaskGP] = SaasFullyBayesianSingleTaskGP + pyro_model_cls: type[PyroModel] = SaasPyroModel model_kwargs = {} @property - def expected_keys(self): + def expected_keys(self) -> list[str]: return [ "mean_module.raw_constant", "covar_module.raw_outputscale", @@ -107,17 +108,21 @@ def expected_keys(self): ] @property - def expected_keys_noise(self): + def expected_keys_noise(self) -> list[str]: return self.expected_keys + [ "likelihood.noise_covar.raw_noise", "likelihood.noise_covar.raw_noise_constraint.lower_bound", "likelihood.noise_covar.raw_noise_constraint.upper_bound", ] - def _test_f(self, X): + def _test_f(self, X: torch.Tensor) -> torch.Tensor: return torch.sin(X[:, :1]) - def _get_data_and_model(self, infer_noise: bool, **tkwargs): + def _get_data_and_model( + self, infer_noise: bool, **tkwargs + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor | None, FullyBayesianSingleTaskGP + ]: with torch.random.fork_rng(): torch.manual_seed(0) train_X = torch.rand(10, 4, **tkwargs) @@ -135,7 +140,9 @@ def _get_data_and_model(self, infer_noise: bool, **tkwargs): ) return train_X, train_Y, train_Yvar, model - def _get_unnormalized_data(self, infer_noise: bool, **tkwargs): + def _get_unnormalized_data( + self, infer_noise: bool, **tkwargs + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]: with torch.random.fork_rng(): torch.manual_seed(0) train_X = 5 + 5 * torch.rand(10, 4, **tkwargs) @@ -148,7 +155,7 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs): def _get_unnormalized_condition_data( self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: with torch.random.fork_rng(): torch.manual_seed(0) cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs) @@ -160,7 +167,7 @@ def _get_unnormalized_condition_data( def _get_mcmc_samples( self, num_samples: int, dim: int, infer_noise: bool, **tkwargs - ): + ) -> dict[str, torch.Tensor]: mcmc_samples = { "lengthscale": torch.rand(num_samples, 1, dim, **tkwargs), "outputscale": torch.rand(num_samples, **tkwargs), @@ -170,7 +177,7 @@ def _get_mcmc_samples( mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs) return mcmc_samples - def test_raises(self): + def test_raises(self) -> None: tkwargs = {"device": self.device, "dtype": torch.double} with self.assertRaisesRegex( ValueError, @@ -230,7 +237,7 @@ def test_raises(self): with self.assertRaisesRegex(RuntimeError, not_fitted_error_msg): model.posterior(torch.rand(1, 4, 
**tkwargs)) - def test_fit_model(self): + def test_fit_model(self) -> None: for infer_noise, dtype in itertools.product( [True, False], [torch.float, torch.double] ): @@ -434,13 +441,14 @@ def test_fit_model(self): # Make sure the model shapes are set correctly self.assertEqual(model.pyro_model.train_X.shape, torch.Size([n, d])) self.assertAllClose(model.pyro_model.train_X, train_X) - model.train() # Put the model in train mode + trained_model = model.train() # Put the model in train mode + self.assertIs(trained_model, model) self.assertAllClose(train_X, model.pyro_model.train_X) self.assertIsNone(model.mean_module) self.assertIsNone(model.covar_module) self.assertIsNone(model.likelihood) - def test_empty(self): + def test_empty(self) -> None: model = self.model_cls( train_X=torch.rand(0, 3), train_Y=torch.rand(0, 1), @@ -451,13 +459,12 @@ def test_empty(self): ) self.assertEqual(model.covar_module.outputscale.shape, torch.Size([2])) - def test_transforms(self): + def test_transforms(self) -> None: for infer_noise in [True, False]: tkwargs = {"device": self.device, "dtype": torch.double} train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data( infer_noise=infer_noise, **tkwargs ) - n, d = train_X.shape lb, ub = train_X.min(dim=0).values, train_X.max(dim=0).values mu, sigma = train_Y.mean(), train_Y.std() @@ -515,7 +522,7 @@ def test_transforms(self): self.assertIsInstance(tf, Normalize) self.assertEqual(tf.center, 0.0) - def test_acquisition_functions(self): + def test_acquisition_functions(self) -> None: tkwargs = {"device": self.device, "dtype": torch.double} train_X, train_Y, train_Yvar, model = self._get_data_and_model( infer_noise=True, **tkwargs @@ -654,15 +661,14 @@ def test_acquisition_functions(self): ) self.assertTrue(X_pruned.ndim == 2 and X_pruned.shape[-1] == 4) - def test_load_samples(self): + def test_load_samples(self) -> None: for infer_noise, dtype in itertools.product( [True, False], [torch.float, torch.double] ): tkwargs = {"device": self.device, "dtype": dtype} - train_X, train_Y, train_Yvar, model = self._get_data_and_model( + train_X, _, train_Yvar, model = self._get_data_and_model( infer_noise=infer_noise, **tkwargs ) - n, d = train_X.shape mcmc_samples = self._get_mcmc_samples( num_samples=3, dim=train_X.shape[-1], infer_noise=infer_noise, **tkwargs ) @@ -698,7 +704,7 @@ def test_load_samples(self): self.assertAllClose(warp.concentration0, mcmc_samples["c0"]) self.assertAllClose(warp.concentration1, mcmc_samples["c1"]) - def test_construct_inputs(self): + def test_construct_inputs(self) -> None: for infer_noise, dtype in itertools.product( (True, False), (torch.float, torch.double) ): @@ -718,7 +724,7 @@ def test_construct_inputs(self): else: self.assertTrue(Yvar.equal(data_dict["train_Yvar"])) - def test_custom_pyro_model(self): + def test_custom_pyro_model(self) -> None: for infer_noise, dtype in itertools.product( (True, False), (torch.float, torch.double) ): @@ -769,7 +775,7 @@ def test_custom_pyro_model(self): atol=5e-4, ) - def test_condition_on_observation(self): + def test_condition_on_observation(self) -> None: # The following conditioned data shapes should work (output describes): # training data shape after cond(batch shape in output is req. 
in gpytorch) # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d @@ -884,7 +890,7 @@ def test_condition_on_observation(self): torch.Size([num_models, num_train + 3 * num_cond, num_dims]), ) - def test_bisect(self): + def test_bisect(self) -> None: def f(x): return 1 + x @@ -940,14 +946,14 @@ def test_deprecated_posterior(self) -> None: class TestPyroCatchNumericalErrors(BotorchTestCase): - def tearDown(self): + def tearDown(self) -> None: super().tearDown() # Remove exception handlers so they don't affect the tests on rerun # TODO: Add functionality to pyro to clear the handlers so this # does not require touching the internals. del _EXCEPTION_HANDLERS["foo_runtime"] - def test_pyro_catch_error(self): + def test_pyro_catch_error(self) -> None: def potential_fn(z): mvn = pyro.distributions.MultivariateNormal( loc=torch.zeros(2), @@ -970,7 +976,7 @@ def potential_fn(z): # Default behavior should catch the LinAlgError when performing a # Cholesky decomposition and return NaN instead - def potential_fn_chol(z): + def potential_fn_chol(z) -> torch.Tensor: return torch.linalg.cholesky(z["K"]) _, val = potential_grad(potential_fn_chol, z) @@ -984,7 +990,7 @@ def potential_fn_rterr_foo(z): potential_grad(potential_fn_rterr_foo, z) # But once we register this specific error then it should - def catch_runtime_error(e): + def catch_runtime_error(e) -> bool: return type(e) is RuntimeError and "foo" in str(e) register_exception_handler("foo_runtime", catch_runtime_error) @@ -1008,7 +1014,7 @@ def _test_f(self, X): return X.sum(dim=-1, keepdim=True) @property - def expected_keys(self): + def expected_keys(self) -> list[str]: expected_keys = [ "mean_module.raw_constant", "covar_module.raw_variance", @@ -1044,7 +1050,7 @@ def _get_mcmc_samples( dim: int, infer_noise: bool, **tkwargs, - ): + ) -> dict[str, torch.Tensor]: mcmc_samples = { "weight_variance": torch.rand(num_samples, 1, dim, **tkwargs), "mean": torch.randn(num_samples, **tkwargs), @@ -1056,11 +1062,11 @@ def _get_mcmc_samples( mcmc_samples[k] = torch.rand(num_samples, 1, dim, **tkwargs) return mcmc_samples - def test_custom_pyro_model(self): + def test_custom_pyro_model(self) -> None: # custom pyro models are not supported by FullyBayesianLinearSingleTaskGP pass - def test_empty(self): + def test_empty(self) -> None: # TODO: support empty models with LinearKernels pass diff --git a/test/models/test_map_saas.py b/test/models/test_map_saas.py new file mode 100644 index 0000000000..dc6fe6a506 --- /dev/null +++ b/test/models/test_map_saas.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +import itertools +import math +import pickle +from itertools import product +from typing import Any +from unittest import mock + +import torch + +from botorch.exceptions import UnsupportedError +from botorch.fit import ( + fit_gpytorch_mll, + get_fitted_map_saas_ensemble, + get_fitted_map_saas_model, +) +from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP +from botorch.models.map_saas import ( + add_saas_prior, + AdditiveMapSaasSingleTaskGP, + get_additive_map_saas_covar_module, + get_gaussian_likelihood_with_gamma_prior, + get_mean_module_with_normal_prior, +) +from botorch.models.transforms.input import AppendFeatures, FilterFeatures, Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.optim.utils import get_parameters_and_bounds, sample_all_priors +from botorch.posteriors.gpytorch import GPyTorchPosterior +from botorch.utils.constraints import LogTransformedInterval +from botorch.utils.testing import BotorchTestCase +from gpytorch.constraints import Interval +from gpytorch.kernels import AdditiveKernel, MaternKernel, ScaleKernel +from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood +from gpytorch.means import ConstantMean +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from gpytorch.priors import GammaPrior, HalfCauchyPrior, NormalPrior +from torch import Tensor + + +class TestMapSaas(BotorchTestCase): + def _get_data(self, **tkwargs) -> tuple[Tensor, Tensor, Tensor]: + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = 1 + 2 * torch.rand(10, 3, **tkwargs) + train_Y = torch.sin(train_X[:, :1]) + test_X = 1 + 2 * torch.rand(5, 3, **tkwargs) + return train_X, train_Y, test_X + + def _get_data_hardcoded(self, **tkwargs) -> tuple[Tensor, Tensor, Tensor]: + """This is equal to _get_data on CPU with a seed of 0, and is hard-coded here + to ensure that the results are identical on GPUs, which have different RNGs. 
+ """ + train_X = torch.tensor( + [ + [2.9401, 2.4156, 1.9188], + [2.8415, 2.2900, 2.5823], + [1.3572, 1.7022, 2.1627], + [1.5765, 1.9057, 1.3536], + [1.7105, 2.2438, 1.9637], + [1.8816, 1.8146, 1.4109], + [2.3301, 2.5697, 1.4207], + [2.3535, 1.2195, 2.0475], + [1.4520, 2.1165, 2.1751], + [2.3639, 2.4908, 1.4553], + ], + **tkwargs, + ) + train_Y = torch.tensor( + [ + [0.2001], + [0.2956], + [0.9773], + [1.0000], + [0.9903], + [0.9521], + [0.7253], + [0.7090], + [0.9930], + [0.7016], + ], + **tkwargs, + ) + test_X = torch.tensor( + [ + [2.6198, 2.2612, 1.3918], + [1.3055, 1.9630, 2.8350], + [2.1441, 1.9617, 2.5921], + [1.7359, 1.5444, 1.0829], + [2.4974, 1.6835, 2.0765], + ], + **tkwargs, + ) + return train_X, train_Y, test_X + + def test_add_saas_prior(self) -> None: + for dtype, infer_tau in itertools.product( + [torch.float, torch.double], [True, False] + ): + tkwargs = {"device": self.device, "dtype": dtype} + train_X, _, _ = self._get_data(**tkwargs) + base_kernel = MaternKernel(nu=2.5, ard_num_dims=train_X.shape[-1]).to( + **tkwargs + ) + tau = None if infer_tau else 0.1234 + add_saas_prior(base_kernel=base_kernel, tau=tau) + pickle.loads(pickle.dumps(base_kernel)) # pickle and unpickle should work + if not infer_tau: # Make sure there is no raw_tau parameter + self.assertFalse(hasattr(base_kernel, "raw_tau")) + else: + self.assertTrue(hasattr(base_kernel, "raw_tau")) + self.assertIsInstance(base_kernel.tau_prior, HalfCauchyPrior) + self.assertAlmostEqual(base_kernel.tau_prior.scale.item(), 0.1) + # Make sure there is a constraint on tau + self.assertIsInstance(base_kernel.raw_tau_constraint, Interval) + self.assertEqual(base_kernel.raw_tau_constraint.lower_bound, 1e-3) + self.assertEqual(base_kernel.raw_tau_constraint.upper_bound, 10.0) + self.assertIsInstance(base_kernel.inv_lengthscale_prior, HalfCauchyPrior) + self.assertAlmostEqual(base_kernel.inv_lengthscale_prior.scale.item(), 1.0) + # Make sure we have specified a constraint on the lengthscale + self.assertIsInstance(base_kernel.raw_lengthscale_constraint, Interval) + self.assertEqual(base_kernel.raw_lengthscale_constraint.lower_bound, 1e-2) + self.assertEqual(base_kernel.raw_lengthscale_constraint.upper_bound, 1e4) + # Lengthscale closures + _inv_lengthscale_prior = base_kernel._priors["inv_lengthscale_prior"] + self.assertIsInstance(_inv_lengthscale_prior[0], HalfCauchyPrior) + self.assertAlmostEqual(_inv_lengthscale_prior[0].scale.item(), 1.0) + base_kernel.lengthscale = 0.5678 + true_value = (0.1 if infer_tau else tau) / (0.5678**2) # tau / ell^2 + self.assertAllClose( + _inv_lengthscale_prior[1](base_kernel), + true_value * torch.ones(1, train_X.shape[-1], **tkwargs), + ) + _inv_lengthscale_prior[2](base_kernel, torch.tensor(5.55, **tkwargs)) + true_value = math.sqrt((0.1 if infer_tau else tau) / 5.55) + self.assertAllClose( + base_kernel.lengthscale, + true_value * torch.ones(1, train_X.shape[-1], **tkwargs), + ) + if infer_tau: # Global shrinkage closures + _tau_prior = base_kernel._priors["tau_prior"] + self.assertIsInstance(_tau_prior[0], HalfCauchyPrior) + self.assertAlmostEqual(_tau_prior[0].scale.item(), 0.1) + _tau_prior[2](base_kernel, torch.tensor(1.234, **tkwargs)) + self.assertAlmostEqual( + _tau_prior[1](base_kernel).item(), 1.234, delta=1e-6 + ) + + with self.assertRaisesRegex(UnsupportedError, "must have lengthscale"): + add_saas_prior(base_kernel=ScaleKernel(base_kernel)) + + kernel_with_prior = MaternKernel( + nu=2.5, + ard_num_dims=train_X.shape[-1], + lengthscale_prior=GammaPrior(3.0, 6.0), + 
).to(**tkwargs) + with self.assertRaisesRegex( + UnsupportedError, "must not specify a lengthscale prior" + ): + add_saas_prior(base_kernel=kernel_with_prior) + + def test_get_saas_model(self) -> None: + for infer_tau, infer_noise in itertools.product([True, False], [True, False]): + tkwargs = {"device": self.device, "dtype": torch.double} + train_X, train_Y, test_X = self._get_data_hardcoded(**tkwargs) + + lb, ub = train_X.min(dim=0).values, train_X.max(dim=0).values + mu, sigma = train_Y.mean(), train_Y.std() + d = train_X.shape[-1] + tau = None if infer_tau else 0.1234 + train_Yvar = ( + None + if infer_noise + else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1) + ) + # Fit with transforms + model = get_fitted_map_saas_model( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=Normalize(d=d), + outcome_transform=Standardize(m=1), + tau=tau, + ) + posterior = model.posterior(test_X) + pred_mean, pred_var = posterior.mean, posterior.variance + # Make sure the lengthscales are reasonable + self.assertTrue( + (model.covar_module.base_kernel.lengthscale[:, 1:] > 1e2).all() + ) + self.assertTrue(model.covar_module.base_kernel.lengthscale[:, 0] < 10) + # Fit without transforms and make sure the predictions match + model2 = get_fitted_map_saas_model( + train_X=(train_X - lb) / (ub - lb), + train_Y=(train_Y - mu) / sigma, + train_Yvar=( + train_Yvar / (sigma**2) if train_Yvar is not None else train_Yvar + ), + tau=tau, + ) + posterior2 = model2.posterior((test_X - lb) / (ub - lb)) + pred_mean2 = mu + sigma * posterior2.mean + pred_var2 = (sigma**2) * posterior2.variance + self.assertAllClose(pred_mean, pred_mean2) + self.assertAllClose(pred_var, pred_var2) + + # testing optimizer_options: short optimization run with maxiter = 3 + fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll) + with mock.patch( + "botorch.fit.fit_gpytorch_mll", + new=fit_gpytorch_mll_mock, + ): + maxiter = 3 + model_short = get_fitted_map_saas_model( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=Normalize(d=d), + outcome_transform=Standardize(m=1), + tau=tau, + optimizer_kwargs={"options": {"maxiter": maxiter}}, + ) + kwargs = fit_gpytorch_mll_mock.call_args.kwargs + # fit_gpytorch_mll takes "optimizer_kwargs" (with an inner "options" dict), not "optimizer_options" + self.assertEqual( + kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter + ) + + # Compute marginal likelihood after short run. + # Putting the MLL in train mode to silence warnings.
+ mll_short = ExactMarginalLogLikelihood( + model=model_short, likelihood=model_short.likelihood + ).train() + train_inputs = mll_short.model.train_inputs + train_targets = mll_short.model.train_targets + output = mll_short.model(*train_inputs) + loss_short = -mll_short(output, train_targets).item() + + # Make sure the correct bounds are extracted + _, bounds = get_parameters_and_bounds(mll_short) + if infer_noise: + self.assertAllClose( + bounds["likelihood.noise_covar.raw_noise"][0].item(), math.log(1e-4) + ) + self.assertAllClose( + bounds["likelihood.noise_covar.raw_noise"][1].item(), math.log(1) + ) + self.assertAllClose( + bounds["model.mean_module.raw_constant"][0].item(), -10.0 + ) + self.assertAllClose( + bounds["model.mean_module.raw_constant"][1].item(), 10.0 + ) + self.assertAllClose( + bounds["model.covar_module.raw_outputscale"][0].item(), math.log(1e-2) + ) + self.assertAllClose( + bounds["model.covar_module.raw_outputscale"][1].item(), math.log(1e4) + ) + self.assertAllClose( + bounds["model.covar_module.base_kernel.raw_lengthscale"][0].item(), + math.log(1e-2), + ) + self.assertAllClose( + bounds["model.covar_module.base_kernel.raw_lengthscale"][1].item(), + math.log(1e4), + ) + if infer_tau: + self.assertAllClose( + bounds["model.covar_module.base_kernel.raw_tau"][0].item(), + math.log(1e-3), + ) + self.assertAllClose( + bounds["model.covar_module.base_kernel.raw_tau"][1].item(), + math.log(10.0), + ) + + # compute marginal likelihood after standard run + mll = ExactMarginalLogLikelihood( + model=model, likelihood=model.likelihood + ).train() + # reusing train_inputs and train_targets, since the transforms are the same + loss = -mll(model(*train_inputs), train_targets).item() + # longer running optimization should have smaller loss than the shorter one + self.assertTrue(loss < loss_short) + + def test_get_saas_ensemble(self) -> None: + for infer_noise, taus in itertools.product([True, False], [None, [0.1, 0.2]]): + tkwargs = {"device": self.device, "dtype": torch.double} + train_X, train_Y, _ = self._get_data_hardcoded(**tkwargs) + d = train_X.shape[-1] + train_Yvar = ( + None + if infer_noise + else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1) + ) + # Fit without specifying tau + with torch.random.fork_rng(): + torch.manual_seed(0) + model = get_fitted_map_saas_ensemble( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=Normalize(d=d), + outcome_transform=Standardize(m=1), + taus=taus, + ) + self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP) + num_taus = 4 if taus is None else len(taus) + self.assertEqual( + model.covar_module.base_kernel.lengthscale.shape, + torch.Size([num_taus, 1, d]), + ) + self.assertEqual(model.batch_shape, torch.Size([num_taus])) + # Make sure the lengthscales are reasonable + self.assertGreater( + model.covar_module.base_kernel.lengthscale[..., 1:].min(), 50 + ) + self.assertLess( + model.covar_module.base_kernel.lengthscale[..., 0].max(), 10 + ) + + # testing optimizer_options: short optimization run with maxiter = 3 + with torch.random.fork_rng(): + torch.manual_seed(0) + fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll) + with mock.patch( + "botorch.fit.fit_gpytorch_mll", + new=fit_gpytorch_mll_mock, + ): + maxiter = 3 + model_short = get_fitted_map_saas_ensemble( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=Normalize(d=d), + outcome_transform=Standardize(m=1), + taus=taus, + optimizer_kwargs={"options": {"maxiter": maxiter}}, + ) + kwargs = 
fit_gpytorch_mll_mock.call_args.kwargs + # fit_gpytorch_mll takes "optimizer_kwargs" (with an inner "options" dict), not "optimizer_options" + self.assertEqual( + kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter + ) + + # compute sum of marginal likelihoods of ensemble after short run + # NOTE: We can't put MLL in train mode here since + # SaasFullyBayesianSingleTaskGP requires NUTS for training. + mll_short = ExactMarginalLogLikelihood( + model=model_short, likelihood=model_short.likelihood + ) + train_inputs = mll_short.model.train_inputs + train_targets = mll_short.model.train_targets + loss_short = -mll_short(model_short(*train_inputs), train_targets) + # compute sum of marginal likelihoods of ensemble after standard run + mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood) + # reusing train_inputs and train_targets, since the transforms are the same + loss = -mll(model(*train_inputs), train_targets) + # the longer running optimization should have smaller loss than the shorter + self.assertLess((loss - loss_short).max(), 0.0) + + # test error message + with self.assertRaisesRegex( + ValueError, "if you only specify one value of tau" + ): + model_short = get_fitted_map_saas_ensemble( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + input_transform=Normalize(d=d), + outcome_transform=Standardize(m=1), + taus=[0.1], + ) + + def test_input_transform_in_train(self) -> None: + train_X, train_Y, test_X = self._get_data() + # Use a transform that only works in eval mode. + append_tf = AppendFeatures(feature_set=torch.randn_like(train_X)).eval() + with mock.patch.object(SingleTaskGP, "_validate_tensor_args") as mock_validate: + get_fitted_map_saas_model( + train_X=train_X, + train_Y=train_Y, + input_transform=append_tf, + outcome_transform=Standardize(m=1), + ) + call_X = mock_validate.call_args[1]["X"] + self.assertTrue(torch.equal(call_X, train_X)) + + def test_filterfeatures_input_transform(self) -> None: + train_X, train_Y, test_X = self._get_data() + idxs_to_filter = [0, 2] + filter_feature_transforn = FilterFeatures( + feature_indices=torch.tensor(idxs_to_filter) + ) + # FilterFeatures subsets the input columns before they reach the model.
+ model = get_fitted_map_saas_model( + train_X=train_X, + train_Y=train_Y, + input_transform=filter_feature_transforn, + outcome_transform=Standardize(m=1), + ) + self.assertTrue(model.train_inputs[0].shape[-1] == len(idxs_to_filter)) + self.assertAllClose(model.train_inputs[0], train_X[:, idxs_to_filter]) + self.assertTrue( + model.covar_module.base_kernel.lengthscale.shape[-1] == len(idxs_to_filter) + ) + + def test_batch_model_fitting(self) -> None: + tkwargs = {"device": self.device, "dtype": torch.double} + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = 1 + 2 * torch.rand(10, 3, **tkwargs) + train_Y = torch.cat( + (torch.sin(train_X[:, :1]), torch.cos(train_X[:, :1])), dim=-1 + ) + + for tau in [0.1, None]: + batch_model = get_fitted_map_saas_model( + train_X=train_X, train_Y=train_Y, tau=tau + ) + model_1 = get_fitted_map_saas_model( + train_X=train_X, train_Y=train_Y[:, :1], tau=tau + ) + model_2 = get_fitted_map_saas_model( + train_X=train_X, train_Y=train_Y[:, 1:], tau=tau + ) + # Check lengthscales + self.assertEqual( + batch_model.covar_module.base_kernel.lengthscale.shape, + torch.Size([2, 1, 3]), + ) + self.assertEqual( + model_1.covar_module.base_kernel.lengthscale.shape, + torch.Size([1, 3]), + ) + self.assertEqual( + model_2.covar_module.base_kernel.lengthscale.shape, + torch.Size([1, 3]), + ) + self.assertAllClose( + batch_model.covar_module.base_kernel.lengthscale[0, :], + model_1.covar_module.base_kernel.lengthscale, + atol=1e-3, + ) + self.assertAllClose( + batch_model.covar_module.base_kernel.lengthscale[1, :], + model_2.covar_module.base_kernel.lengthscale, + atol=1e-3, + ) + # Check the outputscale + self.assertEqual( + batch_model.covar_module.outputscale.shape, torch.Size([2]) + ) + self.assertEqual(model_1.covar_module.outputscale.shape, torch.Size([])) + self.assertEqual(model_2.covar_module.outputscale.shape, torch.Size([])) + self.assertAllClose( + batch_model.covar_module.outputscale, + torch.stack( + ( + model_1.covar_module.outputscale, + model_2.covar_module.outputscale, + ) + ), + atol=1e-3, + ) + # Check the mean + self.assertEqual(batch_model.mean_module.constant.shape, torch.Size([2])) + self.assertEqual(model_1.mean_module.constant.shape, torch.Size([])) + self.assertEqual(model_2.mean_module.constant.shape, torch.Size([])) + self.assertAllClose( + batch_model.mean_module.constant, + torch.stack( + (model_1.mean_module.constant, model_2.mean_module.constant) + ), + atol=1e-3, + ) + # Check noise + self.assertEqual(batch_model.likelihood.noise.shape, torch.Size([2, 1])) + self.assertEqual(model_1.likelihood.noise.shape, torch.Size([1])) + self.assertEqual(model_2.likelihood.noise.shape, torch.Size([1])) + self.assertAllClose( + batch_model.likelihood.noise, + torch.stack((model_1.likelihood.noise, model_2.likelihood.noise)), + atol=1e-3, + ) + # Check tau + if tau is None: + self.assertEqual( + batch_model.covar_module.base_kernel.raw_tau.shape, torch.Size([2]) + ) + self.assertEqual( + model_1.covar_module.base_kernel.raw_tau.shape, torch.Size([]) + ) + self.assertEqual( + model_2.covar_module.base_kernel.raw_tau.shape, torch.Size([]) + ) + self.assertAllClose( + batch_model.covar_module.base_kernel.raw_tau, + torch.stack( + ( + model_1.covar_module.base_kernel.raw_tau, + model_2.covar_module.base_kernel.raw_tau, + ) + ), + atol=1e-3, + ) + + +class TestAdditiveMapSaasSingleTaskGP(BotorchTestCase): + def _get_data_and_model( + self, + infer_noise: bool, + m: int = 1, + batch_shape: list[int] | None = None, + **tkwargs, + ): + 
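+        """Generate random training data and an (unfitted) AdditiveMapSaasSingleTaskGP.
+
+        A `train_Yvar` noise tensor is only generated when `infer_noise=False`.
+        """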
+        batch_shape = batch_shape or []
+        with torch.random.fork_rng():
+            torch.manual_seed(0)
+            train_X = torch.rand(batch_shape + [10, 4], **tkwargs)
+            train_Y = (
+                torch.sin(train_X)
+                .sum(dim=-1, keepdim=True)
+                .repeat(*[1] * (train_X.ndim - 1), m)
+            )
+            train_Yvar = (
+                None if infer_noise else torch.rand(batch_shape + [10, m], **tkwargs)
+            )
+            model = AdditiveMapSaasSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+        return train_X, train_Y, train_Yvar, model
+
+    def test_construct_mean_module(self) -> None:
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        for batch_shape in [None, torch.Size([5])]:
+            mean_module = get_mean_module_with_normal_prior(batch_shape=batch_shape).to(
+                **tkwargs
+            )
+            self.assertIsInstance(mean_module, ConstantMean)
+            self.assertIsInstance(mean_module.mean_prior, NormalPrior)
+            self.assertEqual(mean_module.mean_prior.loc, torch.zeros(1, **tkwargs))
+            self.assertEqual(mean_module.mean_prior.scale, torch.ones(1, **tkwargs))
+            self.assertEqual(
+                mean_module.raw_constant.shape, batch_shape or torch.Size()
+            )
+            self.assertEqual(mean_module.raw_constant_constraint.lower_bound, -10.0)
+            self.assertEqual(mean_module.raw_constant_constraint.upper_bound, 10.0)
+
+    def test_construct_likelihood(self) -> None:
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        for batch_shape in [None, torch.Size([5])]:
+            likelihood = get_gaussian_likelihood_with_gamma_prior(
+                batch_shape=batch_shape
+            ).to(**tkwargs)
+            self.assertIsInstance(likelihood, GaussianLikelihood)
+            self.assertIsInstance(likelihood.noise_covar.noise_prior, GammaPrior)
+            self.assertAllClose(
+                likelihood.noise_covar.noise_prior.concentration,
+                torch.tensor(0.9, **tkwargs),
+            )
+            self.assertAllClose(
+                likelihood.noise_covar.noise_prior.rate, torch.tensor(10, **tkwargs)
+            )
+            self.assertEqual(
+                likelihood.noise_covar.raw_noise.shape,
+                torch.Size([1]) if batch_shape is None else torch.Size([5, 1]),
+            )
+            self.assertIsInstance(
+                likelihood.noise_covar.raw_noise_constraint, LogTransformedInterval
+            )
+            self.assertAllClose(
+                likelihood.noise_covar.raw_noise_constraint.lower_bound.item(), 1e-4
+            )
+            self.assertAllClose(
+                likelihood.noise_covar.raw_noise_constraint.upper_bound.item(), 1.0
+            )
+
+    def test_construct_covar_module(self) -> None:
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        for batch_shape in [None, torch.Size([5])]:
+            covar_module = get_additive_map_saas_covar_module(
+                ard_num_dims=10, num_taus=4, batch_shape=batch_shape
+            ).to(**tkwargs)
+            self.assertIsInstance(covar_module, AdditiveKernel)
+            self.assertEqual(len(covar_module.kernels), 4)
+            for kernel in covar_module.kernels:
+                self.assertIsInstance(kernel, ScaleKernel)
+                self.assertIsInstance(kernel.base_kernel, MaternKernel)
+                expected_shape = (
+                    torch.Size([1, 10])
+                    if batch_shape is None
+                    else torch.Size([5, 1, 10])
+                )
+                self.assertEqual(kernel.base_kernel.lengthscale.shape, expected_shape)
+                # Check for a SAAS prior
+                self.assertFalse(hasattr(kernel.base_kernel, "raw_tau"))
+                self.assertIsInstance(
+                    kernel.base_kernel.inv_lengthscale_prior, HalfCauchyPrior
+                )
+                self.assertAlmostEqual(
+                    kernel.base_kernel.inv_lengthscale_prior.scale.item(), 1.0
+                )
+
+    def test_fit_model(self) -> None:
+        for infer_noise, m, batch_shape in (
+            (True, 1, None),
+            (False, 1, [5]),
+            (True, 2, None),
+            (False, 2, [3]),
+            (True, 3, [5]),
+        ):
+            tkwargs = {"device": self.device, "dtype": torch.double}
+            train_X, train_Y, train_Yvar, model = self._get_data_and_model(
+                infer_noise=infer_noise,
+                m=m,
+                batch_shape=batch_shape,
+                **tkwargs,
+            )
+            expected_batch_shape = (
+                torch.Size(batch_shape) if batch_shape else torch.Size()
+            )
+            expected_aug_batch_shape = expected_batch_shape + torch.Size(
+                [m] if m > 1 else []
+            )
+
+            # Test init
+            self.assertIsInstance(model.mean_module, ConstantMean)
+            self.assertIsInstance(model.covar_module, AdditiveKernel)
+            self.assertIsInstance(
+                model.likelihood,
+                GaussianLikelihood if infer_noise else FixedNoiseGaussianLikelihood,
+            )
+            expected_X, expected_Y, expected_Yvar = model._transform_tensor_args(
+                X=train_X, Y=train_Y, Yvar=train_Yvar
+            )
+            self.assertAllClose(expected_X, model.train_inputs[0])
+            self.assertAllClose(expected_Y, model.train_targets)
+            if not infer_noise:
+                self.assertAllClose(model.likelihood.noise_covar.noise, expected_Yvar)
+
+            # Fit a model
+            mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
+            fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 5}})
+            self.assertEqual(model.batch_shape, expected_batch_shape)
+            self.assertEqual(model._aug_batch_shape, expected_aug_batch_shape)
+
+            # Predict on some test points
+            test_X = torch.rand(13, train_X.shape[-1], **tkwargs)
+            posterior = model.posterior(test_X)
+            self.assertIsInstance(posterior, GPyTorchPosterior)
+            # Mean/variance
+            expected_shape = (*expected_batch_shape, 13, m)
+            mean, var = posterior.mean, posterior.variance
+            self.assertEqual(mean.shape, torch.Size(expected_shape))
+            self.assertEqual(var.shape, torch.Size(expected_shape))
+
+            # Test AdditiveMapSaasSingleTaskGP constructor
+            input_transform = Normalize(d=train_X.shape[-1])
+            outcome_transform = Standardize(m=m, batch_shape=expected_batch_shape)
+            with mock.patch.object(
+                SingleTaskGP, "__init__", wraps=SingleTaskGP.__init__
+            ) as mock_init:
+                model = AdditiveMapSaasSingleTaskGP(
+                    train_X=train_X,
+                    train_Y=train_Y,
+                    train_Yvar=train_Yvar,
+                    input_transform=input_transform,
+                    outcome_transform=outcome_transform,
+                    num_taus=3,
+                )
+            self.assertEqual(
+                input_transform, mock_init.call_args[1]["input_transform"]
+            )
+            self.assertEqual(
+                outcome_transform, mock_init.call_args[1]["outcome_transform"]
+            )
+            self.assertIsInstance(
+                mock_init.call_args[1]["covar_module"], AdditiveKernel
+            )
+            self.assertEqual(3, len(mock_init.call_args[1]["covar_module"].kernels))
+
+    def test_sample_from_prior_additive_map_saas(self) -> None:
+        tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
+        for batch, m in product((torch.Size([]), torch.Size([3])), (1, 2)):
+            train_X = torch.rand(*batch, 10, 4, **tkwargs)
+            train_Y = torch.rand(*batch, 10, m, **tkwargs)
+            for _ in range(10):
+                model = AdditiveMapSaasSingleTaskGP(train_X=train_X, train_Y=train_Y)
+                sample_all_priors(model)
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index 3093650439..e37c9c73f2 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -1036,7 +1036,7 @@ def test_warp_transform(self) -> None:
             eps = 1e-6 if dtype == torch.double else 1e-5
             if dtype == torch.float32:
                 # defaults are 1e-5, 1e-8
-                tols = {"rtol": 2e-5, "atol": 8e-8}
+                tols = {"rtol": 1e-4, "atol": 5e-7}
             else:
                 tols = {}
diff --git a/test/utils/test_constraints.py b/test/utils/test_constraints.py
index c18278a0d4..1988c8b86b 100644
--- a/test/utils/test_constraints.py
+++ b/test/utils/test_constraints.py
@@ -6,7 +6,10 @@
 import torch
 from botorch.utils import get_outcome_constraint_transforms
-from botorch.utils.constraints import get_monotonicity_constraints
+from botorch.utils.constraints import (
+    get_monotonicity_constraints,
+    LogTransformedInterval,
+)
 from botorch.utils.testing import BotorchTestCase
@@ -80,3 +83,17 @@ def test_get_monotonicity_constraints(self):
         Ad, bd = get_monotonicity_constraints(d, descending=True, **tkwargs)
         self.assertAllClose(Ad, -A)
         self.assertAllClose(bd, b)
+
+    def test_log_transformed_interval(self):
+        constraint = LogTransformedInterval(
+            lower_bound=0.1, upper_bound=0.2, initial_value=0.15
+        )
+        x = torch.tensor(0.1, device=self.device)
+        self.assertAllClose(constraint.transform(x), x.exp())
+        self.assertAllClose(constraint.inverse_transform(constraint.transform(x)), x)
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot make an Interval directly with non-finite bounds"
+        ):
+            constraint = LogTransformedInterval(
+                lower_bound=-torch.inf, upper_bound=torch.inf, initial_value=0.15
+            )