Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
naik-aakash committed Dec 15, 2024
1 parent 0484b5f commit e9365ee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
7 changes: 2 additions & 5 deletions tests/fitting/common/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_mlip_fit_maker_jace(

@pytest.mark.skip(reason="We can enable this after mock_nep fixture is added")
def test_mlip_fit_maker_nep(
test_dir, memory_jobstore, vasp_test_dir, fit_input_dict,
test_dir, memory_jobstore, vasp_test_dir, fit_input_dict, clean_dir
):
from pathlib import Path
from jobflow import run_locally
Expand All @@ -244,9 +244,6 @@ def test_mlip_fit_maker_nep(
mlip_type="NEP",
pre_database_dir=str(test_files_dir),
pre_xyz_files=["pre_xyz_train.extxyz", "pre_xyz_test.extxyz"],
ref_energy_name= "energy",
ref_force_name= "force",
ref_virial_name= "virial",
num_processes_fit=1,
apply_data_preprocessing=True,
).make(
Expand All @@ -255,7 +252,7 @@ def test_mlip_fit_maker_nep(
**{"generation": 100, "batch": 100},
)

run_locally(
_ = run_locally(
nepfit, ensure_success=True, create_folders=True, store=memory_jobstore
)

Expand Down
13 changes: 5 additions & 8 deletions tests/fitting/common/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,20 @@ def test_jace_fit_maker(test_dir, memory_jobstore, clean_dir):


@pytest.mark.skip(reason="We can enable this after mock_nep fixture is added")
def test_nep_fit_maker(test_dir, memory_jobstore):
database_dir = test_dir / "fitting/NEP/"
def test_nep_fit_maker(test_dir, memory_jobstore, clean_dir):
database_dir = test_dir / "fitting/ref_files/"

nepfit = MLIPFitMaker(
mlip_type="NEP",
num_processes_fit=1,
apply_data_preprocessing=False,
ref_force_name="forces",
ref_energy_name="energy",
ref_virial_name="virial",
database_dir=database_dir,
).make(
species_list=["Al"],
**{"generation": 100, "batch": 100},
species_list=["Li", "Cl"],
**{"generation": 100, "batch": 100, "type_weight":[0.5, 1.0]},
)

responses = run_locally(
_ = run_locally(
nepfit, ensure_success=True, create_folders=True, store=memory_jobstore
)

Expand Down

0 comments on commit e9365ee

Please sign in to comment.