
Commit fd6b680
Merge branch 'pytorch:main' into main
FrankWanger authored Jan 27, 2025
2 parents e5d7fd3 + 90fc872
Showing 10 changed files with 1,444 additions and 77 deletions.
137 changes: 135 additions & 2 deletions botorch/fit.py
@@ -15,13 +15,19 @@
from typing import Any
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage

import torch

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.logging import logger
from botorch.models import SingleTaskGP
from botorch.models.approximate_gp import ApproximateGPyTorchModel
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.map_saas import get_map_saas_model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
@@ -38,11 +44,13 @@
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, Tensor
from torch.distributions import HalfCauchy
from torch.nn import Parameter
from torch.utils.data import DataLoader

@@ -326,7 +334,7 @@ def _fit_fallback_approximate(


def fit_fully_bayesian_model_nuts(
-    model: SaasFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
+    model: FullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
    max_tree_depth: int = 6,
    warmup_steps: int = 512,
    num_samples: int = 256,
@@ -382,3 +390,128 @@ def fit_fully_bayesian_model_nuts(
    # Load the MCMC samples back into the BoTorch model
    model.load_mcmc_samples(mcmc_samples)
    model.eval()
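
For orientation, a minimal usage sketch of the broadened function (the toy data and the reduced warmup/sample counts are illustrative assumptions, not library defaults):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Toy data; any FullyBayesianSingleTaskGP subclass is now accepted.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
gp = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(gp, warmup_steps=64, num_samples=32, thinning=4)
posterior = gp.posterior(torch.rand(5, 3, dtype=torch.double))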


def get_fitted_map_saas_model(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor | None = None,
    input_transform: InputTransform | None = None,
    outcome_transform: OutcomeTransform | None = None,
    tau: float | None = None,
    optimizer_kwargs: dict[str, Any] | None = None,
) -> SingleTaskGP:
    """Get a fitted MAP SAAS model with a Matern kernel.

    Args:
        train_X: Tensor of shape `n x d` with training inputs.
        train_Y: Tensor of shape `n x 1` with training targets.
        train_Yvar: Optional tensor of shape `n x 1` with observed noise,
            inferred if None.
        input_transform: An optional input transform.
        outcome_transform: An optional outcome transform.
        tau: Fixed value of the global shrinkage tau. If None, the model
            places a HC(0.1) prior on tau.
        optimizer_kwargs: A dict of options for the optimizer passed
            to fit_gpytorch_mll.

    Returns:
        A fitted SingleTaskGP with a Matern kernel.
    """
    # Make sure optimizer_kwargs is a dict
    optimizer_kwargs = optimizer_kwargs or {}
    model = get_map_saas_model(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        input_transform=(
            input_transform.train() if input_transform is not None else None
        ),
        outcome_transform=outcome_transform,
        tau=tau,
    )
    mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
    fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
    return model
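
A hedged usage sketch for the helper above (toy tensors; the optimizer_kwargs shown assume fit_gpytorch_mll forwards scipy-style options, so treat them as illustrative):

import torch
from botorch.fit import get_fitted_map_saas_model

train_X = torch.rand(20, 5, dtype=torch.double)
train_Y = train_X[:, :1] + 0.1 * torch.randn(20, 1, dtype=torch.double)
# Fixing tau bypasses the HC(0.1) prior on the global shrinkage.
model = get_fitted_map_saas_model(
    train_X,
    train_Y,
    tau=0.1,
    optimizer_kwargs={"options": {"maxiter": 100}},
)
print(model.covar_module.base_kernel.lengthscale)  # sparse lengthscales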


def get_fitted_map_saas_ensemble(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor | None = None,
    input_transform: InputTransform | None = None,
    outcome_transform: OutcomeTransform | None = None,
    taus: Tensor | list[float] | None = None,
    num_taus: int = 4,
    optimizer_kwargs: dict[str, Any] | None = None,
) -> SaasFullyBayesianSingleTaskGP:
    """Get a fitted SAAS ensemble using several different tau values.

    Args:
        train_X: Tensor of shape `n x d` with training inputs.
        train_Y: Tensor of shape `n x 1` with training targets.
        train_Yvar: Optional tensor of shape `n x 1` with observed noise,
            inferred if None.
        input_transform: An optional input transform.
        outcome_transform: An optional outcome transform.
        taus: Global shrinkage values to use. If None, we sample `num_taus` values
            from an HC(0.1) distribution.
        num_taus: How many taus to sample when `taus` is None.
        optimizer_kwargs: A dict of options for the optimizer passed
            to fit_gpytorch_mll.

    Returns:
        A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.
    """
    tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
    if taus is None:
        taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
    num_samples = len(taus)
    if num_samples == 1:
        raise ValueError(
            "Use `get_fitted_map_saas_model` if you only specify one value of tau"
        )

    mean = torch.zeros(num_samples, **tkwargs)
    outputscale = torch.zeros(num_samples, **tkwargs)
    lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
    noise = torch.zeros(num_samples, **tkwargs)

    # Fit a model for each tau and save the hyperparameters
    for i, tau in enumerate(taus):
        model = get_fitted_map_saas_model(
            train_X,
            train_Y,
            train_Yvar=train_Yvar,
            input_transform=input_transform,
            outcome_transform=outcome_transform,
            tau=tau,
            optimizer_kwargs=optimizer_kwargs,
        )
        mean[i] = model.mean_module.constant.detach().clone()
        outputscale[i] = model.covar_module.outputscale.detach().clone()
        lengthscale[i, :] = (
            model.covar_module.base_kernel.lengthscale.detach().clone()
        )
        if train_Yvar is None:
            noise[i] = model.likelihood.noise.detach().clone()

    # Load the samples into a fully Bayesian SAAS model
    ensemble_model = SaasFullyBayesianSingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        input_transform=(
            input_transform.train() if input_transform is not None else None
        ),
        outcome_transform=outcome_transform,
    )
    mcmc_samples = {
        "mean": mean,
        "outputscale": outputscale,
        "lengthscale": lengthscale,
    }
    if train_Yvar is None:
        mcmc_samples["noise"] = noise
    ensemble_model.train()
    ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
    ensemble_model.eval()
    return ensemble_model
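
And a matching sketch for the ensemble helper (again with toy data; passing taus=None would instead draw num_taus values from the HC(0.1) prior):

import torch
from botorch.fit import get_fitted_map_saas_ensemble

train_X = torch.rand(20, 5, dtype=torch.double)
train_Y = train_X[:, :1] + 0.1 * torch.randn(20, 1, dtype=torch.double)
ensemble = get_fitted_map_saas_ensemble(train_X, train_Y, taus=[0.01, 0.1, 1.0])
# The fitted MAP models act like MCMC samples of the SAAS model:
posterior = ensemble.posterior(torch.rand(5, 5, dtype=torch.double))
print(posterior.mean.shape)  # batched over the three tau values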
4 changes: 4 additions & 0 deletions botorch/models/__init__.py
@@ -21,12 +21,16 @@
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.higher_order_gp import HigherOrderGP

from botorch.models.map_saas import add_saas_prior, AdditiveMapSaasSingleTaskGP
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood

__all__ = [
"add_saas_prior",
"AdditiveMapSaasSingleTaskGP",
"AffineDeterministicModel",
"AffineFidelityCostModel",
"ApproximateGPyTorchModel",
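
With these exports in place, both new symbols are importable from the top-level models module; a short sketch (the constructor call assumes AdditiveMapSaasSingleTaskGP follows the usual train_X/train_Y convention):

import torch
from botorch.models import add_saas_prior, AdditiveMapSaasSingleTaskGP

train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = torch.rand(10, 1, dtype=torch.double)
model = AdditiveMapSaasSingleTaskGP(train_X=train_X, train_Y=train_Y)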
32 changes: 22 additions & 10 deletions botorch/models/fully_bayesian.py
@@ -33,7 +33,7 @@
import math
from abc import ABC, abstractmethod
from collections.abc import Mapping
-from typing import Any
+from typing import Any, TypeVar

import pyro
import torch
@@ -67,6 +67,11 @@
from pyro.ops.integrator import register_exception_handler
from torch import Tensor

# Can replace with Self type once 3.11 is the minimum version
TFullyBayesianSingleTaskGP = TypeVar(
    "TFullyBayesianSingleTaskGP", bound="FullyBayesianSingleTaskGP"
)
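# With Python >= 3.11, the TypeVar above could be replaced by typing.Self,
# e.g. (a sketch): def train(self, mode: bool = True) -> Self: ...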

_sqrt5 = math.sqrt(5)


@@ -572,8 +577,8 @@ def __init__(
        validate_input_scaling(
            train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
        )
-        self._num_outputs = train_Y.shape[-1]
-        self._input_batch_shape = train_X.shape[:-2]
+        self._num_outputs: int = train_Y.shape[-1]
+        self._input_batch_shape: torch.Size = train_X.shape[:-2]
        if train_Yvar is not None:  # Clamp after transforming
            train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)

@@ -591,11 +596,11 @@ def __init__(
        )
        self.pyro_model: PyroModel = pyro_model
        if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
+            self.outcome_transform: OutcomeTransform = outcome_transform
        if input_transform is not None:
-            self.input_transform = input_transform
+            self.input_transform: InputTransform = input_transform

-    def _check_if_fitted(self):
+    def _check_if_fitted(self) -> None:
        r"""Raise an exception if the model hasn't been fitted."""
        if self.covar_module is None:
            raise RuntimeError(
@@ -623,13 +628,16 @@ def _aug_batch_shape(self) -> torch.Size:
        aug_batch_shape += torch.Size([self.num_outputs])
        return aug_batch_shape

-    def train(self, mode: bool = True) -> None:
+    def train(
+        self: TFullyBayesianSingleTaskGP, mode: bool = True
+    ) -> TFullyBayesianSingleTaskGP:
        r"""Puts the model in `train` mode."""
        super().train(mode=mode)
        if mode:
            self.mean_module = None
            self.covar_module = None
            self.likelihood = None
+        return self
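    # Note: `train` now returns `self`, so construction and mode-setting can
    # be chained, e.g. (a sketch; constructor arguments abbreviated):
    #     gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y).train()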

    def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
        r"""Load the MCMC hyperparameter samples into the model.
@@ -761,7 +769,9 @@ def median_lengthscale(self) -> Tensor:
        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
        return lengthscale.median(0).values.squeeze(0)

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True
+    ) -> None:
        r"""Custom logic for loading the state dict.

        The standard approach of calling `load_state_dict` currently doesn't play well
@@ -886,7 +896,9 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
        else:
            self.input_transform = input_transform

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True
+    ) -> None:
        r"""Custom logic for loading the state dict.

        The standard approach of calling `load_state_dict` currently doesn't play well
@@ -924,7 +936,7 @@ def construct_inputs(
        *,
        use_input_warping: bool = True,
        indices_to_warp: list[int] | None = None,
-    ) -> dict[str, BotorchContainer | Tensor]:
+    ) -> dict[str, BotorchContainer | Tensor | None]:
        r"""Construct `SingleTaskGP` keyword arguments from a `SupervisedDataset`.

        Args: