Update all models to use automated init parser (#621)
Summary:
Pull Request resolved: #621

Update all models to use the automated init parser, with bespoke behavior implemented where needed.

Reviewed By: crasanders

Differential Revision: D69158866

fbshipit-source-id: 5db257f7472a9706184936e2fada6de6b1a8d787
JasonKChow authored and facebook-github-bot committed Feb 5, 2025
1 parent bbbc7e8 commit 5c8c00e
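The pattern the diff moves every model toward is sketched below: the base class's get_config_options now pulls the standard __init__ arguments out of the config automatically, and a subclass only overrides it to add or strip the keys it handles differently. This is a minimal illustration of that pattern, not code from the repository; ExampleModel is a hypothetical subclass, and the import paths simply follow the files touched in this diff.

from typing import Any, Dict, Optional

from aepsych.config import Config
from aepsych.models.variationalgp import VariationalGPModel


class ExampleModel(VariationalGPModel):  # hypothetical subclass for illustration only
    @classmethod
    def get_config_options(
        cls,
        config: Config,
        name: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # Let the automated parser fill the common arguments
        # (dim, mean_module, covar_module, likelihood, max_fit_time, ...).
        options = super().get_config_options(config=config, name=name, options=options)

        # Bespoke behavior: drop or rewrite keys this model handles itself,
        # mirroring what pairwise_probit.py and semi_p.py do below.
        options.pop("mean_module", None)

        return options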
Showing 6 changed files with 40 additions and 193 deletions.
64 changes: 31 additions & 33 deletions aepsych/models/base.py
@@ -7,6 +7,7 @@
from __future__ import annotations

import time
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

@@ -20,7 +21,7 @@
from botorch.models.gpytorch import GPyTorchModel
from botorch.posteriors import TransformedPosterior
from gpytorch.mlls import MarginalLogLikelihood

from torch.nn import Module

logger = getLogger()

@@ -132,45 +133,42 @@ def get_config_options(
Returns:
Dict[str, Any]: a dictionary of the model's configuration options parsed from the experiment's config
"""
options = options or {}
# NOTE: This get_config_options implies there should be an __init__ in this base
# class, but because the exact order of superclasses in this class's
# subclasses is very particular to ensure the MRO is exactly right, we cannot
# have an __init__ here. Subclass __init__s are expected to accept the arguments
# dim, mean_module, covar_module, likelihood, max_fit_time, and options; see the
# subclasses for exact typing.

options = super().get_config_options(config=config, name=name, options=options)

name = name or cls.__name__

dim = config.getint(name, "dim", fallback=None)
if dim is None:
dim = get_dims(config)
# Missing dims
if "dim" not in options:
options["dim"] = get_dims(config)

mean_covar_factory = config.getobj(
name, "mean_covar_factory", fallback=default_mean_covar_factory
)
# Missing mean/covar modules
if (
options.get("mean_module", None) is None
and options.get("covar_module", None) is None
):
# Get the factory
mean_covar_factory = config.getobj(
name, "mean_covar_factory", fallback=default_mean_covar_factory
)

mean, covar = mean_covar_factory(
config, stimuli_per_trial=cls.stimuli_per_trial
)
max_fit_time = config.getfloat(name, "max_fit_time", fallback=None)
mean_module, covar_module = mean_covar_factory(
config, stimuli_per_trial=cls.stimuli_per_trial
)

likelihood_cls = config.getobj(name, "likelihood", fallback=None)
options["mean_module"] = mean_module
options["covar_module"] = covar_module

if likelihood_cls is not None:
if hasattr(likelihood_cls, "from_config"):
likelihood = likelihood_cls.from_config(config)
else:
likelihood = likelihood_cls()
else:
likelihood = None # fall back to __init__ default

optimizer_options = get_optimizer_options(config, name)

options.update(
{
"dim": dim,
"mean_module": mean,
"covar_module": covar,
"max_fit_time": max_fit_time,
"likelihood": likelihood,
"optimizer_options": optimizer_options,
}
)
if "likelihood" in options and isinstance(options["likelihood"], type):
options["likelihood"] = options["likelihood"]() # Initialize it

# Get optimizer options; this is necessarily bespoke
options["optimizer_options"] = get_optimizer_options(config, name)

return options
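The key behavioral change in this method is that it now only fills in options the caller has not already supplied, instead of unconditionally rebuilding everything from the config. A dependency-free sketch of that fill-only-missing pattern follows; plain dicts stand in for Config and the mean/covar factory, and all names here are illustrative rather than aepsych APIs.

from typing import Any, Callable, Dict, Optional, Tuple


def fill_missing_options(
    defaults: Dict[str, Any],
    options: Optional[Dict[str, Any]] = None,
    mean_covar_factory: Optional[Callable[[], Tuple[Any, Any]]] = None,
) -> Dict[str, Any]:
    # Simplified stand-in for the flow above.
    options = dict(options or {})

    # Missing dims: only derive a default when the caller did not pass one.
    if "dim" not in options:
        options["dim"] = defaults.get("dim", 1)

    # Missing mean/covar modules: build both from the factory only if neither was given.
    if options.get("mean_module") is None and options.get("covar_module") is None:
        if mean_covar_factory is not None:
            options["mean_module"], options["covar_module"] = mean_covar_factory()

    # A likelihood supplied as a class (rather than an instance) is instantiated.
    if isinstance(options.get("likelihood"), type):
        options["likelihood"] = options["likelihood"]()

    return options


class DummyLikelihood:  # illustrative placeholder
    pass


# Caller-supplied values win: covar_module is kept, dim comes from defaults,
# and the likelihood class is instantiated.
opts = fill_missing_options(
    defaults={"dim": 2},
    options={"covar_module": "my_kernel", "likelihood": DummyLikelihood},
)
print(opts["dim"], type(opts["likelihood"]).__name__)  # -> 2 DummyLikelihood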

1 change: 0 additions & 1 deletion aepsych/models/gp_classification.py
@@ -16,7 +16,6 @@
from gpytorch.likelihoods import BernoulliLikelihood, Likelihood

from .transformed_posteriors import BernoulliProbitProbabilityPosterior

from .variationalgp import VariationalGPModel

logger = getLogger()
52 changes: 6 additions & 46 deletions aepsych/models/monotonic_projection_gp.py
@@ -98,7 +98,7 @@ def __init__(
lb: torch.Tensor,
ub: torch.Tensor,
dim: int,
monotonic_dims: List[int],
monotonic_dims: Optional[List[int]] = None,
monotonic_grid_size: int = 20,
min_f_val: Optional[float] = None,
mean_module: Optional[gpytorch.means.Mean] = None,
@@ -115,8 +115,8 @@
lb (torch.Tensor): Lower bounds of the parameters.
ub (torch.Tensor): Upper bounds of the parameters.
dim (int, optional): The number of dimensions in the parameter space.
monotonic_dims (List[int]): A list of the dimensions on which monotonicity should
be enforced.
monotonic_dims (List[int], optional): A list of the dimensions on which monotonicity should
be enforced. If not set, it will default to [-1].
monotonic_grid_size (int): The size of the grid, s, in 1. above. Defaults to 20.
min_f_val (float, optional): If provided, maintains this minimum in the projection in 5. Defaults to None.
mean_module (gpytorch.means.Mean, optional): GP mean class. Defaults to None; if None, a constant mean with a normal prior is used.
@@ -130,6 +130,9 @@
max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None,
there is no limit to the fitting time. Defaults to None.
"""
if monotonic_dims is None:
monotonic_dims = [-1]

assert len(monotonic_dims) > 0
self.monotonic_dims = [int(d) for d in monotonic_dims]
self.mon_grid_size = monotonic_grid_size
@@ -217,46 +220,3 @@ def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
if self.min_f_val is not None:
samps = samps.clamp(min=self.min_f_val)
return samps

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get configuration options for the model.
Args:
config (Config): Configuration object.
name (str, optional): Name of the model, defaults to None.
options (Dict[str, Any], optional): Additional options, defaults to None.
Returns:
Dict[str, Any]: Configuration options for the model.
"""
options = options or {}
options.update(super().get_config_options(config, name, options))

name = name or cls.__name__

lb = config.gettensor(name, "lb")
ub = config.gettensor(name, "ub")

monotonic_dims: List[int] = config.getlist(
name, "monotonic_dims", fallback=[-1]
)
monotonic_grid_size = config.getint(name, "monotonic_grid_size", fallback=20)
min_f_val = config.getfloat(name, "min_f_val", fallback=None)

options.update(
{
"lb": lb,
"ub": ub,
"monotonic_dims": monotonic_dims,
"monotonic_grid_size": monotonic_grid_size,
"min_f_val": min_f_val,
}
)

return options
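The net effect for this model is that monotonic_dims is now optional at construction time (defaulting to [-1] inside __init__), so the bespoke get_config_options above could be deleted. A rough usage sketch, assuming the remaining constructor defaults are sufficient and the import path matches this file:

import torch

from aepsych.models.monotonic_projection_gp import MonotonicProjectionGP

# Two-parameter space; monotonicity is enforced on the last dimension by default,
# equivalent to passing monotonic_dims=[-1] explicitly.
model = MonotonicProjectionGP(
    lb=torch.tensor([0.0, 0.0]),
    ub=torch.tensor([1.0, 1.0]),
    dim=2,
)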
14 changes: 1 addition & 13 deletions aepsych/models/pairwise_probit.py
@@ -289,9 +289,7 @@ def get_config_options(
Returns:
Dict[str, Any]: Configuration options for the model.
"""
options = options or {}
options.update(super().get_config_options(config, name, options))
name = name or cls.__name__
options = super().get_config_options(config, name, options)

# no way of passing mean into PairwiseGP right now
if "mean_module" in options:
@@ -301,14 +299,4 @@
if "likelihood" in options:
del options["likelihood"]

lb = config.gettensor(name, "lb")
ub = config.gettensor(name, "ub")

options.update(
{
"lb": lb,
"ub": ub,
}
)

return options
61 changes: 2 additions & 59 deletions aepsych/models/semi_p.py
@@ -255,7 +255,7 @@ def __init__(
mean_module: Optional[gpytorch.means.Mean] = None,
covar_module: Optional[gpytorch.kernels.Kernel] = None,
likelihood: Optional[Any] = None,
slope_mean: float = 2,
slope_mean: float = 2.0,
inducing_point_method: Optional[InducingPointAllocator] = None,
inducing_size: int = 100,
max_fit_time: Optional[float] = None,
@@ -320,41 +320,6 @@ def __init__(
optimizer_options=optimizer_options,
)

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get configuration options for the model.
Args:
config (Config): Configuration object.
name (str, optional): Name of the model, defaults to None.
options (Dict[str, Any], optional): Additional options, defaults to None.
Returns:
Dict[str, Any]: Configuration options for the model.
"""

options = options or {}
options.update(super().get_config_options(config, name, options))

name = name or cls.__name__

stim_dim = config.getint(name, "stim_dim", fallback=0)
slope_mean = config.getfloat(name, "slope_mean", fallback=2)

options.update(
{
"stim_dim": stim_dim,
"slope_mean": slope_mean,
}
)

return options

def fit(
self,
train_x: torch.Tensor,
@@ -619,9 +584,7 @@ def get_config_options(
Returns:
Dict[str, Any]: Configuration options for the model.
"""

options = options or {}
options.update(super().get_config_options(config, name, options))
options = super().get_config_options(config, name, options)

# This model has special modules
if "mean_module" in options:
@@ -630,26 +593,6 @@
if "covar_module" in options:
del options["covar_module"]

name = name or cls.__name__

slope_mean_module = config.getobj(name, "slope_mean_module", fallback=None)
slope_covar_module = config.getobj(name, "slope_covar_module", fallback=None)
offset_mean_module = config.getobj(name, "offset_mean_module", fallback=None)
offset_covar_module = config.getobj(name, "offset_covar_module", fallback=None)
stim_dim = config.getint(name, "stim_dim", fallback=0)
slope_mean = config.getfloat(name, "slope_mean", fallback=2)

options.update(
{
"stim_dim": stim_dim,
"slope_mean_module": slope_mean_module,
"slope_covar_module": slope_covar_module,
"offset_mean_module": offset_mean_module,
"offset_covar_module": offset_covar_module,
"slope_mean": slope_mean,
}
)

return options

def predict(
41 changes: 0 additions & 41 deletions aepsych/models/variationalgp.py
@@ -21,7 +21,6 @@
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


logger = getLogger()


@@ -101,46 +100,6 @@ def __init__(
self._fresh_state_dict = deepcopy(self.state_dict())
self._fresh_likelihood_dict = deepcopy(self.likelihood.state_dict())

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get configuration options for the model.
Args:
config (Config): Configuration object.
name (str, optional): Name of the model, defaults to None.
options (Dict[str, Any], optional): Additional options, defaults to None.
Returns:
Dict[str, Any]: Configuration options for the model.
"""
options = options or {}
options.update(super().get_config_options(config, name, options))

name = name or cls.__name__
inducing_size = config.getint(name, "inducing_size", fallback=100)
inducing_point_method_class = config.getobj(
name, "inducing_point_method", fallback=GreedyVarianceReduction
)
# Check if allocator class has a `from_config` method
if hasattr(inducing_point_method_class, "from_config"):
inducing_point_method = inducing_point_method_class.from_config(config)
else:
inducing_point_method = inducing_point_method_class()

options.update(
{
"inducing_size": inducing_size,
"inducing_point_method": inducing_point_method,
}
)

return options

def _reset_hyperparameters(self) -> None:
"""Reset hyperparameters to their initial values."""
# warmstart_hyperparams affects hyperparams but not the variational strat,
