Refactor model hierarchy (#601)
Summary:
Pull Request resolved: #601

This creates a base VariationalGPModel class that GPClassificationModel and SemiP inherit from, and it moves more functionality to the mixin.

Reviewed By: JasonKChow

Differential Revision: D68982469

fbshipit-source-id: 85b0c1e1becc873e463ab89f1612b8028c2b4131
crasanders authored and facebook-github-bot committed Feb 4, 2025
1 parent a3a86a8 commit 18c52ca
Showing 17 changed files with 439 additions and 363 deletions.
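For orientation, below is a minimal sketch of the class hierarchy the summary describes. It is illustrative only: AEPsychMixin's bases are taken from the diff of aepsych/models/base.py further down, but aepsych/models/variationalgp.py is not among the excerpted diffs, so the parent list and bodies shown for VariationalGPModel and its subclasses are assumptions, not the committed code.

# Illustrative sketch of the refactored hierarchy; bodies elided.
from aepsych.config import ConfigurableMixin
from botorch.models.gpytorch import GPyTorchModel

class AEPsychMixin(GPyTorchModel, ConfigurableMixin):
    """Shared functionality; this commit moves predict, predict_transform,
    sample, and update here (see aepsych/models/base.py below)."""

class VariationalGPModel(AEPsychMixin):
    """New common base for variational models (defined in variationalgp.py,
    not excerpted here)."""

class GPClassificationModel(VariationalGPModel):
    """Now inherits from the new base, per the summary."""

class SemiParametricGPModel(VariationalGPModel):
    """'SemiP' in the summary; also inherits from the new base."""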
aepsych/benchmark/problem.py (10 changes: 5 additions & 5 deletions)

@@ -7,9 +7,9 @@
 from functools import cached_property
 from typing import Any, Dict, List, Union

-import aepsych
 import numpy as np
 import torch
+from aepsych.models.model_protocol import ModelProtocol
 from aepsych.models.utils import p_below_threshold
 from aepsych.strategy import SequentialStrategy
 from aepsych.utils import make_scaled_sobol
@@ -78,11 +78,11 @@ def sample_y(
         """
         return bernoulli.rvs(self.p(x))

-    def f_hat(self, model: aepsych.models.base.ModelProtocol) -> torch.Tensor:
+    def f_hat(self, model: ModelProtocol) -> torch.Tensor:
         """Generate mean predictions from the model over the evaluation grid.

         Args:
-            model (aepsych.models.base.ModelProtocol): Model to evaluate.
+            model (ModelProtocol): Model to evaluate.

         Returns:
             torch.Tensor: Posterior mean from underlying model over the evaluation grid.
@@ -109,11 +109,11 @@ def p_true(self) -> torch.Tensor:
         normal_dist = torch.distributions.Normal(0, 1)
         return normal_dist.cdf(self.f_true)

-    def p_hat(self, model: aepsych.models.base.ModelProtocol) -> torch.Tensor:
+    def p_hat(self, model: ModelProtocol) -> torch.Tensor:
         """Generate mean predictions from the model over the evaluation grid.

         Args:
-            model (aepsych.models.base.ModelProtocol): Model to evaluate.
+            model (ModelProtocol): Model to evaluate.

         Returns:
             torch.Tensor: Posterior mean from underlying model over the evaluation grid.
aepsych/generators/acqf_grid_search_generator.py (2 changes: 1 addition & 1 deletion)

@@ -10,7 +10,7 @@

 import numpy as np
 import torch
-from aepsych.models.base import ModelProtocol
+from aepsych.models.model_protocol import ModelProtocol
 from aepsych.utils_logging import getLogger
 from numpy.random import choice

aepsych/generators/acqf_thompson_sampler_generator.py (2 changes: 1 addition & 1 deletion)

@@ -10,7 +10,7 @@

 import numpy as np
 import torch
-from aepsych.models.base import ModelProtocol
+from aepsych.models.model_protocol import ModelProtocol
 from aepsych.utils_logging import getLogger
 from numpy.random import choice

aepsych/generators/base.py (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@
 )
 from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption

-from ..models.base import ModelProtocol
+from ..models.model_protocol import ModelProtocol

 AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychMixin)

aepsych/generators/epsilon_greedy_generator.py (2 changes: 1 addition & 1 deletion)

@@ -11,7 +11,7 @@
 import torch
 from aepsych.config import Config

-from ..models.base import ModelProtocol
+from ..models.model_protocol import ModelProtocol
 from .base import AEPsychGenerator
 from .optimize_acqf_generator import OptimizeAcqfGenerator

aepsych/generators/grid_eval_acqf_generator.py (2 changes: 1 addition & 1 deletion)

@@ -11,7 +11,7 @@
 from aepsych.config import Config
 from aepsych.generators.base import AcqfGenerator, AEPsychGenerator
 from aepsych.generators.sobol_generator import SobolGenerator
-from aepsych.models.base import ModelProtocol
+from aepsych.models.model_protocol import ModelProtocol
 from aepsych.utils_logging import getLogger
 from botorch.acquisition import AcquisitionFunction

aepsych/generators/optimize_acqf_generator.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@
 from aepsych.acquisition.lookahead import LookaheadAcquisitionFunction
 from aepsych.config import Config
 from aepsych.generators.base import AcqfGenerator
-from aepsych.models.base import ModelProtocol
+from aepsych.models.model_protocol import ModelProtocol
 from aepsych.utils_logging import getLogger
 from botorch.acquisition import AcquisitionFunction
 from botorch.optim import optimize_acqf
aepsych/models/__init__.py (5 changes: 3 additions & 2 deletions)

@@ -8,7 +8,7 @@
 import sys

 from ..config import Config
-from .gp_classification import GPBetaRegressionModel, GPClassificationModel
+from .gp_classification import GPClassificationModel
 from .gp_regression import GPRegressionModel
 from .monotonic_projection_gp import MonotonicProjectionGP
 from .pairwise_probit import PairwiseProbitModel
@@ -17,6 +17,7 @@
     semi_p_posterior_transform,
     SemiParametricGPModel,
 )
+from .variationalgp import VariationalGPModel

 __all__ = [
     "GPClassificationModel",
@@ -25,8 +26,8 @@
     "HadamardSemiPModel",
     "SemiParametricGPModel",
     "semi_p_posterior_transform",
-    "GPBetaRegressionModel",
     "PairwiseProbitModel",
+    "VariationalGPModel",
 ]

 Config.register_module(sys.modules[__name__])
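A quick check of the public import surface after this change (a sketch; assumes an aepsych install that includes this commit, and relies on the summary's claim that GPClassificationModel inherits from the new base):

# VariationalGPModel is now exported from aepsych.models;
# GPBetaRegressionModel is no longer exported from this module.
from aepsych.models import GPClassificationModel, VariationalGPModel

assert issubclass(GPClassificationModel, VariationalGPModel)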
aepsych/models/base.py (134 changes: 76 additions & 58 deletions)

@@ -8,75 +8,21 @@

 import time
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Mapping, Optional, Protocol, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import gpytorch
 import torch
 from aepsych.config import Config, ConfigurableMixin
 from aepsych.factory.default import default_mean_covar_factory
-from aepsych.utils import get_dims, get_optimizer_options
+from aepsych.utils import get_dims, get_optimizer_options, promote_0d
 from aepsych.utils_logging import getLogger
 from botorch.fit import fit_gpytorch_mll, fit_gpytorch_mll_scipy
 from botorch.models.gpytorch import GPyTorchModel
-from botorch.posteriors import GPyTorchPosterior
-from gpytorch.likelihoods import Likelihood
+from botorch.posteriors import TransformedPosterior
 from gpytorch.mlls import MarginalLogLikelihood

-logger = getLogger()
-
-
-class ModelProtocol(Protocol):
-    @property
-    def _num_outputs(self) -> int:
-        pass
-
-    @property
-    def outcome_type(self) -> str:
-        pass
-
-    @property
-    def extremum_solver(self) -> str:
-        pass
-
-    @property
-    def train_inputs(self) -> torch.Tensor:
-        pass
-
-    @property
-    def dim(self) -> int:
-        pass
-
-    @property
-    def device(self) -> torch.device:
-        pass
-
-    def posterior(self, X: torch.Tensor) -> GPyTorchPosterior:
-        pass
-
-    def predict(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
-        pass
-
-    def predict_probability(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
-        pass
-
-    @property
-    def stimuli_per_trial(self) -> int:
-        pass
-
-    @property
-    def likelihood(self) -> Likelihood:
-        pass
-
-    def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
-        pass
-
-    def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any) -> None:
-        pass
-
-    def update(
-        self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any
-    ) -> None:
-        pass
+logger = getLogger()


 class AEPsychMixin(GPyTorchModel, ConfigurableMixin):
@@ -328,3 +274,75 @@ def train_targets(self, train_targets: Optional[torch.Tensor]) -> None:
         # setting device on copy to not change original
         train_targets = deepcopy(train_targets).to(self.device)
         self._train_targets = train_targets
+
+    def predict(
+        self,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Query the model for posterior mean and variance.
+
+        Args:
+            x (torch.Tensor): Points at which to predict from the model.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at the queried points.
+        """
+        with torch.no_grad():
+            x = x.to(self.device)
+            post = self.posterior(x)
+            mean = post.mean.squeeze()
+            var = post.variance.squeeze()
+
+        return promote_0d(mean.to(self.device)), promote_0d(var.to(self.device))
+
+    def predict_transform(
+        self,
+        x: torch.Tensor,
+        transformed_posterior_cls: Optional[type[TransformedPosterior]] = None,
+        **transform_kwargs,
+    ):
+        """Query the model for posterior mean and variance under some transformation.
+
+        Args:
+            x (torch.Tensor): Points at which to predict from the model.
+            transformed_posterior_cls (type[TransformedPosterior], optional): The type of
+                transformation to apply to the posterior. Note that you should give the
+                TransformedPosterior class itself, rather than an instance. Defaults to None,
+                in which case no transformation is applied.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at the queried points.
+        """
+        if transformed_posterior_cls is None:
+            return self.predict(x)
+        with torch.no_grad():
+            x = x.to(self.device)
+            post = self.posterior(x)
+            post = transformed_posterior_cls(post, **transform_kwargs)
+
+            mean = post.mean.squeeze()
+            var = post.variance.squeeze()
+
+        return promote_0d(mean.to(self.device)), promote_0d(var.to(self.device))
+
+    def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
+        """Sample from the underlying model.
+
+        Args:
+            x (torch.Tensor): Points at which to sample.
+            num_samples (int): Number of samples to return.
+
+        Returns:
+            torch.Tensor: Posterior samples [num_samples x dim].
+        """
+        x = x.to(self.device)
+        return self.posterior(x).sample(torch.Size([num_samples])).squeeze()
+
+    def update(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs):
+        """Perform a warm-start update of the model from a previous fit.
+
+        Args:
+            train_x (torch.Tensor): Inputs.
+            train_y (torch.Tensor): Responses.
+        """
+        return self.fit(train_x, train_y, **kwargs)
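Because these methods now live on the mixin, every model that inherits it shares one prediction API. Below is a hedged usage sketch: `model`, `grid`, `train_x`, and `train_y` are hypothetical stand-ins (a fitted AEPsych model and appropriately shaped tensors), and the transform lambdas are toy examples rather than AEPsych's own posterior transforms.

import torch
from botorch.posteriors import TransformedPosterior

# Latent posterior mean and variance at the query points.
mean, var = model.predict(grid)

# predict_transform takes the TransformedPosterior class itself (not an
# instance); extra kwargs are forwarded to its constructor. Note that
# TransformedPosterior needs mean/variance transforms for its .mean and
# .variance properties to be defined.
tmean, tvar = model.predict_transform(
    grid,
    transformed_posterior_cls=TransformedPosterior,
    sample_transform=torch.sigmoid,                        # toy transform
    mean_transform=lambda mean, var: torch.sigmoid(mean),  # crude plug-in mean
    variance_transform=lambda mean, var: var,              # placeholder
)

# Draw 100 joint posterior samples, then warm-start refit on new data.
samps = model.sample(grid, num_samples=100)
model.update(train_x, train_y)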
[Diffs for the remaining 8 of 17 changed files are not shown.]
