Skip to content

Commit

Permalink
Add multiprocessing to run all benchmarks script.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694046858
  • Loading branch information
Nush395 authored and Torax team committed Nov 7, 2024
1 parent f78f714 commit d6f5cee
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions torax/tests/scripts/run_and_save_all_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@
"""Helper script to run and regenerate all benchmarks under test_data."""

from collections.abc import Sequence
import functools
import importlib
import os
import time

from absl import app
from absl import flags
from absl import logging
from torax import simulation_app
from torax.config import build_sim
from torax.tests.test_lib import paths
from torax.tests.test_lib import sim_test_case

import multiprocessing


# --output_dir: directory under which simulation outputs are written.
_OUTPUT_DIR = flags.DEFINE_string(
    name='output_dir',
    default='/tmp/torax_sim_outputs',
    help='Where to save sim outputs.',
)
# --num_proc: number of worker processes to use.
_NUM_PROCESSES = flags.DEFINE_integer(
    name='num_proc',
    default=16,
    help='Number of processes to use.',
)


def _get_config_module(
Expand All @@ -48,10 +54,17 @@ def _get_config_module(
)


def _run_sim(test_data_dir: str, config_name: str):
def _run_sim(config_name: str, test_data_dir: str, output_dir: str):
"""Run simulation for given config."""
logging.info('Running %s', config_name)
config_module = _get_config_module(test_data_dir, config_name + '.py')
flags.FLAGS.mark_as_parsed()
print(f'Running {config_name}')

try:
config_module = _get_config_module(test_data_dir, config_name + '.py')
except ImportError:
print(f'Failed to import config module {config_name}, skipping.')
return

if hasattr(config_module, 'get_sim'):
# The config module likely uses the "advanced" configuration setup with
# python functions defining all the Sim object attributes.
Expand All @@ -66,24 +79,40 @@ def _run_sim(test_data_dir: str, config_name: str):
f'Config module {config_name} must either define a get_sim() method'
' or a CONFIG dictionary.'
)
simulation_app.main(
lambda: sim,
output_dir=os.path.join(_OUTPUT_DIR.value, config_name),
)
try:
_, output_file = simulation_app.main(
lambda: sim,
output_dir=os.path.join(output_dir, config_name),
)
print(f'Finished running {config_name}, output saved to {output_file}')
except Exception as e: # pylint: disable=broad-except
print(f'Failed to run {config_name}: {e}')


def main(argv: Sequence[str]) -> None:
  """Regenerate every benchmark found in the test data directory.

  Each `.nc` file in `paths.test_data_dir()` names one config. The config
  names are mapped over `_run_sim` in a pool of worker processes, with the
  test data directory and the `--output_dir` flag value passed to each call.

  Args:
    argv: Command-line arguments left over after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  start_time_s = time.time()
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  test_data_dir = paths.test_data_dir()
  # splitext strips only the trailing '.nc', so a file name containing extra
  # dots no longer breaks a two-value unpacking of `split('.')`. Sorting makes
  # the submission order independent of the directory listing order.
  configs = sorted(
      os.path.splitext(filename)[0]
      for filename in os.listdir(test_data_dir)
      if filename.endswith('.nc')
  )
  print(f'Found {len(configs)} config experiments to run.')
  if not configs:
    # Nothing to run, so skip starting any worker processes.
    print(f'Running took {time.time() - start_time_s}s')
    return

  # Flag values are read here and bound into the callable so that workers
  # receive them as plain arguments.
  run_sim = functools.partial(
      _run_sim,
      test_data_dir=test_data_dir,
      output_dir=_OUTPUT_DIR.value,
  )
  # Never start more workers than there are configs to run.
  num_workers = min(_NUM_PROCESSES.value, len(configs))
  # Use the 'spawn' start method so each worker begins in a fresh interpreter;
  # per the original note here, JAX is not fork-safe.
  mp_context = multiprocessing.get_context('spawn')
  with mp_context.Pool(processes=num_workers) as pool:
    pool.map(run_sim, configs)
    # map() has already returned; close and join so the workers exit before
    # the context manager tears the pool down.
    pool.close()
    pool.join()
  print(f'Running took {time.time() - start_time_s}s')


if __name__ == '__main__':
Expand Down

0 comments on commit d6f5cee

Please sign in to comment.