Skip to content

Commit

Permalink
Benchmarks: restore some deleted functionality (#1683)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1683

Put percentiles back into AggregatedBenchmarkResult since Sait is using them

Reviewed By: Balandat

Differential Revision: D46921900

fbshipit-source-id: 5d71df1b01589521c795188eca996295fd3b4972
  • Loading branch information
esantorella authored and facebook-github-bot committed Jun 24, 2023
1 parent 298df24 commit f2aa68e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ax/benchmark/benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from ax.core.experiment import Experiment
from ax.utils.common.base import Base
from numpy import nanmean, ndarray
from numpy import nanmean, nanquantile, ndarray
from pandas import DataFrame
from scipy.stats import sem

Expand All @@ -18,6 +18,8 @@
# `BenchmarkResult` as return type annotation, used for serialization and rendering
# in the UI.

# Quantile levels, expressed as fractions, that are reported next to the mean
# and SEM of each aggregated trace. They are passed as `q` to `nanquantile` and
# labelled "P25", "P50" and "P75" in the resulting statistics.
PERCENTILES = [0.25, 0.5, 0.75]


@dataclass(frozen=True, eq=False)
class BenchmarkResult(Base):
Expand Down Expand Up @@ -78,7 +80,7 @@ def from_benchmark_results(
trace_stats = {}
for name in ("optimization_trace", "score_trace"):
step_data = zip(*(getattr(res, name) for res in results))
stats = _get_stats(step_data=step_data)
stats = _get_stats(step_data=step_data, percentiles=PERCENTILES)
trace_stats[name] = stats

# Return aggregated results
Expand All @@ -91,9 +93,15 @@ def from_benchmark_results(
)


def _get_stats(step_data: Iterable[np.ndarray]) -> Dict[str, List[float]]:
def _get_stats(
step_data: Iterable[np.ndarray],
percentiles: List[float],
) -> Dict[str, List[float]]:
quantiles = []
stats = {"mean": [], "sem": []}
for step_vals in step_data:
stats["mean"].append(nanmean(step_vals))
stats["sem"].append(sem(step_vals, ddof=1, nan_policy="propagate"))
quantiles.append(nanquantile(step_vals, q=percentiles))
stats.update({f"P{100 * p:.0f}": q for p, q in zip(percentiles, zip(*quantiles))})
return stats
7 changes: 7 additions & 0 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def test_test(self) -> None:
"All experiments must have 4 trials",
)

for col in ["mean", "P25", "P50", "P75"]:
self.assertTrue((agg.score_trace[col] <= 100).all())

@fast_botorch_optimize
def test_full_run(self) -> None:
aggs = benchmark_full_run(
Expand All @@ -85,6 +88,10 @@ def test_full_run(self) -> None:

self.assertEqual(len(aggs), 2)

for agg in aggs:
for col in ["mean", "P25", "P50", "P75"]:
self.assertTrue((agg.score_trace[col] <= 100).all())

def test_timeout(self) -> None:
problem = SingleObjectiveBenchmarkProblem.from_botorch_synthetic(
test_problem_class=Branin,
Expand Down

0 comments on commit f2aa68e

Please sign in to comment.