diff --git a/configs/htex.yaml b/configs/htex.yaml index 87d49af..faccb96 100644 --- a/configs/htex.yaml +++ b/configs/htex.yaml @@ -8,6 +8,7 @@ ModelTraining: max_walltime: 1 gpu: true ReferenceEvaluation: + max_walltime: 0.3 mpi_command: 'mpirun -np {}' # cp2k on conda-forge comes with OpenMPI (not MPICH as in container) cores_per_worker: 1 mode: 'htex' diff --git a/psiflow/reference/_pyscf.py b/psiflow/reference/_pyscf.py index de7ba3c..cfeccc0 100644 --- a/psiflow/reference/_pyscf.py +++ b/psiflow/reference/_pyscf.py @@ -1,6 +1,7 @@ from __future__ import annotations # necessary for type-guarding class methods import logging +from typing import Optional import numpy as np import parsl @@ -8,11 +9,12 @@ from ase.data import atomic_numbers from parsl.app.app import bash_app, join_app, python_app from parsl.data_provider.files import File +from parsl.executors import WorkQueueExecutor import psiflow from psiflow.data import FlowAtoms, NullState from psiflow.reference.base import BaseReference -from psiflow.utils import copy_app_future +from psiflow.utils import copy_app_future, get_active_executor logger = logging.getLogger(__name__) # logging per module @@ -141,6 +143,7 @@ def pyscf_singlepoint_pre( stdout: str = "", stderr: str = "", walltime: int = 0, + parsl_resource_specification: Optional[dict] = None, **parameters, ) -> str: from psiflow.reference._pyscf import generate_script @@ -232,6 +235,11 @@ def create_apps(cls): ncores = definition.cores_per_worker walltime = definition.max_walltime + if isinstance(get_active_executor(label), WorkQueueExecutor): + resource_specification = definition.generate_parsl_resource_specification() + else: + resource_specification = {} + singlepoint_pre = bash_app( pyscf_singlepoint_pre, executors=[label], @@ -253,6 +261,7 @@ def singlepoint_wrapped(atoms, parameters, file_names, inputs=[]): stdout=parsl.AUTO_LOGNAME, stderr=parsl.AUTO_LOGNAME, walltime=60 * walltime, # killed after walltime - 10s + parsl_resource_specification=resource_specification, **parameters, ) return singlepoint_post(