Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Dec 20, 2024
1 parent aaa10fa commit ae7ad4b
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 31 deletions.
3 changes: 3 additions & 0 deletions psiflow/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Dataset:
This class provides methods for manipulating and analyzing collections of atomic structures.
"""

extxyz: psiflow._DataFuture

def __init__(
Expand Down Expand Up @@ -415,6 +416,7 @@ def _concatenate_multiple(*args: list[np.ndarray]) -> list[np.ndarray]:
Note:
This function is wrapped as a Parsl app and executed using the default_threads executor.
"""

def pad_arrays(
arrays: list[np.ndarray],
pad_dimension: int = 1,
Expand Down Expand Up @@ -621,6 +623,7 @@ class Computable:
outputs (ClassVar[tuple[str, ...]]): Names of output quantities.
batch_size (ClassVar[Optional[int]]): Default batch size for computation.
"""

outputs: ClassVar[tuple[str, ...]] = ()
batch_size: ClassVar[Optional[int]] = None

Expand Down
17 changes: 9 additions & 8 deletions psiflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def from_config(
provider_cls = SlurmProvider
provider_kwargs = kwargs.pop("slurm") # do not allow empty dict
provider_kwargs["init_blocks"] = 0
if not 'exclusive' in provider_kwargs:
provider_kwargs['exclusive'] = False
if "exclusive" not in provider_kwargs:
provider_kwargs["exclusive"] = False
else:
provider_cls = LocalProvider # noqa: F405
provider_kwargs = kwargs.pop("local", {})
Expand Down Expand Up @@ -452,7 +452,7 @@ def from_config(
max_idletime: float = 20,
internal_tasks_max_threads: int = 10,
default_threads: int = 4,
htex_address: str = '127.0.0.1',
htex_address: str = "127.0.0.1",
zip_staging: Optional[bool] = None,
container_uri: Optional[str] = None,
container_engine: str = "apptainer",
Expand Down Expand Up @@ -552,11 +552,14 @@ def from_config(
context = ExecutionContext(config, definitions, path / "context_dir")

if make_symlinks:
src, dest = Path.cwd() / f'psiflow_log', path / 'parsl.log'
src, dest = Path.cwd() / "psiflow_log", path / "parsl.log"
_create_symlink(src, dest)
src, dest = Path.cwd() / f'psiflow_submit_scripts', path / '000' / 'submit_scripts'
src, dest = (
Path.cwd() / "psiflow_submit_scripts",
path / "000" / "submit_scripts",
)
_create_symlink(src, dest, is_dir=True)
src, dest = Path.cwd() / f'psiflow_task_logs', path / '000' / 'task_logs'
src, dest = Path.cwd() / "psiflow_task_logs", path / "000" / "task_logs"
_create_symlink(src, dest, is_dir=True)

return context
Expand Down Expand Up @@ -684,5 +687,3 @@ def _create_symlink(src: Path, dest: Path, is_dir: bool = False) -> None:
else:
dest.touch(exist_ok=True)
src.symlink_to(dest, target_is_directory=is_dir)


18 changes: 11 additions & 7 deletions psiflow/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,26 +274,30 @@ def from_string(cls, s: str, natoms: Optional[int] = None) -> Optional[Geometry]

# read and format per_atom data
column_indices = {}
if 'Properties' in comment_dict:
properties = comment_dict['Properties'].split(':')
if "Properties" in comment_dict:
properties = comment_dict["Properties"].split(":")
count = 0
for i in range(len(properties) // 3):
name = properties[3 * i]
ncolumns = int(properties[3 * i + 2])
column_indices[name] = count
count += ncolumns
assert 'pos' in column_indices # positions need to be there
assert "pos" in column_indices # positions need to be there

per_atom = np.recarray(natoms, dtype=per_atom_dtype)
per_atom.forces[:] = np.nan
POS_INDEX = column_indices.get('pos', 1)
FORCES_INDEX = column_indices.get('forces', None)
POS_INDEX = column_indices.get("pos", 1)
FORCES_INDEX = column_indices.get("forces", None)
for i in range(natoms):
values = lines[i + 1].split()
per_atom.numbers[i] = chemical_symbols.index(values[0])
per_atom.positions[i, :] = [float(_) for _ in values[POS_INDEX:POS_INDEX + 3]]
per_atom.positions[i, :] = [
float(_) for _ in values[POS_INDEX : POS_INDEX + 3]
]
if FORCES_INDEX is not None:
per_atom.forces[i, :] = [float(_) for _ in values[FORCES_INDEX:FORCES_INDEX + 3]]
per_atom.forces[i, :] = [
float(_) for _ in values[FORCES_INDEX : FORCES_INDEX + 3]
]

order = {}
for key, value in comment_dict.items():
Expand Down
2 changes: 1 addition & 1 deletion psiflow/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from psiflow.models import Model
from psiflow.reference import Reference, evaluate
from psiflow.sampling import SimulationOutput, Walker, sample
from psiflow.utils.apps import boolean_or, setup_logger, unpack_i, isnan
from psiflow.utils.apps import boolean_or, isnan, setup_logger, unpack_i

logger = setup_logger(__name__)

Expand Down
21 changes: 8 additions & 13 deletions psiflow/reference/_cp2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def set_global_section(cp2k_input_dict: dict, properties: tuple):
global_dict = cp2k_input_dict["global"]

# override low/silent print levels
level = global_dict.pop('print_level', 'MEDIUM')
if level in ['SILENT', 'LOW']:
global_dict['print_level'] = 'MEDIUM'
level = global_dict.pop("print_level", "MEDIUM")
if level in ["SILENT", "LOW"]:
global_dict["print_level"] = "MEDIUM"

if properties == ("energy",):
global_dict["run_type"] = "ENERGY"
Expand Down Expand Up @@ -156,11 +156,11 @@ def _prepare_input(
if "forces" in properties:
cp2k_input_dict["force_eval"]["print"] = {"FORCES": {}}
cp2k_input_str = dict_to_str(cp2k_input_dict)
with open(outputs[0], 'w') as f:
with open(outputs[0], "w") as f:
f.write(cp2k_input_str)


prepare_input = python_app(_prepare_input, executors=['default_threads'])
prepare_input = python_app(_prepare_input, executors=["default_threads"])


# typeguarding for some reason incompatible with WQ
Expand All @@ -175,14 +175,9 @@ def cp2k_singlepoint_pre(
cd_command = "cd $mytmpdir"
cp_command = "cp {} cp2k.inp".format(inputs[0].filepath)

command_list = [
tmp_command,
cd_command,
cp_command,
cp2k_command
]
command_list = [tmp_command, cd_command, cp_command, cp2k_command]

return ' && '.join(command_list)
return " && ".join(command_list)


@typeguard.typechecked
Expand Down Expand Up @@ -242,7 +237,7 @@ def wrapped_app_pre(geometry, stdout: str, stderr: str):
geometry,
cp2k_input_dict=self.cp2k_input_dict,
properties=tuple(self.outputs),
outputs=[psiflow.context().new_file('cp2k_', '.inp')],
outputs=[psiflow.context().new_file("cp2k_", ".inp")],
)
return app_pre(
cp2k_command=cp2k_command,
Expand Down
2 changes: 1 addition & 1 deletion psiflow/reference/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _nan_if_unsuccessful(
return result


nan_if_unsuccessful = python_app(_nan_if_unsuccessful, executors=['default_threads'])
nan_if_unsuccessful = python_app(_nan_if_unsuccessful, executors=["default_threads"])


@join_app
Expand Down
2 changes: 1 addition & 1 deletion psiflow/utils/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ def _isnan(a: Union[float, np.ndarray]) -> bool:
return bool(np.any(np.isnan(a)))


isnan = python_app(_isnan, executors=['default_threads'])
isnan = python_app(_isnan, executors=["default_threads"])

0 comments on commit ae7ad4b

Please sign in to comment.