Skip to content

Commit

Permalink
Remove options-level timeout and only pass timeout to run_n_trials (#…
Browse files Browse the repository at this point in the history
…3069)

Summary:
Pull Request resolved: #3069

As titled - T205636338 for more motivation on this clean up.

This change is to only pass the timeout to run_n_trials, to remove any ambiguity about whether the timeout is for one call or over the whole run

Reviewed By: esantorella

Differential Revision: D65947305

fbshipit-source-id: 95e66f3be7d62493dc88fec38eb8d1fce4498510
  • Loading branch information
paschai authored and facebook-github-bot committed Nov 14, 2024
1 parent 49698fc commit 241ad71
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
25 changes: 14 additions & 11 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def __init__(
self.markdown_messages = {
"Generation strategy": GS_TYPE_MSG.format(gs_name=generation_strategy.name)
}
self._timeout_hours = options.timeout_hours

@classmethod
def get_default_db_settings(cls) -> DBSettings:
Expand Down Expand Up @@ -736,7 +735,6 @@ def run_trials_and_yield_results(
raise UserInputError(
f"Expected `timeout_hours` >= 0, got {timeout_hours}."
)
self._timeout_hours = timeout_hours

self._latest_optimization_start_timestamp = current_timestamp_in_millis()
self.__ignore_global_stopping_strategy = ignore_global_stopping_strategy
Expand All @@ -755,7 +753,7 @@ def run_trials_and_yield_results(
self._num_remaining_requested_trials > 0
and not self.should_consider_optimization_complete()[0]
):
if self.should_abort_optimization():
if self.should_abort_optimization(timeout_hours=timeout_hours):
yield self._abort_optimization(num_preexisting_trials=n_existing)
return

Expand All @@ -766,7 +764,8 @@ def run_trials_and_yield_results(
self.candidate_trials
)
while self._num_remaining_requested_trials > 0 and self.run(
max_new_trials=n_remaining_to_generate
max_new_trials=n_remaining_to_generate,
timeout_hours=timeout_hours,
):
# Not checking `should_abort_optimization` on every trial for perf.
# reasons.
Expand Down Expand Up @@ -799,7 +798,7 @@ def run_trials_and_yield_results(
)

while self.running_trials:
if self.should_abort_optimization():
if self.should_abort_optimization(timeout_hours=timeout_hours):
yield self._abort_optimization(num_preexisting_trials=n_existing)
return
report_results = self._check_exit_status_and_report_results(
Expand Down Expand Up @@ -975,7 +974,7 @@ def should_consider_optimization_complete(self) -> tuple[bool, str]:
self.logger.info(f"Completing the optimization: {completion_message}.")
return should_complete, completion_message

def should_abort_optimization(self) -> bool:
def should_abort_optimization(self, timeout_hours: float | None = None) -> bool:
"""Checks whether this scheduler has reached some intertuption / abort
criterion, such as an overall optimization timeout, tolerated failure rate, etc.
"""
Expand All @@ -985,15 +984,15 @@ def should_abort_optimization(self) -> bool:

# if optimization is timed out, return True, else return False
timed_out = (
self._timeout_hours is not None
timeout_hours is not None
and self._latest_optimization_start_timestamp is not None
and current_timestamp_in_millis()
- none_throws(self._latest_optimization_start_timestamp)
>= none_throws(self._timeout_hours) * 60 * 60 * 1000
>= none_throws(timeout_hours) * 60 * 60 * 1000
)
if timed_out:
self.logger.error(
"Optimization timed out (timeout hours: " f"{self._timeout_hours})!"
"Optimization timed out (timeout hours: " f"{timeout_hours})!"
)
return timed_out

Expand Down Expand Up @@ -1179,7 +1178,7 @@ def _check_exit_status_and_report_results(
idle_callback, force_refit=True
)

def run(self, max_new_trials: int) -> bool:
def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool:
"""Schedules trial evaluation(s) if stopping criterion is not triggered,
maximum parallelism is not currently reached, and capacity allows.
Logs any failures / issues.
Expand All @@ -1189,6 +1188,10 @@ def run(self, max_new_trials: int) -> bool:
and run (useful when generating and running trials in batches). Note
that this function might also re-deploy existing ``CANDIDATE`` trials
that failed to deploy before, which will not count against this number.
timeout_hours: Maximum number of hours, for which
to run the optimization. This function will abort after running
for `timeout_hours` even if stopping criterion has not been reached.
If set to `None`, no optimization timeout will be applied.
Returns:
Boolean representing success status.
Expand All @@ -1204,7 +1207,7 @@ def run(self, max_new_trials: int) -> bool:
)
return False

if self.should_abort_optimization():
if self.should_abort_optimization(timeout_hours=timeout_hours):
self.logger.info(
"`should_abort_optimization` is `True`, not running more trials."
)
Expand Down
2 changes: 1 addition & 1 deletion ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class AxSchedulerTestCase(TestCase):
"min_failed_trials_for_failure_rate_check=5, log_filepath=None, "
"logging_level=20, ttl_seconds_for_trials=None, init_seconds_between_"
"polls=10, min_seconds_before_poll=1.0, seconds_between_polls_backoff_"
"factor=1.5, timeout_hours=None, run_trials_in_batches=False, "
"factor=1.5, run_trials_in_batches=False, "
"debug_log_run_metadata=False, early_stopping_strategy=None, "
"global_stopping_strategy=None, suppress_storage_errors_after_"
"retries=False, wait_for_running_trials=True, fetch_kwargs={}, "
Expand Down
2 changes: 1 addition & 1 deletion ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TestAxSchedulerMultiTypeExperiment(AxSchedulerTestCase):
"min_failed_trials_for_failure_rate_check=5, log_filepath=None, "
"logging_level=20, ttl_seconds_for_trials=None, init_seconds_between_"
"polls=10, min_seconds_before_poll=1.0, seconds_between_polls_backoff_"
"factor=1.5, timeout_hours=None, run_trials_in_batches=False, "
"factor=1.5, run_trials_in_batches=False, "
"debug_log_run_metadata=False, early_stopping_strategy=None, "
"global_stopping_strategy=None, suppress_storage_errors_after_"
"retries=False, wait_for_running_trials=True, fetch_kwargs={}, "
Expand Down
1 change: 0 additions & 1 deletion ax/service/utils/scheduler_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ class SchedulerOptions:
init_seconds_between_polls: int | None = 1
min_seconds_before_poll: float = 1.0
seconds_between_polls_backoff_factor: float = 1.5
timeout_hours: float | None = None
run_trials_in_batches: bool = False
debug_log_run_metadata: bool = False
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None
Expand Down

0 comments on commit 241ad71

Please sign in to comment.