Skip to content

Commit

Permalink
fix #271 by adding subprocess control and output logger for mpi debug
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Sep 6, 2024
1 parent d83e770 commit 2167d89
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 46 deletions.
6 changes: 3 additions & 3 deletions pyxtal/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def copy(self, db_name, csd_codes):
"""
if db_name == self.db_name:
raise RuntimeError("Cannot use the same db file for copy")
with connect(db_name) as db:
with connect(db_name, serial=True) as db:
for csd_code in csd_codes:
row_info = self.get_row_info(code=csd_code)
(atom, kvp, data) = row_info
Expand Down Expand Up @@ -603,7 +603,7 @@ def add_strucs_from_db(self, db_file, check=False, tol=0.1, freq=50):
print(f"\nAdding new strucs from {db_file:s}")

count = 0
with connect(db_file) as db:
with connect(db_file, serial=True) as db:
for row in db.select():
atoms = row.toatoms()
xtal = pyxtal()
Expand Down Expand Up @@ -1448,7 +1448,7 @@ def get_db_unique(self, db_name=None, prec=3):
unique_props[prop_key] = (row.id, dof)

ids = [unique_props[key][0] for key in unique_props.keys()]
with connect(db_name) as db:
with connect(db_name, serial=True) as db:
for id in ids:
row = self.db.get(id)
kvp = {}
Expand Down
78 changes: 54 additions & 24 deletions pyxtal/interface/charmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import os
import shutil

import subprocess
import numpy as np


Expand Down Expand Up @@ -36,11 +36,14 @@ def __init__(
output="charmm.log",
dump="result.pdb",
debug=False,
timeout=20,
):
self.errorE = 1e+5
self.error = False
if steps is None:
steps = [2000, 1000]
self.debug = debug

self.timeout = timeout
# check charmm Executable
if shutil.which(exe) is None:
raise BaseException(f"{exe} is not installed")
Expand Down Expand Up @@ -101,16 +104,32 @@ def run(self, clean=True):
os.chdir(self.folder)

self.write() # ; print("write", time()-t0)
self.execute() # ; print("exe", time()-t0)
self.read() # ; print("read", self.structure.energy)
res = self.execute() # ; print("exe", time()-t0)
if res is not None:
self.read() # ; print("read", self.structure.energy)
else:
self.structure.energy = self.errorE
self.error = True
if clean:
self.clean()

os.chdir(cwd)

def execute(self):
    """
    Run the CHARMM executable on the prepared input file.

    The executable is launched directly (no intermediate shell) with
    ``self.input`` connected to its stdin and ``self.output`` to its
    stdout; stderr is discarded. Launching without a shell means file
    names containing spaces or shell metacharacters are passed through
    safely, and a timeout terminates the CHARMM process itself rather
    than a wrapping shell.

    Returns:
        0 if the process finished successfully within ``self.timeout``
        seconds; ``None`` if it exited with a non-zero code, timed out,
        or could not be started (e.g. missing input file or executable).
    """
    # Human-readable form of the command, used only in diagnostics.
    cmd = self.exe + " < " + self.input + " > " + self.output
    try:
        # Open the input first so a missing input file fails before the
        # output file is created (same order a shell redirection uses).
        with open(self.input, "rb") as stdin, open(self.output, "wb") as stdout:
            result = subprocess.run(
                [self.exe],
                stdin=stdin,
                stdout=stdout,
                stderr=subprocess.DEVNULL,
                timeout=self.timeout,
                check=True,
            )
    except subprocess.CalledProcessError as e:
        print(f"Command '{cmd}' failed with return code {e.returncode}.")
        return None
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed and reaped the child here.
        print(f"External command {cmd} timed out.")
        return None
    except OSError as e:
        # Input/output file could not be opened or the executable vanished.
        print(f"Command '{cmd}' could not be started: {e}")
        return None
    # check=True guarantees a zero return code at this point.
    return result.returncode

def clean(self):
os.remove(self.input)
Expand All @@ -129,7 +148,8 @@ def write(self):

a, b, c, alpha, beta, gamma = lat.get_para(degree=True)
ltype = lat.ltype
if ltype in ['trigonal', 'Trigonal']: ltype = 'hexagonal'
if ltype in ['trigonal', 'Trigonal']:
ltype = 'hexagonal'

fft = self.FFTGrid(np.array([a, b, c]))

Expand All @@ -148,14 +168,16 @@ def write(self):
if self.atom_info is None:
f.write(f"U0{site.type:d} ")
else:
f.write("{:s} ".format(self.atom_info["resName"][site.type]))
f.write("{:s} ".format(
self.atom_info["resName"][site.type]))

f.write("\ngenerate main first none last none setup warn\n")
f.write("Read coor card free\n")
f.write("* Residues coordinate\n*\n")
f.write(f"{sum(atom_count):5d}\n")
for i, site in enumerate(self.structure.mol_sites):
res_name = f"U0{site.type:d}" if self.atom_info is None else self.atom_info["resName"][site.type]
res_name = f"U0{site.type:d}" if self.atom_info is None else self.atom_info[
"resName"][site.type]

# reset lattice if needed (to move out later)
site.lattice = lat
Expand All @@ -179,12 +201,13 @@ def write(self):
j + 1 + count, i + 1, res_name, label, *coord
)
)
# quickly check if
# quickly check if
if abs(coord).max() > 500.0:
print("Unexpectedly large input coordinates, stop and debug")
print(self.structure)
self.structure.to_file('bug.cif')
import sys; sys.exit()
import sys
sys.exit()

f.write(f"write psf card name {self.psf:s}\n")
f.write(f"write coor crd card name {self.crd:s}\n")
Expand All @@ -204,26 +227,31 @@ def write(self):
f.write("coor stat select all end\n")
f.write("Crystal Define @shape @a @b @c @alpha @beta @gamma\n")
site0 = self.structure.mol_sites[0]
f.write(f"Crystal Build cutoff 14.0 noperations {len(site0.wp.ops) - 1:d}\n")
f.write(
f"Crystal Build cutoff 14.0 noperations {len(site0.wp.ops) - 1:d}\n")
for i, op in enumerate(site0.wp.ops):
if i > 0:
f.write(f"({op.as_xyz_str():s})\n")

f.write("image byres xcen ?xave ycen ?yave zcen ?zave sele resn LIG end\n")
f.write(
"image byres xcen ?xave ycen ?yave zcen ?zave sele resn LIG end\n")
f.write("set 7 fswitch\n")
f.write("set 8 atom\n")
f.write("set 9 vatom\n")
f.write("Update inbfrq 10 imgfrq 10 ihbfrq 10 -\n")
f.write("ewald pmewald lrc fftx {:d} ffty {:d} fftz {:d} -\n".format(*fft))
f.write(
"ewald pmewald lrc fftx {:d} ffty {:d} fftz {:d} -\n".format(*fft))
f.write("kappa 0.34 order 6 CTOFNB 12.0 CUTNB 14.0 QCOR 1.0 -\n")
f.write("@7 @8 @9 vfswitch !\n")
f.write(f"mini {self.algo:s} nstep {self.steps[0]:d}\n")
if len(self.steps) > 1:
f.write(f"mini {self.algo:s} lattice nstep {self.steps[1]:d} \n")
f.write(
f"mini {self.algo:s} lattice nstep {self.steps[1]:d} \n")
if len(self.steps) > 2:
f.write(f"mini {self.algo:s} nstep {self.steps[2]:d}\n")

f.write("coor conv SYMM FRAC ?xtla ?xtlb ?xtlc ?xtlalpha ?xtlbeta ?xtlgamma\n") #
f.write(
"coor conv SYMM FRAC ?xtla ?xtlb ?xtlc ?xtlalpha ?xtlbeta ?xtlgamma\n") #
f.write(f"\nwrite coor pdb name {self.dump:s}\n") #
f.write("*CELL : ?xtla ?xtlb ?xtlc ?xtlalpha ?xtlbeta ?xtlgamma\n") #
f.write(f"*Z = {len(site0.wp):d}\n")
Expand Down Expand Up @@ -270,7 +298,8 @@ def read(self):
XYZ = [float(x) for x in xyz]
positions.append(XYZ)
except:
pass # print("Warning: BAD charmm output: " + line)
# print("Warning: BAD charmm output: " + line)
pass
positions = np.array(positions)
self.structure.energy *= Z

Expand All @@ -283,7 +312,7 @@ def read(self):
# if True:
try:
for _i, site in enumerate(self.structure.mol_sites):
coords = positions[count : count + len(site.molecule.mol)]
coords = positions[count: count + len(site.molecule.mol)]
site.update(coords, self.structure.lattice)
count += len(site.molecule.mol)
# print("after relaxation : ", self.structure.lattice, "iter: ", self.structure.iter)
Expand All @@ -292,7 +321,8 @@ def read(self):
# print("after latticeopt : ", self.structure.lattice, self.structure.check_distance()); import sys; sys.exit()
except:
# molecular connectivity or lattice optimization
self.structure.energy = 10000
self.structure.energy = self.errorE
self.error = True
if self.debug:
print("Unable to retrieve Structure after optimization")
print("lattice", self.structure.lattice)
Expand All @@ -304,12 +334,11 @@ def read(self):
print("short distance pair", pairs)

else:
self.structure.energy = 10000
self.structure.energy = self.errorE
self.error = True
if self.debug:
print(self.structure)
import sys

sys.exit()
import sys; sys.exit()

def FFTGrid(self, ABC):
"""
Expand Down Expand Up @@ -616,7 +645,8 @@ def merge(self, rtf1=None, single=None):
for a in res["ANGL"]:
tmp = a.split("!")
tmp1 = tmp[0].split()
a1, a2, a3 = str(i) + tmp1[1], str(i) + tmp1[2], str(i) + tmp1[3]
a1, a2, a3 = str(
i) + tmp1[1], str(i) + tmp1[2], str(i) + tmp1[3]
a = f"ANGL {a1:6s} {a2:6s} {a3:6s} "
if len(tmp) > 1:
a += f"!{tmp[-1]:12s}"
Expand Down
15 changes: 10 additions & 5 deletions pyxtal/optimize/WFS.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
strs = self.full_str()
self.logging.info(strs)
print(strs)
print(f"Rank {self.rank} finish initialization {self.tag}")

