From 104b2660d70ce2e06014e18079771517aba73cb3 Mon Sep 17 00:00:00 2001
From: Yousif Alsaffar
Date: Thu, 5 Dec 2024 09:39:29 -0800
Subject: [PATCH] Refactor Inducing Point Selection to Use Allocator Classes for Enhanced Flexibility (#377) (#435)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Description

This PR proposes a solution to issue #378 (https://github.com/facebookresearch/aepsych/issues/378). Although the implementation is fully functional and tested, it is presented as an overall proposal, open for discussion and potential refinement.

### Key Changes

- **Refactor of the `select_inducing_points` function**:
  - **Previous implementation**: Accepted `method` as a string (`"pivoted_chol"`, `"kmeans++"`, `"auto"`, `"sobol"`) and used conditional logic to select inducing points.
  - **New implementation**: Accepts an `InducingPointAllocator` instance, with `AutoAllocator` as the default when `allocator` is `None`. Users can now pass allocator instances directly, which meets the issue's goal of enabling flexible use of custom allocators such as `GreedyImprovementReduction`.

- **New `inducing_point_allocators.py` file**:
  - Introduces the `SobolAllocator`, `KMeansAllocator`, and `AutoAllocator` classes, all implementing the `InducingPointAllocator` interface from BoTorch. This moves allocator logic out of `select_inducing_points` into modular classes that follow the established base-class structure.

- **Modifications to models and example files**:
  - Updated `gp_classification.py`, `monotonic_projection_gp.py`, `monotonic_rejection_gp.py`, `semi_p.py`, and `example_problems.py` to accept allocator class instances rather than string-based methods, improving overall consistency.
  - Added imports for the new allocator classes in `__init__.py` so they are accessible across the codebase.

- **Updated tests**:
  - Adjusted tests in `test_semi_p.py`, `test_utils.py`, and `test_config.py` to work with allocator classes instead of the previous string-based interface.

### Additional Notes

This PR preserves most of the existing logic in `select_inducing_points` to keep the changes minimal. I know further work is needed to confirm compatibility with additional BoTorch allocators and to support advanced configurations using `from_config` for custom allocator setups. I'd love to hear your feedback on the overall approach before moving forward with these additional refinements.
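### Example Usage

A minimal usage sketch of the allocator-based API, assuming the class and argument names introduced in this diff (the training data `X` below is hypothetical):

```python
import torch

from aepsych.models import GPClassificationModel
from aepsych.models.inducing_point_allocators import KMeansAllocator, SobolAllocator
from aepsych.models.utils import select_inducing_points

# Allocators can be used standalone. SobolAllocator needs bounds of shape (2, d).
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
sobol_points = SobolAllocator(bounds=bounds).allocate_inducing_points(num_inducing=10)

# KMeansAllocator clusters the observed inputs instead.
X = torch.rand(100, 2)  # hypothetical training data
kmeans_points = KMeansAllocator().allocate_inducing_points(inputs=X, num_inducing=10)

# Models now take an allocator instance instead of a string.
model = GPClassificationModel(
    lb=bounds[0],
    ub=bounds[1],
    inducing_size=100,
    inducing_point_method=KMeansAllocator(),
)

# select_inducing_points accepts an allocator as well; the legacy string values
# still work but emit a DeprecationWarning.
points = select_inducing_points(
    inducing_size=10,
    allocator=KMeansAllocator(),
    X=X,
)
```

Allocators can also be selected from an INI config; this sketch mirrors the config strings used in the new tests, where `Strategy.from_config` resolves `inducing_point_method` through the allocator's `from_config`:

```ini
[common]
parnames = [par1]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]

[par1]
par_type = continuous
lower_bound = 0
upper_bound = 1

[init_strat]
generator = OptimizeAcqfGenerator
acqf = MCLevelSetEstimation
min_asks = 2
model = GPClassificationModel

[GPClassificationModel]
inducing_point_method = KMeansAllocator
inducing_size = 2
```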
Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/435 Reviewed By: crasanders Differential Revision: D65451912 Pulled By: JasonKChow fbshipit-source-id: e0529e545e428ad94ef965cc9d642577cbeb2777 --- aepsych/benchmark/example_problems.py | 3 +- aepsych/models/__init__.py | 14 + aepsych/models/gp_classification.py | 44 +- aepsych/models/inducing_point_allocators.py | 594 +++++++++++++++++++ aepsych/models/monotonic_projection_gp.py | 18 +- aepsych/models/monotonic_rejection_gp.py | 16 +- aepsych/models/semi_p.py | 38 +- aepsych/models/utils.py | 84 +-- tests/models/test_gp_classification.py | 12 +- tests/models/test_semi_p.py | 11 +- tests/models/test_utils.py | 77 ++- tests/test_config.py | 13 +- tests/test_points_allocators.py | 606 ++++++++++++++++++++ 13 files changed, 1427 insertions(+), 103 deletions(-) create mode 100644 aepsych/models/inducing_point_allocators.py create mode 100644 tests/test_points_allocators.py diff --git a/aepsych/benchmark/example_problems.py b/aepsych/benchmark/example_problems.py index 9e130a2d5..28d8380e7 100644 --- a/aepsych/benchmark/example_problems.py +++ b/aepsych/benchmark/example_problems.py @@ -11,6 +11,7 @@ novel_discrimination_testfun, ) from aepsych.models import GPClassificationModel +from aepsych.models.inducing_point_allocators import KMeansAllocator """The DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary classes are copied from bernoulli_lse github repository (https://github.com/facebookresearch/bernoulli_lse) @@ -109,7 +110,7 @@ def __init__( lb=self.bounds[0], ub=self.bounds[1], inducing_size=100, - inducing_point_method="kmeans++", + inducing_point_method=KMeansAllocator(), ) self.m.fit( diff --git a/aepsych/models/__init__.py b/aepsych/models/__init__.py index a3acf1294..463ab7f3e 100644 --- a/aepsych/models/__init__.py +++ b/aepsych/models/__init__.py @@ -10,6 +10,14 @@ from ..config import Config from .gp_classification import GPBetaRegressionModel, GPClassificationModel from .gp_regression import GPRegressionModel +from .inducing_point_allocators import ( + AutoAllocator, + DummyAllocator, + FixedAllocator, + GreedyVarianceReduction, + KMeansAllocator, + SobolAllocator, +) from .monotonic_projection_gp import MonotonicProjectionGP from .monotonic_rejection_gp import MonotonicRejectionGP from .multitask_regression import IndependentMultitaskGPRModel, MultitaskGPRModel @@ -34,6 +42,12 @@ "semi_p_posterior_transform", "GPBetaRegressionModel", "PairwiseProbitModel", + "AutoAllocator", + "KMeansAllocator", + "SobolAllocator", + "DummyAllocator", + "FixedAllocator", + "GreedyVarianceReduction", ] Config.register_module(sys.modules[__name__]) diff --git a/aepsych/models/gp_classification.py b/aepsych/models/gp_classification.py index 325fdd8cc..80133dcc2 100644 --- a/aepsych/models/gp_classification.py +++ b/aepsych/models/gp_classification.py @@ -16,9 +16,15 @@ from aepsych.config import Config from aepsych.factory.default import default_mean_covar_factory from aepsych.models.base import AEPsychModelDeviceMixin +from aepsych.models.inducing_point_allocators import ( + AutoAllocator, + DummyAllocator, + SobolAllocator, +) from aepsych.models.utils import select_inducing_points from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d from aepsych.utils_logging import getLogger +from botorch.models.utils.inducing_point_allocators import InducingPointAllocator from gpytorch.likelihoods import BernoulliLikelihood, BetaLikelihood, Likelihood from gpytorch.models import ApproximateGP 
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy @@ -56,7 +62,7 @@ def __init__( likelihood: Optional[Likelihood] = None, inducing_size: Optional[int] = None, max_fit_time: Optional[float] = None, - inducing_point_method: str = "auto", + inducing_point_method: InducingPointAllocator = AutoAllocator(), optimizer_options: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the GP Classification model @@ -74,11 +80,8 @@ def __init__( inducing_size (int, optional): Number of inducing points. Defaults to 99. max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None, there is no limit to the fitting time. - inducing_point_method (string): The method to use to select the inducing points. Defaults to "auto". - If "sobol", a number of Sobol points equal to inducing_size will be selected. - If "pivoted_chol", selects points based on the pivoted Cholesky heuristic. - If "kmeans++", selects points by performing kmeans++ clustering on the training data. - If "auto", tries to determine the best method automatically. + inducing_point_method (InducingPointAllocator): The method to use for selecting inducing points. + Defaults to AutoAllocator(). optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during fitting. Assumes we are using L-BFGS-B. """ @@ -106,10 +109,10 @@ def __init__( # initialize to sobol before we have data inducing_points = select_inducing_points( + allocator=DummyAllocator(bounds=torch.stack((lb, ub))), inducing_size=self.inducing_size, - bounds=torch.stack((lb, ub)), - method="sobol", ) + self.last_inducing_points_method = "DummyAllocator" variational_distribution = CholeskyVariationalDistribution( inducing_points.size(0), batch_shape=torch.Size([self._batch_size]) @@ -122,7 +125,6 @@ def __init__( learn_inducing_locations=False, ) super().__init__(variational_strategy) - if mean_module is None or covar_module is None: default_mean, default_covar = default_mean_covar_factory( dim=self.dim, stimuli_per_trial=self.stimuli_per_trial @@ -166,9 +168,14 @@ def from_config(cls, config: Config) -> GPClassificationModel: mean, covar = mean_covar_factory(config) max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None) - inducing_point_method = config.get( - classname, "inducing_point_method", fallback="auto" + inducing_point_method_class = config.getobj( + classname, "inducing_point_method", fallback=AutoAllocator ) + # Check if allocator class has a `from_config` method + if hasattr(inducing_point_method_class, "from_config"): + inducing_point_method = inducing_point_method_class.from_config(config) + else: + inducing_point_method = inducing_point_method_class() likelihood_cls = config.getobj(classname, "likelihood", fallback=None) @@ -211,14 +218,16 @@ def _reset_variational_strategy(self) -> None: if self.train_inputs is not None: # remember original device device = self.device - inducing_points = select_inducing_points( + allocator=self.inducing_point_method, inducing_size=self.inducing_size, covar_module=self.covar_module, X=self.train_inputs[0], bounds=self.bounds, - method=self.inducing_point_method, ).to(device) + self.last_inducing_points_method = ( + self.inducing_point_method.__class__.__name__ + ) variational_distribution = CholeskyVariationalDistribution( inducing_points.size(0), batch_shape=torch.Size([self._batch_size]) @@ -255,7 +264,10 @@ def fit( if not warmstart_hyperparams: self._reset_hyperparameters() - if not 
warmstart_induc: + if not warmstart_induc or ( + self.last_inducing_points_method == "DummyAllocator" + and self.inducing_point_method.__class__.__name__ != "DummyAllocator" + ): self._reset_variational_strategy() n = train_y.shape[0] @@ -360,7 +372,7 @@ def __init__( likelihood: Optional[Likelihood] = None, inducing_size: Optional[int] = None, max_fit_time: Optional[float] = None, - inducing_point_method: str = "auto", + inducing_point_method: InducingPointAllocator = AutoAllocator(), optimizer_options: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the GP Beta Regression model @@ -378,7 +390,7 @@ def __init__( inducing_size (int, optional): Number of inducing points. Defaults to 100. max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None, there is no limit to the fitting time. Defaults to None. - inducing_point_method (string): The method to use to select the inducing points. Defaults to "auto". + inducing_point_method (InducingPointAllocator): The method to use to select the inducing points. If None, defaults to AutoAllocator(). """ if likelihood is None: likelihood = BetaLikelihood() diff --git a/aepsych/models/inducing_point_allocators.py b/aepsych/models/inducing_point_allocators.py new file mode 100644 index 000000000..630da1807 --- /dev/null +++ b/aepsych/models/inducing_point_allocators.py @@ -0,0 +1,594 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch + +from aepsych.config import Config, ConfigurableMixin +from aepsych.utils import get_bounds +from botorch.models.utils.inducing_point_allocators import ( + GreedyVarianceReduction as BaseGreedyVarianceReduction, + InducingPointAllocator, +) +from botorch.utils.sampling import draw_sobol_samples +from scipy.cluster.vq import kmeans2 + + +class BaseAllocator(InducingPointAllocator, ConfigurableMixin): + """Base class for inducing point allocators.""" + + def __init__(self, bounds: Optional[torch.Tensor] = None) -> None: + """ + Initialize the allocator with optional bounds. + + Args: + bounds (torch.Tensor, optional): Bounds for allocating points. Should be of shape (2, d). + """ + self.bounds = bounds + self.dim = self._initialize_dim() + + def _initialize_dim(self) -> Optional[int]: + """ + Initialize the dimension `dim` based on the bounds, if available. + + Returns: + int: The dimension `d` if bounds are provided, or None otherwise. + """ + if self.bounds is not None: + # Validate bounds and extract dimension + assert self.bounds.shape[0] == 2, "Bounds must have shape (2, d)!" + lb, ub = self.bounds[0], self.bounds[1] + for i, (l, u) in enumerate(zip(lb, ub)): + assert ( + l <= u + ), f"Lower bound {l} is not less than or equal to upper bound {u} on dimension {i}!" + return self.bounds.shape[1] # Number of dimensions (d) + return None + + def _determine_dim_from_inputs(self, inputs: torch.Tensor) -> int: + """ + Determine dimension `dim` from the inputs tensor. + + Args: + inputs (torch.Tensor): Input tensor of shape (..., d). + + Returns: + int: The inferred dimension `d`. 
+ """ + return inputs.shape[-1] + + @abstractmethod + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor], + covar_module: Optional[torch.nn.Module], + num_inducing: int, + input_batch_shape: torch.Size, + ) -> torch.Tensor: + """ + Abstract method for allocating inducing points. + + Args: + inputs (torch.Tensor, optional): Input tensor, implementation-specific. + covar_module (torch.nn.Module, optional): Kernel covariance module. + num_inducing (int): Number of inducing points to allocate. + input_batch_shape (torch.Size): Shape of the input batch. + + Returns: + torch.Tensor: Allocated inducing points. + """ + if self.dim is None and inputs is not None: + self.dim = self._determine_dim_from_inputs(inputs) + + raise NotImplementedError("This method should be implemented by subclasses.") + + @abstractmethod + def _get_quality_function(self) -> Optional[Any]: + """ + Abstract method for returning a quality function if required. + + Returns: + None or Callable: Quality function if needed. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class SobolAllocator(BaseAllocator): + """An inducing point allocator that uses Sobol sequences to allocate inducing points.""" + + def __init__(self, bounds: torch.Tensor) -> None: + """Initialize the SobolAllocator with bounds.""" + self.bounds: torch.Tensor = bounds + super().__init__(bounds=bounds) + + def _get_quality_function(self) -> None: + """Sobol sampling does not require a quality function, so this returns None.""" + return None + + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor] = None, + covar_module: Optional[torch.nn.Module] = None, + num_inducing: int = 10, + input_batch_shape: torch.Size = torch.Size([]), + ) -> torch.Tensor: + """ + Generates `num_inducing` inducing points within the specified bounds using Sobol sampling. + + Args: + inputs (torch.Tensor): Input tensor, not required for Sobol sampling. + covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here. + num_inducing (int, optional): The number of inducing points to generate. Defaults to 10. + input_batch_shape (torch.Size, optional): Batch shape, defaults to an empty size; included for API compatibility, but not used here. + + + Returns: + torch.Tensor: A (num_inducing, d)-dimensional tensor of inducing points within the specified bounds. + + Raises: + ValueError: If `bounds` is not provided. + """ + + # Validate bounds shape + assert ( + self.bounds.shape[0] == 2 + ), "Bounds must have shape (2, d) for Sobol sampling." + # if bounds are long, make them float + if self.bounds.dtype == torch.long: + self.bounds = self.bounds.float() + # Generate Sobol samples within the unit cube [0,1]^d and rescale to [bounds[0], bounds[1]] + inducing_points = draw_sobol_samples( + bounds=self.bounds, n=num_inducing, q=1 + ).squeeze() + + # Ensure correct shape in case Sobol sampling returns a 1D tensor + if inducing_points.ndim == 1: + inducing_points = inducing_points.view(-1, 1) + + return inducing_points + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Get configuration options for the SobolAllocator. + + Args: + config (Config): Configuration object. + name (str, optional): Name of the allocator, defaults to None. + options (Dict[str, Any], optional): Additional options, defaults to None. 
+ + Returns: + Dict[str, Any]: Configuration options for the SobolAllocator. + """ + if name is None: + name = cls.__name__ + lb = config.gettensor("common", "lb") + ub = config.gettensor("common", "ub") + bounds = torch.stack((lb, ub)) + return {"bounds": bounds} + + +class KMeansAllocator(BaseAllocator): + """An inducing point allocator that uses k-means++ to allocate inducing points.""" + + def __init__(self, bounds: Optional[torch.Tensor] = None) -> None: + """Initialize the KMeansAllocator.""" + super().__init__(bounds=bounds) + if bounds is not None: + self.bounds = bounds + self.dummy_allocator = DummyAllocator(bounds) + + def _get_quality_function(self) -> None: + """K-means++ does not require a quality function, so this returns None.""" + return None + + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor] = None, + covar_module: Optional[torch.nn.Module] = None, + num_inducing: int = 10, + input_batch_shape: torch.Size = torch.Size([]), + ) -> torch.Tensor: + """ + Generates `num_inducing` inducing points using k-means++ initialization on the input data. + + Args: + inputs (torch.Tensor): A tensor of shape (n, d) containing the input data. + covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here. + num_inducing (int, optional): The number of inducing points to generate. Defaults to 10. + input_batch_shape (torch.Size, optional): Batch shape, defaults to an empty size; included for API compatibility, but not used here. + + Returns: + torch.Tensor: A (num_inducing, d)-dimensional tensor of inducing points selected via k-means++. + """ + if inputs is None and self.bounds is not None: + return self.dummy_allocator.allocate_inducing_points( + inputs=inputs, + covar_module=covar_module, + num_inducing=num_inducing, + input_batch_shape=input_batch_shape, + ) + elif inputs is None and self.bounds is None: + raise ValueError("Either inputs or bounds must be provided.") + # Ensure inputs are unique to avoid duplication issues with k-means++ + unique_inputs = torch.unique(inputs, dim=0) + + # If unique inputs are less than or equal to the required inducing points, return them directly + if unique_inputs.shape[0] <= num_inducing: + return unique_inputs + + # Run k-means++ on the unique inputs to select inducing points + inducing_points = torch.tensor( + kmeans2(unique_inputs.cpu().numpy(), num_inducing, minit="++")[0] + ) + + return inducing_points + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Get configuration options for the KMeansAllocator. + + Args: + config (Config): Configuration object. + name (str, optional): Name of the allocator, defaults to None. + options (Dict[str, Any], optional): Additional options, defaults to None. + + Returns: + Dict[str, Any]: Configuration options for the KMeansAllocator. + """ + if name is None: + name = cls.__name__ + lb = config.gettensor("common", "lb") + ub = config.gettensor("common", "ub") + bounds = torch.stack((lb, ub)) + return {"bounds": bounds} + + +class DummyAllocator(BaseAllocator): + def __init__(self, bounds: torch.Tensor) -> None: + """Initialize the DummyAllocator with bounds. + + Args: + bounds (torch.Tensor): Bounds for allocating points. Should be of shape (2, d). 
+ """ + super().__init__(bounds=bounds) + self.bounds: torch.Tensor = bounds + + def _get_quality_function(self) -> None: + """DummyAllocator does not require a quality function, so this returns None.""" + return None + + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor] = None, + covar_module: Optional[torch.nn.Module] = None, + num_inducing: int = 10, + input_batch_shape: torch.Size = torch.Size([]), + ) -> torch.Tensor: + """Allocate inducing points by returning zeros of the appropriate shape. + + Args: + inputs (torch.Tensor): Input tensor, not required for DummyAllocator. + covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here. + num_inducing (int, optional): The number of inducing points to generate. Defaults to 10. + input_batch_shape (torch.Size, optional): Batch shape, defaults to an empty size; included for API compatibility, but not used here. + + Returns: + torch.Tensor: A (num_inducing, d)-dimensional tensor of zeros. + """ + return torch.zeros(num_inducing, self.bounds.shape[-1]) + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Get configuration options for the DummyAllocator. + + Args: + config (Config): Configuration object. + name (str, optional): Name of the allocator, defaults to None. + options (Dict[str, Any], optional): Additional options, defaults to None. + + Returns: + Dict[str, Any]: Configuration options for the DummyAllocator. + """ + if name is None: + name = cls.__name__ + lb = config.gettensor("common", "lb") + ub = config.gettensor("common", "ub") + bounds = torch.stack((lb, ub)) + return {"bounds": bounds} + + +class AutoAllocator(BaseAllocator): + """An inducing point allocator that dynamically chooses an allocation strategy + based on the number of unique data points available.""" + + def __init__( + self, + bounds: Optional[torch.Tensor] = None, + fallback_allocator: InducingPointAllocator = KMeansAllocator(), + ) -> None: + """ + Initialize the AutoAllocator with a fallback allocator. + + Args: + fallback_allocator (InducingPointAllocator, optional): Allocator to use if there are + more unique points than required. + """ + super().__init__(bounds=bounds) + self.fallback_allocator = fallback_allocator + if bounds is not None: + self.bounds = bounds + self.dummy_allocator = DummyAllocator(bounds=bounds) + + def _get_quality_function(self) -> None: + """AutoAllocator does not require a quality function, so this returns None.""" + return None + + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor], + covar_module: Optional[torch.nn.Module] = None, + num_inducing: int = 10, + input_batch_shape: torch.Size = torch.Size([]), + ) -> torch.Tensor: + """ + Allocate inducing points by either using the unique input data directly + or falling back to another allocation method if there are too many unique points. + + Args: + inputs (torch.Tensor): A tensor of shape (n, d) containing the input data. + covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here. + num_inducing (int, optional): The number of inducing points to generate. + input_batch_shape (torch.Size, optional): Batch shape, defaults to an empty size; included for API compatibility, but not used here. + + Returns: + torch.Tensor: A (num_inducing, d)-dimensional tensor of inducing points. 
+ """ + # Ensure inputs are not None + if inputs is None and self.bounds is not None: + return self.dummy_allocator.allocate_inducing_points( + inputs=inputs, + covar_module=covar_module, + num_inducing=num_inducing, + input_batch_shape=input_batch_shape, + ) + elif inputs is None and self.bounds is None: + raise ValueError(f"Either inputs or bounds must be provided.{self.bounds}") + + assert ( + inputs is not None + ), "inputs should not be None here" # to make mypy happy + + unique_inputs = torch.unique(inputs, dim=0) + + # If there are fewer unique points than required, return unique inputs directly + if unique_inputs.shape[0] <= num_inducing: + return unique_inputs + + # Otherwise, fall back to the provided allocator (e.g., KMeansAllocator) + if inputs.shape[0] <= num_inducing: + return inputs + else: + return self.fallback_allocator.allocate_inducing_points( + inputs=inputs, + covar_module=covar_module, + num_inducing=num_inducing, + input_batch_shape=input_batch_shape, + ) + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Get configuration options for the AutoAllocator. + + Args: + config (Config): Configuration object. + name (str, optional): Name of the allocator, defaults to None. + options (Dict[str, Any], optional): Additional options, defaults to None. + + Returns: + Dict[str, Any]: Configuration options for the AutoAllocator. + """ + if name is None: + name = cls.__name__ + lb = config.gettensor("common", "lb") + ub = config.gettensor("common", "ub") + bounds = torch.stack((lb, ub)) + fallback_allocator_cls = config.getobj( + name, "fallback_allocator", fallback=KMeansAllocator + ) + fallback_allocator = ( + fallback_allocator_cls.from_config(config) + if hasattr(fallback_allocator_cls, "from_config") + else fallback_allocator_cls() + ) + + return {"fallback_allocator": fallback_allocator, "bounds": bounds} + + +class FixedAllocator(BaseAllocator): + def __init__( + self, points: torch.Tensor, bounds: Optional[torch.Tensor] = None + ) -> None: + """Initialize the FixedAllocator with inducing points and bounds. + + Args: + points (torch.Tensor): Inducing points to use. + bounds (torch.Tensor, optional): Bounds for allocating points. Should be of shape (2, d). + """ + super().__init__(bounds=bounds) + self.points = points + + def _get_quality_function(self) -> None: + """FixedAllocator does not require a quality function, so this returns None.""" + return None + + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor] = None, + covar_module: Optional[torch.nn.Module] = None, + num_inducing: int = 10, + input_batch_shape: torch.Size = torch.Size([]), + ) -> torch.Tensor: + """Allocate inducing points by returning the fixed inducing points. + + Args: + inputs (torch.Tensor): Input tensor, not required for FixedAllocator. + covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here. + num_inducing (int, optional): The number of inducing points to generate. Defaults to 10. + input_batch_shape (torch.Size, optional): Batch shape, defaults to an empty size; included for API compatibility, but not used here. + + Returns: + torch.Tensor: The fixed inducing points. 
+ """ + return self.points + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Get configuration options for the FixedAllocator. + + Args: + config (Config): Configuration object. + name (str, optional): Name of the allocator, defaults to None. + options (Dict[str, Any], optional): Additional options, defaults to None. + + Returns: + Dict[str, Any]: Configuration options for the FixedAllocator. + """ + if name is None: + name = cls.__name__ + lb = config.gettensor("common", "lb") + ub = config.gettensor("common", "ub") + bounds = torch.stack((lb, ub)) + num_inducing = config.getint("common", "num_inducing", fallback=99) + fallback_allocator = config.getobj( + name, "fallback_allocator", fallback=DummyAllocator(bounds=bounds) + ) + points = config.gettensor( + name, + "points", + fallback=fallback_allocator.allocate_inducing_points( + num_inducing=num_inducing + ), + ) + return {"points": points, "bounds": bounds} + + +class GreedyVarianceReduction(BaseGreedyVarianceReduction, ConfigurableMixin): + def __init__(self, bounds: Optional[torch.Tensor] = None) -> None: + """Initialize the GreedyVarianceReduction with bounds. + + Args: + bounds (torch.Tensor, optional): Bounds for allocating points. Should be of shape (2, d). + """ + super().__init__() + + self.bounds = bounds + if bounds is not None: + self.dummy_allocator = DummyAllocator(bounds) + self.dim = self._initialize_dim() + + def _initialize_dim(self) -> Optional[int]: + """Initialize the dimension `dim` based on the bounds, if available. + + Returns: + int: The dimension `d` if bounds are provided, or None otherwise. + """ + if self.bounds is not None: + assert self.bounds.shape[0] == 2, "Bounds must have shape (2, d)!" + lb, ub = self.bounds[0], self.bounds[1] + for i, (l, u) in enumerate(zip(lb, ub)): + assert ( + l <= u + ), f"Lower bound {l} is not less than or equal to upper bound {u} on dimension {i}!" + return self.bounds.shape[1] + return None + + def allocate_inducing_points( + self, + inputs: Optional[torch.Tensor] = None, + covar_module: Optional[torch.nn.Module] = None, + num_inducing: int = 10, + input_batch_shape: torch.Size = torch.Size([]), + ) -> torch.Tensor: + """Allocate inducing points using the GreedyVarianceReduction strategy. + + Args: + inputs (torch.Tensor): Input tensor, not required for GreedyVarianceReduction. + covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here. + num_inducing (int, optional): The number of inducing points to generate. Defaults to 10. + input_batch_shape (torch.Size, optional): Batch shape, defaults to an empty size; included for API compatibility, but not used here. + + Returns: + torch.Tensor: The allocated inducing points. 
+ """ + if inputs is None and self.bounds is not None: + return self.dummy_allocator.allocate_inducing_points( + inputs=inputs, + covar_module=covar_module, + num_inducing=num_inducing, + input_batch_shape=input_batch_shape, + ) + elif inputs is None and self.bounds is None: + raise ValueError("Either inputs or bounds must be provided.") + else: + return super().allocate_inducing_points( + inputs=inputs, + covar_module=covar_module, + num_inducing=num_inducing, + input_batch_shape=input_batch_shape, + ) + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Get configuration options for the GreedyVarianceReduction allocator. + + Args: + config (Config): Configuration object. + name (str, optional): Name of the allocator, defaults to None. + options (Dict[str, Any], optional): Additional options, defaults to None. + + Returns: + Dict[str, Any]: Configuration options for the GreedyVarianceReduction allocator. + """ + if name is None: + name = cls.__name__ + lb = config.gettensor("common", "lb") + ub = config.gettensor("common", "ub") + bounds = torch.stack((lb, ub)) + return {"bounds": bounds} diff --git a/aepsych/models/monotonic_projection_gp.py b/aepsych/models/monotonic_projection_gp.py index bdfe50bdf..80141eabc 100644 --- a/aepsych/models/monotonic_projection_gp.py +++ b/aepsych/models/monotonic_projection_gp.py @@ -15,7 +15,9 @@ from aepsych.config import Config from aepsych.factory.default import default_mean_covar_factory from aepsych.models.gp_classification import GPClassificationModel +from aepsych.models.inducing_point_allocators import AutoAllocator from aepsych.utils import get_optimizer_options +from botorch.models.utils.inducing_point_allocators import InducingPointAllocator from botorch.posteriors.gpytorch import GPyTorchPosterior from gpytorch.likelihoods import Likelihood from statsmodels.stats.moment_helpers import corr2cov, cov2corr @@ -104,7 +106,7 @@ def __init__( likelihood: Optional[Likelihood] = None, inducing_size: Optional[int] = None, max_fit_time: Optional[float] = None, - inducing_point_method: str = "auto", + inducing_point_method: InducingPointAllocator = AutoAllocator(), optimizer_options: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the MonotonicProjectionGP model. @@ -126,8 +128,8 @@ def __init__( inducing_size (int, optional): The number of inducing points to use. Defaults to None. max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None, there is no limit to the fitting time. Defaults to None. - inducing_point_method (string): The method to use to select the inducing points. Defaults to "auto". - """ + inducing_point_method (InducingPointAllocator, optional): The method to use for allocating inducing points. 
+ Defaults to AutoAllocator.""" assert len(monotonic_dims) > 0 self.monotonic_dims = [int(d) for d in monotonic_dims] self.mon_grid_size = monotonic_grid_size @@ -243,10 +245,14 @@ def from_config(cls, config: Config) -> MonotonicProjectionGP: mean, covar = mean_covar_factory(config) max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None) - inducing_point_method = config.get( - classname, "inducing_point_method", fallback="auto" + inducing_point_method_class = config.getobj( + classname, "inducing_point_method", fallback=AutoAllocator ) - + # Check if allocator class has a `from_config` method + if hasattr(inducing_point_method_class, "from_config"): + inducing_point_method = inducing_point_method_class.from_config(config) + else: + inducing_point_method = inducing_point_method_class() likelihood_cls = config.getobj(classname, "likelihood", fallback=None) if likelihood_cls is not None: diff --git a/aepsych/models/monotonic_rejection_gp.py b/aepsych/models/monotonic_rejection_gp.py index 02a7c6c0c..5025ad0a9 100644 --- a/aepsych/models/monotonic_rejection_gp.py +++ b/aepsych/models/monotonic_rejection_gp.py @@ -19,9 +19,14 @@ from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad from aepsych.models.base import AEPsychMixin +from aepsych.models.inducing_point_allocators import AutoAllocator, SobolAllocator from aepsych.models.utils import select_inducing_points from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d from botorch.fit import fit_gpytorch_mll +from botorch.models.utils.inducing_point_allocators import ( + GreedyVarianceReduction, + InducingPointAllocator, +) from gpytorch.kernels import Kernel from gpytorch.likelihoods import BernoulliLikelihood, Likelihood from gpytorch.means import Mean @@ -29,6 +34,7 @@ from gpytorch.models import ApproximateGP from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy from scipy.stats import norm +from torch import Tensor class MonotonicRejectionGP(AEPsychMixin, ApproximateGP): @@ -61,7 +67,7 @@ def __init__( num_induc: int = 25, num_samples: int = 250, num_rejection_samples: int = 5000, - inducing_point_method: str = "auto", + inducing_point_method: InducingPointAllocator = AutoAllocator(), optimizer_options: Optional[Dict[str, Any]] = None, ) -> None: """Initialize MonotonicRejectionGP. @@ -82,7 +88,7 @@ def __init__( num_samples (int): Number of samples for estimating posterior on preDict or acquisition function evaluation. Defaults to 250. num_rejection_samples (int): Number of samples used for rejection sampling. Defaults to 4096. - inducing_point_method (str): Method for selecting inducing points. Defaults to "auto". + inducing_point_method (InducingPointAllocator): Method for selecting inducing points. Defaults to AutoAllocator(). optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during fitting. Assumes we are using L-BFGS-B. 
""" @@ -92,10 +98,10 @@ def __init__( self.inducing_size = num_induc self.inducing_point_method = inducing_point_method + inducing_points = select_inducing_points( + allocator=SobolAllocator(bounds=torch.stack((self.lb, self.ub))), inducing_size=self.inducing_size, - bounds=self.bounds, - method="sobol", ) inducing_points_aug = self._augment_with_deriv_index(inducing_points, 0) @@ -162,11 +168,11 @@ def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs) -> None: self.set_train_data(train_x, train_y) self.inducing_points = select_inducing_points( + allocator=self.inducing_point_method, inducing_size=self.inducing_size, covar_module=self.covar_module, X=self.train_inputs[0], bounds=self.bounds, - method=self.inducing_point_method, ) self._set_model(train_x, train_y) diff --git a/aepsych/models/semi_p.py b/aepsych/models/semi_p.py index c038a686b..3217c7c56 100644 --- a/aepsych/models/semi_p.py +++ b/aepsych/models/semi_p.py @@ -18,9 +18,11 @@ from aepsych.config import Config from aepsych.likelihoods import BernoulliObjectiveLikelihood, LinearBernoulliLikelihood from aepsych.models import GPClassificationModel +from aepsych.models.inducing_point_allocators import AutoAllocator from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d from aepsych.utils_logging import getLogger from botorch.acquisition.objective import PosteriorTransform +from botorch.models.utils.inducing_point_allocators import InducingPointAllocator from botorch.optim.fit import fit_gpytorch_mll_scipy from botorch.posteriors import GPyTorchPosterior from gpytorch.distributions import MultivariateNormal @@ -259,7 +261,7 @@ def __init__( slope_mean: float = 2, inducing_size: Optional[int] = None, max_fit_time: Optional[float] = None, - inducing_point_method: str = "auto", + inducing_point_method: InducingPointAllocator = AutoAllocator(), optimizer_options: Optional[Dict[str, Any]] = None, ) -> None: """ @@ -279,11 +281,7 @@ def __init__( inducing_size (int, optional): Number of inducing points. Defaults to 99. max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None, there is no limit to the fitting time. - inducing_point_method (string): The method to use to select the inducing points. Defaults to "auto". - If "sobol", a number of Sobol points equal to inducing_size will be selected. - If "pivoted_chol", selects points based on the pivoted Cholesky heuristic. - If "kmeans++", selects points by performing kmeans++ clustering on the training data. - If "auto", tries to determine the best method automatically. + inducing_point_method (InducingPointAllocator): The method to use to select the inducing points. Defaults to AutoAllocator. optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during fitting. Assumes we are using L-BFGS-B. 
""" @@ -352,10 +350,14 @@ def from_config(cls, config: Config) -> SemiParametricGPModel: max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None) - inducing_point_method = config.get( - classname, "inducing_point_method", fallback="auto" + inducing_point_method_class = config.getobj( + classname, "inducing_point_method", fallback=AutoAllocator ) - + # Check if allocator class has a `from_config` method + if hasattr(inducing_point_method_class, "from_config"): + inducing_point_method = inducing_point_method_class.from_config(config) + else: + inducing_point_method = inducing_point_method_class() likelihood_cls = config.getobj(classname, "likelihood", fallback=None) if hasattr(likelihood_cls, "from_config"): @@ -528,7 +530,7 @@ def __init__( slope_mean: float = 2, inducing_size: Optional[int] = None, max_fit_time: Optional[float] = None, - inducing_point_method: str = "auto", + inducing_point_method: InducingPointAllocator = AutoAllocator(), optimizer_options: Optional[Dict[str, Any]] = None, ) -> None: """ @@ -548,11 +550,7 @@ def __init__( inducing_size (int, optional): Number of inducing points. Defaults to 99. max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None, there is no limit to the fitting time. - inducing_point_method (string): The method to use to select the inducing points. Defaults to "auto". - If "sobol", a number of Sobol points equal to inducing_size will be selected. - If "pivoted_chol", selects points based on the pivoted Cholesky heuristic. - If "kmeans++", selects points by performing kmeans++ clustering on the training data. - If "auto", tries to determine the best method automatically. + inducing_point_method (InducingPointAllocator): The method to use to select the inducing points. Defaults to AutoAllocator. optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during fitting. Assumes we are using L-BFGS-B. 
""" @@ -675,10 +673,14 @@ def from_config(cls, config: Config) -> HadamardSemiPModel: max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None) - inducing_point_method = config.get( - classname, "inducing_point_method", fallback="auto" + inducing_point_method_class = config.getobj( + classname, "inducing_point_method", fallback=AutoAllocator ) - + # Check if allocator class has a `from_config` method + if hasattr(inducing_point_method_class, "from_config"): + inducing_point_method = inducing_point_method_class.from_config(config) + else: + inducing_point_method = inducing_point_method_class() likelihood_cls = config.getobj(classname, "likelihood", fallback=None) if hasattr(likelihood_cls, "from_config"): likelihood = likelihood_cls.from_config(config) diff --git a/aepsych/models/utils.py b/aepsych/models/utils.py index 2f0c3fe3a..2050b84cc 100644 --- a/aepsych/models/utils.py +++ b/aepsych/models/utils.py @@ -17,7 +17,10 @@ ScalarizedPosteriorTransform, ) from botorch.models.model import Model -from botorch.models.utils.inducing_point_allocators import GreedyVarianceReduction +from botorch.models.utils.inducing_point_allocators import ( + GreedyVarianceReduction, + InducingPointAllocator, +) from botorch.optim import optimize_acqf from botorch.posteriors import GPyTorchPosterior from botorch.utils.sampling import draw_sobol_samples @@ -61,72 +64,81 @@ def compute_p_quantile( def select_inducing_points( inducing_size: int, - covar_module: Kernel = None, + allocator: Union[str, InducingPointAllocator], + covar_module: Optional[torch.nn.Module] = None, X: Optional[torch.Tensor] = None, bounds: Optional[torch.Tensor] = None, - method: str = "auto", ) -> torch.Tensor: - """Select inducing points for GP model + """ + Select inducing points using a specified allocator instance or legacy method. Args: - inducing_size (int): Number of inducing points to select. - covar_module (Kernel): The kernel module to use for inducing point selection. - X (torch.Tensor, optional): The training data. - bounds (torch.Tensor, optional): The bounds of the input space. - method (str): The method to use for inducing point selection. One of - "pivoted_chol", "kmeans++", "auto", or "sobol". + inducing_size (int): Number of inducing points. + allocator (Union[str, InducingPointAllocator]): An inducing point allocator or a legacy string indicating method. + covar_module (torch.nn.Module, optional): Covariance module, required for some allocators. + X (torch.Tensor, optional): Input data tensor, required for most allocators. + bounds (torch.Tensor, optional): Bounds for Sobol sampling in legacy mode. Returns: - torch.Tensor: The selected inducing points. + torch.Tensor: Selected inducing points. """ - with torch.no_grad(): - assert ( - method - in ( - "pivoted_chol", - "kmeans++", - "auto", - "sobol", - ) - ), f"Inducing point method should be one of pivoted_chol, kmeans++, sobol, or auto; got {method}" + # Handle legacy string methods with a deprecation warning + if isinstance(allocator, str): + warnings.warn( + f"Using string '{allocator}' for inducing point method is deprecated. " + "Please use an InducingPointAllocator class instead.", + DeprecationWarning, + ) - if method == "sobol": - assert bounds is not None, "Must pass bounds for sobol inducing points!" + if allocator == "sobol": + assert ( + bounds is not None + ), "Bounds must be provided for Sobol inducing points!" 
inducing_points = ( draw_sobol_samples(bounds=bounds, n=inducing_size, q=1) .squeeze() - .to(bounds) + .to(bounds.device) ) - if len(inducing_points.shape) == 1: - inducing_points = inducing_points.reshape(-1, 1) + if inducing_points.ndim == 1: + inducing_points = inducing_points.view(-1, 1) return inducing_points - assert X is not None, "Must pass X for non-sobol inducing point selection!" - # remove dupes from X, which is both wasteful for inducing points - # and would break kmeans++ + assert X is not None, "Must pass X for non-Sobol inducing point selection!" + unique_X = torch.unique(X, dim=0) - if method == "auto": + if allocator == "auto": if unique_X.shape[0] <= inducing_size: return unique_X else: - method = "kmeans++" + allocator = "kmeans++" - if method == "pivoted_chol": + if allocator == "pivoted_chol": inducing_point_allocator = GreedyVarianceReduction() inducing_points = inducing_point_allocator.allocate_inducing_points( inputs=X, covar_module=covar_module, num_inducing=inducing_size, input_batch_shape=torch.Size([]), - ).to(X) - elif method == "kmeans++": - # initialize using kmeans + ).to(X.device) + + elif allocator == "kmeans++": inducing_points = torch.tensor( kmeans2(unique_X.cpu().numpy(), inducing_size, minit="++")[0], dtype=X.dtype, - ).to(X) + ).to(X.device) + return inducing_points + # Call allocate_inducing_points with allocator instance + inducing_points = allocator.allocate_inducing_points( + inputs=X, + covar_module=covar_module, + num_inducing=inducing_size, + input_batch_shape=torch.Size([]), + ) + + return inducing_points + def get_probability_space( likelihood: Likelihood, posterior: GPyTorchPosterior diff --git a/tests/models/test_gp_classification.py b/tests/models/test_gp_classification.py index f8218687a..ccb6f902d 100644 --- a/tests/models/test_gp_classification.py +++ b/tests/models/test_gp_classification.py @@ -801,7 +801,17 @@ def test_hyperparam_consistency(self): m1 = GPClassificationModel(lb=[1, 2], ub=[3, 4]) m2 = GPClassificationModel.from_config( - config=Config(config_dict={"common": {"lb": "[1,2]", "ub": "[3,4]"}}) + config=Config( + config_dict={ + "common": { + "parnames": ["par1", "par2"], + "lb": "[1, 2]", + "ub": "[3, 4]", + }, + "par1": {"value_type": "float"}, + "par2": {"value_type": "float"}, + } + ) ) self.assertTrue(isinstance(m1.covar_module, type(m2.covar_module))) self.assertTrue(isinstance(m1.covar_module, type(m2.covar_module))) diff --git a/tests/models/test_semi_p.py b/tests/models/test_semi_p.py index becfc130a..d5fc80d9c 100644 --- a/tests/models/test_semi_p.py +++ b/tests/models/test_semi_p.py @@ -26,6 +26,7 @@ from aepsych.likelihoods import BernoulliObjectiveLikelihood from aepsych.likelihoods.semi_p import LinearBernoulliLikelihood from aepsych.models import HadamardSemiPModel, SemiParametricGPModel +from aepsych.models.inducing_point_allocators import AutoAllocator from aepsych.models.semi_p import _hadamard_mvn_approx, semi_p_posterior_transform from aepsych.strategy import SequentialStrategy, Strategy from gpytorch.distributions import MultivariateNormal @@ -39,7 +40,7 @@ def _hadamard_model_constructor(lb, ub, stim_dim, floor, objective=FloorLogitObj stim_dim=stim_dim, likelihood=BernoulliObjectiveLikelihood(objective=objective(floor=floor)), inducing_size=10, - inducing_point_method="auto", + inducing_point_method=AutoAllocator(), max_fit_time=0.5, ) @@ -51,7 +52,7 @@ def _semip_model_constructor(lb, ub, stim_dim, floor, objective=FloorLogitObject stim_dim=stim_dim, 
likelihood=LinearBernoulliLikelihood(objective=objective(floor=floor)), inducing_size=10, - inducing_point_method="auto", + inducing_point_method=AutoAllocator(), ) @@ -97,7 +98,7 @@ def test_mc_generation(self, objective): stim_dim=self.stim_dim, likelihood=LinearBernoulliLikelihood(), inducing_size=10, - inducing_point_method="auto", + inducing_point_method=AutoAllocator(), ) generator = OptimizeAcqfGenerator( @@ -303,7 +304,7 @@ def test_slope_mean_setting(self): likelihood=LinearBernoulliLikelihood(), inducing_size=10, slope_mean=slope_mean, - inducing_point_method="auto", + inducing_point_method=AutoAllocator(), ) with self.subTest(model=model, slope_mean=slope_mean): self.assertEqual(model.mean_module.constant[-1], slope_mean) @@ -314,7 +315,7 @@ def test_slope_mean_setting(self): likelihood=BernoulliObjectiveLikelihood(objective=ProbitObjective()), inducing_size=10, slope_mean=slope_mean, - inducing_point_method="auto", + inducing_point_method=AutoAllocator(), ) with self.subTest(model=model, slope_mean=slope_mean): self.assertEqual(model.slope_mean_module.constant.item(), slope_mean) diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index b95864588..ba2f6bad1 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -10,7 +10,16 @@ import numpy as np import torch from aepsych.models import GPClassificationModel +from aepsych.models.inducing_point_allocators import ( + AutoAllocator, + DummyAllocator, + FixedAllocator, + GreedyVarianceReduction, + KMeansAllocator, + SobolAllocator, +) from aepsych.models.utils import select_inducing_points + from sklearn.datasets import make_classification @@ -38,11 +47,11 @@ def test_select_inducing_points(self): self.assertTrue( np.allclose( select_inducing_points( + allocator=AutoAllocator(), inducing_size=inducing_size, covar_module=model.covar_module, X=model.train_inputs[0], bounds=model.bounds, - method="auto", ), X[:10].sort(0).values, ) @@ -53,11 +62,11 @@ def test_select_inducing_points(self): self.assertTrue( len( select_inducing_points( + allocator=AutoAllocator(), inducing_size=inducing_size, covar_module=model.covar_module, X=model.train_inputs[0], bounds=model.bounds, - method="auto", ) ) <= 20 @@ -66,11 +75,10 @@ def test_select_inducing_points(self): self.assertTrue( len( select_inducing_points( + allocator=GreedyVarianceReduction(), inducing_size=inducing_size, covar_module=model.covar_module, X=model.train_inputs[0], - bounds=model.bounds, - method="pivoted_chol", ) ) <= 20 @@ -79,24 +87,67 @@ def test_select_inducing_points(self): self.assertEqual( len( select_inducing_points( + allocator=KMeansAllocator(), inducing_size=inducing_size, covar_module=model.covar_module, X=model.train_inputs[0], bounds=model.bounds, - method="kmeans++", ) ), 20, ) - - with self.assertRaises(AssertionError): - select_inducing_points( - inducing_size=inducing_size, - covar_module=model.covar_module, - X=model.train_inputs[0], - bounds=model.bounds, - method="12345", + self.assertTrue( + len( + select_inducing_points( + allocator="auto", + inducing_size=inducing_size, + covar_module=model.covar_module, + X=model.train_inputs[0], + bounds=model.bounds, + ) + ) + <= 20 + ) + self.assertTrue( + len( + select_inducing_points( + allocator=SobolAllocator( + bounds=torch.stack([torch.tensor([0]), torch.tensor([1])]) + ), + inducing_size=inducing_size, + covar_module=model.covar_module, + X=model.train_inputs[0], + bounds=model.bounds, + ) + ) + <= 20 + ) + self.assertTrue( + len( + select_inducing_points( + 
allocator=DummyAllocator( + bounds=torch.stack([torch.tensor([0]), torch.tensor([1])]) + ), + inducing_size=inducing_size, + covar_module=model.covar_module, + X=model.train_inputs[0], + bounds=model.bounds, + ) ) + <= 20 + ) + self.assertTrue( + len( + select_inducing_points( + allocator=FixedAllocator(points=torch.tensor([0, 1, 2, 3])), + inducing_size=inducing_size, + covar_module=model.covar_module, + X=model.train_inputs[0], + bounds=model.bounds, + ) + ) + <= 20 + ) if __name__ == "__main__": diff --git a/tests/test_config.py b/tests/test_config.py index 599b592ad..70641506f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -27,6 +27,7 @@ MonotonicRejectionGP, PairwiseProbitModel, ) +from aepsych.models.inducing_point_allocators import SobolAllocator from aepsych.server import AEPsychServer from aepsych.server.message_handlers.handle_setup import configure from aepsych.strategy import SequentialStrategy, Strategy @@ -999,7 +1000,7 @@ def test_semip_config(self): [HadamardSemiPModel] stim_dim = 1 inducing_size = 10 - inducing_point_method = sobol + inducing_point_method = SobolAllocator likelihood = BernoulliObjectiveLikelihood [BernoulliObjectiveLikelihood] @@ -1011,6 +1012,7 @@ def test_semip_config(self): [OptimizeAcqfGenerator] restarts = 10 samps = 1000 + """ config = Config() @@ -1026,7 +1028,14 @@ def test_semip_config(self): self.assertTrue(model.dim == 2) self.assertTrue(model.inducing_size == 10) self.assertTrue(model.stim_dim == 1) - self.assertTrue(model.inducing_point_method == "sobol") + + # Verify the allocator and bounds + self.assertTrue(isinstance(model.inducing_point_method, SobolAllocator)) + expected_bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64) + self.assertTrue( + torch.equal(model.inducing_point_method.bounds, expected_bounds) + ) + self.assertTrue(isinstance(model.likelihood, BernoulliObjectiveLikelihood)) self.assertTrue(isinstance(model.likelihood.objective, FloorGumbelObjective)) diff --git a/tests/test_points_allocators.py b/tests/test_points_allocators.py new file mode 100644 index 000000000..3c604dc3c --- /dev/null +++ b/tests/test_points_allocators.py @@ -0,0 +1,606 @@ +import unittest + +import torch +from aepsych.config import Config +from aepsych.models.gp_classification import GPClassificationModel +from aepsych.models.inducing_point_allocators import ( + AutoAllocator, + DummyAllocator, + FixedAllocator, + GreedyVarianceReduction, + KMeansAllocator, + SobolAllocator, +) +from aepsych.models.utils import select_inducing_points + +from aepsych.strategy import Strategy +from aepsych.transforms.parameters import ParameterTransforms, transform_options +from botorch.models.utils.inducing_point_allocators import GreedyImprovementReduction +from botorch.utils.sampling import draw_sobol_samples +from sklearn.datasets import make_classification + + +class TestInducingPointAllocators(unittest.TestCase): + def test_sobol_allocator_from_config(self): + config_str = """ + [common] + parnames = [par1] + + [par1] + par_type = continuous + lower_bound = 0.0 + upper_bound = 1.0 + log_scale = true + + """ + config = Config() + config.update(config_str=config_str) + allocator = SobolAllocator.from_config(config) + + # Check if bounds are correctly loaded + expected_bounds = torch.tensor([[0.0], [1.0]]) + self.assertTrue(torch.equal(allocator.bounds, expected_bounds)) + + def test_kmeans_allocator_from_config(self): + config_str = """ + [common] + parnames = [par1] + + [par1] + par_type = continuous + lower_bound = 0.0 + upper_bound 
= 1.0 + log_scale = true + + [KMeansAllocator] + """ + config = Config() + config.update(config_str=config_str) + allocator = KMeansAllocator.from_config(config) + + # No specific configuration to check, just test instantiation + self.assertTrue(isinstance(allocator, KMeansAllocator)) + + def test_auto_allocator_from_config_with_fallback(self): + config_str = """ + [common] + parnames = [par1] + + [par1] + par_type = continuous + lower_bound = 0.0 + upper_bound = 1.0 + log_scale = true + + """ + config = Config() + config.update(config_str=config_str) + allocator = AutoAllocator.from_config(config) + + # Check if fallback allocator is an instance of SobolAllocator with correct bounds + self.assertTrue(isinstance(allocator.fallback_allocator, KMeansAllocator)) + + def test_sobol_allocator_allocate_inducing_points(self): + bounds = torch.tensor([[0.0], [1.0]]) + allocator = SobolAllocator(bounds=bounds) + inducing_points = allocator.allocate_inducing_points(num_inducing=5) + + # Check shape and bounds of inducing points + self.assertEqual(inducing_points.shape, (5, 1)) + self.assertTrue( + torch.all(inducing_points >= bounds[0]) + and torch.all(inducing_points <= bounds[1]) + ) + + def test_kmeans_allocator_allocate_inducing_points(self): + inputs = torch.rand(100, 2) # 100 points in 2D + allocator = KMeansAllocator() + inducing_points = allocator.allocate_inducing_points( + inputs=inputs, num_inducing=10 + ) + + # Check shape of inducing points + self.assertEqual(inducing_points.shape, (10, 2)) + + def test_auto_allocator_with_kmeans_fallback(self): + inputs = torch.rand(50, 2) + fallback_allocator = KMeansAllocator() + allocator = AutoAllocator(fallback_allocator=fallback_allocator) + inducing_points = allocator.allocate_inducing_points( + inputs=inputs, num_inducing=10 + ) + + # Check shape of inducing points and that fallback allocator is used + self.assertEqual(inducing_points.shape, (10, 2)) + + def test_select_inducing_points_legacy(self): + with self.assertWarns(DeprecationWarning): + # Call select_inducing_points directly with a string for allocator to trigger the warning + bounds = torch.tensor([[0.0], [1.0]]) + points = select_inducing_points( + inducing_size=5, + allocator="sobol", # Legacy string argument to trigger DeprecationWarning + bounds=bounds, + ) + self.assertEqual(points.shape, (5, 1)) + + def test_auto_allocator_allocate_inducing_points(self): + train_X = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]) + train_Y = torch.tensor([[1.0], [2.0], [3.0]]) + model = GPClassificationModel( + lb=torch.tensor([0, 0]), + ub=torch.tensor([4, 4]), + inducing_point_method=AutoAllocator(), + inducing_size=3, + ) + self.assertTrue(model.last_inducing_points_method == "DummyAllocator") + auto_inducing_points = AutoAllocator( + bounds=torch.stack([torch.tensor([0, 0]), torch.tensor([4, 4])]) + ).allocate_inducing_points( + inputs=train_X, + covar_module=model.covar_module, + num_inducing=model.inducing_size, + ) + inital_inducing_points = DummyAllocator( + bounds=torch.stack([torch.tensor([0, 0]), torch.tensor([4, 4])]) + ).allocate_inducing_points( + inputs=train_X, + covar_module=model.covar_module, + num_inducing=model.inducing_size, + ) + + # Should be different from the initial inducing points + self.assertFalse( + torch.allclose( + auto_inducing_points, model.variational_strategy.inducing_points + ) + ) + self.assertTrue( + torch.allclose( + inital_inducing_points, model.variational_strategy.inducing_points + ) + ) + + model.fit(train_X, train_Y) + 
+        self.assertTrue(model.last_inducing_points_method == "AutoAllocator")
+        self.assertEqual(
+            model.variational_strategy.inducing_points.shape, auto_inducing_points.shape
+        )
+
+        # Check that inducing points are updated after fitting
+        self.assertTrue(
+            torch.allclose(
+                auto_inducing_points, model.variational_strategy.inducing_points
+            )
+        )
+
+    def test_sobol_allocator_from_model_config(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = SobolAllocator
+        inducing_size = 2
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(isinstance(strat.model.inducing_point_method, SobolAllocator))
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+    def test_kmeans_allocator_from_model_config(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = KMeansAllocator
+        inducing_size = 2
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(isinstance(strat.model.inducing_point_method, KMeansAllocator))
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+    def test_auto_allocator_from_model_config(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = AutoAllocator
+        inducing_size = 2
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(isinstance(strat.model.inducing_point_method, AutoAllocator))
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+    def test_dummy_allocator_from_model_config(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = DummyAllocator
+        inducing_size = 2
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(isinstance(strat.model.inducing_point_method, DummyAllocator))
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+    def test_inducing_point_before_and_after_auto(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = AutoAllocator
+        inducing_size = 2
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(isinstance(strat.model.inducing_point_method, AutoAllocator))
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+        train_X = torch.tensor([[0.0], [1.0]])
+        train_Y = torch.tensor([[1.0], [0.0]])
+
+        auto_inducing_points = AutoAllocator(
+            bounds=torch.stack([torch.tensor([0]), torch.tensor([1])])
+        ).allocate_inducing_points(
+            inputs=train_X,
+            covar_module=strat.model.covar_module,
+            num_inducing=strat.model.inducing_size,
+        )
+        initial_inducing_points = DummyAllocator(
+            bounds=torch.stack([torch.tensor([0]), torch.tensor([1])])
+        ).allocate_inducing_points(
+            inputs=train_X,
+            covar_module=strat.model.covar_module,
+            num_inducing=strat.model.inducing_size,
+        )
+
+        # Should be different from the initial inducing points
+        self.assertFalse(
+            torch.allclose(
+                auto_inducing_points, strat.model.variational_strategy.inducing_points
+            )
+        )
+        # Should be the same as the initial inducing points
+        self.assertTrue(
+            torch.allclose(
+                initial_inducing_points, strat.model.variational_strategy.inducing_points
+            )
+        )
+
+        # Fit the model and check that the inducing points are updated
+        strat.add_data(train_X, train_Y)
+        strat.fit()
+        self.assertEqual(
+            strat.model.variational_strategy.inducing_points.shape,
+            auto_inducing_points.shape,
+        )
+
+    def test_fixed_allocator_allocate_inducing_points(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = FixedAllocator
+        inducing_size = 2
+
+        [FixedAllocator]
+        points = [[0.1], [0.2]]
+
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(isinstance(strat.model.inducing_point_method, FixedAllocator))
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+        # Check that the inducing points are the same as the fixed points (pre-transformation)
+        inducing_points_pre_transform = FixedAllocator(
+            points=torch.tensor([[0.1], [0.2]])
+        ).allocate_inducing_points(num_inducing=2)
+        self.assertTrue(
+            torch.equal(inducing_points_pre_transform, torch.tensor([[0.1], [0.2]]))
+        )
+
+        # Check that the inducing points are not the same as the fixed points (post-transformation)
+        inducing_points_after_transform = (
+            strat.model.inducing_point_method.allocate_inducing_points(num_inducing=2)
+        )
+        self.assertFalse(
+            torch.equal(inducing_points_after_transform, torch.tensor([[0.1], [0.2]]))
+        )
+
+        # make the transformation
+        transforms = ParameterTransforms.from_config(config)
+        transformed_config = transform_options(config, transforms)
+        transformed_points = torch.tensor(
+            eval(transformed_config["FixedAllocator"]["points"])
+        )
+        # Check that the inducing points are the same as the fixed points (post-transformation)
+        self.assertTrue(
+            torch.equal(inducing_points_after_transform, transformed_points)
+        )
+
+        # Fit the model and check that the inducing points are updated
+        train_X = torch.tensor([[6.0], [3.0]])
+        train_Y = torch.tensor([[1.0], [1.0]])
+        strat.add_data(train_X, train_Y)
+        strat.fit()
+        self.assertTrue(
+            torch.equal(
+                strat.model.variational_strategy.inducing_points, transformed_points
+            )
+        )
+
+
+class TestGreedyAllocators(unittest.TestCase):
+    def test_greedy_variance_reduction_allocate_inducing_points(self):
+        # Mock data for testing
+        train_X = torch.rand(100, 1)
+        train_Y = torch.rand(100, 1)
+        model = GPClassificationModel(
+            lb=0,
+            ub=1,
+            inducing_point_method=GreedyVarianceReduction(),
+            inducing_size=10,
+        )
+
+        # Instantiate GreedyVarianceReduction allocator
+        allocator = GreedyVarianceReduction()
+
+        # Allocate inducing points and verify output shape
+        inducing_points = allocator.allocate_inducing_points(
+            inputs=train_X,
+            covar_module=model.covar_module,
+            num_inducing=10,
+            input_batch_shape=torch.Size([]),
+        )
+        initial_inducing_points = DummyAllocator(
+            bounds=torch.stack([torch.tensor([0]), torch.tensor([1])])
+        ).allocate_inducing_points(
+            inputs=train_X,
+            covar_module=model.covar_module,
+            num_inducing=10,
+            input_batch_shape=torch.Size([]),
+        )
+        self.assertEqual(inducing_points.shape, (10, 1))
+        # Should be different from the initial inducing points
+        self.assertFalse(
+            torch.allclose(inducing_points, model.variational_strategy.inducing_points)
+        )
+        self.assertTrue(
+            torch.allclose(
+                initial_inducing_points, model.variational_strategy.inducing_points
+            )
+        )
+
+        # Then fit the model and check that the inducing points are updated
+        model.fit(train_X, train_Y)
+
+        self.assertTrue(
+            torch.allclose(inducing_points, model.variational_strategy.inducing_points)
+        )
+
+    def test_greedy_variance_from_config(self):
+        config_str = """
+        [common]
+        parnames = [par1]
+        stimuli_per_trial = 1
+        outcome_types = [binary]
+        strategy_names = [init_strat]
+
+        [par1]
+        par_type = continuous
+        lower_bound = 10
+        upper_bound = 1000
+        log_scale = True
+
+        [init_strat]
+        generator = OptimizeAcqfGenerator
+        acqf = MCLevelSetEstimation
+        min_asks = 2
+        model = GPClassificationModel
+
+        [GPClassificationModel]
+        inducing_point_method = GreedyVarianceReduction
+        inducing_size = 2
+        """
+
+        config = Config()
+        config.update(config_str=config_str)
+        strat = Strategy.from_config(config, "init_strat")
+        self.assertTrue(
+            isinstance(strat.model.inducing_point_method, GreedyVarianceReduction)
+        )
+
+        # check that the bounds are scaled correctly
+        self.assertFalse(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[10], [100]])
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                strat.model.inducing_point_method.bounds, torch.tensor([[0], [1]])
+            )
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()