-
Notifications
You must be signed in to change notification settings - Fork 87
/
Copy pathrunner_utils.py
54 lines (40 loc) · 1.61 KB
/
runner_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Dict, Any
import logging
import sys
import os
from absl import flags
from clu import platform
from datetime import datetime
# Command-line flags shared by experiments launched through this module.
flags.DEFINE_string("workdir", "./", "workdir for logs and checkpoints of the experiment")
flags.DEFINE_boolean("pod_job", False, "For running on tpu pods with slurm")
FLAGS = flags.FLAGS

# Experiment logger: INFO and above, written to stderr with a fixed prefix.
# Obtained via logging.getLogger (the documented way to create loggers, so
# other modules can retrieve the same instance by name) instead of
# instantiating logging.Logger directly. Propagation is switched off to keep
# the previous behaviour of a standalone logger whose records are NOT also
# emitted by handlers attached to the root logger.
LOGGER = logging.getLogger("Experiment")
LOGGER.setLevel(logging.INFO)
LOGGER.propagate = False
LOGGER_HANDLER = logging.StreamHandler(sys.stderr)
LOGGER_HANDLER.setFormatter(
    logging.Formatter("[%(asctime)s] FoT Tunning [%(levelname)s] : %(message)s")
)
LOGGER.addHandler(LOGGER_HANDLER)
def override_flags(overrides: Dict[str, Any]):
    """Apply dotted-path overrides to the global ``FLAGS`` object.

    Each key is a dot-separated attribute path. Walking starts at ``FLAGS``:
    every component except the last is resolved by calling ``__getattr__`` on
    the current object, and the final component is assigned the value through
    ``__setattr__``. For example ``{"workdir": "/tmp/run"}`` sets
    ``FLAGS.workdir``.

    ``__getattr__`` is invoked explicitly rather than through ``getattr()``, so
    the object's attribute hook is used even when a name collides with a
    regular attribute or method.

    Args:
        overrides: mapping of dotted attribute path -> new value. Each override
            is logged (via the root ``logging`` module) before it is applied.
    """
    for path, value in overrides.items():
        *parents, leaf = path.split(".")
        logging.info(f"Flags: Overriding {path} to {value}")
        target = FLAGS
        for name in parents:
            target = target.__getattr__(name)
        target.__setattr__(leaf, value)
def prepare_for_run(config_dict: Dict[str, Any]):
    """Apply flag overrides and, for pod jobs, initialise JAX distributed mode.

    Args:
        config_dict: dotted-path flag overrides, forwarded to
            ``override_flags``.

    After the overrides are applied, if ``FLAGS.pod_job`` is true, every
    environment variable whose name starts with ``"SLURM"`` is removed from
    ``os.environ`` and then ``jax.distributed.initialize()`` is called.
    """
    override_flags(overrides=config_dict)
    if not FLAGS.pod_job:
        return

    # Imported here, so JAX is only imported when this branch actually runs.
    import jax

    # NOTE(review): presumably the SLURM variables are removed so that
    # jax.distributed.initialize() does not auto-detect a SLURM cluster and
    # falls back to its other (e.g. TPU pod) detection — confirm.
    # Collect the names first: deleting from a mapping while iterating a live
    # view of it is unsafe (plain dicts raise RuntimeError); the previous code
    # only worked because CPython's os.environ happens to copy keys internally.
    slurm_vars = [name for name in os.environ if name.startswith("SLURM")]
    for name in slurm_vars:
        os.environ.pop(name, None)
    jax.distributed.initialize()
def create_workdir():
    """Register ``FLAGS.workdir`` as a directory artifact of the CLU work unit.

    The artifact is created with type ``platform.ArtifactType.DIRECTORY`` and
    the description ``"workdir"``. This function performs no filesystem
    operations itself; whether anything is created on disk depends on the
    ``clu`` platform implementation — NOTE(review): confirm.
    """
    work_unit = platform.work_unit()
    work_unit.create_artifact(
        platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir"
    )
def add_time_to_workdir():
    """Append a timestamp sub-directory to ``FLAGS.workdir``.

    The flag becomes ``os.path.join(<old workdir>, "DD.MM.YYYY_HH:MM:SS")``,
    using the local (naive) wall-clock time from ``datetime.now()``. Only the
    flag value changes; nothing is created on disk.

    NOTE: the ``:`` characters in the timestamp are not valid in Windows paths.
    """
    stamp = format(datetime.now(), "%d.%m.%Y_%H:%M:%S")
    FLAGS.workdir = os.path.join(FLAGS.workdir, stamp)
def run_from_dict(main_fn, config_dict: Dict[str, Any], post_override_callback=None):
    """Run an experiment entry point configured from a dict of flag overrides.

    Steps, in order:
      1. ``prepare_for_run(config_dict)``: apply the overrides to ``FLAGS``
         and, when ``FLAGS.pod_job`` is set, initialise JAX distributed mode.
      2. ``post_override_callback()``, if one was given.
      3. ``create_workdir()``: register ``FLAGS.workdir`` with the CLU work
         unit.
      4. ``main_fn(None)``.

    Args:
        main_fn: callable taking one positional argument; it is called with
            ``None`` (presumably standing in for absl's ``argv`` — confirm).
        config_dict: dotted-path flag overrides.
        post_override_callback: optional zero-argument callable run after the
            overrides are applied and before the workdir is registered, so it
            can still adjust flags such as ``FLAGS.workdir`` (a hook like
            ``add_time_to_workdir`` fits here). ``None`` skips this step.
    """
    prepare_for_run(config_dict=config_dict)
    if post_override_callback is not None:
        post_override_callback()
    create_workdir()
    main_fn(None)