def full_str(self):
s = str(self)
Expand All @@ -168,17 +169,15 @@ def _run(self, pool=None):
# Related to the FF optimization
N_added = 0
success_rate = 0
print(f"Rank {self.rank} starts WFS in {self.tag}")

for gen in range(self.N_gen):
self.generation = gen
cur_xtals = None
print(f"Rank {self.rank} entering generation {gen} in {self.tag}")

self.logging.info(f"Gen {gen} starts in Rank {self.rank} {self.tag}")
if self.rank == 0:
print(f"\nGeneration {gen:d} starts")
self.logging.info(f"Generation {gen:d} starts")
t0 = time()

# Initialize
cur_xtals = [(None, "Random")] * self.N_pop

Expand All @@ -194,10 +193,11 @@ def _run(self, pool=None):
# broadcast
if self.use_mpi:
cur_xtals = self.comm.bcast(cur_xtals, root=0)
#print(f"Rank {self.rank} after broadcast: current_xtals = {current_xtals}")
#self.logging.info(f"Rank {self.rank} gets {len(cur_xtals)} strucs {self.tag}")

# Local optimization
gen_results = self.local_optimization(cur_xtals, pool=pool)
self.logging.info(f"Rank {self.rank} finishes local_opt {self.tag}")

prev_xtals = None
if self.rank == 0:
Expand All @@ -212,6 +212,8 @@ def _run(self, pool=None):
if self.use_mpi:
prev_xtals = self.comm.bcast(prev_xtals, root=0)

self.logging.info(f"Gen {gen} bcast in Rank {self.rank} {self.tag}")

# Update the FF parameters if necessary
if self.ff_opt:
N_added = self.update_ff_paramters(cur_xtals, engs, N_added)
Expand All @@ -225,13 +227,16 @@ def _run(self, pool=None):

elif self.ref_pxrd is not None:
self.count_pxrd_match(cur_xtals, matches)

# quit the loop
if self.use_mpi:
quit = self.comm.bcast(quit, root=0)
self.comm.Barrier()

self.logging.info(f"Gen {gen} Finish in Rank {self.rank} {self.tag}")
# Ensure that all ranks exit
if quit:
self.logging.info(f"Early Termination in Rank {self.rank} {self.tag}")
return success_rate

return success_rate
Expand Down
Loading

0 comments on commit 2167d89

Please sign in to comment.