From 7ca5fec9737e5d30de3faf882976bc8be858eb27 Mon Sep 17 00:00:00 2001 From: qzhu2017 Date: Sun, 8 Sep 2024 17:05:21 -0400 Subject: [PATCH] add optimize_reps to reduce the memory storage --- pyxtal/lego/builder.py | 88 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/pyxtal/lego/builder.py b/pyxtal/lego/builder.py index f8f10dba..acdddf48 100644 --- a/pyxtal/lego/builder.py +++ b/pyxtal/lego/builder.py @@ -178,8 +178,8 @@ def minimize_from_x(x, dim, spg, wps, elements, calculator, ref_environments, while True: count += 1 try: - xtal.from_random(dim, g, elements, numIons, - sites=sites_wp, factor=1.0, + xtal.from_random(dim, g, elements, numIons, + sites=sites_wp, factor=1.0, random_state=random_state) except RuntimeError: print(g.number, numIons, sites) @@ -600,6 +600,7 @@ def print_memory_usage(self): process = psutil.Process(os.getpid()) mem = process.memory_info().rss / 1024 ** 2 self.logging.info(f"Rank {self.rank} memory: {mem:.1f} MB") + print(f"Rank {self.rank} memory: {mem:.1f} MB") def set_descriptor_calculator(self, dtype='SO3', mykwargs={}): """ @@ -793,6 +794,89 @@ def optimize_xtals_serial(self, xtals, args): xtals_opt.append(xtal) return xtals_opt + def optimize_reps(self, reps, ncpu=1, opt_type='local', + T=0.2, niter=20, early_quit=0.02, + add_db=True, symmetrize=False, + minimizers=[('Nelder-Mead', 100), ('L-BFGS-B', 100)], + ): + """ + Perform optimization for each structure + + Args: + reps: list of reps + ncpu (int): + """ + args = (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers) + if ncpu > 1: + valid_xtals = self.optimize_reps_mproc(reps, ncpu, args) + return valid_xtals + else: + raise NotImplementedError("optimize_reps works in parallel mode") + + def optimize_reps_mproc(self, reps, ncpu, args): + """ + Optimization in multiprocess mode. + + Args: + reps: list of reps + ncpu (int): number of parallel python processes + args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers) + """ + from multiprocessing import Pool + from collections import deque + import gc + + pool = Pool(processes=ncpu) + (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers) = args + xtals_opt = deque() + + # Split the input structures to minibatches + N_rep = 4 + N_batches = N_rep * ncpu + for _i, i in enumerate(range(0, len(reps), N_batches)): + start, end = i, min([i+N_batches, len(reps)]) + ids = list(range(start, end)) + print(f"Rank {self.rank} minibatch {start} {end}") + self.print_memory_usage() + + def generate_args(): + """ + A generator to yield argument lists for minimize_from_x_par. + """ + for j in range(ncpu): + _ids = ids[j::ncpu] + wp_libs = [] + for id in _ids: + rep = reps[id] + xtal = pyxtal() + xtal.from_tabular_representation(rep, normalize=False) + x = xtal.get_1d_rep_x() + spg, wps, _ = self.get_input_from_ref_xtal(xtal) + wp_libs.append((x, spg, wps)) + yield (self.dim, wp_libs, self.elements, self.calculator, + self.ref_environments, opt_type, T, niter, + early_quit, minimizers) + # Use the generator to pass args to reduce memory usage + _xtal, _xs = None, None + for result in pool.imap_unordered(minimize_from_x_par, generate_args()): + if result is not None: + (_xtals, _xs) = result + valid_xtals = self.process_xtals( + _xtals, _xs, add_db, symmetrize) + xtals_opt.extend(valid_xtals) # Use deque to reduce memory + + # Remove the duplicate structures + self.db.update_row_topology(overwrite=False, prefix=self.prefix) + self.db.clean_structures_spg_topology(dim=self.dim) + + # After each minibatch, delete the local variables and run garbage collection + del ids, _xtals, _xs + gc.collect() # Explicitly call garbage collector to free memory + + xtals_opt = list(xtals_opt) + print(f"Rank {self.rank} finish optimize_reps_mproc {len(xtals_opt)}") + return xtals_opt + def optimize_xtals_mproc(self, xtals, ncpu, args): """ Optimization in multiprocess mode.