From 8e3febeff7254e1696799ae0f672df3fc92e006c Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 23 Dec 2024 18:22:39 +0800 Subject: [PATCH 1/2] make viber adapted for dp v3 --- ai2_kit/feat/spectrum/viber.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ai2_kit/feat/spectrum/viber.py b/ai2_kit/feat/spectrum/viber.py index 58cb927..912b446 100644 --- a/ai2_kit/feat/spectrum/viber.py +++ b/ai2_kit/feat/spectrum/viber.py @@ -173,8 +173,15 @@ def dpdata_read_cp2k_viber_data(data_dir: str, # get cell cell = dp_sys.data['cells'][0] + # get selected atoms ids + symbols = np.array(dp_sys.data["atom_names"])[dp_sys.data["atom_types"]] + sel_ids = [np.where(symbols == atype)[0] for atype in lumped_dict.keys()] + sel_ids = np.concatenate(sel_ids) + # build the data of atomic_dipole and atomic_polarizability with numpy wannier_atoms = ase.io.read(os.path.join(data_dir, wannier), index=":", format='extxyz') + n_frames = len(wannier_atoms) + n_atoms = np.sum(np.logical_not(wannier_atoms[0].symbols == wacent_symbol)) lumped_dict_c = lumped_dict.copy() del_list = [] @@ -189,7 +196,9 @@ def dpdata_read_cp2k_viber_data(data_dir: str, wfc_compute_polar = _set_lumped_wfc(stc_list, lumped_dict_c, cutoff, wacent_symbol, to_polar = True) wfc_save = _set_lumped_wfc(stc_list, lumped_dict_c, cutoff, wacent_symbol, to_polar = False) - dp_sys.data['atomic_dipole'] = wfc_save + wannier = np.zeros([n_frames, n_atoms, 3]) + wannier[:, sel_ids, :] = np.reshape(wfc_save, [n_frames, -1, 3]) + dp_sys.data["atomic_dipole"] = wannier if mode == 'both': wannier_atoms_x = ase.io.read(os.path.join(data_dir, wannier_x), index=":", format='extxyz') @@ -210,9 +219,9 @@ def dpdata_read_cp2k_viber_data(data_dir: str, polar[:, :, 1] = (wfc_y - wfc_compute_polar) / eps polar[:, :, 2] = (wfc_z - wfc_compute_polar) / eps - dp_sys.data['atomic_polarizability'] = polar.reshape(polar.shape[0], -1) + dp_sys.data['atomic_polarizability'] = np.reshape(polar, [n_frames, n_atoms, 9]) elif mode == 'dipole_only': - dp_sys.data['atomic_polarizability'] = np.array([-1,-1]).reshape(1,-1) + dp_sys.data['atomic_polarizability'] = np.full((n_frames, n_atoms, 9), -1.0) else: logger.warning(f"There is no mode called '{mode}', expected 'both' or 'dipole_only'") From 9b331936ec909b9b7cf7a704a7345b08a12db2bb Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 23 Dec 2024 18:38:19 +0800 Subject: [PATCH 2/2] add sorting UT for dplr in dpdata --- tests/test_dplr.py | 53 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/test_dplr.py b/tests/test_dplr.py index 2836081..b4826b4 100644 --- a/tests/test_dplr.py +++ b/tests/test_dplr.py @@ -64,7 +64,7 @@ def test_shape(self): assert self.data.data["atomic_dipole"].shape[1] == natoms assert self.data.data["atomic_dipole"].shape[2] == 3 - def test_consisitent(self): + def test_consistent(self): np.testing.assert_allclose( self.data.data["atomic_dipole"][0], self.data_with_ion.data["atomic_dipole"][0], @@ -200,5 +200,56 @@ def test_merged_consistent(self): shutil.rmtree(".tmp_data-merged") +class TestDPLRSorted(unittest.TestCase): + def setUp(self) -> None: + self.cp2k_output = "output" + self.wannier_file = "wannier.xyz" + self.type_map = [ + "Na", + "S", + "O", + "N", + "Cl", + "H", + ] + self.sel_type = [0, 2] + + self.data_sorted = dpdata_read_cp2k_dplr_data( + str(Path(__file__).parent / "data-sample/cp2k_wannier_sort_test/sorted"), + self.cp2k_output, + self.wannier_file, + self.type_map, + self.sel_type, + ) + self.data_random = dpdata_read_cp2k_dplr_data( + str(Path(__file__).parent / "data-sample/cp2k_wannier_sort_test/random"), + self.cp2k_output, + self.wannier_file, + self.type_map, + self.sel_type, + ) + + self.random_ids = np.loadtxt( + str( + Path(__file__).parent + / "data-sample/cp2k_wannier_sort_test/random_ids.txt" + ), + dtype=int, + ) + + def test_coord(self): + np.testing.assert_allclose( + self.data_sorted.data["coords"][0][self.random_ids], + self.data_random.data["coords"][0], + ) + + def test_consistent(self): + np.testing.assert_allclose( + self.data_sorted.data["atomic_dipole"][0][self.random_ids], + self.data_random.data["atomic_dipole"][0], + atol=1e-7, + ) + + if __name__ == "__main__": unittest.main()