Skip to content

Commit

Permalink
Merge pull request #5 from ChiahsinChu/main
Browse files Browse the repository at this point in the history
Update dpdata for dplr and dpff
  • Loading branch information
link89 authored Nov 22, 2024
2 parents c2e5d11 + 7bd7edf commit bfe1108
Show file tree
Hide file tree
Showing 24 changed files with 23,685 additions and 233 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ If you want to run `ai2-kit` from source, you can run the following commands in

```bash
pip install poetry
# If you meet ConnectionError:
# poetry config installer.max-workers 4
poetry install

./ai2-kit --help
Expand Down
20 changes: 13 additions & 7 deletions ai2_kit/domain/deepmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import copy
import random
import dpdata
import numpy as np

from .iface import ICllTrainOutput, BaseCllContext, TRAINING_MODE
from .data import DataFormat, get_data_format
from .dpff import set_dplr_ext_from_cp2k_output
from .dpff import set_dpff_ext_from_cp2k_output
from .dplr import dplr_v3_to_v2

from .constant import (
DP_CHECKPOINT_FILE,
Expand Down Expand Up @@ -532,7 +534,7 @@ def make_deepmd_dataset(
# Arguments of DPLR model can be found here:
# https://github.com/deepmodeling/deepmd-kit/blob/master/doc/model/dplr.md
try:
set_dplr_ext_from_cp2k_output(
set_dpff_ext_from_cp2k_output(
dp_sys=dp_system,
cp2k_output=os.path.join(raw_data['url'], 'output'),
wannier_file=os.path.join(raw_data['url'], 'wannier.xyz'),
Expand All @@ -545,7 +547,7 @@ def make_deepmd_dataset(
sel_type=sel_type,
)
except Exception:
logger.exception(f'dpff: failed to set dplr ext')
logger.exception(f'dpff: failed to set dpff ext')
continue
else:
raise ValueError(f"Unsupported data format: {data_format}")
Expand All @@ -562,14 +564,14 @@ def make_deepmd_dataset(

_write_dp_dataset = _write_dp_dataset_by_formula if group_by_formula else _write_dp_dataset_by_ancestor

dataset_dirs = _write_dp_dataset(dp_system_list=dataset_collection, out_dir=dataset_dir, type_map=type_map)
outlier_dirs = _write_dp_dataset(dp_system_list=outlier_collection, out_dir=outlier_dir, type_map=type_map)
dataset_dirs = _write_dp_dataset(dp_system_list=dataset_collection, out_dir=dataset_dir, type_map=type_map, sel_type=sel_type)
outlier_dirs = _write_dp_dataset(dp_system_list=outlier_collection, out_dir=outlier_dir, type_map=type_map, sel_type=sel_type)

return dataset_dirs, outlier_dirs


def _write_dp_dataset_by_formula(dp_system_list: List[Tuple[ArtifactDict, dpdata.LabeledSystem]],
out_dir: str, type_map: List[str]):
out_dir: str, type_map: List[str], sel_type: Optional[List[int]] = None):
"""
Write dp dataset that grouping by formula
Use dpdata.MultipleSystems to merge systems with the same formula
Expand All @@ -585,6 +587,8 @@ def _write_dp_dataset_by_formula(dp_system_list: List[Tuple[ArtifactDict, dpdata

# pylint: disable=no-member
multi_systems.to_deepmd_npy(out_dir, type_map=type_map) # type: ignore
if sel_type is not None:
dplr_v3_to_v2(out_dir, np.array(type_map)[sel_type].tolist())

return [ {
'url': os.path.join(out_dir, fname),
Expand All @@ -593,7 +597,7 @@ def _write_dp_dataset_by_formula(dp_system_list: List[Tuple[ArtifactDict, dpdata
} for fname in os.listdir(out_dir)]


def _write_dp_dataset_by_ancestor(dp_system_list: List[Tuple[ArtifactDict, dpdata.LabeledSystem]], out_dir: str, type_map: List[str]):
def _write_dp_dataset_by_ancestor(dp_system_list: List[Tuple[ArtifactDict, dpdata.LabeledSystem]], out_dir: str, type_map: List[str], sel_type: Optional[List[int]] = None):
"""
write dp dataset that grouping by ancestor
"""
Expand All @@ -612,6 +616,8 @@ def get_ancestor(x):
for item in dp_system_group[1:]:
dp_system += item[1]
dp_system.to_deepmd_npy(group_out_dir, set_size=len(dp_system), type_map=type_map) # type: ignore
if sel_type is not None:
dplr_v3_to_v2(out_dir, np.array(type_map)[sel_type].tolist())
# inherit attrs key from input artifact
output_dirs.append({'url': group_out_dir,
'format': DataFormat.DEEPMD_NPY,
Expand Down
Loading

0 comments on commit bfe1108

Please sign in to comment.