From 541cb6863ccadb59d7e25050eca863ab446775c3 Mon Sep 17 00:00:00 2001 From: hvarfner Date: Fri, 21 Feb 2025 11:11:46 +0100 Subject: [PATCH 1/4] Boltzmann sampling function added in utils/sampling to remove duplicate code, reshuffling of other sampling methods (that don't take an acqf) --- botorch/optim/initializers.py | 180 +++++--------------------------- botorch/utils/sampling.py | 173 +++++++++++++++++++++++++++++- test/optim/test_initializers.py | 108 +------------------ test/utils/test_sampling.py | 156 ++++++++++++++++++++++++++- 4 files changed, 351 insertions(+), 266 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 753ca86124..e5e81f3dcc 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -17,7 +17,6 @@ import warnings from collections.abc import Callable -from math import ceil from typing import Optional, Union import torch @@ -43,14 +42,15 @@ from botorch.optim.utils import fix_features, get_X_baseline from botorch.utils.multi_objective.pareto import is_non_dominated from botorch.utils.sampling import ( - batched_multinomial, + boltzmann_sample, draw_sobol_samples, get_polytope_samples, manual_seed, + sample_perturbed_subset_dims, + sample_truncated_normal_perturbations, ) -from botorch.utils.transforms import normalize, standardize, unnormalize +from botorch.utils.transforms import unnormalize from torch import Tensor -from torch.distributions import Normal from torch.quasirandom import SobolEngine TGenInitialConditions = Callable[ @@ -578,10 +578,12 @@ def gen_one_shot_kg_initial_conditions( # sampling from the optimizers n_value = int((1 - frac_random) * (q_aug - q)) # number of non-random ICs - eta = options.get("eta", 2.0) - weights = torch.exp(eta * standardize(fantasy_vals)) - idx = torch.multinomial(weights, num_restarts * n_value, replacement=True) - + idx = boltzmann_sample( + function_values=fantasy_vals, + num_samples=num_restarts * n_value, + eta=options.get("eta", 2.0), + replacement=True, + ) # set the respective initial conditions to the sampled optimizers ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1) return ics @@ -699,14 +701,14 @@ def gen_one_shot_hvkg_initial_conditions( sequential=False, ) # sampling from the optimizers - eta = options.get("eta", 2.0) if num_optim_restarts > 0: - probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0) - idx = torch.multinomial( - probs, - num_optim_restarts * acq_function.num_fantasies, + idx = boltzmann_sample( + function_values=fantasy_vals, + num_samples=num_optim_restarts * acq_function.num_fantasies, + eta=options.get("eta", 2.0), replacement=True, ) + optim_ics = fantasy_cands[idx] if is_mf_hvkg: # add fixed features @@ -885,11 +887,10 @@ def gen_value_function_initial_conditions( # sampling from the optimizers n_value = int((1 - frac_random) * raw_samples) # number of non-random ICs if n_value > 0: - eta = options.get("eta", 2.0) - weights = torch.exp(eta * standardize(fantasy_vals)) - idx = batched_multinomial( - weights=weights.expand(*batch_shape, -1), + idx = boltzmann_sample( + function_values=fantasy_vals.expand(*batch_shape, -1), num_samples=n_value, + eta=options.get("eta", 2.0), replacement=True, ).permute(-1, *range(len(batch_shape))) resampled = fantasy_cands[idx] @@ -979,18 +980,12 @@ def initialize_q_batch( return X[idcs], acq_vals[idcs] max_val, max_idx = torch.max(acq_vals, dim=0) - Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd - etaZ = eta * Z - weights = torch.exp(etaZ) - while torch.isinf(weights).any(): - etaZ *= 0.5 - weights = torch.exp(etaZ) - if batch_shape == torch.Size(): - idcs = torch.multinomial(weights, n) - else: - idcs = batched_multinomial( - weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n - ).permute(-1, *range(len(batch_shape))) + idcs = boltzmann_sample( + acq_vals.permute(*range(1, len(batch_shape) + 1), 0), + num_samples=n, + eta=eta, + ).permute(-1, *range(len(batch_shape))) + # make sure we get the maximum if max_idx not in idcs: idcs[-1] = max_idx @@ -1239,133 +1234,6 @@ def sample_points_around_best( return perturbed_X -def sample_truncated_normal_perturbations( - X: Tensor, - n_discrete_points: int, - sigma: float, - bounds: Tensor, - qmc: bool = True, -) -> Tensor: - r"""Sample points around `X`. - - Sample perturbed points around `X` such that the added perturbations - are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d. - - Args: - X: A `n x d`-dim tensor starting points. - n_discrete_points: The number of points to sample. - sigma: The standard deviation of the additive gaussian noise for - perturbing the points. - bounds: A `2 x d`-dim tensor containing the bounds. - qmc: A boolean indicating whether to use qmc. - - Returns: - A `n_discrete_points x d`-dim tensor containing the sampled points. - """ - X = normalize(X, bounds=bounds) - d = X.shape[1] - # sample points from N(X_center, sigma^2 I), truncated to be within - # [0, 1]^d. - if X.shape[0] > 1: - rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) - X = X[rand_indices] - if qmc: - std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device) - std_bounds[1] = 1 - u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1) - else: - u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device) - # compute bounds to sample from - a = -X - b = 1 - X - # compute z-score of bounds - alpha = a / sigma - beta = b / sigma - normal = Normal(0, 1) - cdf_alpha = normal.cdf(alpha) - # use inverse transform - perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma - # add perturbation and clip points that are still outside - perturbed_X = (X + perturbation).clamp(0.0, 1.0) - return unnormalize(perturbed_X, bounds=bounds) - - -def sample_perturbed_subset_dims( - X: Tensor, - bounds: Tensor, - n_discrete_points: int, - sigma: float = 1e-1, - qmc: bool = True, - prob_perturb: float | None = None, -) -> Tensor: - r"""Sample around `X` by perturbing a subset of the dimensions. - - By default, dimensions are perturbed with probability equal to - `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number - of dimensions can be beneificial. The perturbations are sampled - from N(0, sigma^2 I) and truncated to be within [0,1]^d. - - Args: - X: A `n x d`-dim tensor starting points. `X` - must be normalized to be within `[0, 1]^d`. - bounds: The bounds to sample perturbed values from - n_discrete_points: The number of points to sample. - sigma: The standard deviation of the additive gaussian noise for - perturbing the points. - qmc: A boolean indicating whether to use qmc. - prob_perturb: The probability of perturbing each dimension. If omitted, - defaults to `min(20 / d, 1)`. - - Returns: - A `n_discrete_points x d`-dim tensor containing the sampled points. - - """ - if bounds.ndim != 2: - raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.") - elif X.ndim != 2: - raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.") - d = bounds.shape[-1] - if prob_perturb is None: - # Only perturb a subset of the features - prob_perturb = min(20.0 / d, 1.0) - - if X.shape[0] == 1: - X_cand = X.repeat(n_discrete_points, 1) - else: - rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) - X_cand = X[rand_indices] - pert = sample_truncated_normal_perturbations( - X=X_cand, - n_discrete_points=n_discrete_points, - sigma=sigma, - bounds=bounds, - qmc=qmc, - ) - - # find cases where we are not perturbing any dimensions - mask = ( - torch.rand( - n_discrete_points, - d, - dtype=bounds.dtype, - device=bounds.device, - ) - <= prob_perturb - ) - ind = (~mask).all(dim=-1).nonzero() - # perturb `n_perturb` of the dimensions - n_perturb = ceil(d * prob_perturb) - perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device) - perturb_mask[:n_perturb].fill_(1) - # TODO: use batched `torch.randperm` when available: - # https://github.com/pytorch/pytorch/issues/42502 - for idx in ind: - mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)] - # Create candidate points - X_cand[mask] = pert[mask] - return X_cand - - def is_nonnegative(acq_function: AcquisitionFunction) -> bool: r"""Determine whether a given acquisition function is non-negative. diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 5f366d45e6..adcb58fe2a 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -21,18 +21,27 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager +from math import ceil from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt import scipy + import torch -from botorch.exceptions.errors import BotorchError, InfeasibilityError + +from botorch.exceptions.errors import ( + BotorchError, + BotorchTensorDimensionError, + InfeasibilityError, +) from botorch.exceptions.warnings import UserInputWarning from botorch.sampling.qmc import NormalQMCEngine -from botorch.utils.transforms import unnormalize + +from botorch.utils.transforms import normalize, standardize, unnormalize from scipy.spatial import Delaunay, HalfspaceIntersection from torch import LongTensor, Tensor +from torch.distributions import Normal from torch.quasirandom import SobolEngine @@ -1061,3 +1070,163 @@ def path_func(x) -> Tensor: f_opt = paths(X_opt.unsqueeze(-2)).squeeze(-2) return X_opt, f_opt + + +def boltzmann_sample( + function_values: Tensor, + num_samples: int, + eta: float, + replacement: bool = False, + temp_decrease: float = 0.5, +): + """ + Perform Boltzmann sampling from a set of function values, weighted by the + exponentiated difference between function values and their standardized mean. + + Args: + function_values: A [batch_shape] x N tensor of function values. + num_samples: The number of samples (restarts) to draw. + eta: The Boltzmann temperature, controls the sharpness of the weighting. If the + temperature is too high, causing NaN values, the eta parameter is + succesively decreased by 'temp_decrease'. + replacement: If True, samples are drawn with replacement, allowing duplicates. + temp_decrease: The rate at which temperature decreases in case of inf weights. + Returns: + A [batch_shape] x num_samples tensor of indices of sampled positions. + """ + norm_weights = standardize(function_values) + weights = torch.exp(eta * norm_weights) + while torch.isinf(weights).any(): + eta *= temp_decrease + weights = torch.exp(eta * norm_weights) + + return batched_multinomial( + weights=weights, num_samples=num_samples, replacement=replacement + ) + + +def sample_truncated_normal_perturbations( + X: Tensor, + n_discrete_points: int, + sigma: float, + bounds: Tensor, + qmc: bool = True, +) -> Tensor: + r"""Sample points around `X`. + + Sample perturbed points around `X` such that the added perturbations + are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d. + + Args: + X: A `n x d`-dim tensor starting points. + n_discrete_points: The number of points to sample. + sigma: The standard deviation of the additive gaussian noise for + perturbing the points. + bounds: A `2 x d`-dim tensor containing the bounds. + qmc: A boolean indicating whether to use qmc. + + Returns: + A `n_discrete_points x d`-dim tensor containing the sampled points. + """ + X = normalize(X, bounds=bounds) + d = X.shape[1] + # sample points from N(X_center, sigma^2 I), truncated to be within + # [0, 1]^d. + if X.shape[0] > 1: + rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) + X = X[rand_indices] + if qmc: + std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device) + std_bounds[1] = 1 + u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1) + else: + u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device) + # compute bounds to sample from + a = -X + b = 1 - X + # compute z-score of bounds + alpha = a / sigma + beta = b / sigma + normal = Normal(0, 1) + cdf_alpha = normal.cdf(alpha) + # use inverse transform + perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma + # add perturbation and clip points that are still outside + perturbed_X = (X + perturbation).clamp(0.0, 1.0) + return unnormalize(perturbed_X, bounds=bounds) + + +def sample_perturbed_subset_dims( + X: Tensor, + bounds: Tensor, + n_discrete_points: int, + sigma: float = 1e-1, + qmc: bool = True, + prob_perturb: float | None = None, +) -> Tensor: + r"""Sample around `X` by perturbing a subset of the dimensions. + + By default, dimensions are perturbed with probability equal to + `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number + of dimensions can be beneificial. The perturbations are sampled + from N(0, sigma^2 I) and truncated to be within [0,1]^d. + + Args: + X: A `n x d`-dim tensor starting points. `X` + must be normalized to be within `[0, 1]^d`. + bounds: The bounds to sample perturbed values from + n_discrete_points: The number of points to sample. + sigma: The standard deviation of the additive gaussian noise for + perturbing the points. + qmc: A boolean indicating whether to use qmc. + prob_perturb: The probability of perturbing each dimension. If omitted, + defaults to `min(20 / d, 1)`. + + Returns: + A `n_discrete_points x d`-dim tensor containing the sampled points. + + """ + if bounds.ndim != 2: + raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.") + elif X.ndim != 2: + raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.") + d = bounds.shape[-1] + if prob_perturb is None: + # Only perturb a subset of the features + prob_perturb = min(20.0 / d, 1.0) + + if X.shape[0] == 1: + X_cand = X.repeat(n_discrete_points, 1) + else: + rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) + X_cand = X[rand_indices] + pert = sample_truncated_normal_perturbations( + X=X_cand, + n_discrete_points=n_discrete_points, + sigma=sigma, + bounds=bounds, + qmc=qmc, + ) + + # find cases where we are not perturbing any dimensions + mask = ( + torch.rand( + n_discrete_points, + d, + dtype=bounds.dtype, + device=bounds.device, + ) + <= prob_perturb + ) + ind = (~mask).all(dim=-1).nonzero() + # perturb `n_perturb` of the dimensions + n_perturb = ceil(d * prob_perturb) + perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device) + perturb_mask[:n_perturb].fill_(1) + # TODO: use batched `torch.randperm` when available: + # https://github.com/pytorch/pytorch/issues/42502 + for idx in ind: + mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)] + # Create candidate points + X_cand[mask] = pert[mask] + return X_cand diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 65c3e2b6bb..187a09d7f3 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -41,13 +41,12 @@ sample_perturbed_subset_dims, sample_points_around_best, sample_q_batches_from_polytope, - sample_truncated_normal_perturbations, transform_constraints, transform_inter_point_constraint, transform_intra_point_constraint, ) from botorch.sampling.normal import IIDNormalSampler -from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize +from botorch.utils.sampling import manual_seed, unnormalize from botorch.utils.testing import ( _get_max_violation_of_bounds, _get_max_violation_of_constraints, @@ -1266,111 +1265,6 @@ def test_gen_value_function_initial_conditions(self): class TestSampleAroundBest(BotorchTestCase): - def test_sample_truncated_normal_perturbations(self): - tkwargs = {"device": self.device} - n_discrete_points = 5 - _bounds = torch.ones(2, 4) - _bounds[1] = 2 - for dtype in (torch.float, torch.double): - tkwargs["dtype"] = dtype - bounds = _bounds.to(**tkwargs) - for n_best in (1, 2): - X = 1 + torch.rand(n_best, 4, **tkwargs) - # basic test - perturbed_X = sample_truncated_normal_perturbations( - X=X, - n_discrete_points=n_discrete_points, - sigma=4, - bounds=bounds, - qmc=False, - ) - self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 4])) - self.assertTrue((perturbed_X >= 1).all()) - self.assertTrue((perturbed_X <= 2).all()) - # test qmc - with mock.patch( - "botorch.optim.initializers.draw_sobol_samples", - wraps=draw_sobol_samples, - ) as mock_sobol: - perturbed_X = sample_truncated_normal_perturbations( - X=X, - n_discrete_points=n_discrete_points, - sigma=4, - bounds=bounds, - qmc=True, - ) - mock_sobol.assert_called_once() - self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 4])) - self.assertTrue((perturbed_X >= 1).all()) - self.assertTrue((perturbed_X <= 2).all()) - - def test_sample_perturbed_subset_dims(self): - tkwargs = {"device": self.device} - n_discrete_points = 5 - - # test that errors are raised - with self.assertRaises(BotorchTensorDimensionError): - sample_perturbed_subset_dims( - X=torch.zeros(1, 1), - n_discrete_points=1, - sigma=1e-3, - bounds=torch.zeros(1, 2, 1), - ) - with self.assertRaises(BotorchTensorDimensionError): - sample_perturbed_subset_dims( - X=torch.zeros(1, 1, 1), - n_discrete_points=1, - sigma=1e-3, - bounds=torch.zeros(2, 1), - ) - for dtype in (torch.float, torch.double): - for n_best in (1, 2): - tkwargs["dtype"] = dtype - bounds = torch.zeros(2, 21, **tkwargs) - bounds[1] = 1 - X = torch.rand(n_best, 21, **tkwargs) - # basic test - with mock.patch( - "botorch.optim.initializers.draw_sobol_samples", - ) as mock_sobol: - perturbed_X = sample_perturbed_subset_dims( - X=X, - n_discrete_points=n_discrete_points, - qmc=False, - sigma=1e-3, - bounds=bounds, - ) - mock_sobol.assert_not_called() - self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 21])) - self.assertTrue((perturbed_X >= 0).all()) - self.assertTrue((perturbed_X <= 1).all()) - # test qmc - with mock.patch( - "botorch.optim.initializers.draw_sobol_samples", - wraps=draw_sobol_samples, - ) as mock_sobol: - perturbed_X = sample_perturbed_subset_dims( - X=X, - n_discrete_points=n_discrete_points, - sigma=1e-3, - bounds=bounds, - ) - mock_sobol.assert_called_once() - self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 21])) - self.assertTrue((perturbed_X >= 0).all()) - self.assertTrue((perturbed_X <= 1).all()) - # for each point in perturbed_X compute the number of - # dimensions it has in common with each point in X - # and take the maximum number - max_equal_dims = ( - (perturbed_X.unsqueeze(0) == X.unsqueeze(1)) - .sum(dim=-1) - .max(dim=0) - .values - ) - # check that at least one dimension is perturbed - self.assertTrue((20 - max_equal_dims >= 1).all()) - def test_sample_points_around_best(self): tkwargs = {"device": self.device} _bounds = torch.ones(2, 2) diff --git a/test/utils/test_sampling.py b/test/utils/test_sampling.py index e1c7e65a4c..62dd0b5bbd 100644 --- a/test/utils/test_sampling.py +++ b/test/utils/test_sampling.py @@ -18,13 +18,14 @@ LinearMCObjective, ScalarizedPosteriorTransform, ) -from botorch.exceptions.errors import BotorchError +from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError from botorch.exceptions.warnings import UserInputWarning from botorch.models.gp_regression import SingleTaskGP from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model from botorch.utils.sampling import ( _convert_bounds_to_inequality_constraints, batched_multinomial, + boltzmann_sample, DelaunayPolytopeSampler, draw_sobol_samples, find_interior_point, @@ -35,8 +36,10 @@ optimize_posterior_samples, PolytopeSampler, sample_hypersphere, + sample_perturbed_subset_dims, sample_polytope, sample_simplex, + sample_truncated_normal_perturbations, sparse_to_dense_constraints, ) from botorch.utils.testing import BotorchTestCase @@ -684,3 +687,154 @@ def test_optimize_posterior_samples_multi_objective(self): ) correct_f_shape = (num_optima,) + batch_shape + (1,) self.assertEqual(f_opt.shape, correct_f_shape) + + +class TestSampleTruncatedNormalPerturbations(BotorchTestCase): + def test_sample_truncated_normal_perturbations(self): + tkwargs = {"device": self.device} + n_discrete_points = 5 + _bounds = torch.ones(2, 4) + _bounds[1] = 2 + for dtype in (torch.float, torch.double): + tkwargs["dtype"] = dtype + bounds = _bounds.to(**tkwargs) + for n_best in (1, 2): + X = 1 + torch.rand(n_best, 4, **tkwargs) + # basic test + perturbed_X = sample_truncated_normal_perturbations( + X=X, + n_discrete_points=n_discrete_points, + sigma=4, + bounds=bounds, + qmc=False, + ) + self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 4])) + self.assertTrue((perturbed_X >= 1).all()) + self.assertTrue((perturbed_X <= 2).all()) + # test qmc + with mock.patch( + "botorch.utils.sampling.draw_sobol_samples", + wraps=draw_sobol_samples, + ) as mock_sobol: + perturbed_X = sample_truncated_normal_perturbations( + X=X, + n_discrete_points=n_discrete_points, + sigma=4, + bounds=bounds, + qmc=True, + ) + mock_sobol.assert_called_once() + self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 4])) + self.assertTrue((perturbed_X >= 1).all()) + self.assertTrue((perturbed_X <= 2).all()) + + +class TestSamplePerturbedSubsetDims(BotorchTestCase): + def test_sample_perturbed_subset_dims(self): + tkwargs = {"device": self.device} + n_discrete_points = 5 + + # test that errors are raised + with self.assertRaises(BotorchTensorDimensionError): + sample_perturbed_subset_dims( + X=torch.zeros(1, 1), + n_discrete_points=1, + sigma=1e-3, + bounds=torch.zeros(1, 2, 1), + ) + with self.assertRaises(BotorchTensorDimensionError): + sample_perturbed_subset_dims( + X=torch.zeros(1, 1, 1), + n_discrete_points=1, + sigma=1e-3, + bounds=torch.zeros(2, 1), + ) + for dtype in (torch.float, torch.double): + for n_best in (1, 2): + tkwargs["dtype"] = dtype + bounds = torch.zeros(2, 21, **tkwargs) + bounds[1] = 1 + X = torch.rand(n_best, 21, **tkwargs) + # basic test + with mock.patch( + "botorch.utils.sampling.draw_sobol_samples", + ) as mock_sobol: + perturbed_X = sample_perturbed_subset_dims( + X=X, + n_discrete_points=n_discrete_points, + qmc=False, + sigma=1e-3, + bounds=bounds, + ) + mock_sobol.assert_not_called() + self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 21])) + self.assertTrue((perturbed_X >= 0).all()) + self.assertTrue((perturbed_X <= 1).all()) + # test qmc + with mock.patch( + "botorch.utils.sampling.draw_sobol_samples", + wraps=draw_sobol_samples, + ) as mock_sobol: + perturbed_X = sample_perturbed_subset_dims( + X=X, + n_discrete_points=n_discrete_points, + sigma=1e-3, + bounds=bounds, + ) + mock_sobol.assert_called_once() + self.assertEqual(perturbed_X.shape, torch.Size([n_discrete_points, 21])) + self.assertTrue((perturbed_X >= 0).all()) + self.assertTrue((perturbed_X <= 1).all()) + # for each point in perturbed_X compute the number of + # dimensions it has in common with each point in X + # and take the maximum number + max_equal_dims = ( + (perturbed_X.unsqueeze(0) == X.unsqueeze(1)) + .sum(dim=-1) + .max(dim=0) + .values + ) + # check that at least one dimension is perturbed + self.assertTrue((20 - max_equal_dims >= 1).all()) + + +class TestBoltzmannSample(BotorchTestCase): + def test_boltzmann_sample(self): + tkwargs = {"device": self.device} + for dtype in (torch.float32, torch.float64): + tkwargs["dtype"] = dtype + + function_values = torch.tensor([1.0, 2.0, 3.0, 4.0], **tkwargs) + num_samples = 2 + eta = 1.0 + result = boltzmann_sample(function_values, num_samples, eta) + self.assertEqual(result.shape, (num_samples,)) + + # test batch dimensions + function_values = torch.tensor( + [[-1.0, 2.0, -3.0], [-4.0, -3.0, 1.0]], **tkwargs + ) + num_samples = 2 + eta = 1.0 + result = boltzmann_sample(function_values, num_samples, eta) + self.assertEqual(result.shape, (function_values.shape[0], num_samples)) + + function_values = torch.tensor([1.0, 2.0, 3.0, 4.0], **tkwargs) + num_samples = 5 + eta = 0.1 + + # With replacement (should succeed even if num_samples > len(function_values)) + result_with_replacement = boltzmann_sample( + function_values, num_samples, eta, replacement=True + ) + self.assertEqual(result_with_replacement.shape, (num_samples,)) + + # Without replacement (should fail if num_samples > len(function_values)) + with self.assertRaises(RuntimeError): + boltzmann_sample(function_values, num_samples, eta, replacement=False) + + function_values = torch.tensor([1.0, 2.0, 3.0, 4.0], **tkwargs) + num_samples = 2 + large_eta = 1000.0 + result = boltzmann_sample(function_values, num_samples, large_eta) + self.assertEqual(result.shape, (num_samples,)) From 3ec3cff7b5de0b9f52b7f86446a9df21260b1638 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 24 Feb 2025 18:23:24 -0500 Subject: [PATCH 2/4] Update botorch/utils/sampling.py Co-authored-by: Elizabeth Santorella --- botorch/utils/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index adcb58fe2a..ed88ee4fdd 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -1091,6 +1091,7 @@ def boltzmann_sample( succesively decreased by 'temp_decrease'. replacement: If True, samples are drawn with replacement, allowing duplicates. temp_decrease: The rate at which temperature decreases in case of inf weights. + Returns: A [batch_shape] x num_samples tensor of indices of sampled positions. """ From 1cdd7777373da593528dc8cb15fa5ffcb3000bd5 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 24 Feb 2025 18:25:48 -0500 Subject: [PATCH 3/4] Update botorch/utils/sampling.py --- botorch/utils/sampling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index ed88ee4fdd..adcb58fe2a 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -1091,7 +1091,6 @@ def boltzmann_sample( succesively decreased by 'temp_decrease'. replacement: If True, samples are drawn with replacement, allowing duplicates. temp_decrease: The rate at which temperature decreases in case of inf weights. - Returns: A [batch_shape] x num_samples tensor of indices of sampled positions. """ From 6995223c843567427223b34b18be64ba64b4ac15 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 24 Feb 2025 18:26:19 -0500 Subject: [PATCH 4/4] Update botorch/utils/sampling.py --- botorch/utils/sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index adcb58fe2a..7066578b9d 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -1091,7 +1091,8 @@ def boltzmann_sample( succesively decreased by 'temp_decrease'. replacement: If True, samples are drawn with replacement, allowing duplicates. temp_decrease: The rate at which temperature decreases in case of inf weights. - Returns: + + Returns: A [batch_shape] x num_samples tensor of indices of sampled positions. """ norm_weights = standardize(function_values)