Skip to content

Commit

Permalink
Merge pull request #9 from ChiahsinChu/main
Browse files Browse the repository at this point in the history
enhancement: dpdata for dipole/polarizability
  • Loading branch information
link89 authored Dec 24, 2024
2 parents 6ee64a4 + 9b33193 commit db8263b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
15 changes: 12 additions & 3 deletions ai2_kit/feat/spectrum/viber.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,15 @@ def dpdata_read_cp2k_viber_data(data_dir: str,
# get cell
cell = dp_sys.data['cells'][0]

# get selected atoms ids
symbols = np.array(dp_sys.data["atom_names"])[dp_sys.data["atom_types"]]
sel_ids = [np.where(symbols == atype)[0] for atype in lumped_dict.keys()]
sel_ids = np.concatenate(sel_ids)

# build the data of atomic_dipole and atomic_polarizability with numpy
wannier_atoms = ase.io.read(os.path.join(data_dir, wannier), index=":", format='extxyz')
n_frames = len(wannier_atoms)
n_atoms = np.sum(np.logical_not(wannier_atoms[0].symbols == wacent_symbol))

lumped_dict_c = lumped_dict.copy()
del_list = []
Expand All @@ -189,7 +196,9 @@ def dpdata_read_cp2k_viber_data(data_dir: str,
wfc_compute_polar = _set_lumped_wfc(stc_list, lumped_dict_c, cutoff, wacent_symbol, to_polar = True)
wfc_save = _set_lumped_wfc(stc_list, lumped_dict_c, cutoff, wacent_symbol, to_polar = False)

dp_sys.data['atomic_dipole'] = wfc_save
wannier = np.zeros([n_frames, n_atoms, 3])
wannier[:, sel_ids, :] = np.reshape(wfc_save, [n_frames, -1, 3])
dp_sys.data["atomic_dipole"] = wannier

if mode == 'both':
wannier_atoms_x = ase.io.read(os.path.join(data_dir, wannier_x), index=":", format='extxyz')
Expand All @@ -210,9 +219,9 @@ def dpdata_read_cp2k_viber_data(data_dir: str,
polar[:, :, 1] = (wfc_y - wfc_compute_polar) / eps
polar[:, :, 2] = (wfc_z - wfc_compute_polar) / eps

dp_sys.data['atomic_polarizability'] = polar.reshape(polar.shape[0], -1)
dp_sys.data['atomic_polarizability'] = np.reshape(polar, [n_frames, n_atoms, 9])
elif mode == 'dipole_only':
dp_sys.data['atomic_polarizability'] = np.array([-1,-1]).reshape(1,-1)
dp_sys.data['atomic_polarizability'] = np.full((n_frames, n_atoms, 9), -1.0)
else:
logger.warning(f"There is no mode called '{mode}', expected 'both' or 'dipole_only'")

Expand Down
53 changes: 52 additions & 1 deletion tests/test_dplr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_shape(self):
assert self.data.data["atomic_dipole"].shape[1] == natoms
assert self.data.data["atomic_dipole"].shape[2] == 3

def test_consisitent(self):
def test_consistent(self):
np.testing.assert_allclose(
self.data.data["atomic_dipole"][0],
self.data_with_ion.data["atomic_dipole"][0],
Expand Down Expand Up @@ -200,5 +200,56 @@ def test_merged_consistent(self):
shutil.rmtree(".tmp_data-merged")


class TestDPLRSorted(unittest.TestCase):
def setUp(self) -> None:
self.cp2k_output = "output"
self.wannier_file = "wannier.xyz"
self.type_map = [
"Na",
"S",
"O",
"N",
"Cl",
"H",
]
self.sel_type = [0, 2]

self.data_sorted = dpdata_read_cp2k_dplr_data(
str(Path(__file__).parent / "data-sample/cp2k_wannier_sort_test/sorted"),
self.cp2k_output,
self.wannier_file,
self.type_map,
self.sel_type,
)
self.data_random = dpdata_read_cp2k_dplr_data(
str(Path(__file__).parent / "data-sample/cp2k_wannier_sort_test/random"),
self.cp2k_output,
self.wannier_file,
self.type_map,
self.sel_type,
)

self.random_ids = np.loadtxt(
str(
Path(__file__).parent
/ "data-sample/cp2k_wannier_sort_test/random_ids.txt"
),
dtype=int,
)

def test_coord(self):
np.testing.assert_allclose(
self.data_sorted.data["coords"][0][self.random_ids],
self.data_random.data["coords"][0],
)

def test_consistent(self):
np.testing.assert_allclose(
self.data_sorted.data["atomic_dipole"][0][self.random_ids],
self.data_random.data["atomic_dipole"][0],
atol=1e-7,
)


if __name__ == "__main__":
unittest.main()

0 comments on commit db8263b

Please sign in to comment.