Skip to content

Commit

Permalink
add simple tests for RSSMaker __post_init__
Browse files Browse the repository at this point in the history
  • Loading branch information
naik-aakash committed Jan 8, 2025
1 parent 354fbd6 commit 0a90c9f
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 6 deletions.
7 changes: 3 additions & 4 deletions src/autoplex/auto/rss/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class RssMaker(Maker):
name: str
Name of the flow.
config_file: Path | str | None
Path to the configuration file that defines the setup parameters for the whole RSS workflow.
Path to the custom configuration file that defines the setup parameters for the whole RSS workflow.
If not provided, the default file 'rss_default_configuration.yaml' will be used.
"""

Expand All @@ -36,10 +36,9 @@ def __post_init__(self) -> None:
new_config = loadfn(self.config_file)

for key, value in new_config.items():
if key in self.CONFIG and isinstance(value, self.CONFIG[key]):
# TODO: Need better defaults in default file or we move to pydantic models
if key in self.CONFIG and isinstance(value, type(self.CONFIG[key])):
self.CONFIG[key] = value
else:
raise ValueError(f"Invalid key '{key}' in the configuration file.")

@job
def make(self, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/autoplex/auto/rss/rss_default_configuration.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# General Parameters
tag:
tag: ''
train_from_scratch: true
resume_from_previous_state:
test_error:
Expand Down
19 changes: 18 additions & 1 deletion tests/auto/rss/test_flows.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import pytest
from pathlib import Path
from jobflow import run_locally, Flow
from tests.conftest import mock_rss, mock_do_rss_iterations, mock_do_rss_iterations_multi_jobs
from autoplex.auto.rss.flows import RssMaker

os.environ["OMP_NUM_THREADS"] = "1"

Expand Down Expand Up @@ -307,3 +307,20 @@ def test_mock_workflow_multi_node(test_dir, mock_vasp, memory_jobstore, clean_di
selected_atoms = job2.output.resolve(memory_jobstore)

assert len(selected_atoms) == 3

def test_rssmaker_custom_config(test_dir):

# For now only test if __post_init is working and updating defaults
rss = RssMaker(config_file= test_dir / "rss" / "rss_config.yaml")

# TODO: test needs to be more robust after updating default config files
assert rss.CONFIG["tag"] == "test"
assert rss.CONFIG["generated_struct_numbers"] == [9000, 1000]
assert rss.CONFIG["num_processes_buildcell"] == 64
assert rss.CONFIG["num_processes_fit"] == 64
assert rss.CONFIG["device_for_rss"] == "gpu"
assert rss.CONFIG["isolatedatom_box"] == [10, 10, 10]
assert rss.CONFIG["dimer_box"] == [10, 10, 10]



143 changes: 143 additions & 0 deletions tests/test_data/rss/rss_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# General Parameters
tag: "test"
train_from_scratch: true
resume_from_previous_state:
test_error:
pre_database_dir:
mlip_path:
isolated_atom_energies:

# Buildcell Parameters
generated_struct_numbers:
- 9000
- 1000
buildcell_options:
fragment_file: null
fragment_numbers: null
num_processes_buildcell: 64

# Sampling Parameters
num_of_initial_selected_structs:
- 80
- 20
num_of_rss_selected_structs: 100
initial_selection_enabled: true
rss_selection_method: 'bcur2i'
bcur_params:
soap_paras:
l_max: 12
n_max: 12
atom_sigma: 0.0875
cutoff: 10.5
cutoff_transition_width: 1.0
zeta: 4.0
average: true
species: true
frac_of_bcur: 0.8
bolt_max_num: 3000
random_seed: null

# DFT Labelling Parameters
include_isolated_atom: true
isolatedatom_box:
- 10.0
- 10.0
- 10.0
e0_spin: false
include_dimer: true
dimer_box:
- 10.0
- 10.0
- 10.0
dimer_range:
- 1.0
- 5.0
dimer_num: 21
custom_incar:
ISMEAR: 0
SIGMA: 0.05
PREC: 'Accurate'
ADDGRID: '.TRUE.'
EDIFF: 1e-7
NELM: 250
LWAVE: '.FALSE.'
LCHARG: '.FALSE.'
ALGO: 'Normal'
AMIX: null
LREAL: '.FALSE.'
ISYM: 0
ENCUT: 520.0
KSPACING: 0.20
GGA: null
KPAR: 8
NCORE: 16
LSCALAPACK: '.FALSE.'
LPLANE: '.FALSE.'
custom_potcar:
vasp_ref_file: 'vasp_ref.extxyz'

# Data Preprocessing Parameters
config_types:
- 'initial'
- 'traj_early'
- 'traj'
rss_group:
- 'traj'
test_ratio: 0.1
regularization: true
scheme: 'linear-hull'
reg_minmax:
- [0.1, 1]
- [0.001, 0.1]
- [0.0316, 0.316]
- [0.0632, 0.632]
distillation: false
force_max: null
force_label: null
pre_database_dir: null

# MLIP Parameters
mlip_type: 'GAP'
ref_energy_name: 'REF_energy'
ref_force_name: 'REF_forces'
ref_virial_name: 'REF_virial'
auto_delta: true
num_processes_fit: 64
device_for_fitting: 'cpu'
##The following hyperparameters are only applicable to GAP.
##If you want to use other models, please replace the corresponding hyperparameters.
twob:
cutoff: 5.0
n_sparse: 15
theta_uniform: 1.0
threeb:
cutoff: 3.0
soap:
l_max: 10
n_max: 10
atom_sigma: 0.5
n_sparse: 2500
cutoff: 5.0
general:
three_body: false

# RSS Exploration Parameters
scalar_pressure_method: 'uniform'
scalar_exp_pressure: 1
scalar_pressure_exponential_width: 0.2
scalar_pressure_low: 0
scalar_pressure_high: 25
max_steps: 300
force_tol: 0.01
stress_tol: 0.01
stop_criterion: 0.01
max_iteration_number: 25
num_groups: 6
initial_kt: 0.3
current_iter_index: 1
hookean_repul: false
hookean_paras:
keep_symmetry: false
write_traj: true
num_processes_rss: 128
device_for_rss: 'gpu'

0 comments on commit 0a90c9f

Please sign in to comment.