fully debug mace
qzhu2017 committed Sep 28, 2024
1 parent 8fa49ac commit ac664cc
Showing 2 changed files with 84 additions and 99 deletions.
119 changes: 52 additions & 67 deletions pyxtal/db.py
@@ -24,7 +24,6 @@ def setup_worker_logger(log_file):
filename=log_file,
level=logging.INFO)


def call_opt_single(p):
"""
Optimize a single structure and log the result.
@@ -43,16 +42,11 @@ def call_opt_single(p):
This function calls `opt_single` to perform the optimization of the structure
associated with the given id.
"""
logger = logging.getLogger()
logger.info(f"ID: {p[0]} *{sum(p[1].numIons)}")
id = p[0]
#logger = logging.getLogger()
#logger.info(f"ID: {p[0]} *{sum(p[1].numIons)}")
myid = p[0]
xtal, eng, status = opt_single(*p)

if eng is not None:
logger.info(f"ID: {id}, eng {eng:.3f} *{sum(xtal.numIons)}")
else:
logger.info(f"ID: {id}, Failed")
return id, xtal, eng
return myid, xtal, eng


def opt_single(id, xtal, calc, *args):
@@ -253,26 +247,28 @@ def mace_opt_single(id, xtal, criteria, step=250):
"""
from pyxtal.interface.ase_opt import ASE_relax as mace_opt

logger = logging.getLogger()
atoms = xtal.to_ase(resort=False)
s = mace_opt(atoms,
'MACE',
opt_cell=True,
step=step,
max_time=10.0)
error = False
max_time=9.0 * max([1, (len(atoms)/200)]),
label=str(id))
if s is None:
logger.info(f"mace_opt_single Failure {id}")
return None, None, False

try:
xtal = pyxtal()
xtal.from_seed(s)
eng = s.get_potential_energy() / len(s)
except:
error = True
eng = None
xtal = None

status = False
if not error:
status = process_xtal(id, xtal, eng, criteria)
return xtal, eng, status
logger.info(f"mace_opt_single Success {id}")
return xtal, eng, status
except:
logger.info(f"mace_opt_single Bug {id}")
return None, None, False


def process_xtal(id, xtal, eng, criteria):
@@ -1261,6 +1257,7 @@ def update_row_energy_serial(self, generator, write_freq, args, args_up):
results = []
for id, xtal in generator:
self.logging.info(f"Processing {id} {xtal.lattice} {args[0]}")
print(f"Processing {id} {xtal.lattice} {args[0]}")
res = opt_single(id, xtal, *args)
(xtal, eng, status) = res
if status:
@@ -1314,35 +1311,39 @@ def chunkify(generator, chunk_size):
if chunk:
yield chunk

results = []
for chunk in chunkify(generator, ncpu*8):
for chunk in chunkify(generator, ncpu*10):
myargs = []
for _id, xtal in chunk:
if xtal is not None:
myargs.append(tuple([_id, xtal] + args))

results = []
self.logging.info(f"Start minicycle: {myargs[0][0]}-{myargs[-1][0]}")
for result in pool.imap_unordered(call_opt_single,
myargs,
chunksize=1):
if result is not None:
(id, xtal, eng) = result
(myid, xtal, eng) = result
if eng is not None:
results.append(result)
#self.logging.info(f"Add {id}, size: {len(results)}")
numIons = sum(xtal.numIons)
count = len(results)
self.logging.info(f"Add {myid:4d} {eng:.3f} *{numIons} {count}")

if len(results) >= ncpu:
# Only do frequent update for slow calculator VASP
if len(results) >= ncpu and args[0] == 'VASP':
self._update_db(results, args[0], *args_up)
self.logging.info(f"Finish minibatch: {len(results)}")
self.print_memory_usage()
self.logging.info(f"Update db: {len(results)}")
results = []

self.logging.info(f"Complete minicycle: {len(chunk)}")
self.logging.info(f"Done minicycle: {myargs[0][0]}-{myargs[-1][0]}")

# After the loop, handle the remaining results
if results:
self.logging.info(f"Update db leftover: {len(results)}")
self._update_db(results, args[0], *args_up)
self.logging.info(f"Finish Update db leftover: {len(results)}")
# After the loop, handle the remaining results
if results:
self.logging.info(f"Start Update db: {len(results)}")
self._update_db(results, args[0], *args_up)
self.logging.info(f"Finish Update db: {len(results)}")

pool.close()
pool.join()
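
The chunked dispatch in the loop above is easier to follow outside the diff. Below is a minimal, self-contained sketch of the same pattern, assuming a hypothetical work() function and a print() standing in for _update_db; it is not the PyXtal API itself, only the shape of the logic: chunk a generator, fan each chunk out with Pool.imap_unordered, and reserve mid-run database flushes for the slow VASP calculator.

# Minimal sketch of the chunked imap_unordered loop above.
# `work` and the print() standing in for _update_db are hypothetical stand-ins.
from multiprocessing import Pool


def chunkify(generator, chunk_size):
    """Yield lists of up to chunk_size items from a generator."""
    chunk = []
    for item in generator:
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


def work(item):
    """Hypothetical worker: return (id, result), with result None on failure."""
    myid, value = item
    return myid, value * value


if __name__ == "__main__":
    ncpu = 4
    calc = "MACE"
    generator = ((i, float(i)) for i in range(100))
    results = []
    with Pool(ncpu) as pool:
        for chunk in chunkify(generator, ncpu * 10):
            for myid, res in pool.imap_unordered(work, chunk, chunksize=1):
                if res is not None:
                    results.append((myid, res))
            # Frequent flushes are reserved for the slow VASP calculator;
            # other calculators are written in one batch after the loop.
            if len(results) >= ncpu and calc == "VASP":
                print(f"update db: {len(results)}")
                results = []
    if results:
        print(f"update db leftover: {len(results)}")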
@@ -1356,48 +1357,32 @@ def _update_db(self, results, calc, *args):
results: list of (id, xtal, eng) tuples
calc (str): calculator
"""
delay = 20
max_retries = 3

self.logging.info(f"Update db: {len(results)}")
#self.logging.info(f"====================Update db: {len(results)}")
if calc == 'GULP':
ff_lib = args[0]

with self.db:
for result in results:
(id, xtal, eng) = result
#self.logging.info(f"Start: id{id}")
if xtal is not None:
for attempt in range(max_retries):
self.logging.info(f'update_db_{calc}_{id}, att: {attempt}')
if True: #try:
if calc == 'GULP':
self.db.update(id,
ff_energy=eng,
ff_lib=ff_lib,
ff_relaxed=xtal.to_file())
elif calc == 'MACE':
self.db.update(id,
mace_energy=eng,
mace_relaxed=xtal.to_file())
elif calc == 'VASP':
self.db.update(id,
vasp_energy=eng,
vasp_relaxed=xtal.to_file())
elif calc == 'DFTB':
self.db.update(id,
dftb_energy=eng,
dftb_relaxed=xtal.to_file())
# If update is successful, break out loop
break

#except Exception as e:
# msg = f"Rank-{self.rank} failed in updating {id}, Wait"
# self.logging.info(msg)
# if attempt == max_retries - 1:
# break
# time.sleep(delay)
self.logging.info(f'update_db_{calc}, {id}')
if calc == 'GULP':
self.db.update(id,
ff_energy=eng,
ff_lib=ff_lib,
ff_relaxed=xtal.to_file())
elif calc == 'MACE':
self.db.update(id,
mace_energy=eng,
mace_relaxed=xtal.to_file())
elif calc == 'VASP':
self.db.update(id,
vasp_energy=eng,
vasp_relaxed=xtal.to_file())
elif calc == 'DFTB':
self.db.update(id,
dftb_energy=eng,
dftb_relaxed=xtal.to_file())
#self.logging.info(f'update_db_{calc}, {id}')

def update_row_topology(self, StructureType="Auto", overwrite=True, prefix=None, ref_dim=3):
"""
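The _update_db branches above differ only in the column names they write (ff_*, mace_*, vasp_*, dftb_*). Assuming the underlying store is an ase.db database, as pyxtal/db.py uses, a minimal sketch of the MACE branch looks like this; demo.db, the Cu cell, and the values are illustrative rather than PyXtal output.

# Hedged sketch of the MACE branch of _update_db using the ase.db API directly.
# "demo.db", the Cu cell, and the numbers are illustrative, not PyXtal output.
from ase.build import bulk
from ase.db import connect

db = connect("demo.db", append=False)   # fresh database file
rid = db.write(bulk("Cu"))              # one structure, one row id

# _update_db stores the per-atom energy and the relaxed structure as a string.
db.update(rid, mace_energy=-3.75, mace_relaxed="relaxed-structure-as-text")
print(db.get(rid).mace_energy)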
64 changes: 32 additions & 32 deletions pyxtal/interface/ase_opt.py
@@ -5,19 +5,13 @@
from ase.filters import UnitCellFilter
from ase.optimize.fire import FIRE
import torchani
import logging

import os
from mace.calculators import mace_mp
_cached_mace_mp = None


# Define a handler for the timeout
class TimeoutException(Exception):
pass

def timeout_handler(signum, frame):
raise TimeoutException

def get_calculator(calculator):
global _cached_mace_mp

@@ -33,8 +27,8 @@ def get_calculator(calculator):

return calc

#def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile=None, max_time=10.0):
def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile='ase.log', max_time=10.0):
#def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile=None, max_time=10.0, label='ase'):
def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile='ase.log', max_time=10.0, label='ase'):
"""
ASE optimizer
Args:
@@ -43,12 +37,20 @@ def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile='as
step: optimization steps (int)
max_time: float (minutes)
"""

def handler(signum, frame):
raise TimeoutError("Optimization timed out")

step_init = 40
logger = logging.getLogger()
max_time *= 60
timeout = int(max_time)
signal.signal(signal.SIGALRM, timeout_handler)
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
#logger.info(f"{label} start calculation")

try:
#if True:
calc = get_calculator(calculator)
struc.set_calculator(calc)
struc.set_constraint(FixSymmetry(struc))
@@ -59,40 +61,38 @@ def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile='as
dyn = FIRE(struc, a=0.1, logfile=logfile) if logfile is not None else FIRE(struc, a=0.1)

# Run relaxation
dyn.run(fmax=fmax, steps=20)
dyn.run(fmax=fmax, steps=step_init)
forces = dyn.optimizable.get_forces()
_fmax = np.sqrt((forces ** 2).sum(axis=1).max())
# print("debug", _fmax)

if _fmax < 1e+3:
if step < 50:
dyn.run(fmax=fmax, steps=step)
else:
t0 = time()
dyn.run(fmax=fmax, steps=int(step / 2))
# If time is too long, only run half steps
if (time() - t0) < max_time / 2:
dyn.run(fmax=fmax, steps=int(step / 2))
if _fmax < 1e+3 and step > step_init:
dyn.run(fmax=fmax, steps=step-step_init)
forces = dyn.optimizable.get_forces()
_fmax = np.sqrt((forces ** 2).sum(axis=1).max())
eng = struc.get_potential_energy() / len(struc)
if _fmax > 100:
return None
logger.info(f"Warning {label} big stress {eng:.2f} / {_fmax:.2f}, skip")
struc = None
else:
return struc
logger.info(f"{label} Success {eng:.2f} / {_fmax:.2f}")
else:
logger.info(f"Warning {label} big stress {_fmax:.2f} for 20 steps, skip")
struc = None
signal.alarm(0) # Cancel the alarm if finished within time

except TimeoutException:
print(f"ASE_relax timed out after {timeout} seconds.")
return None
except TimeoutError:
logger.info(f"Warning {label} timed out after {timeout} seconds.")
struc = None

except TypeError:
print("spglib error in getting the lattice")
return None

finally:
logger.info(f"Warning {label} spglib error in getting the lattice")
struc = None
signal.alarm(0) # Cancel the alarm if finished within time

return None

tag = 'False' if struc is None else 'True'
logger.info(f"Finishing {label} {tag}")
#signal.alarm(0) # Cancel the alarm
return struc

class ASE_optimizer:
"""
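A hedged usage sketch for the revised ASE_relax signature (the new label argument and a size-scaled max_time, as in mace_opt_single). The Si cell and parameter values are assumptions for illustration; note that the SIGALRM-based timeout only works in the main thread on POSIX systems.

# Illustrative call of the revised ASE_relax; the Si cell and values are assumptions.
from ase.build import bulk
from pyxtal.interface.ase_opt import ASE_relax

atoms = bulk("Si", "diamond", a=5.46, cubic=True)
relaxed = ASE_relax(atoms,
                    "MACE",
                    opt_cell=True,
                    step=250,
                    fmax=0.1,
                    logfile="ase.log",
                    max_time=9.0 * max([1, len(atoms) / 200]),  # minutes, scaled with size
                    label="demo-0")

if relaxed is None:
    print("relaxation failed: timeout, spglib error, or large residual forces")
else:
    print("energy per atom:", relaxed.get_potential_energy() / len(relaxed))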
