Skip to content

Commit

Permalink
add git hash changes announcement
Browse files (browse the repository at this point in the history)
  • Loading branch information
lbluque committed Jan 31, 2025
1 parent d513ffa commit d042ffa
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
36 changes: 24 additions & 12 deletions src/fairchem/core/_cli_hydra.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,6 +50,7 @@ class SchedulerConfig:
mode: SchedulerType = SchedulerType.LOCAL
ranks_per_node: int = 1
num_nodes: int = 1
num_jobs: int = 1
slurm: dict = field(
default_factory=lambda: {
"mem_gb": 80, # slurm mem in GB
Expand Down Expand Up @@ -83,7 +84,7 @@ def checkpoint_dir(self) -> str:


class Submitit(Checkpointable):
def __call__(self, dict_config: DictConfig) -> None:
def __call__(self, dict_config: DictConfig, **run_kwargs) -> None:
self.config = dict_config
job_config: JobConfig = OmegaConf.to_object(dict_config.job)
# TODO: setup_imports is not needed if we stop instantiating models with Registry.
Expand All @@ -94,7 +95,7 @@ def __call__(self, dict_config: DictConfig) -> None:
runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.job_config = job_config
runner.load_state()
runner.run()
runner.run(**run_kwargs)
distutils.cleanup()

def _init_logger(self) -> None:
Expand Down Expand Up @@ -151,10 +152,6 @@ def get_hydra_config_from_yaml(
return hydra.compose(config_name=config_name, overrides=overrides_args)


def runner_wrapper(config: DictConfig):
Submitit()(config)


def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
Expand Down Expand Up @@ -188,10 +185,25 @@ def main(
slurm_qos=scheduler_cfg.slurm.qos,
slurm_account=scheduler_cfg.slurm.account,
)
job = executor.submit(runner_wrapper, cfg)
logging.info(
f"Submitted job id: {job_cfg.timestamp_id}, slurm id: {job.job_id}, logs: {job_cfg.log_dir}"
)
if scheduler_cfg.num_jobs == 1:
job = executor.submit(Submitit(), cfg)
logging.info(
f"Submitted job id: {job_cfg.timestamp_id}, slurm id: {job.job_id}, logs: {job_cfg.log_dir}"
)
elif scheduler_cfg.num_jobs > 1:
executor.update_parameters(slurm_array_parallelism=scheduler_cfg.num_jobs)

jobs = []
with executor.batch():
for job_number in range(scheduler_cfg.num_jobs):
job = executor.submit(
Submitit(),
cfg,
job_number=job_number,
num_jobs=scheduler_cfg.num_jobs,
)
jobs.append(job)
logging.info(f"Submitted {len(jobs)} jobs: {jobs[0].job_id.split("_")[0]}")
else:
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

Expand All @@ -204,8 +216,8 @@ def main(
rdzv_backend="c10d",
max_restarts=0,
)
elastic_launch(launch_config, runner_wrapper)(cfg)
elastic_launch(launch_config, Submitit())(cfg)
else:
logging.info("Running in local mode without elastic launch")
distutils.setup_env_local()
runner_wrapper(cfg)
Submitit()(cfg)
2 changes: 1 addition & 1 deletion src/fairchem/core/components/runner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,7 +25,7 @@ def job_config(self, cfg: DictConfig):
self._job_config = cfg

@abstractmethod
def run(self) -> Any:
def run(self, **kwargs) -> Any:
raise NotImplementedError

@abstractmethod
Expand Down

0 comments on commit d042ffa

Please sign in to comment.