Skip to content

Commit

Permalink
Properly handle observed noise in AdditiveMapSaasSingleTaskGP with ou…
Browse files Browse the repository at this point in the history
…tcome transforms (#2763)

Summary:
Pull Request resolved: #2763

Currently the noise is not transformed properly when there are outcome transforms (e.g. `Standardize`). This was very problematic for model fitting.

This also updates the default outcome transform to match `SingleTaskGP`

Reviewed By: saitcakmak

Differential Revision: D70785752

fbshipit-source-id: 6805b42f2667de687b71a90ee3a4937d43c13c35
  • Loading branch information
sdaulton authored and facebook-github-bot committed Mar 7, 2025
1 parent 290c0ba commit a79b050
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
28 changes: 16 additions & 12 deletions botorch/models/map_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_lognormal_prior,
)
from botorch.utils.constraints import LogTransformedInterval
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.constraints import Interval
from gpytorch.kernels import AdditiveKernel, Kernel, MaternKernel, ScaleKernel
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.priors import GammaPrior, HalfCauchyPrior, NormalPrior
from torch import Tensor
Expand Down Expand Up @@ -363,7 +367,7 @@ def __init__(
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Tensor | None = None,
outcome_transform: OutcomeTransform | None = None,
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
input_transform: InputTransform | None = None,
num_taus: int = 4,
) -> None:
Expand All @@ -373,26 +377,25 @@ def __init__(
train_X: A `batch_shape x n x d` tensor of training features.
train_Y: A `batch_shape x n x m` tensor of training observations.
train_Yvar: A `batch_shape x n x m` tensor of observed noise.
outcome_transform: An optional outcome transform.
outcome_transform: An outcome transform that is applied to the
training data during instantiation and to the posterior during
inference (that is, the `Posterior` obtained by calling
`.posterior` on the model will be on the original scale). We use a
`Standardize` transform if no `outcome_transform` is specified.
Pass down `None` to use no outcome transform.
input_transform: An optional input transform.
num_taus: The number of taus to use (4 if omitted).
"""
self._set_dimensions(train_X=train_X, train_Y=train_Y)
mean_module = get_mean_module_with_normal_prior(
batch_shape=self._aug_batch_shape
)
if train_Yvar is not None:
_, _, train_Yvar = self._transform_tensor_args(
X=train_X, Y=train_Y, Yvar=train_Yvar
)
likelihood = (
FixedNoiseGaussianLikelihood(
noise=train_Yvar, batch_shape=self._aug_batch_shape
)
if train_Yvar is not None
else get_gaussian_likelihood_with_gamma_prior(
get_gaussian_likelihood_with_lognormal_prior(
batch_shape=self._aug_batch_shape
)
if train_Yvar is None
else None
)
covar_module = get_additive_map_saas_covar_module(
ard_num_dims=train_X.shape[-1],
Expand All @@ -409,6 +412,7 @@ def __init__(
self=self,
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
mean_module=mean_module,
covar_module=covar_module,
likelihood=likelihood,
Expand Down
5 changes: 4 additions & 1 deletion test/models/test_map_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,11 @@ def test_fit_model(self) -> None:
model.likelihood,
GaussianLikelihood if infer_noise else FixedNoiseGaussianLikelihood,
)
expected_Y, expected_Yvar = model.outcome_transform(
Y=train_Y, Yvar=train_Yvar
)
expected_X, expected_Y, expected_Yvar = model._transform_tensor_args(
X=train_X, Y=train_Y, Yvar=train_Yvar
X=train_X, Y=expected_Y, Yvar=expected_Yvar
)
self.assertAllClose(expected_X, model.train_inputs[0])
self.assertAllClose(expected_Y, model.train_targets)
Expand Down

0 comments on commit a79b050

Please sign in to comment.