Skip to content

Commit

Permalink
remove config_file arg and use pydantic model instead
Browse files Browse the repository at this point in the history
  • Loading branch information
naik-aakash committed Jan 9, 2025
1 parent 4ee2e2c commit 3e77720
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions src/autoplex/auto/rss/flows.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""RSS (random structure searching) flow for exploring and learning potential energy surfaces from scratch."""

import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path

from jobflow import Flow, Maker, Response, job
from monty.serialization import loadfn

from autoplex.auto.rss.jobs import do_rss_iterations, initial_rss
from autoplex.settings import RssConfig

MODULE_DIR = Path(os.path.dirname(__file__))

Expand All @@ -21,24 +21,13 @@ class RssMaker(Maker):
----------
name: str
Name of the flow.
config_file: Path | str | None
Path to the custom configuration file that defines the setup parameters for the whole RSS workflow.
If not provided, the default file 'rss_default_configuration.yaml' will be used.
config: RssConfig | None
Pydantic model that defines the setup parameters for the whole RSS workflow.
If not provided, the defaults from 'autoplex.settings.RssConfig' will be used.
"""

name: str = "ml-driven rss"
config_file: Path | str | None = None
CONFIG = loadfn(MODULE_DIR / "rss_default_configuration.yaml")

def __post_init__(self) -> None:
"""Ensure that custom configuration parameters are loaded when the maker is initialized."""
if self.config_file and Path(self.config_file).resolve(strict=True):
new_config = loadfn(Path(self.config_file).resolve())

for key, value in new_config.items():
# TODO: Need better defaults in default file or we move to pydantic models
if key in self.CONFIG and isinstance(value, type(self.CONFIG[key])):
self.CONFIG[key] = value
config: RssConfig | None = field(default_factory=lambda: RssConfig())

@job
def make(self, **kwargs):
Expand Down Expand Up @@ -75,7 +64,7 @@ def make(self, **kwargs):
buildcell_options: list[dict] | None
Customized parameters for buildcell. Default is None.
fragment: Atoms | list[Atoms] | None
Fragment(s) for random structures, e.g., molecules, to be placed indivudally intact.
Fragment(s) for random structures, e.g., molecules, to be placed individually intact.
atoms.arrays should have a 'fragment_id' key with unique identifiers for each fragment if in same Atoms.
atoms.cell must be defined (e.g., Atoms.cell = np.eye(3)*20).
fragment_numbers: list[str] | None
Expand Down Expand Up @@ -245,10 +234,10 @@ def make(self, **kwargs):
- 'current_iter': int, The current iteration index.
- 'kb_temp': float, The temperature (in eV) for Boltzmann sampling.
"""
self.CONFIG.update(kwargs)
self._process_hookean_paras(self.CONFIG)
updated_config = self.config.model_copy(update=kwargs)
config_params = updated_config.model_dump()

config_params = self.CONFIG.copy()
self._process_hookean_paras(config_params)

if "train_from_scratch" not in config_params:
raise ValueError(
Expand Down Expand Up @@ -345,7 +334,8 @@ def make(self, **kwargs):

return Response(replace=Flow(rss_flow), output=do_rss_job.output)

def _process_hookean_paras(self, config):
@staticmethod
def _process_hookean_paras(config):
if "hookean_paras" in config:
config["hookean_paras"] = {
tuple(map(int, k.strip("()").split(", "))): tuple(v)
Expand Down

0 comments on commit 3e77720

Please sign in to comment.