Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
JaGeo committed Jan 26, 2024
1 parent 49a9e7e commit 959b73c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 53 deletions.
66 changes: 33 additions & 33 deletions autoplex/auto/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ class CompleteDFTvsMLBenchmarkWorkflow(
uc: bool = False

def make(
self,
structure_list: list[Structure],
mp_ids,
xyz_file: str | None = None,
dft_reference: PhononBSDOSDoc | None = None,
benchmark_structure: Structure | None = None,
mp_id: str | None = None,
**fit_kwargs,
self,
structure_list: list[Structure],
mp_ids,
xyz_file: str | None = None,
dft_reference: PhononBSDOSDoc | None = None,
benchmark_structure: Structure | None = None,
mp_id: str | None = None,
**fit_kwargs,
):
"""
Make flow for adding data to the dataset.
Expand Down Expand Up @@ -295,7 +295,7 @@ def make(
if (mp_id in mp_ids) and self.add_dft_phonon_struct:
dft_reference = fit_input[mp_id]["phonon_data"]["001"]
elif (mp_id not in mp_ids) or ( # else?
self.add_dft_phonon_struct is False
self.add_dft_phonon_struct is False
):
dft_phonons = DFTPhononMaker(
symprec=self.symprec,
Expand Down Expand Up @@ -329,12 +329,12 @@ def make(
return Flow(flows, collect_bm.output)

def add_dft_phonons(
self,
structure: Structure,
displacements,
symprec,
phonon_displacement_maker,
min_length,
self,
structure: Structure,
displacements,
symprec,
phonon_displacement_maker,
min_length,
):
additonal_dft_phonon = dft_phonopy_gen_data(
structure, displacements, symprec, phonon_displacement_maker, min_length
Expand All @@ -349,13 +349,13 @@ def add_dft_phonons(
)

def add_dft_random(
self,
structure: Structure,
mp_id,
phonon_displacement_maker,
n_struct,
uc,
supercell_matrix: Matrix3D | None = None,
self,
structure: Structure,
mp_id,
phonon_displacement_maker,
n_struct,
uc,
supercell_matrix: Matrix3D | None = None,
):
additonal_dft_random = dft_random_gen_data(
structure, mp_id, phonon_displacement_maker, n_struct, uc, supercell_matrix
Expand Down Expand Up @@ -474,12 +474,12 @@ class PhononDFTMLFitFlow(Maker):
name: str = "ML fit"

def make(
self,
species,
isolated_atoms_energy,
fit_input: dict,
xyz_file: str | None = None,
**fit_kwargs,
self,
species,
isolated_atoms_energy,
fit_input: dict,
xyz_file: str | None = None,
**fit_kwargs,
):
"""
Make flow for to fit potential.
Expand Down Expand Up @@ -525,11 +525,11 @@ class PhononDFTMLBenchmarkFlow(Maker):
name: str = "ML DFT benchmark"

def make(
self,
structure: Structure,
mp_id: str,
ml_phonon_task_doc: PhononBSDOSDoc,
dft_phonon_task_doc: PhononBSDOSDoc,
self,
structure: Structure,
mp_id: str,
ml_phonon_task_doc: PhononBSDOSDoc,
dft_phonon_task_doc: PhononBSDOSDoc,
):
"""
Create flow to benchmark the ML potential.
Expand Down
58 changes: 38 additions & 20 deletions tests/auto/test_auto_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def test_complete_dft_vs_ml_benchmark_workflow(
vasp_test_dir, mock_vasp, test_dir, memory_jobstore, clean_dir
vasp_test_dir, mock_vasp, test_dir, memory_jobstore, clean_dir
):
import pytest
from jobflow import run_locally
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_complete_dft_vs_ml_benchmark_workflow(


def test_add_data_to_dataset_workflow(
vasp_test_dir, mock_vasp, test_dir, memory_jobstore, clean_dir
vasp_test_dir, mock_vasp, test_dir, memory_jobstore, clean_dir
):
import pytest
from monty.serialization import loadfn
Expand All @@ -100,66 +100,81 @@ def test_add_data_to_dataset_workflow(
dft_reference: PhononBSDOSDoc = dft_data["output"]

add_data_workflow = CompleteDFTvsMLBenchmarkWorkflow(
n_struct=3, symprec=1e-2, min_length=8, displacements=[0.01],
phonon_displacement_maker=PhononDisplacementMaker()
n_struct=3,
symprec=1e-2,
min_length=8,
displacements=[0.01],
phonon_displacement_maker=PhononDisplacementMaker(),
).make(
structure_list=[structure],
mp_ids=["test"],
mp_id="mp-22905",
benchmark_structure=structure,
xyz_file=test_dir / "fitting" / "ref_files" / "trainGAP.xyz",
dft_reference=None
dft_reference=None,
)

add_data_workflow_with_dft_reference = CompleteDFTvsMLBenchmarkWorkflow(
n_struct=3, symprec=1e-2, min_length=8, displacements=[0.01],
n_struct=3,
symprec=1e-2,
min_length=8,
displacements=[0.01],
add_dft_phonon_struct=False,
phonon_displacement_maker=PhononDisplacementMaker()
phonon_displacement_maker=PhononDisplacementMaker(),
).make(
structure_list=[structure],
mp_ids=["test"],
mp_id="mp-22905",
benchmark_structure=structure,
xyz_file=test_dir / "fitting" / "ref_files" / "trainGAP.xyz",
dft_reference=dft_reference
dft_reference=dft_reference,
)

add_data_workflow_add_phonon_false = CompleteDFTvsMLBenchmarkWorkflow(
n_struct=3, symprec=1e-2, min_length=8, displacements=[0.01],
n_struct=3,
symprec=1e-2,
min_length=8,
displacements=[0.01],
add_dft_phonon_struct=False,
phonon_displacement_maker=PhononDisplacementMaker()
phonon_displacement_maker=PhononDisplacementMaker(),
).make(
structure_list=[structure],
mp_ids=["test"],
mp_id="mp-22905",
benchmark_structure=structure,
xyz_file=test_dir / "fitting" / "ref_files" / "trainGAP.xyz",
dft_reference=None
dft_reference=None,
)

add_data_workflow_add_random_false = CompleteDFTvsMLBenchmarkWorkflow(
n_struct=3, symprec=1e-2, min_length=8, displacements=[0.01],
n_struct=3,
symprec=1e-2,
min_length=8,
displacements=[0.01],
add_dft_random_struct=False,
phonon_displacement_maker=PhononDisplacementMaker()
phonon_displacement_maker=PhononDisplacementMaker(),
).make(
structure_list=[structure],
mp_ids=["test"],
mp_id="mp-22905",
benchmark_structure=structure,
xyz_file=test_dir / "fitting" / "ref_files" / "trainGAP.xyz",
dft_reference=None
dft_reference=None,
)

add_data_workflow_with_same_mpid = CompleteDFTvsMLBenchmarkWorkflow(
n_struct=3, symprec=1e-2, min_length=8, displacements=[0.01],
phonon_displacement_maker=PhononDisplacementMaker()
n_struct=3,
symprec=1e-2,
min_length=8,
displacements=[0.01],
phonon_displacement_maker=PhononDisplacementMaker(),
).make(
structure_list=[structure],
mp_ids=["mp-22905"],
mp_id="mp-22905",
benchmark_structure=structure,
xyz_file=test_dir / "fitting" / "ref_files" / "trainGAP.xyz",
dft_reference=None
dft_reference=None,
)

ref_paths = {
Expand Down Expand Up @@ -216,7 +231,8 @@ def test_add_data_to_dataset_workflow(
0.5716963823412201, abs=0.02
)
for job, uuid in add_data_workflow.iterflow():
if "dft_phonopy_gen_data" in job.name: assert True
if "dft_phonopy_gen_data" in job.name:
assert True
for job, uuid in add_data_workflow_with_dft_reference.iterflow():
assert job.name != "dft_phonopy_gen_data"
for job, uuid in add_data_workflow_add_phonon_false.iterflow():
Expand All @@ -230,7 +246,7 @@ def test_add_data_to_dataset_workflow(


def test_phonon_dft_ml_data_generation_flow(
vasp_test_dir, mock_vasp, clean_dir, memory_jobstore
vasp_test_dir, mock_vasp, clean_dir, memory_jobstore
):
from jobflow import run_locally

Expand Down Expand Up @@ -332,4 +348,6 @@ def test_phonon_dft_ml_data_generation_flow(
for item in output.resolve(store=memory_jobstore):
paths_to_rand_calcs_worattled.append(item)

assert len(paths_to_phonon_calcs_worattled) + len(paths_to_rand_calcs_worattled) == 2
assert (
len(paths_to_phonon_calcs_worattled) + len(paths_to_rand_calcs_worattled) == 2
)

0 comments on commit 959b73c

Please sign in to comment.