Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Jan 22, 2025
1 parent 9c23c96 commit a6dba0e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/fairchem/core/_cli_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SchedulerConfig:


@dataclass
class FairchemJobConfig:
class JobConfig:
run_name: str = field(default_factory=lambda: uuid.uuid4().hex.upper()[0:8])
timestamp_id: str = field(default_factory=lambda: get_timestamp_uid())
run_dir: str = field(default_factory=lambda: tempfile.TemporaryDirectory().name)
Expand All @@ -85,14 +85,14 @@ def checkpoint_dir(self) -> str:
class Submitit(Checkpointable):
def __call__(self, dict_config: DictConfig) -> None:
self.config = dict_config
job_config: FairchemJobConfig = OmegaConf.to_object(dict_config.job)
job_config: JobConfig = OmegaConf.to_object(dict_config.job)
# TODO: setup_imports is not needed if we stop instantiating models with Registry.
setup_imports()
setup_env_vars()
distutils.setup(map_job_config_to_dist_config(job_config))
self._init_logger()
runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.fairchem_config = job_config
runner.job_config = job_config
runner.load_state()
runner.run()
distutils.cleanup()
Expand Down Expand Up @@ -126,7 +126,7 @@ def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
return DelayedSubmission(new_runner, self.config, self.cli_args)


def map_job_config_to_dist_config(job_cfg: FairchemJobConfig) -> dict:
def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict:
scheduler_config = job_cfg.scheduler
return {
"world_size": scheduler_config.num_nodes * scheduler_config.ranks_per_node,
Expand Down Expand Up @@ -165,7 +165,7 @@ def main(

cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
# merge default structured config with job
cfg = OmegaConf.merge({"job": OmegaConf.structured(FairchemJobConfig)}, cfg)
cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg)
log_dir = OmegaConf.to_object(cfg.job).log_dir
os.makedirs(cfg.job.run_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
Expand Down
12 changes: 6 additions & 6 deletions src/fairchem/core/components/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from omegaconf import DictConfig

from fairchem.core._cli_hydra import FairchemJobConfig
from fairchem.core._cli_hydra import JobConfig


class Runner(metaclass=ABCMeta):
Expand All @@ -17,12 +17,12 @@ class Runner(metaclass=ABCMeta):
"""

@property
def fairchem_config(self) -> FairchemJobConfig:
return self._fairchem_config
def job_config(self) -> JobConfig:
return self._job_config

@fairchem_config.setter
def fairchem_config(self, cfg: DictConfig):
self._fairchem_config = cfg
@job_config.setter
def job_config(self, cfg: DictConfig):
self._job_config = cfg

@abstractmethod
def run(self) -> Any:
Expand Down

0 comments on commit a6dba0e

Please sign in to comment.