Skip to content

Commit

Permalink
add znlib.atomistic.ase.FileToASE (#1)
Browse files Browse the repository at this point in the history
* add 'atomistic' extras

* add pylint; modify imports; code linting
  • Loading branch information
PythonFZ authored Oct 30, 2022
1 parent 2adddbf commit 5b95482
Show file tree
Hide file tree
Showing 10 changed files with 1,770 additions and 14 deletions.
1,513 changes: 1,510 additions & 3 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ python = ">=3.8,<4.0.0"
colorama = "^0.4.5"
zntrack = { version = "^0.4.3", optional = true }
matplotlib = { version = "^3.6.1", optional = true }
ase = { version = "^3.22.1", optional = true }


[tool.poetry.extras]
zntrack = ["zntrack", "matplotlib"]
atomistic = ["ase", "zntrack"]

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
Expand All @@ -23,12 +25,18 @@ black = "^22.10.0"
isort = "^5.10.1"
coverage = "^6.5.0"
pre-commit = "^2.20.0"
jupyterlab = "^3.5.0"
pylint = "^2.15.5"


[tool.poetry.group.zntrack.dependencies]
zntrack = "^0.4.3"
matplotlib = "^3.6.1"


[tool.poetry.group.atomistic.dependencies]
ase = "^3.22.1"

[tool.poetry.urls]
repository = "https://github.com/zincware/znlib"

Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
"""
import os
import pathlib
import random
import shutil
import subprocess

import ase.io
import pytest


Expand All @@ -31,3 +33,25 @@ def proj_path(tmp_path, request) -> pathlib.Path:
subprocess.check_call(["git", "commit", "-m", "initial commit"])

return tmp_path


@pytest.fixture
def tetraeder_test_traj(tmp_path_factory) -> str:
    """Write 20 randomly rattled copies of a CH4 tetrahedron to an extxyz file.

    Parameters
    ----------
    tmp_path_factory:
        The pytest fixture used to locate the base temporary directory.

    Returns
    -------
    str
        Absolute posix path of the written ``tetraeder.extxyz`` file.
    """
    tetraeder = ase.Atoms(
        "CH4",
        positions=[(1, 1, 1), (0, 0, 0), (0, 2, 2), (2, 2, 0), (2, 0, 2)],
        cell=(2, 2, 2),
    )

    # A private generator yields the same seed sequence as seeding the global
    # ``random`` module with 42, but does not alter the global random state
    # that other tests may use.
    rng = random.Random(42)

    atoms = []
    for _ in range(20):
        frame = tetraeder.copy()
        # rattle displaces the positions in place, one fresh seed per frame
        frame.rattle(stdev=0.5, seed=rng.randint(1, int(1e6)))
        atoms.append(frame)

    temporary_path = tmp_path_factory.getbasetemp()
    file = temporary_path / "tetraeder.extxyz"
    ase.io.write(file, atoms)
    return file.resolve().as_posix()
25 changes: 25 additions & 0 deletions tests/test_atomistic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pathlib
import shutil
import subprocess

import ase

import znlib


def test_AddData(proj_path, tetraeder_test_traj):
    """Run FileToASE on a DVC tracked trajectory and check the lazy loading."""
    source = pathlib.Path(tetraeder_test_traj)
    shutil.copy(source, ".")

    dvc_add = ["dvc", "add", source.name]
    subprocess.check_call(dvc_add)

    node = znlib.atomistic.FileToASE(file=source.name)
    node.run_and_save()

    loaded = znlib.atomistic.FileToASE.load()

    assert isinstance(loaded.atoms, znlib.atomistic.ase.LazyAtomsSequence)
    assert isinstance(loaded.atoms[0], ase.Atoms)
    # only the single configuration requested so far may be in the cache
    assert loaded.atoms.__dict__["atoms"] == {0: loaded.atoms[0]}
    # list, slice and full conversion all have to return plain lists
    assert isinstance(loaded.atoms[[0, 1]], list)
    assert isinstance(loaded.atoms[:], list)
    assert isinstance(loaded.atoms.tolist(), list)
6 changes: 6 additions & 0 deletions znlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""The znlib package"""
import importlib.metadata
import importlib.util

Expand All @@ -9,3 +10,8 @@
from znlib import examples # noqa: F401

__all__.append("examples")

# The atomistic interface is only exposed when the optional 'ase' package
# is installed.
if importlib.util.find_spec("ase") is not None:
    from znlib import atomistic  # noqa: F401

    # __all__ has to list the name that was actually imported above.
    __all__.append("atomistic")
5 changes: 5 additions & 0 deletions znlib/atomistic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""The znlib atomistic interface"""
from znlib.atomistic import ase
from znlib.atomistic.ase import FileToASE

__all__ = ["ase", "FileToASE"]
170 changes: 170 additions & 0 deletions znlib/atomistic/ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Atomic Simulation Environment interface for znlib / ZnTrack """
import collections.abc
import logging
import pathlib
import typing

import ase.db
import ase.io
import tqdm
from zntrack import Node, dvc, utils, zn
from zntrack.core import ZnTrackOption

log = logging.getLogger(__name__)

AtomsList = typing.List[ase.Atoms]


class LazyAtomsSequence(collections.abc.Sequence):
    """Sequence that loads atoms objects from ase Database only when accessed

    Atoms that have been read once are kept in memory in the ``"atoms"`` entry
    of the instance ``__dict__`` (a dict mapping 0-based index to atoms).
    This sequence does not support modifications but only reading values from it.
    """

    def __init__(self, database: str, threshold: int = 100):
        """Default __init__

        Parameters
        ----------
        database: str
            The database to read from
        threshold: int
            Minimum number of atoms to read at once to print tqdm loading bars
        """
        self._database = database
        self._threshold = threshold
        self.__dict__["atoms"]: typing.Dict[int, ase.Atoms] = {}
        self._len = None

    def _normalize_index(self, index: int) -> int:
        """Convert a possibly negative index into its 0-based position

        Parameters
        ----------
        index: int
            The requested index. Negative values count from the end.

        Returns
        -------
        int
            The position in the range ``0 <= position < len(self)``

        Raises
        ------
        IndexError
            If the index does not refer to an entry of the sequence. The mixin
            methods of collections.abc.Sequence rely on this exception type.
        """
        length = len(self)
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(
                f"index {index} is out of range for a sequence of length {length}"
            )
        return position

    def _update_state_from_db(self, indices: list):
        """Load requested atoms into memory

        If the atoms are not present in __dict__ they will be read from db

        Parameters
        ----------
        indices: list
            The indices of the atoms. Indices are 0based and will be converted to 1 based
            when reading from the ase db.
        """
        cache = self.__dict__["atoms"]
        # dict.fromkeys removes duplicates while keeping the requested order
        missing = [x for x in dict.fromkeys(indices) if x not in cache]
        if not missing:
            # everything is cached already, no need to open the database
            return

        with ase.db.connect(self._database) as database:
            for key in tqdm.tqdm(
                missing,
                disable=len(missing) < self._threshold,
                ncols=120,
                desc=f"Loading atoms from {self._database}",
            ):
                cache[key] = database[key + 1].toatoms()

    def __iter__(self):
        """Enable iterating over the sequence. This will load all data at once"""
        size = len(self)
        self._update_state_from_db(list(range(size)))
        for idx in range(size):
            yield self[idx]

    def __getitem__(self, item) -> typing.Union[ase.Atoms, AtomsList]:
        """Get atoms

        Parameters
        ----------
        item: int | list | slice
            The identifier of the requested atoms to return. Negative integers
            count from the end of the sequence.

        Returns
        -------
        Atoms | list[Atoms]

        Raises
        ------
        IndexError
            If an integer index is outside of the sequence.
        """
        cache = self.__dict__["atoms"]
        if isinstance(item, int):
            # The most simple case
            index = self._normalize_index(item)
            if index not in cache:
                self._update_state_from_db([index])
            return cache[index]

        # everything with lists
        if isinstance(item, slice):
            # slicing a range resolves negative / open bounds and never raises
            indices = list(range(len(self))[item])
        else:
            indices = [self._normalize_index(x) for x in item]

        self._update_state_from_db(indices)
        return [cache[x] for x in indices]

    def __len__(self):
        """Get the len based on the db. This value is cached because
        the db is not expected to change during the lifetime of this class
        """
        if self._len is None:
            with ase.db.connect(self._database) as db:
                self._len = len(db)
        return self._len

    def __repr__(self):
        return f"{self.__class__.__name__}(db={self._database})"

    def tolist(self) -> AtomsList:
        """Convert sequence to a list of atoms objects"""
        return list(self)


class ZnAtoms(ZnTrackOption):
    """Store list[ase.Atoms] in an ASE database."""

    dvc_option = "outs"
    zn_type = utils.ZnTypes.RESULTS

    def get_filename(self, instance) -> pathlib.Path:
        """Overwrite filename to an ASE database file

        Returns
        -------
        pathlib.Path
            The relative path ``nodes/<node_name>/<option name>.db``
        """
        return pathlib.Path("nodes", instance.node_name, f"{self.name}.db")

    def save(self, instance):
        """Save value with ase.db.connect

        All atoms stored on the instance under this option's name are written
        to a new database; ``append=False`` is passed so existing content is
        not extended.
        """
        atoms: AtomsList = getattr(instance, self.name)
        file = self.get_filename(instance)
        # Make sure the target directory exists before the database is created.
        # This is a no-op if it is already there.
        file.parent.mkdir(exist_ok=True, parents=True)
        with ase.db.connect(file, append=False) as db:
            for atom in tqdm.tqdm(atoms, desc=f"Writing atoms to {file}"):
                db.write(atom, group=instance.node_name)

    def get_data_from_files(self, instance) -> LazyAtomsSequence:
        """Load value with ase.db.connect

        The database is not read here. A LazyAtomsSequence pointing to the
        absolute database path is returned instead.
        """
        database = self.get_filename(instance).resolve().as_posix()
        return LazyAtomsSequence(database=database)


class FileToASE(Node):
    """Read a ASE compatible file and make it available as list of atoms objects

    The atoms object is a LazyAtomsSequence
    """

    file: typing.Union[str, pathlib.Path] = dvc.deps()
    # None (the default) reads every frame of the file
    frames_to_read: typing.Optional[int] = zn.params(None)

    atoms: AtomsList = ZnAtoms()

    def post_init(self):
        """Convert 'file' to a Path and warn if no '.dvc' file exists for it"""
        if not self.is_loaded:
            self.file = pathlib.Path(self.file)
            # 'dvc add <file>' is expected to create '<file>.dvc' next to the
            # file, so look there and not only in the current directory.
            dvc_file = self.file.parent / (self.file.name + ".dvc")
            if not dvc_file.exists():
                log.warning(
                    "The File you are adding as a dependency does not seem to be tracked"
                    f" by DVC. Please run 'dvc add {self.file}' to avoid tracking the"
                    " file directly with GIT."
                )

    def run(self):
        """Read at most 'frames_to_read' configurations from 'file'"""
        frames = []
        for config, atom in enumerate(
            tqdm.tqdm(ase.io.iread(self.file), desc="Reading File")
        ):
            if self.frames_to_read is not None and config >= self.frames_to_read:
                break
            frames.append(atom)
        self.atoms = frames
27 changes: 17 additions & 10 deletions znlib/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Command Line Interface"""
import dataclasses
import importlib.metadata
from importlib.util import find_spec
Expand Down Expand Up @@ -33,13 +34,19 @@ def znlib_status():
"""All zincware packages that should be listed"""
print(f"Available {Fore.LIGHTBLUE_EX}zincware{Style.RESET_ALL} packages:")

ZnModules(name="znlib")
ZnModules(name="zntrack")
ZnModules(name="mdsuite")
ZnModules(name="znjson")
ZnModules(name="zninit")
ZnModules(name="dot4dict")
ZnModules(name="znipy")
ZnModules(name="supercharge")
ZnModules(name="znvis")
ZnModules(name="symdet")
packages = sorted(
[
"znlib",
"zntrack",
"mdsuite",
"znjson",
"zninit",
"dot4dict",
"znipy",
"supercharge",
"znvis",
"symdet",
]
)
for package in packages:
ZnModules(package)
1 change: 1 addition & 0 deletions znlib/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""znlib / ZnTrack examples"""
from znlib.examples.general import (
AddInputs,
InputToMetric,
Expand Down
5 changes: 4 additions & 1 deletion znlib/examples/mc_pi_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


def plot_sampling(ax, coordinates, n_points, estimate):
"""Plot a quarter of a circle with the sampled points"""
circle = plt.Circle((0, 0), 1, fill=False, linewidth=3, edgecolor="k", zorder=10)

ax.set_xlim(-0.0, 1.0)
Expand All @@ -20,7 +21,7 @@ def plot_sampling(ax, coordinates, n_points, estimate):
inner_points = np.array(list(filter(lambda x: np.linalg.norm(x) <= 1, coordinates)))
ax.plot(inner_points[:, 0], inner_points[:, 1], "r.")
ax.add_patch(circle)
ax.set_title(f"N: {n_points} ; $\pi$ = {estimate}")
ax.set_title(rf"N: {n_points} ; $\pi$ = {estimate}")
ax.set_aspect("equal")


Expand All @@ -34,13 +35,15 @@ class MonteCarloPiEstimator(Node):
estimate: float = zn.metrics()

def run(self):
"""Compute pi using MC"""
np.random.seed(self.seed)
self.coordinates = np.random.random(size=(self.n_points, 2))
radial_values = np.linalg.norm(self.coordinates, axis=1)
n_circle_points = len(list(filter(lambda x: x <= 1, radial_values)))
self.estimate = 4 * n_circle_points / self.n_points

def plot(self, ax):
"""Create a plot of the sampled coordinates"""
plot_sampling(
ax,
coordinates=self.coordinates,
Expand Down

0 comments on commit 5b95482

Please sign in to comment.