-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
85 lines (71 loc) · 2.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Code borrowed from the Open Catalyst Project (OCP): https://github.com/Open-Catalyst-Project/ocp (MIT license)
"""
import copy
import logging
import submitit
from m2models.common.flags import flags
from m2models.common.utils import (
build_config,
create_grid,
new_trainer_context,
save_experiment_log,
setup_logging,
)
class Runner(submitit.helpers.Checkpointable):
    """Callable job wrapper that submitit can checkpoint and requeue.

    Calling an instance opens a trainer context for the given config and runs
    the task inside it. ``checkpoint`` saves the trainer state and returns a
    ``DelayedSubmission`` that restarts a fresh ``Runner`` from that state.
    """

    def __init__(self) -> None:
        # All three are populated by __call__ once the trainer context exists.
        # They are declared here so checkpoint() can detect an interruption
        # that happened before setup finished instead of raising
        # AttributeError.
        self.config = None
        self.task = None
        self.trainer = None

    def __call__(self, config) -> None:
        """Build the trainer context for ``config`` and run the task.

        Args:
            config: Experiment configuration passed to ``new_trainer_context``.
        """
        # NOTE: `args` is not a parameter; it resolves to the module-level
        # `args` assigned in the `__main__` section of this file, so that
        # section must run (or `args` must otherwise exist as a global)
        # before a Runner is called.
        with new_trainer_context(args=args, config=config) as ctx:
            self.config = ctx.config
            self.task = ctx.task
            self.trainer = ctx.trainer
            self.task.setup(self.trainer)
            self.task.run()

    def checkpoint(self, *args, **kwargs):
        """Save training state and return the submission that resumes it.

        Args:
            *args: Positional arguments of the original call, as forwarded by
                submitit. Only used when no trainer has been built yet.
            **kwargs: Keyword arguments of the original call; same usage.

        Returns:
            A ``submitit.helpers.DelayedSubmission`` wrapping a new ``Runner``.
            Normally it carries ``self.config`` updated with the checkpoint
            path and timestamp id; if the job was interrupted before the
            trainer existed, it carries the original call arguments instead.
        """
        new_runner = Runner()

        if self.trainer is None or self.task is None:
            # Interrupted during start-up: there is no state to save, so
            # requeue the job exactly as it was first submitted.
            return submitit.helpers.DelayedSubmission(
                new_runner, *args, **kwargs
            )

        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        # Record where to resume from so the requeued run picks up this state.
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id
        if self.trainer.logger is not None:
            self.trainer.logger.mark_preempting()
        return submitit.helpers.DelayedSubmission(new_runner, self.config)
if __name__ == "__main__":
    # Entry point: parse flags, build the config, then either run in-process
    # or submit one job per config through submitit.
    # This deliberately stays at module level (not inside a main() function):
    # Runner.__call__ reads `args` as a module global.
    setup_logging()

    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)

    if not args.submit:
        # Run locally
        Runner()(config)
    else:
        # Run on cluster
        # Additional slurm arguments, if the config supplies any.
        slurm_add_params = config.get("slurm", None)

        # A sweep file expands the base config into a grid; otherwise a
        # single job is submitted.
        configs = (
            create_grid(config, args.sweep_yml) if args.sweep_yml else [config]
        )
        logging.info(f"Submitting {len(configs)} jobs")

        # "%j" is left in the folder name for submitit to fill in
        # (job-id placeholder per the submitit docs).
        job_folder = args.logdir / "%j"
        executor = submitit.AutoExecutor(
            folder=job_folder, slurm_max_num_timeout=3
        )

        executor_settings = dict(
            name=args.identifier,
            mem_gb=args.slurm_mem,
            # NOTE(review): the * 60 suggests slurm_timeout is given in
            # hours - confirm against the flag definition.
            timeout_min=args.slurm_timeout * 60,
            slurm_partition=args.slurm_partition,
            gpus_per_node=args.num_gpus,
            cpus_per_task=(config["optim"]["num_workers"] + 1),
            tasks_per_node=(args.num_gpus if args.distributed else 1),
            nodes=args.num_nodes,
            slurm_additional_parameters=slurm_add_params,
        )
        executor.update_parameters(**executor_settings)

        # Store the resolved executor settings and folder in every config.
        for job_config in configs:
            slurm_record = copy.deepcopy(executor.parameters)
            slurm_record["folder"] = str(executor.folder)
            job_config["slurm"] = slurm_record

        jobs = executor.map_array(Runner(), configs)
        logging.info(
            f"Submitted jobs: {', '.join([job.job_id for job in jobs])}"
        )

        log_file = save_experiment_log(args, jobs, configs)
        logging.info(f"Experiment log saved to: {log_file}")