Skip to content

Commit

Permalink
add the option to allow the reconstruction from discrete rep
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Oct 7, 2024
1 parent 205c131 commit 0df8653
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 41 deletions.
7 changes: 5 additions & 2 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,13 +1666,14 @@ def build(self, group, species, numIons, lattice, sites, tol=1e-2, dim=3, use_ha
elif len(wp) == 4: # tuple:
(key, x, y, z) = wp
_wp = choose_wyckoff(self.group, site=key, dim=dim)
# print(key, x, y, z, _wp.get_label())
#print('debug build', key, x, y, z, _wp.get_label())
if _wp is not False:
if _wp.get_dof() == 0: # fixed pos
pt = [0.0, 0.0, 0.0]
else:
ans = _wp.get_all_positions([x, y, z])
pt = ans[0] if ans is not None else None
#print('debug build', ans, x, y, z)
# print('debug', ans)
if pt is not None:
_sites.append(atom_site(_wp, pt, sp))
Expand Down Expand Up @@ -3844,10 +3845,12 @@ def from_tabular_representation(

# Conversion from discrete to continuous
if discrete:
#print('discrete', x, y, z)
[x, y, z] = wp.from_discrete_grid([x, y, z], N_grids)
#print('conversion', x, y, z)

# ; print(wp.get_label(), xyz)
xyz = wp.search_generator([x, y, z], tol=tol)
xyz = wp.search_generator([x, y, z], tol=tol, symmetrize=True)
if xyz is not None:
xyz, wp, _ = wp.merge(xyz, np.eye(3), 1e-3)
label = wp.get_label()
Expand Down
101 changes: 63 additions & 38 deletions pyxtal/lego/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,31 +795,12 @@ def optimize_xtals_serial(self, xtals, args):
xtals_opt.append(xtal)
return xtals_opt

def optimize_reps(self, reps, ncpu=1, opt_type='local',
T=0.2, niter=20, early_quit=0.02,
add_db=True, symmetrize=False,
minimizers=[('Nelder-Mead', 100), ('L-BFGS-B', 100)],
discrete=False):
"""
Perform optimization for each structure
Args:
reps: list of reps
ncpu (int):
"""
args = (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers)
if ncpu > 1:
valid_xtals = self.optimize_reps_mproc(reps, ncpu, args, discrete)
return valid_xtals
else:
raise NotImplementedError("optimize_reps works in parallel mode")

def optimize_reps_mproc(self, reps, ncpu, args, discrete):
def optimize_xtals_mproc(self, xtals, ncpu, args):
"""
Optimization in multiprocess mode.
Args:
reps: list of reps
xtals: list of xtals
ncpu (int): number of parallel python processes
args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)
"""
Expand All @@ -830,11 +811,10 @@ def optimize_reps_mproc(self, reps, ncpu, args, discrete):

# Split the input structures to minibatches
N_batches = 10 * ncpu
for _i, i in enumerate(range(0, len(reps), N_batches)):
start, end = i, min([i+N_batches, len(reps)])
for _i, i in enumerate(range(0, len(xtals), N_batches)):
start, end = i, min([i+N_batches, len(xtals)])
ids = list(range(start, end))
print(f"Rank {self.rank} minibatch {start} {end}")
self.logging.info(f"Rank {self.rank} minibatch {start} {end}")
self.print_memory_usage()

def generate_args():
Expand All @@ -845,14 +825,10 @@ def generate_args():
_ids = ids[j::ncpu]
wp_libs = []
for id in _ids:
rep = reps[id]
xtal = pyxtal()
xtal.from_tabular_representation(rep,
normalize=False,
discrete=discrete)
xtal = xtals[id]
x = xtal.get_1d_rep_x()
spg, wps, _ = self.get_input_from_ref_xtal(xtal)
wp_libs.append((x, spg, wps))
wp_libs.append((x, xtal.group.number, wps))
yield (self.dim, wp_libs, self.elements, self.calculator,
self.ref_environments, opt_type, T, niter,
early_quit, minimizers)
Expand All @@ -877,15 +853,59 @@ def generate_args():
gc.collect() # Explicitly call garbage collector to free memory

xtals_opt = list(xtals_opt)
print(f"Rank {self.rank} finish optimize_reps_mproc {len(xtals_opt)}")
print(f"Rank {self.rank} finish optimize_xtals_mproc {len(xtals_opt)}")
return xtals_opt

def optimize_xtals_mproc(self, xtals, ncpu, args):
def optimize_reps(self, reps, ncpu=1, opt_type='local',
T=0.2, niter=20, early_quit=0.02,
add_db=True, symmetrize=False,
minimizers=[('Nelder-Mead', 100), ('L-BFGS-B', 100)],
discrete=False):
"""
Perform optimization for each structure
Args:
reps: list of reps
ncpu (int):
"""
args = (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers)
if ncpu == 1:
valid_xtals = self.optimize_reps_serial(reps, args, discrete)
else:
valid_xtals = self.optimize_reps_mproc(reps, ncpu, args, discrete)
return valid_xtals

def optimize_reps_serial(self, reps, args, discrete):
"""
Optimization in multiprocess mode.
Args:
xtals: list of xtals
reps: list of reps
ncpu (int): number of parallel python processes
args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)
"""
xtals_opt = []
for i, rep in enumerate(reps):
#print('start', i, rep, len(rep))
xtal = pyxtal()
xtal.from_tabular_representation(rep,
normalize=False,
discrete=discrete)
#print(xtal.get_xtal_string())
#print(xtal)
xtal, sim, _xs = self.optimize_xtal(xtal, i, *args)
if xtal is not None:
xtals_opt.append(xtal)
else:
import sys; sys.exit()
return xtals_opt

def optimize_reps_mproc(self, reps, ncpu, args, discrete):
"""
Optimization in multiprocess mode.
Args:
reps: list of reps
ncpu (int): number of parallel python processes
args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)
"""
Expand All @@ -896,10 +916,11 @@ def optimize_xtals_mproc(self, xtals, ncpu, args):

# Split the input structures to minibatches
N_batches = 10 * ncpu
for _i, i in enumerate(range(0, len(xtals), N_batches)):
start, end = i, min([i+N_batches, len(xtals)])
for _i, i in enumerate(range(0, len(reps), N_batches)):
start, end = i, min([i+N_batches, len(reps)])
ids = list(range(start, end))
print(f"Rank {self.rank} minibatch {start} {end}")
self.logging.info(f"Rank {self.rank} minibatch {start} {end}")
self.print_memory_usage()

def generate_args():
Expand All @@ -910,10 +931,14 @@ def generate_args():
_ids = ids[j::ncpu]
wp_libs = []
for id in _ids:
xtal = xtals[id]
rep = reps[id]
xtal = pyxtal()
xtal.from_tabular_representation(rep,
normalize=False,
discrete=discrete)
x = xtal.get_1d_rep_x()
spg, wps, _ = self.get_input_from_ref_xtal(xtal)
wp_libs.append((x, xtal.group.number, wps))
wp_libs.append((x, spg, wps))
yield (self.dim, wp_libs, self.elements, self.calculator,
self.ref_environments, opt_type, T, niter,
early_quit, minimizers)
Expand All @@ -938,7 +963,7 @@ def generate_args():
gc.collect() # Explicitly call garbage collector to free memory

xtals_opt = list(xtals_opt)
print(f"Rank {self.rank} finish optimize_xtals_mproc {len(xtals_opt)}")
print(f"Rank {self.rank} finish optimize_reps_mproc {len(xtals_opt)}")
return xtals_opt

def optimize_xtal(self, xtal, count=0, opt_type='local',
Expand Down
5 changes: 4 additions & 1 deletion pyxtal/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,14 +2918,15 @@ def search_generator_dist(self, pt, lattice=None, group=None):
min_index = np.argmin(distances)
return pts[min_index], np.min(distances)

def search_generator(self, pos, ops=None, tol=1e-2):
def search_generator(self, pos, ops=None, tol=1e-2, symmetrize=False):
"""
search generator for a special Wyckoff position
Args:
pos: initial xyz position
ops: list of symops
tol: tolerance
symmetrize (bool): apply symmetrization
Return:
pos1: the position that matchs the standard setting
Expand All @@ -2945,6 +2946,8 @@ def search_generator(self, pos, ops=None, tol=1e-2):
if diff.sum() < tol:
pos1 -= np.floor(pos1)
match = True
if symmetrize:
pos1 = pos0
break

if match:
Expand Down

0 comments on commit 0df8653

Please sign in to comment.