Skip to content

Commit

Permalink
Merge pull request #132 from pyiron/central_dataclass
Browse files Browse the repository at this point in the history
Central data classes for output
  • Loading branch information
jan-janssen authored Dec 13, 2023
2 parents 44cd7b8 + a527fc9 commit d9601c1
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 53 deletions.
17 changes: 8 additions & 9 deletions atomistics/calculators/ase.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations
import dataclasses

from ase.constraints import UnitCellFilter
from typing import TYPE_CHECKING

from atomistics.calculators.output import AtomisticsOutput
from atomistics.calculators.output import OutputStatic
from atomistics.calculators.wrapper import as_task_dict_evaluator

if TYPE_CHECKING:
Expand All @@ -29,11 +28,11 @@ def get_stress(self):
return self.structure.get_stress(voigt=False)


@dataclasses.dataclass
class ASEStaticOutput(AtomisticsOutput):
forces: callable = ASEExecutor.get_forces
energy: callable = ASEExecutor.get_energy
stress: callable = ASEExecutor.get_stress
ASEOutputStatic = OutputStatic(
forces=ASEExecutor.get_forces,
energy=ASEExecutor.get_energy,
stress=ASEExecutor.get_stress,
)


@as_task_dict_evaluator
Expand Down Expand Up @@ -80,9 +79,9 @@ def evaluate_with_ase(
def calc_static_with_ase(
structure,
ase_calculator,
quantities=ASEStaticOutput.fields(),
quantities=OutputStatic.fields(),
):
return ASEStaticOutput.get(
return ASEOutputStatic.get(
ASEExecutor(ase_structure=structure, ase_calculator=ase_calculator), *quantities
)

Expand Down
15 changes: 9 additions & 6 deletions atomistics/calculators/lammps/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
LAMMPS_RUN,
LAMMPS_MINIMIZE_VOLUME,
)
from atomistics.calculators.lammps.output import LammpsMDOutput, LammpsStaticOutput
from atomistics.calculators.lammps.output import (
LammpsOutputMolecularDynamics,
LammpsOutputStatic,
)

if TYPE_CHECKING:
from ase import Atoms
Expand Down Expand Up @@ -113,7 +116,7 @@ def calc_static_with_lammps(
structure,
potential_dataframe,
lmp=None,
quantities=LammpsStaticOutput.fields(),
quantities=LammpsOutputStatic.fields(),
**kwargs,
):
template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN
Expand All @@ -127,7 +130,7 @@ def calc_static_with_lammps(
lmp=lmp,
**kwargs,
)
result_dict = LammpsStaticOutput.get(lmp_instance, *quantities)
result_dict = LammpsOutputStatic.get(lmp_instance, *quantities)
lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None)
return result_dict

Expand All @@ -144,7 +147,7 @@ def calc_molecular_dynamics_nvt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -201,7 +204,7 @@ def calc_molecular_dynamics_npt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -259,7 +262,7 @@ def calc_molecular_dynamics_nph_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
**kwargs,
):
init_str = (
Expand Down
8 changes: 4 additions & 4 deletions atomistics/calculators/lammps/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pylammpsmpi import LammpsASELibrary

from atomistics.calculators.lammps.potential import validate_potential_dataframe
from atomistics.calculators.lammps.output import LammpsMDOutput
from atomistics.calculators.lammps.output import LammpsOutputMolecularDynamics


def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs):
Expand Down Expand Up @@ -41,19 +41,19 @@ def lammps_calc_md_step(
lmp_instance,
run_str,
run,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
):
run_str_rendered = Template(run_str).render(run=run)
lmp_instance.interactive_lib_command(run_str_rendered)
return LammpsMDOutput.get(lmp_instance, *quantities)
return LammpsOutputMolecularDynamics.get(lmp_instance, *quantities)


