Skip to content

Commit

Permalink
add energy_label arg to preprocess_data function
Browse files: browse the repository at this point in the history
  • Loading branch information
naik-aakash committed Dec 14, 2024
1 parent bd4681c commit 33fbb16
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/autoplex/data/common/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def preprocess_data(
distillation: bool = False,
force_max: float = 40,
force_label: str = "REF_forces",
energy_label: str = "REF_energy",
pre_database_dir: str | None = None,
reg_minmax: list[tuple] | None = None,
isolated_atom_energies: dict | None = None,
Expand Down Expand Up @@ -726,6 +727,8 @@ def preprocess_data(
Maximum force value to exclude structures.
force_label: str
The label of force values to use for distillation.
energy_label: str
The label of energy values passed to the stratified train/test dataset split.
pre_database_dir : str
Directory where the previous database was saved.
reg_minmax: list[tuple]
Expand All @@ -748,7 +751,9 @@ def preprocess_data(
if test_ratio == 0 or test_ratio is None:
train_structures, test_structures = atoms, atoms
else:
train_structures, test_structures = stratified_dataset_split(atoms, test_ratio)
train_structures, test_structures = stratified_dataset_split(
atoms, test_ratio, energy_label
)

if pre_database_dir and os.path.exists(pre_database_dir):
files_to_copy = ["train.extxyz", "test.extxyz"]
Expand Down

0 comments on commit 33fbb16

Please sign in to comment.