Add qPosteriorStandardDeviation acquisition function #2634

Closed
wants to merge 14 commits
4 changes: 4 additions & 0 deletions botorch/acquisition/__init__.py
@@ -55,7 +55,9 @@
from botorch.acquisition.monte_carlo import (
MCAcquisitionFunction,
qExpectedImprovement,
qLowerConfidenceBound,
qNoisyExpectedImprovement,
qPosteriorStandardDeviation,
qProbabilityOfImprovement,
qSimpleRegret,
qUpperConfidenceBound,
@@ -120,6 +122,8 @@
"qNegIntegratedPosteriorVariance",
"qProbabilityOfImprovement",
"qSimpleRegret",
"qPosteriorStandardDeviation",
"qLowerConfidenceBound",
"qUpperConfidenceBound",
"ConstrainedMCObjective",
"GenericMCObjective",
79 changes: 73 additions & 6 deletions botorch/acquisition/monte_carlo.py
@@ -747,7 +747,7 @@ class qSimpleRegret(SampleReducingMCAcquisitionFunction):
non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
computing the acquisition function) and shifts negative objective values using
-by an infeasible cost to ensure non-negativity (before applying constraints and
+an infeasible cost to ensure non-negativity (before applying constraints and
shifting them back).

Example:
@@ -813,11 +813,11 @@ class qUpperConfidenceBound(SampleReducingMCAcquisitionFunction):
`SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
level and then weights the sample-level acquisition values by a soft feasibility
indicator. Hence, it expects non-log acquisition function values to be
-non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
-a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
-computing the acquisition function) and shifts negative objective values using
-by an infeasible cost to ensure non-negativity (before applying constraints and
-shifting them back).
+non-negative. `qUpperConfidenceBound` acquisition values can be negative, so we
+instead use a `ConstrainedMCObjective` which applies constraints to the objectives
+(e.g. before computing the acquisition function) and shifts negative objective
+values using an infeasible cost to ensure non-negativity (before applying
+constraints and shifting them back).

Example:
>>> model = SingleTaskGP(train_X, train_Y)
@@ -887,3 +887,70 @@ class qLowerConfidenceBound(qUpperConfidenceBound):
def _get_beta_prime(self, beta: float) -> float:
"""Multiply beta prime by -1 to get the lower confidence bound."""
return -super()._get_beta_prime(beta=beta)


class qPosteriorStandardDeviation(SampleReducingMCAcquisitionFunction):
r"""MC-based batch Posterior Standard Deviation.

An acquisition function for pure exploration.

Example:
>>> model = SingleTaskGP(train_X, train_Y)
>>> sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1024]))
>>> qPSTD = qPosteriorStandardDeviation(model, sampler)
>>> std = qPSTD(test_X)
"""

def __init__(
self,
model: Model,
sampler: MCSampler | None = None,
objective: MCAcquisitionObjective | None = None,
posterior_transform: PosteriorTransform | None = None,
X_pending: Tensor | None = None,
constraints: list[Callable[[Tensor], Tensor]] | None = None,
eta: Tensor | float = 1e-3,
) -> None:
r"""q-Posterior Standard Deviation.

Args:
model: A fitted model.
sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
for more details.
objective: The MCAcquisitionObjective under which the samples are
evaluated. Defaults to `IdentityMCObjective()`.
posterior_transform: A PosteriorTransform (optional).
X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that
have been submitted for function evaluation but have not yet been
evaluated. Concatenated into X upon forward call. Copied and set to
have no gradient.
constraints: A list of constraint callables which map a
`sample_shape x batch-shape x q x m`-dim Tensor of posterior samples
to a `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details on this
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
"""
super().__init__(
model=model,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
constraints=constraints,
eta=eta,
)
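# for Gaussian x, E[|x - mu|] = sigma * sqrt(2 / pi), so scaling the mean
# absolute deviation by sqrt(pi / 2) recovers sigma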
self._scale = math.sqrt(math.pi / 2)

def _sample_forward(self, obj: Tensor) -> Tensor:
r"""Evaluate qPosteriorStandardDeviation per sample on the candidate set `X`.

Args:
obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.

Returns:
A `sample_shape x batch_shape x q`-dim Tensor of acquisition values.
"""
mean = obj.mean(dim=0)
return (obj - mean).abs() * self._scale
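The `sqrt(pi / 2)` factor inverts the Gaussian mean-absolute-deviation identity `E[|x - mu|] = sigma * sqrt(2 / pi)`, so once the base class averages the per-sample values over the MC dimension, the acquisition value estimates the posterior standard deviation. A minimal standalone check of the scaling (plain PyTorch; variable names here are illustrative only):

```python
import math

import torch

torch.manual_seed(0)
sigma = 2.5
x = torch.randn(1_000_000) * sigma  # zero-mean Gaussian draws with std sigma
mad = (x - x.mean()).abs().mean()  # mean absolute deviation ~= sigma * sqrt(2 / pi)
print(mad * math.sqrt(math.pi / 2))  # ~2.5, the same scaling as _sample_forward
```

For constrained use, a callable such as `lambda Z: Z[..., 1] - 0.5` (an illustrative constraint, feasible when the second outcome is below 0.5) could be passed via the `constraints` argument.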
10 changes: 10 additions & 0 deletions botorch/models/deterministic.py
@@ -162,6 +162,16 @@ def __init__(self, model: Model) -> None:
def forward(self, X: Tensor) -> Tensor:
return self.model.posterior(X).mean

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
return self.model.num_outputs

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model."""
return self.model.batch_shape


class FixedSingleSampleModel(DeterministicModel):
r"""
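The two new properties simply delegate to the wrapped model, so BoTorch utilities that query `num_outputs` and `batch_shape` also work on `PosteriorMeanModel`; a minimal sketch mirroring the new test at the end of this diff:

```python
import torch

from botorch.models import SingleTaskGP
from botorch.models.deterministic import PosteriorMeanModel

model = SingleTaskGP(train_X=torch.rand(2, 3), train_Y=torch.rand(2, 2))
mean_model = PosteriorMeanModel(model=model)
assert mean_model.num_outputs == 2  # delegated to the wrapped SingleTaskGP
assert mean_model.batch_shape == torch.Size([])  # non-batched training data
```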
3 changes: 2 additions & 1 deletion botorch/utils/testing.py
@@ -347,7 +347,8 @@ def rsample(
do a shape check but return the same mock samples."""
if sample_shape is None:
sample_shape = torch.Size()
-return self._samples.expand(sample_shape + self._samples.shape)
+extended_shape = self._extended_shape(sample_shape)
+return self._samples.expand(extended_shape)

def rsample_from_base_samples(
self,
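Assuming `MockPosterior._extended_shape` follows the `torch.distributions` convention of prepending `sample_shape` to the posterior's base shape, this change avoids double-counting the sample dimension when a test stores pre-expanded samples, as `test_q_pstd` below does; a shape-only sketch:

```python
import torch

sample_shape = torch.Size([128])
base_shape = torch.Size([1, 2, 1])  # the mock posterior's base shape
samples = torch.zeros(128, 1, 2, 1)  # stored pre-expanded over samples

# old target shape: sample_shape + samples.shape == (128, 128, 1, 2, 1) -- wrong
# new target shape: sample_shape + base_shape == (128, 1, 2, 1) -- expands fine
expanded = samples.expand(sample_shape + base_shape)
print(expanded.shape)  # torch.Size([128, 1, 2, 1])
```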
118 changes: 118 additions & 0 deletions test/acquisition/test_monte_carlo.py
@@ -20,6 +20,7 @@
qExpectedImprovement,
qLowerConfidenceBound,
qNoisyExpectedImprovement,
qPosteriorStandardDeviation,
qProbabilityOfImprovement,
qSimpleRegret,
qUpperConfidenceBound,
@@ -37,6 +38,7 @@
from botorch.models import SingleTaskGP
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import sample_cached_cholesky
from botorch.utils.sampling import draw_sobol_normal_samples
from botorch.utils.test_helpers import DummyNonScalarizingPosteriorTransform
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from botorch.utils.transforms import standardize
@@ -1009,6 +1011,121 @@ def test_beta_prime(self):
super().test_beta_prime(negate=True)


class TestQPosteriorStandardDeviation(BotorchTestCase):
def test_q_pstd(self):
n_samples = 128
for dtype in (torch.float, torch.double):
# the event shape is `b x q x t` = 1 x 1 x 1
samples = draw_sobol_normal_samples(
1,
n_samples,
device=self.device,
dtype=dtype,
seed=0,
)[..., None, None]
# samples has shape (n_samples, 1, 1, 1)
std = samples.std(dim=0, correction=0).item()
mm = MockModel(
MockPosterior(samples=samples, base_shape=torch.Size([1, 1, 1]))
)
# X is `q x d` = 1 x 1. X is a dummy and unused b/c of mocking
X = torch.zeros(1, 1, device=self.device, dtype=dtype)

# basic test
sampler = IIDNormalSampler(sample_shape=torch.Size([n_samples]))
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
res = acqf(X)
self.assertAllClose(res.item(), std, rtol=0.02, atol=0)

# basic test, with a fixed seed
sampler = IIDNormalSampler(sample_shape=torch.Size([n_samples]), seed=12345)
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
res = acqf(X)
self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
self.assertEqual(
acqf.sampler.base_samples.shape, torch.Size([n_samples, 1, 1, 1])
)
bs = acqf.sampler.base_samples.clone()
res = acqf(X)
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))

# basic test, qmc
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([n_samples]))
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
res = acqf(X)
self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
self.assertEqual(
acqf.sampler.base_samples.shape, torch.Size([n_samples, 1, 1, 1])
)
bs = acqf.sampler.base_samples.clone()
acqf(X)
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))

# basic test for X_pending and warning
acqf.set_X_pending()
self.assertIsNone(acqf.X_pending)
acqf.set_X_pending(None)
self.assertIsNone(acqf.X_pending)
acqf.set_X_pending(X)
self.assertEqual(acqf.X_pending, X)
mm._posterior._base_shape = torch.Size([1, 2, 1])
mm._posterior._samples = mm._posterior._samples.expand(n_samples, 1, 2, 1)
res = acqf(X)
X2 = torch.zeros(
1, 1, 1, device=self.device, dtype=dtype, requires_grad=True
)
with warnings.catch_warnings(record=True) as ws:
acqf.set_X_pending(X2)
self.assertEqual(acqf.X_pending, X2)
self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)

def test_q_pstd_batch(self):
# the event shape is `b x q x t` = 2 x 2 x 1
for dtype in (torch.float, torch.double):
samples = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)
samples[0, 0, 0] = 1.0
mm = MockModel(MockPosterior(samples=samples))
# X is a dummy and unused b/c of mocking
X = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)

# test batch mode
sampler = IIDNormalSampler(sample_shape=torch.Size([8]))
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
res = acqf(X)
self.assertEqual(res[0].item(), 0.0)
self.assertEqual(res[1].item(), 0.0)

# test batch mode
sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
res = acqf(X) # 1-dim batch
self.assertEqual(res[0].item(), 0.0)
self.assertEqual(res[1].item(), 0.0)
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
bs = acqf.sampler.base_samples.clone()
acqf(X)
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
res = acqf(X.expand(2, -1, 1)) # 2-dim batch
self.assertEqual(res[0].item(), 0.0)
self.assertEqual(res[1].item(), 0.0)
# the base samples should have the batch dim collapsed
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
bs = acqf.sampler.base_samples.clone()
acqf(X.expand(2, -1, 1))
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))

# test batch mode, qmc
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
res = acqf(X)
self.assertEqual(res[0].item(), 0.0)
self.assertEqual(res[1].item(), 0.0)
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
bs = acqf.sampler.base_samples.clone()
acqf(X)
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))


class TestMCAcquisitionFunctionWithConstraints(BotorchTestCase):
def test_mc_acquisition_function_with_constraints(self):
for dtype in (torch.float, torch.double):
@@ -1033,6 +1150,7 @@ def _test_mc_acquisition_function_with_constraints(self, dtype: torch.dtype):
# cache_root=True not supported by MockModel, see test_cache_root
partial(qNoisyExpectedImprovement, cache_root=False, **nei_args),
partial(qNoisyExpectedImprovement, cache_root=True, **nei_args),
partial(qPosteriorStandardDeviation, model=mm),
]:
acqf = acqf_constructor()
mm._posterior._samples = (
2 changes: 2 additions & 0 deletions test/models/test_deterministic.py
@@ -152,6 +152,8 @@ def test_PosteriorMeanModel(self):
train_Y = torch.rand(2, 2)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
mean_model = PosteriorMeanModel(model=model)
self.assertTrue(mean_model.num_outputs == train_Y.shape[-1])
self.assertTrue(mean_model.batch_shape == torch.Size([]))

test_X = torch.rand(2, 3)
post = model.posterior(test_X)