Model fit metrics for logging (#1682)
Summary:
Pull Request resolved: #1682

This commit adds metrics that quantify and log model fit quality for each *experimental* metric. To start, it adds metrics based on the model's posterior statistics; these can be extended readily by adding to the `fit_metrics` dict and generalized to other metric types in follow-up work.
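
A minimal usage sketch of the new API, mirroring the test added in this commit; it assumes an Ax `Scheduler` named `scheduler` that has already run enough trials to fit a model:

```
from ax.service.scheduler import get_fitted_model_bridge

# `scheduler` is assumed to be an already-running Ax Scheduler; see
# ax/modelbridge/tests/test_model_fit_metrics.py below for a full setup.
model_bridge = get_fitted_model_bridge(scheduler)

# Nested dict: model fit metric name -> experimental metric name -> value.
fit_metrics = model_bridge.compute_model_fit_metrics(scheduler.experiment)
r2_per_metric = fit_metrics["coefficient_of_determination"]
```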

Reviewed By: Balandat

Differential Revision: D46816506

fbshipit-source-id: 0e4f9d9d8f4030b9793bdcf9ec5c218fccb91990
SebastianAment authored and facebook-github-bot committed Jun 24, 2023
1 parent 94363dd commit e4aa110
Showing 12 changed files with 671 additions and 97 deletions.
108 changes: 107 additions & 1 deletion ax/modelbridge/base.py
@@ -12,8 +12,9 @@
from dataclasses import dataclass, field

from logging import Logger
from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple, Type
from typing import Any, cast, Dict, List, MutableMapping, Optional, Set, Tuple, Type

import numpy as np
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
@@ -36,6 +37,13 @@
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.stats.model_fit_stats import (
coefficient_of_determination,
compute_model_fit_metrics,
mean_of_the_standardized_error,
ModelFitMetricProtocol,
std_of_the_standardized_error,
)
from botorch.exceptions.warnings import InputDataWarning

logger: Logger = get_logger(__name__)
@@ -918,6 +926,51 @@ def _cross_validate(
"""
raise NotImplementedError # pragma: no cover

def compute_model_fit_metrics(
self,
experiment: Experiment,
fit_metrics_dict: Optional[Dict[str, ModelFitMetricProtocol]] = None,
) -> Dict[str, Dict[str, float]]:
"""Computes the model fit metrics from the scheduler state.
Args:
experiment: The experiment with whose data to compute the model fit metrics.
fit_metrics_dict: An optional dictionary with model fit metric functions,
i.e. a ModelFitMetricProtocol, as values and their names as keys.
Returns:
A nested dictionary mapping from the *model fit* metric names and the
*experimental metric* names to the values of the model fit metrics.
Example for an imaginary AutoML experiment that seeks to minimize the test
error after training an expensive model, with respect to hyper-parameters:
```
model_fit_dict = model_bridge.compute_model_fit_metrics(experiment)
model_fit_dict["coefficient_of_determination"]["test error"] =
`coefficient of determination of the test error predictions`
```
"""
# TODO: cross_validate_by_trial-based generalization quality
# IDEA: store y_obs, y_pred, se_pred as well
y_obs, y_pred, se_pred = _predict_on_training_data(
model_bridge=self, experiment=experiment
)
if fit_metrics_dict is None:
fit_metrics_dict = {
"coefficient_of_determination": coefficient_of_determination,
"mean_of_the_standardized_error": mean_of_the_standardized_error,
"std_of_the_standardized_error": std_of_the_standardized_error,
}
fit_metrics_dict = cast(Dict[str, ModelFitMetricProtocol], fit_metrics_dict)

return compute_model_fit_metrics(
y_obs=y_obs,
y_pred=y_pred,
se_pred=se_pred,
fit_metrics_dict=fit_metrics_dict,
)

def _set_kwargs_to_save(
self,
model_key: str,
@@ -1099,3 +1152,56 @@ def clamp_observation_features(
)
obsf.parameters[p.name] = p.upper
return observation_features


"""
############################## Model Fit Metrics Utils ##############################
"""


def _predict_on_training_data(
model_bridge: ModelBridge,
experiment: Experiment,
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray],]:
"""Makes predictions on the training data of a given experiment using a ModelBridge
and returning the observed values, and the corresponding predictive means and
predictive standard deviations of the model.
NOTE: This is a helper function for `ModelBridge.compute_model_fit_metrics` and
could be attached to the class.
Args:
model_bridge: A ModelBridge object with which to make predictions.
experiment: The experiment whose data is used to compute the model fit metrics.
Returns:
A tuple containing three dictionaries for 1) observed metric values, and the
model's associated 2) predictive means and 3) predictive standard deviations.
"""
data = experiment.fetch_data()
observations = observations_from_data(
experiment=experiment, data=data
) # List[Observation]
observation_features = [obs.features for obs in observations]
mean_predicted, cov_predicted = model_bridge.predict(
observation_features=observation_features
) # Dict[str, List[float]]
mean_observed = [
obs.data.means_dict for obs in observations
] # List[Dict[str, float]]
metric_names = list(data.metric_names)
mean_observed = _list_of_dicts_to_dict_of_lists(
list_of_dicts=mean_observed, keys=metric_names
)
# converting dictionary values to arrays
mean_observed = {k: np.array(v) for k, v in mean_observed.items()}
mean_predicted = {k: np.array(v) for k, v in mean_predicted.items()}
std_predicted = {m: np.sqrt(np.array(cov_predicted[m][m])) for m in cov_predicted}
return mean_observed, mean_predicted, std_predicted


def _list_of_dicts_to_dict_of_lists(
list_of_dicts: List[Dict[str, float]], keys: List[str]
) -> Dict[str, List[float]]:
"""Converts a list of dicts indexed by a string to a dict of lists."""
return {key: [d[key] for d in list_of_dicts] for key in keys}
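
For illustration, a hedged sketch of supplying a custom fit metric via the `fit_metrics_dict` argument of `ModelBridge.compute_model_fit_metrics`. It only assumes the `(y_obs, y_pred, se_pred) -> float` call signature that the built-in metrics follow; `mean_absolute_error`, `model_bridge`, and `experiment` are illustrative names, not part of Ax:

```
import numpy as np

def mean_absolute_error(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # se_pred is unused here; it is kept so that the callable satisfies the
    # ModelFitMetricProtocol signature used by the built-in fit metrics.
    return float(np.mean(np.abs(y_pred - y_obs)))

# `model_bridge` and `experiment` are assumed to exist, e.g. as in the sketch
# accompanying the commit summary above.
custom_fit_metrics = model_bridge.compute_model_fit_metrics(
    experiment, fit_metrics_dict={"mean_absolute_error": mean_absolute_error}
)
```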
131 changes: 53 additions & 78 deletions ax/modelbridge/cross_validation.py
@@ -11,14 +11,37 @@

from logging import Logger
from numbers import Number
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Set,
Tuple,
)

import numpy as np
from ax.core.observation import Observation, ObservationData
from ax.core.optimization_config import OptimizationConfig
from ax.modelbridge.base import ModelBridge
from ax.utils.common.logger import get_logger
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr

from ax.utils.stats.model_fit_stats import (
_correlation_coefficient,
_fisher_exact_test_p,
_log_likelihood,
_mape,
_mean_prediction_ci,
_rank_correlation,
_total_raw_effect,
compute_model_fit_metrics,
ModelFitMetricProtocol,
)

logger: Logger = get_logger(__name__)

@@ -225,27 +248,36 @@ def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics:
k = res.predicted.metric_names.index(metric_name)
y_pred[metric_name].append(res.predicted.means[k])
se_pred[metric_name].append(np.sqrt(res.predicted.covariance[k, k]))
y_obs = _arrayify_dict_values(y_obs)
y_pred = _arrayify_dict_values(y_pred)
se_pred = _arrayify_dict_values(se_pred)

# We need to cast here since pyre infers specific types T < ModelFitMetricProtocol
# for the dict values upon initialization, so diagnostic_fns is not recognized
# as a Mapping[str, ModelFitMetricProtocol]; see the last tip in the Pyre docs
# on [9] Incompatible Variable Type:
# https://staticdocs.internalfb.com/pyre/docs/errors/#9-incompatible-variable-type
diagnostic_fns = cast(
Mapping[str, ModelFitMetricProtocol],
{
MEAN_PREDICTION_CI: _mean_prediction_ci,
MAPE: _mape,
TOTAL_RAW_EFFECT: _total_raw_effect,
CORRELATION_COEFFICIENT: _correlation_coefficient,
RANK_CORRELATION: _rank_correlation,
FISHER_EXACT_TEST_P: _fisher_exact_test_p,
LOG_LIKELIHOOD: _log_likelihood,
},
)
diagnostics = compute_model_fit_metrics(
y_obs=y_obs, y_pred=y_pred, se_pred=se_pred, fit_metrics_dict=diagnostic_fns
)
return diagnostics

diagnostic_fns = {
MEAN_PREDICTION_CI: _mean_prediction_ci,
MAPE: _mape,
TOTAL_RAW_EFFECT: _total_raw_effect,
CORRELATION_COEFFICIENT: _correlation_coefficient,
RANK_CORRELATION: _rank_correlation,
FISHER_EXACT_TEST_P: _fisher_exact_test_p,
LOG_LIKELIHOOD: _log_likelihood,
}

diagnostics: Dict[str, Dict[str, float]] = defaultdict(dict)
# Get all per-metric diagnostics.
for metric_name in y_obs:
for name, fn in diagnostic_fns.items():
diagnostics[name][metric_name] = fn(
y_obs=np.array(y_obs[metric_name]),
y_pred=np.array(y_pred[metric_name]),
se_pred=np.array(se_pred[metric_name]),
)
return diagnostics
def _arrayify_dict_values(d: Dict[str, List[float]]) -> Dict[str, np.ndarray]:
"""Helper to convert dictionary values to numpy arrays."""
return {k: np.array(v) for k, v in d.items()}


def assess_model_fit(
@@ -339,63 +371,6 @@ def _gen_train_test_split(
yield set(arm_names[:-n_test]), set(arm_names[-n_test:])


def _mean_prediction_ci(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
# Pyre does not allow float * np.ndarray.
return float(np.mean(1.96 * 2 * se_pred / np.abs(y_obs)))


def _log_likelihood(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
return float(np.sum(norm.logpdf(y_obs, loc=y_pred, scale=se_pred)))


def _mape(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
return float(np.mean(np.abs((y_pred - y_obs) / y_obs)))


def _total_raw_effect(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
min_y_obs = np.min(y_obs)
return float((np.max(y_obs) - min_y_obs) / min_y_obs)


def _correlation_coefficient(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
with np.errstate(invalid="ignore"):
rho, _ = pearsonr(y_pred, y_obs)
return float(rho)


def _rank_correlation(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
with np.errstate(invalid="ignore"):
rho, _ = spearmanr(y_pred, y_obs)
return float(rho)


def _fisher_exact_test_p(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
n_half = len(y_obs) // 2
top_obs = y_obs.argsort(axis=0)[-n_half:]
top_est = y_pred.argsort(axis=0)[-n_half:]
# Construct contingency table
tp = len(set(top_est).intersection(top_obs))
fp = n_half - tp
fn = n_half - tp
tn = (len(y_obs) - n_half) - (n_half - tp)
table = np.array([[tp, fp], [fn, tn]])
# Compute the test statistic
_, p = fisher_exact(table, alternative="greater")
return float(p)


class BestModelSelector(ABC):
@abstractmethod
def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
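The cross-validation diagnostics now flow through the same `compute_model_fit_metrics` machinery. A brief sketch of the resulting nested `CVDiagnostics` layout, assuming a fitted `model_bridge` as above (`MAPE` is the module-level constant used in `compute_diagnostics`):

```
from ax.modelbridge.cross_validation import (
    compute_diagnostics,
    cross_validate,
    MAPE,
)

cv_results = cross_validate(model_bridge)      # List[CVResult]
diagnostics = compute_diagnostics(cv_results)  # CVDiagnostics
# Nested layout: diagnostic name -> experimental metric name -> value.
mape_per_metric = diagnostics[MAPE]
```
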
83 changes: 83 additions & 0 deletions ax/modelbridge/tests/test_model_fit_metrics.py
@@ -0,0 +1,83 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict

from ax.core.experiment import Experiment
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.metrics.branin import BraninMetric
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.runners.synthetic import SyntheticRunner
from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_search_space

NUM_SOBOL = 5


class TestModelBridgeFitMetrics(TestCase):
def setUp(self) -> None:
# setting up experiment and generation strategy
self.runner = SyntheticRunner()
self.branin_experiment = Experiment(
name="branin_test_experiment",
search_space=get_branin_search_space(),
runner=self.runner,
optimization_config=OptimizationConfig(
objective=Objective(
metric=BraninMetric(name="branin", param_names=["x1", "x2"]),
minimize=True,
),
),
is_test=True,
)
self.branin_experiment._properties[
Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF
] = True
self.generation_strategy = GenerationStrategy(
steps=[
GenerationStep(
model=Models.SOBOL, num_trials=NUM_SOBOL, max_parallelism=NUM_SOBOL
),
GenerationStep(model=Models.GPEI, num_trials=-1),
]
)

def test_model_fit_metrics(self) -> None:
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=self.generation_strategy,
options=SchedulerOptions(),
)
# need to run some trials to initialize the ModelBridge
scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)
model_bridge = get_fitted_model_bridge(scheduler)

# testing ModelBridge.compute_model_fit_metrics with default metrics
fit_metrics = model_bridge.compute_model_fit_metrics(self.branin_experiment)
r2 = fit_metrics.get("coefficient_of_determination")
self.assertIsInstance(r2, dict)
r2 = cast(Dict[str, float], r2)
self.assertTrue("branin" in r2)
r2_branin = r2["branin"]
self.assertIsInstance(r2_branin, float)

std = fit_metrics.get("std_of_the_standardized_error")
self.assertIsInstance(std, dict)
std = cast(Dict[str, float], std)
self.assertTrue("branin" in std)
std_branin = std["branin"]
self.assertIsInstance(std_branin, float)

# testing with empty metrics
empty_metrics = model_bridge.compute_model_fit_metrics(
self.branin_experiment, fit_metrics_dict={}
)
self.assertIsInstance(empty_metrics, dict)
self.assertTrue(len(empty_metrics) == 0)
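
Since the purpose of these metrics is logging, one possible way to emit them (an illustrative sketch, not part of this commit) is via the standard library logger, iterating over the nested dictionary returned by `compute_model_fit_metrics`:

```
import logging

logger = logging.getLogger(__name__)

# `model_bridge` and `experiment` are assumed to exist, as in the sketches above.
fit_metrics = model_bridge.compute_model_fit_metrics(experiment)
for fit_metric_name, per_metric_values in fit_metrics.items():
    for metric_name, value in per_metric_values.items():
        logger.info("%s[%s] = %f", fit_metric_name, metric_name, value)
```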