From df73cb470013b7431a5b791f2f8161055115ec2d Mon Sep 17 00:00:00 2001 From: link89 Date: Tue, 24 Dec 2024 01:16:42 +0000 Subject: [PATCH] deploy: db8263bbcd3c8f324a3c18978077de8be4db98e8 --- _modules/ai2_kit/feat/spectrum/viber.html | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/_modules/ai2_kit/feat/spectrum/viber.html b/_modules/ai2_kit/feat/spectrum/viber.html index 0fa88a4..7c79939 100644 --- a/_modules/ai2_kit/feat/spectrum/viber.html +++ b/_modules/ai2_kit/feat/spectrum/viber.html @@ -606,8 +606,15 @@

Source code for ai2_kit.feat.spectrum.viber

     # get cell
     cell = dp_sys.data['cells'][0]
 
+    # get selected atoms ids
+    symbols = np.array(dp_sys.data["atom_names"])[dp_sys.data["atom_types"]]
+    sel_ids = [np.where(symbols == atype)[0] for atype in lumped_dict.keys()]
+    sel_ids = np.concatenate(sel_ids)
+    
     # build the data of atomic_dipole and atomic_polarizability with numpy
     wannier_atoms = ase.io.read(os.path.join(data_dir, wannier), index=":", format='extxyz')
+    n_frames = len(wannier_atoms)
+    n_atoms = np.sum(np.logical_not(wannier_atoms[0].symbols == wacent_symbol))
 
     lumped_dict_c = lumped_dict.copy()
     del_list = []
@@ -622,7 +629,9 @@ 

Source code for ai2_kit.feat.spectrum.viber

     wfc_compute_polar = _set_lumped_wfc(stc_list, lumped_dict_c, cutoff, wacent_symbol, to_polar = True)
     wfc_save = _set_lumped_wfc(stc_list, lumped_dict_c, cutoff, wacent_symbol, to_polar = False)
 
-    dp_sys.data['atomic_dipole'] = wfc_save
+    wannier = np.zeros([n_frames, n_atoms, 3])
+    wannier[:, sel_ids, :] = np.reshape(wfc_save, [n_frames, -1, 3])
+    dp_sys.data["atomic_dipole"] = wannier
 
     if mode == 'both':
         wannier_atoms_x = ase.io.read(os.path.join(data_dir, wannier_x), index=":", format='extxyz')
@@ -643,9 +652,9 @@ 

Source code for ai2_kit.feat.spectrum.viber

         polar[:, :, 1] = (wfc_y - wfc_compute_polar) / eps
         polar[:, :, 2] = (wfc_z - wfc_compute_polar) / eps
 
-        dp_sys.data['atomic_polarizability'] = polar.reshape(polar.shape[0], -1)
+        dp_sys.data['atomic_polarizability'] = np.reshape(polar, [n_frames, n_atoms, 9])
     elif mode == 'dipole_only':
-        dp_sys.data['atomic_polarizability'] = np.array([-1,-1]).reshape(1,-1)
+        dp_sys.data['atomic_polarizability'] = np.full((n_frames, n_atoms, 9), -1.0)
     else:
         logger.warning(f"There is no mode called '{mode}', expected 'both' or 'dipole_only'")