-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add
znlib.atomistic.ase.FileToASE
(#1)
* add 'atomistic' extras * add pylint; modify imports; code linting
- Loading branch information
Showing
10 changed files
with
1,770 additions
and
14 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pathlib | ||
import shutil | ||
import subprocess | ||
|
||
import ase | ||
|
||
import znlib | ||
|
||
|
||
def test_AddData(proj_path, tetraeder_test_traj):
    """Round-trip a trajectory file through FileToASE and check lazy loading.

    The trajectory is copied into the current directory, registered with
    ``dvc add`` and read by the node.  The loaded node must expose the frames
    as a ``LazyAtomsSequence`` that only caches what has been accessed.
    """
    source = pathlib.Path(tetraeder_test_traj)
    shutil.copy(source, ".")
    local_name = source.name

    subprocess.check_call(["dvc", "add", local_name])
    node = znlib.atomistic.FileToASE(file=local_name)
    node.run_and_save()

    loaded = znlib.atomistic.FileToASE.load()

    assert isinstance(loaded.atoms, znlib.atomistic.ase.LazyAtomsSequence)
    assert isinstance(loaded.atoms[0], ase.Atoms)
    # Only index 0 has been requested so far, so only it may be cached.
    # This check must run before any of the multi-index accesses below.
    assert loaded.atoms.__dict__["atoms"] == {0: loaded.atoms[0]}
    assert isinstance(loaded.atoms[[0, 1]], list)
    assert isinstance(loaded.atoms[:], list)
    assert isinstance(loaded.atoms.tolist(), list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""The znlib atomistic interface""" | ||
from znlib.atomistic import ase | ||
from znlib.atomistic.ase import FileToASE | ||
|
||
__all__ = ["ase", "FileToASE"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
"""Atomic Simulation Environment interface for znlib / ZnTrack """ | ||
import collections.abc | ||
import logging | ||
import pathlib | ||
import typing | ||
|
||
import ase.db | ||
import ase.io | ||
import tqdm | ||
from zntrack import Node, dvc, utils, zn | ||
from zntrack.core import ZnTrackOption | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
AtomsList = typing.List[ase.Atoms] | ||
|
||
|
||
class LazyAtomsSequence(collections.abc.Sequence):
    """Read-only sequence that loads atoms from an ASE database on access.

    Atoms objects are read from the database the first time they are requested
    and are kept in an in-memory cache (``self.__dict__["atoms"]``, keyed by
    the non-negative 0-based index) afterwards.  The sequence does not support
    modifications but only reading values from it.
    """

    def __init__(self, database: str, threshold: int = 100):
        """Store the database location; nothing is read at this point.

        Parameters
        ----------
        database: str
            The database file to read from.
        threshold: int
            Minimum number of atoms to read at once to print tqdm loading bars.
        """
        self._database = database
        self._threshold = threshold
        # cache: 0-based index -> atoms object read from the database
        self.__dict__["atoms"] = {}
        self._len = None

    def _resolve_index(self, index):
        """Convert a possibly negative index to its non-negative equivalent.

        Parameters
        ----------
        index: int
            The requested position, negative values count from the end.

        Returns
        -------
        int
            The equivalent index in ``range(len(self))``.

        Raises
        ------
        IndexError
            If the index is outside of the sequence.
        """
        if index >= 0 and index in self.__dict__["atoms"]:
            # already loaded, hence known to be valid; no need to query the size
            return index
        size = len(self)
        resolved = index + size if index < 0 else index
        if not 0 <= resolved < size:
            raise IndexError(f"{self.__class__.__name__} index out of range")
        return resolved

    def _update_state_from_db(self, indices: list):
        """Load requested atoms into memory.

        Atoms that are not yet cached are read from the database.  No
        connection is opened if every requested index is cached already.

        Parameters
        ----------
        indices: list
            The non-negative indices of the atoms.  Indices are 0-based and are
            converted to 1-based ids when reading from the ase db.
        """
        cache = self.__dict__["atoms"]
        # dict.fromkeys drops duplicates while keeping the requested order
        missing = [x for x in dict.fromkeys(indices) if x not in cache]
        if not missing:
            return

        with ase.db.connect(self._database) as database:
            for key in tqdm.tqdm(
                missing,
                disable=len(missing) < self._threshold,
                ncols=120,
                desc=f"Loading atoms from {self._database}",
            ):
                cache[key] = database[key + 1].toatoms()

    def __iter__(self):
        """Enable iterating over the sequence. This will load all data at once"""
        self._update_state_from_db(list(range(len(self))))
        for idx in range(len(self)):
            yield self[idx]

    def __getitem__(self, item) -> "typing.Union[ase.Atoms, AtomsList]":
        """Get atoms.

        Parameters
        ----------
        item: int | list | slice
            The identifier of the requested atoms to return.  Negative integers
            count from the end of the sequence.

        Returns
        -------
        Atoms | list[Atoms]
            A single atoms object for an integer, otherwise a list.

        Raises
        ------
        IndexError
            If an integer index (also inside a list) is out of range.
        """
        cache = self.__dict__["atoms"]
        if isinstance(item, int):
            # The most simple case
            index = self._resolve_index(item)
            self._update_state_from_db([index])
            return cache[index]

        # everything with lists
        if isinstance(item, slice):
            # slicing a range never yields out-of-range or negative values
            indices = list(range(len(self))[item])
        else:
            indices = [self._resolve_index(x) for x in item]

        self._update_state_from_db(indices)
        return [cache[x] for x in indices]

    def __len__(self):
        """Get the len based on the db.

        This value is cached because the db is not expected to change during
        the lifetime of this class.
        """
        if self._len is None:
            with ase.db.connect(self._database) as db:
                self._len = len(db)
        return self._len

    def __repr__(self):
        return f"{self.__class__.__name__}(db={self._database})"

    def tolist(self) -> "AtomsList":
        """Convert sequence to a list of atoms objects"""
        return list(self)
|
||
|
||
class ZnAtoms(ZnTrackOption):
    """ZnTrack option that stores a list[ase.Atoms] in an ASE database."""

    dvc_option = "outs"
    zn_type = utils.ZnTypes.RESULTS

    def get_filename(self, instance) -> pathlib.Path:
        """Return the database path ``nodes/<node_name>/<option name>.db``"""
        return pathlib.Path("nodes") / instance.node_name / f"{self.name}.db"

    def save(self, instance):
        """Write all atoms of the instance to a fresh ASE database

        The database is opened with ``append=False``; every atoms object is
        written with the node name as its ``group`` key.
        """
        frames: AtomsList = getattr(instance, self.name)
        target = self.get_filename(instance)
        with ase.db.connect(target, append=False) as database:
            progress = tqdm.tqdm(frames, desc=f"Writing atoms to {target}")
            for frame in progress:
                database.write(frame, group=instance.node_name)

    def get_data_from_files(self, instance) -> LazyAtomsSequence:
        """Return a lazy sequence reading from the database of the instance"""
        db_path = self.get_filename(instance).resolve()
        return LazyAtomsSequence(database=db_path.as_posix())
|
||
|
||
class FileToASE(Node):
    """Read an ASE compatible file and make it available as list of atoms objects

    After loading the node, ``atoms`` is a LazyAtomsSequence.
    """

    file: typing.Union[str, pathlib.Path] = dvc.deps()
    # maximum number of frames to read; None reads every frame of the file
    frames_to_read: int = zn.params(None)

    atoms: AtomsList = ZnAtoms()

    def post_init(self):
        """Convert ``file`` to a path and warn if it is not tracked by DVC

        Only runs for nodes that are not loaded.
        """
        if not self.is_loaded:
            self.file = pathlib.Path(self.file)
            # 'dvc add <target>' writes '<target>.dvc' next to the target, so
            # look beside the file instead of in the current working directory.
            dvc_file = self.file.parent / (self.file.name + ".dvc")
            if not dvc_file.exists():
                log.warning(
                    "The File you are adding as a dependency does not seem to be tracked"
                    f" by DVC. Please run 'dvc add {self.file}' to avoid tracking the"
                    " file directly with GIT."
                )

    def run(self):
        """Read the frames of ``file`` into ``atoms``

        Reading stops once ``frames_to_read`` frames were collected, if set.
        """
        self.atoms = []
        for config, atom in enumerate(
            tqdm.tqdm(ase.io.iread(self.file), desc="Reading File")
        ):
            if self.frames_to_read is not None and config >= self.frames_to_read:
                break
            self.atoms.append(atom)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
"""znlib / ZnTrack examples""" | ||
from znlib.examples.general import ( | ||
AddInputs, | ||
InputToMetric, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters