From 0a90c9f6acf8d04f1849b0abfb2cbd57d83fe7f2 Mon Sep 17 00:00:00 2001 From: naik-aakash Date: Wed, 8 Jan 2025 17:36:19 +0100 Subject: [PATCH] add simple tests for RSSMaker __post_init__ --- src/autoplex/auto/rss/flows.py | 7 +- .../auto/rss/rss_default_configuration.yaml | 2 +- tests/auto/rss/test_flows.py | 19 ++- tests/test_data/rss/rss_config.yaml | 143 ++++++++++++++++++ 4 files changed, 165 insertions(+), 6 deletions(-) create mode 100644 tests/test_data/rss/rss_config.yaml diff --git a/src/autoplex/auto/rss/flows.py b/src/autoplex/auto/rss/flows.py index 8529aa44b..0d5d54056 100644 --- a/src/autoplex/auto/rss/flows.py +++ b/src/autoplex/auto/rss/flows.py @@ -22,7 +22,7 @@ class RssMaker(Maker): name: str Name of the flow. config_file: Path | str | None - Path to the configuration file that defines the setup parameters for the whole RSS workflow. + Path to the custom configuration file that defines the setup parameters for the whole RSS workflow. If not provided, the default file 'rss_default_configuration.yaml' will be used. """ @@ -36,10 +36,9 @@ def __post_init__(self) -> None: new_config = loadfn(self.config_file) for key, value in new_config.items(): - if key in self.CONFIG and isinstance(value, self.CONFIG[key]): + # TODO: Need better defaults in default file or we move to pydantic models + if key in self.CONFIG and isinstance(value, type(self.CONFIG[key])): self.CONFIG[key] = value - else: - raise ValueError(f"Invalid key '{key}' in the configuration file.") @job def make(self, **kwargs): diff --git a/src/autoplex/auto/rss/rss_default_configuration.yaml b/src/autoplex/auto/rss/rss_default_configuration.yaml index 89cdc713a..fd46d013d 100644 --- a/src/autoplex/auto/rss/rss_default_configuration.yaml +++ b/src/autoplex/auto/rss/rss_default_configuration.yaml @@ -1,5 +1,5 @@ # General Parameters -tag: +tag: '' train_from_scratch: true resume_from_previous_state: test_error: diff --git a/tests/auto/rss/test_flows.py b/tests/auto/rss/test_flows.py index da0093cbc..86d415992 100644 --- a/tests/auto/rss/test_flows.py +++ b/tests/auto/rss/test_flows.py @@ -1,8 +1,8 @@ import os -import pytest from pathlib import Path from jobflow import run_locally, Flow from tests.conftest import mock_rss, mock_do_rss_iterations, mock_do_rss_iterations_multi_jobs +from autoplex.auto.rss.flows import RssMaker os.environ["OMP_NUM_THREADS"] = "1" @@ -307,3 +307,20 @@ def test_mock_workflow_multi_node(test_dir, mock_vasp, memory_jobstore, clean_di selected_atoms = job2.output.resolve(memory_jobstore) assert len(selected_atoms) == 3 + +def test_rssmaker_custom_config(test_dir): + + # For now only test if __post_init is working and updating defaults + rss = RssMaker(config_file= test_dir / "rss" / "rss_config.yaml") + + # TODO: test needs to be more robust after updating default config files + assert rss.CONFIG["tag"] == "test" + assert rss.CONFIG["generated_struct_numbers"] == [9000, 1000] + assert rss.CONFIG["num_processes_buildcell"] == 64 + assert rss.CONFIG["num_processes_fit"] == 64 + assert rss.CONFIG["device_for_rss"] == "gpu" + assert rss.CONFIG["isolatedatom_box"] == [10, 10, 10] + assert rss.CONFIG["dimer_box"] == [10, 10, 10] + + + diff --git a/tests/test_data/rss/rss_config.yaml b/tests/test_data/rss/rss_config.yaml new file mode 100644 index 000000000..a146ade12 --- /dev/null +++ b/tests/test_data/rss/rss_config.yaml @@ -0,0 +1,143 @@ +# General Parameters +tag: "test" +train_from_scratch: true +resume_from_previous_state: + test_error: + pre_database_dir: + mlip_path: + isolated_atom_energies: + +# Buildcell Parameters +generated_struct_numbers: + - 9000 + - 1000 +buildcell_options: +fragment_file: null +fragment_numbers: null +num_processes_buildcell: 64 + +# Sampling Parameters +num_of_initial_selected_structs: + - 80 + - 20 +num_of_rss_selected_structs: 100 +initial_selection_enabled: true +rss_selection_method: 'bcur2i' +bcur_params: + soap_paras: + l_max: 12 + n_max: 12 + atom_sigma: 0.0875 + cutoff: 10.5 + cutoff_transition_width: 1.0 + zeta: 4.0 + average: true + species: true + frac_of_bcur: 0.8 + bolt_max_num: 3000 +random_seed: null + +# DFT Labelling Parameters +include_isolated_atom: true +isolatedatom_box: + - 10.0 + - 10.0 + - 10.0 +e0_spin: false +include_dimer: true +dimer_box: + - 10.0 + - 10.0 + - 10.0 +dimer_range: + - 1.0 + - 5.0 +dimer_num: 21 +custom_incar: + ISMEAR: 0 + SIGMA: 0.05 + PREC: 'Accurate' + ADDGRID: '.TRUE.' + EDIFF: 1e-7 + NELM: 250 + LWAVE: '.FALSE.' + LCHARG: '.FALSE.' + ALGO: 'Normal' + AMIX: null + LREAL: '.FALSE.' + ISYM: 0 + ENCUT: 520.0 + KSPACING: 0.20 + GGA: null + KPAR: 8 + NCORE: 16 + LSCALAPACK: '.FALSE.' + LPLANE: '.FALSE.' +custom_potcar: +vasp_ref_file: 'vasp_ref.extxyz' + +# Data Preprocessing Parameters +config_types: + - 'initial' + - 'traj_early' + - 'traj' +rss_group: + - 'traj' +test_ratio: 0.1 +regularization: true +scheme: 'linear-hull' +reg_minmax: + - [0.1, 1] + - [0.001, 0.1] + - [0.0316, 0.316] + - [0.0632, 0.632] +distillation: false +force_max: null +force_label: null +pre_database_dir: null + +# MLIP Parameters +mlip_type: 'GAP' +ref_energy_name: 'REF_energy' +ref_force_name: 'REF_forces' +ref_virial_name: 'REF_virial' +auto_delta: true +num_processes_fit: 64 +device_for_fitting: 'cpu' +##The following hyperparameters are only applicable to GAP. +##If you want to use other models, please replace the corresponding hyperparameters. +twob: + cutoff: 5.0 + n_sparse: 15 + theta_uniform: 1.0 +threeb: + cutoff: 3.0 +soap: + l_max: 10 + n_max: 10 + atom_sigma: 0.5 + n_sparse: 2500 + cutoff: 5.0 +general: + three_body: false + +# RSS Exploration Parameters +scalar_pressure_method: 'uniform' +scalar_exp_pressure: 1 +scalar_pressure_exponential_width: 0.2 +scalar_pressure_low: 0 +scalar_pressure_high: 25 +max_steps: 300 +force_tol: 0.01 +stress_tol: 0.01 +stop_criterion: 0.01 +max_iteration_number: 25 +num_groups: 6 +initial_kt: 0.3 +current_iter_index: 1 +hookean_repul: false +hookean_paras: +keep_symmetry: false +write_traj: true +num_processes_rss: 128 +device_for_rss: 'gpu'