diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py index b8d7f5ffb0..cb08714829 100644 --- a/src/fairchem/core/_cli_hydra.py +++ b/src/fairchem/core/_cli_hydra.py @@ -63,7 +63,7 @@ class SchedulerConfig: @dataclass -class FairchemJobConfig: +class JobConfig: run_name: str = field(default_factory=lambda: uuid.uuid4().hex.upper()[0:8]) timestamp_id: str = field(default_factory=lambda: get_timestamp_uid()) run_dir: str = field(default_factory=lambda: tempfile.TemporaryDirectory().name) @@ -85,14 +85,14 @@ def checkpoint_dir(self) -> str: class Submitit(Checkpointable): def __call__(self, dict_config: DictConfig) -> None: self.config = dict_config - job_config: FairchemJobConfig = OmegaConf.to_object(dict_config.job) + job_config: JobConfig = OmegaConf.to_object(dict_config.job) # TODO: setup_imports is not needed if we stop instantiating models with Registry. setup_imports() setup_env_vars() distutils.setup(map_job_config_to_dist_config(job_config)) self._init_logger() runner: Runner = hydra.utils.instantiate(dict_config.runner) - runner.fairchem_config = job_config + runner.job_config = job_config runner.load_state() runner.run() distutils.cleanup() @@ -126,7 +126,7 @@ def checkpoint(self, *args, **kwargs) -> DelayedSubmission: return DelayedSubmission(new_runner, self.config, self.cli_args) -def map_job_config_to_dist_config(job_cfg: FairchemJobConfig) -> dict: +def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict: scheduler_config = job_cfg.scheduler return { "world_size": scheduler_config.num_nodes * scheduler_config.ranks_per_node, @@ -165,7 +165,7 @@ def main( cfg = get_hydra_config_from_yaml(args.config_yml, override_args) # merge default structured config with job - cfg = OmegaConf.merge({"job": OmegaConf.structured(FairchemJobConfig)}, cfg) + cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg) log_dir = OmegaConf.to_object(cfg.job).log_dir os.makedirs(cfg.job.run_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) diff --git a/src/fairchem/core/components/runner.py b/src/fairchem/core/components/runner.py index a76a3ea135..6f5bdfefb8 100644 --- a/src/fairchem/core/components/runner.py +++ b/src/fairchem/core/components/runner.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from omegaconf import DictConfig - from fairchem.core._cli_hydra import FairchemJobConfig + from fairchem.core._cli_hydra import JobConfig class Runner(metaclass=ABCMeta): @@ -17,12 +17,12 @@ class Runner(metaclass=ABCMeta): """ @property - def fairchem_config(self) -> FairchemJobConfig: - return self._fairchem_config + def job_config(self) -> JobConfig: + return self._job_config - @fairchem_config.setter - def fairchem_config(self, cfg: DictConfig): - self._fairchem_config = cfg + @job_config.setter + def job_config(self, cfg: DictConfig): + self._job_config = cfg @abstractmethod def run(self) -> Any: