Skip to content

Commit

Permalink
load yamls using postin > no need to copy to remote server
Browse files Browse the repository at this point in the history
  • Loading branch information
naik-aakash committed Dec 22, 2024
1 parent 5c231fb commit acb670d
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions src/autoplex/auto/rss/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,41 @@ class RssMaker(Maker):
path_to_default_config_parameters: Path | str | None
Path to the default RSS configuration file 'rss_default_configuration.yaml'.
If None, the default path will be used.
config_file: Path | str | None
Path to the configuration file that defines the setup parameters for the whole RSS workflow.
If not provided, the default file 'rss_default_configuration.yaml' will be used.
"""

name: str = "ml-driven rss"
path_to_default_config_parameters: Path | str | None = None
config_file: Path | str | None = None

def __post_init__(self) -> None:
"""Ensure that custom configuration parameters are loaded when the maker is initialized."""
rss_default_config_path = (
self.path_to_default_config_parameters
or Path(__file__).absolute().parent / "rss_default_configuration.yaml"
)

yaml = YAML(typ="safe", pure=True)

with open(rss_default_config_path) as f:
config = yaml.load(f)

if self.config_file and os.path.exists(self.config_file):
with open(self.config_file) as f:
new_config = yaml.load(f)
config.update(new_config)

self.config = config

@job
def make(self, config_file: str | None = None, **kwargs):
def make(self, **kwargs):
"""
Make a rss workflow using the specified configuration file and additional keyword arguments.
Parameters
----------
config_file: str | None
Path to the configuration file that defines the setup parameters for the whole RSS workflow.
If not provided, the default file 'rss_default_configuration.yaml' will be used.
kwargs: dict, optional
Additional optional keyword arguments to customize the job execution.
Expand Down Expand Up @@ -244,25 +264,10 @@ def make(self, config_file: str | None = None, **kwargs):
kb_temp: float
The temperature (in eV) for Boltzmann sampling.
"""
rss_default_config_path = (
self.path_to_default_config_parameters
or Path(__file__).absolute().parent / "rss_default_configuration.yaml"
)

yaml = YAML(typ="safe", pure=True)

with open(rss_default_config_path) as f:
config = yaml.load(f)

if config_file and os.path.exists(config_file):
with open(config_file) as f:
new_config = yaml.load(f)
config.update(new_config)

config.update(kwargs)
self._process_hookean_paras(config)
self.config.update(kwargs)
self._process_hookean_paras(self.config)

config_params = config.copy()
config_params = self.config.copy()

if "train_from_scratch" not in config_params:
raise ValueError(
Expand Down

0 comments on commit acb670d

Please sign in to comment.