Skip to content

Commit

Permalink
output splitter for get_transition function
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Mar 16, 2024
1 parent 789b702 commit 11630be
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 45 deletions.
34 changes: 21 additions & 13 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class pyxtal:
>>> s2 = pyxtal()
>>> s1.from_seed("pyxtal/database/cifs/0-G62.cif") #structure with low symmetry
>>> s2.from_seed("pyxtal/database/cifs/2-G71.cif") #structure with high symmetry
>>> strucs, _, _, _ = s2.get_transition(s1) # get the transition from high to low
>>> strucs, _, _, _, _ = s2.get_transition(s1) # get the transition from high to low
>>> strucs
[
------Crystal from Transition 0 0.000------
Expand Down Expand Up @@ -735,9 +735,10 @@ def subgroup_by_path(self, gtypes, ids, eps=0, mut_lat=False):
Returns:
a pyxtal structure with lower symmetries
list of splitters
"""
struc = self.copy()

splitters = []
G = self.group
for g_type, id in zip(gtypes, ids):
if self.molecular:
Expand All @@ -750,8 +751,9 @@ def subgroup_by_path(self, gtypes, ids, eps=0, mut_lat=False):
struc = struc._subgroup_by_splitter(splitter, eps=eps, mut_lat=mut_lat)
if struc is None:
return None
splitters.append(splitter)
G = splitter.H
return struc
return struc, splitters

def subgroup_once(self, eps=0.1, H=None, perms=None, group_type='t', \
max_cell=4, min_cell=0, mut_lat=True, ignore_special=False):
Expand Down Expand Up @@ -1806,7 +1808,7 @@ def to_subgroup(self, path=None, t_only=True, iterate=False, species=None):
for p in path:
gtypes.append(p[0])
ids.append(p[1])
sub = self.subgroup_by_path(gtypes, ids, eps=0)
sub, _ = self.subgroup_by_path(gtypes, ids, eps=0)
sub.optimize_lattice()
sub.source = "subgroup"
else:
Expand Down Expand Up @@ -2224,16 +2226,17 @@ def get_transition(self, ref_struc, d_tol=1.0, d_tol2=0.3, N_images=2, max_path=
good_disps = []
good_paths = []
good_trans = []
good_splitters = []

for p in paths:
r = self.get_transition_by_path(ref_struc, p, d_tol, d_tol2, N_images, both)
(strucs, disp, tran, count) = r
(strucs, disp, tran, count, sps) = r
if count == 0:
# prepare more paths to increase diversity
add_paths = self.group.add_k_transitions(p)
for p0 in add_paths:
r = self.get_transition_by_path(ref_struc, p0, d_tol, d_tol2, N_images, both)
(strucs, disp, tran, count) = r
(strucs, disp, tran, count, sps) = r
if strucs is not None:
if strucs[-1].disp < d_tol2: #stop
return strucs, disp, tran, p0
Expand All @@ -2243,29 +2246,31 @@ def get_transition(self, ref_struc, d_tol=1.0, d_tol2=0.3, N_images=2, max_path=
good_paths.append(p0)
good_strucs.append(strucs)
good_trans.append(tran)
good_splitters.append(sps)
else:
if strucs is not None:
if strucs[-1].disp < d_tol2:
return strucs, disp, tran, p
return strucs, disp, tran, p, sps
else:
good_ds.append(strucs[-1].disp)
good_disps.append(disp)
good_paths.append(p)
good_strucs.append(strucs)
good_trans.append(tran)
good_splitters.append(sps)
# Early stop
if len(good_ds) > 5:
break
if len(good_ds) > 0:
#print("Number of candidate path:", len(good_ds))
good_ds = np.array(good_ds)
id = np.argmin(good_ds)
return good_strucs[id], good_disps[id], good_trans[id], good_paths[id]
return good_strucs[id], good_disps[id], good_trans[id], good_paths[id], good_splitters[id]

if Skipped > 0:
print("Warning: ignore some solutions: ", Skipped)

return None, None, None, p
return None, None, None, p, None

def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2, both=False):
"""
Expand Down Expand Up @@ -2321,6 +2326,7 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2,
refs = []
trans = []
ds = []
splitters = []

for sol in sols:
_sites = deepcopy(sites_G)
Expand Down Expand Up @@ -2365,7 +2371,7 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2,
# make subgroup
if match:
count_match += 1
s = self.subgroup_by_path(g_types, ids=sol, eps=0)
s, sps = self.subgroup_by_path(g_types, ids=sol, eps=0)
if s is not None:
disp, tran, s, max_disp = ref_struc.get_disps_sets(s, d_tol, d_tol2)
#import sys; sys.exit()
Expand All @@ -2374,24 +2380,26 @@ def get_transition_by_path(self, ref_struc, path, d_tol, d_tol2=0.5, N_images=2,
if max_disp < d_tol2:
cell = s.lattice.matrix
strucs = ref_struc.make_transitions(disp, cell, tran, N_images, both)
return strucs, disp, tran, count_match
return strucs, disp, tran, count_match, sps
else:
disps.append(disp)
refs.append(s)
trans.append(tran)
ds.append(max_disp)
splitters.append(sps)

if len(ds) > 0:
ds = np.array(ds)
id = np.argmin(ds)
cell = refs[id].lattice.matrix
tran = trans[id]
disp = disps[id]
sps = splitters[id]
strucs = ref_struc.make_transitions(disp, cell, tran, N_images, both)
return strucs, disp, tran, count_match
return strucs, disp, tran, count_match, splitters

else:
return None, None, None, count_match
return None, None, None, count_match, None

def translate(self, trans, reset_wp=False):
"""
Expand Down
52 changes: 21 additions & 31 deletions pyxtal/supergroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ def print_detail(self, solution, coords_H, coords_G, elements):
disps.append(dis_abs)
print(output)
output = "Cell: {:7.3f}{:7.3f}{:7.3f}".format(*translation)
output += ", Disp (A): {:6.3f}".format(max(max_disps))
output += ", Disp (A): {:6.3f}".format(max(disps))
print(output)

def sort_solutions(self, solutions):
Expand Down Expand Up @@ -1025,35 +1025,24 @@ def _make_pyxtal(self, sp, coords, elements=None, run_type=0, check=True):
return struc


class symmetry_mapper():
"""
Class to map the symmetry relation between two structures
Args:
struc_H: pyxtal structure with low symmetry (H)
struc_G: pyxtal structure with high symmetry (G)
max_d: maximum displacement to be considered
"""
def __init__(self, struc_H, struc_G, max_d=1.0):
# initilize the necesary parameters
G = struc_G.group
H = struc_H.group
elements, sites = struc._get_elements_and_sites()
paths = G.search_subgroup_paths(H.number)
for path in paths:
sols = self.along_path(path)
for i, sol in enumerate(sols):
max_disp, trans, mapping, sp = self.calc_disps(id, sol, self.max_d)
if max_disp < d_tol:
break
#class symmetry_mapper():
# """
# Class to map the symmetry relation between two structures
# QZ: not needed now
# Args:
# struc_H: pyxtal structure with low symmetry (H)
# struc_G: pyxtal structure with high symmetry (G)
# max_d: maximum displacement to be considered
# """
# def __init__(self, struc_H, struc_G, max_d=1.0):
# # initilize the necesary parameters
# G = struc_G.group
# H = struc_H.group
# elements, sites = struc_G._get_elements_and_sites()
# strucs, disp, cell, path, gts, sols = struc_G.get_transition(struc_H, d_tol=max_d)
# if path is not None:
# struc_G.subgroup_by_path(gts, sols)

#def along_path(self, path):
# #letters_H
# #letters_G
# for ele in eles:
# # 1, make_subgroup to H
# # 2, enumerate all wycs
# return

class supergroups():
"""
Expand Down Expand Up @@ -1130,13 +1119,14 @@ def print_solutions(self):
output = "Cell: {:7.3f}{:7.3f}{:7.3f}".format(*trans)
output += ", Disp (A): {:6.3f}".format(max_disp)
print(output)
#print('mapping', mapping)
for i, wp2 in enumerate(sp.wp2_lists):
wp1 = sp.wp1_lists[i]
ele = sp.elements[i]
l2 = wp1.get_label()
for j, wp in enumerate(wp2):
l1 = wp.get_label()
output = "{:2s} [{:2d}]: ".format(ele, mapping[i][j])
output = "{:2s} [{:2d}]: ".format(ele, mapping[i])
output += "{:3s} -> {:3s}".format(l1, l2)
print(output)

Expand All @@ -1150,7 +1140,7 @@ def get_transformation(self, N_images=2):
Returns:
a series of pyxtal structures
"""
#self.print_solutions()
# self.print_solutions()
# derive the backward subgroup representation
struc0 = self.strucs[-1]
for i in range(1, len(self.solutions)+1):
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_similarity(self):
s2.from_seed(cif_path + cif2 + '.cif')
pmg_s2 = s2.to_pymatgen()

strucs, _, _, _ = s2.get_transition(s1)
strucs, _, _, _, _ = s2.get_transition(s1)

if strucs is None:
print("Problem between ", cif1, cif2)
Expand Down

0 comments on commit 11630be

Please sign in to comment.