From 11630be1259b41325815dc5f6c913e7e899a79c3 Mon Sep 17 00:00:00 2001 From: qzhu2017 Date: Fri, 15 Mar 2024 20:55:20 -0400 Subject: [PATCH] output splitter for get_transition function --- pyxtal/__init__.py | 34 ++++++++++++++++++----------- pyxtal/supergroup.py | 52 ++++++++++++++++++-------------------------- pyxtal/test_all.py | 2 +- 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/pyxtal/__init__.py b/pyxtal/__init__.py index b9eaa639..4027fdcc 100644 --- a/pyxtal/__init__.py +++ b/pyxtal/__init__.py @@ -122,7 +122,7 @@ class pyxtal: >>> s2 = pyxtal() >>> s1.from_seed("pyxtal/database/cifs/0-G62.cif") #structure with low symmetry >>> s2.from_seed("pyxtal/database/cifs/2-G71.cif") #structure with high symmetry - >>> strucs, _, _, _ = s2.get_transition(s1) # get the transition from high to low + >>> strucs, _, _, _, _ = s2.get_transition(s1) # get the transition from high to low >>> strucs [ ------Crystal from Transition 0 0.000------ @@ -735,9 +735,10 @@ def subgroup_by_path(self, gtypes, ids, eps=0, mut_lat=False): Returns: a pyxtal structure with lower symmetries + list of splitters """ struc = self.copy() - + splitters = [] G = self.group for g_type, id in zip(gtypes, ids): if self.molecular: @@ -750,8 +751,9 @@ def subgroup_by_path(self, gtypes, ids, eps=0, mut_lat=False): struc = struc._subgroup_by_splitter(splitter, eps=eps, mut_lat=mut_lat) if struc is None: return None + splitters.append(splitter) G = splitter.H - return struc + return struc, splitters def subgroup_once(self, eps=0.1, H=None, perms=None, group_type='t', \ max_cell=4, min_cell=0, mut_lat=True, ignore_special=False): @@ -1806,7 +1808,7 @@ def to_subgroup(self, path=None, t_only=True, iterate=False, species=None): for p in path: gtypes.append(p[0]) ids.append(p[1]) - sub = self.subgroup_by_path(gtypes, ids, eps=0) + sub, _ = self.subgroup_by_path(gtypes, ids, eps=0) sub.optimize_lattice() sub.source = "subgroup" else: @@ -2224,16 +2226,17 @@ def get_transition(self, ref_struc, d_tol=1.0, d_tol2=0.3, N_images=2, max_path= good_disps = [] good_paths = [] good_trans = [] + good_splitters = [] for p in paths: r = self.get_transition_by_path(ref_struc, p, d_tol, d_tol2, N_images, both) - (strucs, disp, tran, count) = r + (strucs, disp, tran, count, sps) = r if count == 0: # prepare more paths to increase diversity add_paths = self.group.add_k_transitions(p) for p0 in add_paths: r = self.get_transition_by_path(ref_struc, p0, d_tol, d_tol2, N_images, both) - (strucs, disp, tran, count) = r + (strucs, disp, tran, count, sps) = r if strucs is not None: if strucs[-1].disp < d_tol2: #stop return strucs, disp, tran, p0 @@ -2243,16 +2246,18 @@ def get_transition(self, ref_struc, d_tol=1.0, d_tol2=0.3, N_images=2, max_path= good_paths.append(p0) good_strucs.append(strucs) good_trans.append(tran) + good_splitters.append(sps) else: if strucs is not None: if strucs[-1].disp < d_tol2: - return strucs, disp, tran, p + return strucs, disp, tran, p, sps else: good_ds.append(strucs[-1].disp) good_disps.append(disp) good_paths.append(p) good_strucs.append(strucs) good_trans.append(tran) + good_splitters.append(sps) # Early stop if len(good_ds) > 5: break @@ -2260,12 +2265,12 @@ def get_transition(self, ref_struc, d_tol=1.0, d_tol2=0.3, N_images=2, max_path= #print("Number of candidate path:", len(good_ds)) good_ds = np.array(good_ds) id = np.argmin(good_ds) - return good_strucs[id], good_disps[id], good_trans[id], good_paths[id] + return good_strucs[id], good_disps[id], good_trans[id], good_paths[id], good_splitters[id] if Skipped > 0: print("Warning: ignore some solutions: ", Skipped) - return None, None, None, p + return None, None, None, p, None def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2, both=False): """ @@ -2321,6 +2326,7 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2, refs = [] trans = [] ds = [] + splitters = [] for sol in sols: _sites = deepcopy(sites_G) @@ -2365,7 +2371,7 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2, # make subgroup if match: count_match += 1 - s = self.subgroup_by_path(g_types, ids=sol, eps=0) + s, sps = self.subgroup_by_path(g_types, ids=sol, eps=0) if s is not None: disp, tran, s, max_disp = ref_struc.get_disps_sets(s, d_tol, d_tol2) #import sys; sys.exit() @@ -2374,12 +2380,13 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2, if max_disp < d_tol2: cell = s.lattice.matrix strucs = ref_struc.make_transitions(disp, cell, tran, N_images, both) - return strucs, disp, tran, count_match + return strucs, disp, tran, count_match, sps else: disps.append(disp) refs.append(s) trans.append(tran) ds.append(max_disp) + splitters.append(sps) if len(ds) > 0: ds = np.array(ds) @@ -2387,11 +2394,12 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2, cell = refs[id].lattice.matrix tran = trans[id] disp = disps[id] + sps = splitters[id] strucs = ref_struc.make_transitions(disp, cell, tran, N_images, both) - return strucs, disp, tran, count_match + return strucs, disp, tran, count_match, splitters else: - return None, None, None, count_match + return None, None, None, count_match, None def translate(self, trans, reset_wp=False): """ diff --git a/pyxtal/supergroup.py b/pyxtal/supergroup.py index c81bdfa1..97ced5b2 100644 --- a/pyxtal/supergroup.py +++ b/pyxtal/supergroup.py @@ -895,7 +895,7 @@ def print_detail(self, solution, coords_H, coords_G, elements): disps.append(dis_abs) print(output) output = "Cell: {:7.3f}{:7.3f}{:7.3f}".format(*translation) - output += ", Disp (A): {:6.3f}".format(max(max_disps)) + output += ", Disp (A): {:6.3f}".format(max(disps)) print(output) def sort_solutions(self, solutions): @@ -1025,35 +1025,24 @@ def _make_pyxtal(self, sp, coords, elements=None, run_type=0, check=True): return struc -class symmetry_mapper(): - """ - Class to map the symmetry relation between two structures - - Args: - struc_H: pyxtal structure with low symmetry (H) - struc_G: pyxtal structure with high symmetry (G) - max_d: maximum displacement to be considered - """ - def __init__(self, struc_H, struc_G, max_d=1.0): - # initilize the necesary parameters - G = struc_G.group - H = struc_H.group - elements, sites = struc._get_elements_and_sites() - paths = G.search_subgroup_paths(H.number) - for path in paths: - sols = self.along_path(path) - for i, sol in enumerate(sols): - max_disp, trans, mapping, sp = self.calc_disps(id, sol, self.max_d) - if max_disp < d_tol: - break +#class symmetry_mapper(): +# """ +# Class to map the symmetry relation between two structures +# QZ: not needed now +# Args: +# struc_H: pyxtal structure with low symmetry (H) +# struc_G: pyxtal structure with high symmetry (G) +# max_d: maximum displacement to be considered +# """ +# def __init__(self, struc_H, struc_G, max_d=1.0): +# # initilize the necesary parameters +# G = struc_G.group +# H = struc_H.group +# elements, sites = struc_G._get_elements_and_sites() +# strucs, disp, cell, path, gts, sols = struc_G.get_transition(struc_H, d_tol=max_d) +# if path is not None: +# struc_G.subgroup_by_path(gts, sols) - #def along_path(self, path): - # #letters_H - # #letters_G - # for ele in eles: - # # 1, make_subgroup to H - # # 2, enumerate all wycs - # return class supergroups(): """ @@ -1130,13 +1119,14 @@ def print_solutions(self): output = "Cell: {:7.3f}{:7.3f}{:7.3f}".format(*trans) output += ", Disp (A): {:6.3f}".format(max_disp) print(output) + #print('mapping', mapping) for i, wp2 in enumerate(sp.wp2_lists): wp1 = sp.wp1_lists[i] ele = sp.elements[i] l2 = wp1.get_label() for j, wp in enumerate(wp2): l1 = wp.get_label() - output = "{:2s} [{:2d}]: ".format(ele, mapping[i][j]) + output = "{:2s} [{:2d}]: ".format(ele, mapping[i]) output += "{:3s} -> {:3s}".format(l1, l2) print(output) @@ -1150,7 +1140,7 @@ def get_transformation(self, N_images=2): Returns: a series of pyxtal structures """ - #self.print_solutions() + # self.print_solutions() # derive the backward subgroup representation struc0 = self.strucs[-1] for i in range(1, len(self.solutions)+1): diff --git a/pyxtal/test_all.py b/pyxtal/test_all.py index 5b801507..912f6289 100644 --- a/pyxtal/test_all.py +++ b/pyxtal/test_all.py @@ -295,7 +295,7 @@ def test_similarity(self): s2.from_seed(cif_path + cif2 + '.cif') pmg_s2 = s2.to_pymatgen() - strucs, _, _, _ = s2.get_transition(s1) + strucs, _, _, _, _ = s2.get_transition(s1) if strucs is None: print("Problem between ", cif1, cif2)