def lammps_calc_md(
lmp_instance,
run_str,
run,
thermo,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
):
results_lst = [
lammps_calc_md_step(
Expand Down
37 changes: 16 additions & 21 deletions atomistics/calculators/lammps/output.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
import dataclasses

from atomistics.calculators.output import OutputStatic, OutputMolecularDynamics
from pylammpsmpi import LammpsASELibrary

from atomistics.calculators.output import AtomisticsOutput


@dataclasses.dataclass
class LammpsMDOutput(AtomisticsOutput):
positions: callable = LammpsASELibrary.interactive_positions_getter
cell: callable = LammpsASELibrary.interactive_cells_getter
forces: callable = LammpsASELibrary.interactive_forces_getter
temperature: callable = LammpsASELibrary.interactive_temperatures_getter
energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter
energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter
pressure: callable = LammpsASELibrary.interactive_pressures_getter
velocities: callable = LammpsASELibrary.interactive_velocities_getter


@dataclasses.dataclass
class LammpsStaticOutput(AtomisticsOutput):
forces: callable = LammpsASELibrary.interactive_forces_getter
energy: callable = LammpsASELibrary.interactive_energy_pot_getter
stress: callable = LammpsASELibrary.interactive_pressures_getter
LammpsOutputStatic = OutputStatic(
forces=LammpsASELibrary.interactive_forces_getter,
energy=LammpsASELibrary.interactive_energy_pot_getter,
stress=LammpsASELibrary.interactive_pressures_getter,
)
LammpsOutputMolecularDynamics = OutputMolecularDynamics(
positions=LammpsASELibrary.interactive_positions_getter,
cell=LammpsASELibrary.interactive_cells_getter,
forces=LammpsASELibrary.interactive_forces_getter,
temperature=LammpsASELibrary.interactive_temperatures_getter,
energy_pot=LammpsASELibrary.interactive_energy_pot_getter,
energy_tot=LammpsASELibrary.interactive_energy_tot_getter,
pressure=LammpsASELibrary.interactive_pressures_getter,
velocities=LammpsASELibrary.interactive_velocities_getter,
)
26 changes: 22 additions & 4 deletions atomistics/calculators/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,29 @@


@dataclasses.dataclass
class AtomisticsOutput:
class Output:
@classmethod
def fields(cls):
return tuple(field.name for field in dataclasses.fields(cls))

@classmethod
def get(cls, engine, *quantities: str) -> dict:
return {q: getattr(cls, q)(engine) for q in quantities}
def get(self, engine, *quantities: str) -> dict:
return {q: getattr(self, q)(engine) for q in quantities}


@dataclasses.dataclass
class OutputStatic(Output):
forces: callable
energy: callable
stress: callable


@dataclasses.dataclass
class OutputMolecularDynamics(Output):
positions: callable
cell: callable
forces: callable
temperature: callable
energy_pot: callable
energy_tot: callable
pressure: callable
velocities: callable
19 changes: 10 additions & 9 deletions atomistics/calculators/qe.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import dataclasses
import os
import subprocess

from ase.io import write
from pwtools import io

from atomistics.calculators.output import AtomisticsOutput
from atomistics.calculators.output import OutputStatic
from atomistics.calculators.wrapper import as_task_dict_evaluator


Expand All @@ -23,11 +22,11 @@ def get_stress(self):
return self.parser.stress


@dataclasses.dataclass
class QEStaticOutput(AtomisticsOutput):
forces: callable = QEStaticParser.get_forces
energy: callable = QEStaticParser.get_energy
stress: callable = QEStaticParser.get_stress
QuantumEspressoOutputStatic = OutputStatic(
forces=QEStaticParser.get_forces,
energy=QEStaticParser.get_energy,
stress=QEStaticParser.get_stress,
)


def call_qe_via_ase_command(calculation_name, working_directory):
Expand Down Expand Up @@ -185,7 +184,7 @@ def calc_static_with_qe(
pseudopotentials=None,
tstress=True,
tprnfor=True,
quantities=QEStaticOutput.fields(),
quantities=OutputStatic.fields(),
**kwargs,
):
input_file_name = os.path.join(working_directory, calculation_name + ".pwi")
Expand All @@ -211,7 +210,9 @@ def calc_static_with_qe(
call_qe_via_ase_command(
calculation_name=calculation_name, working_directory=working_directory
)
return QEStaticOutput.get(QEStaticParser(filename=output_file_name), *quantities)
return QuantumEspressoOutputStatic.get(
QEStaticParser(filename=output_file_name), *quantities
)


@as_task_dict_evaluator
Expand Down

0 comments on commit d9601c1

Please sign in to comment.