Skip to content

Commit

Permalink
fixed test_trajectory tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiCheng45 committed Jan 16, 2024
1 parent c861a89 commit 9f2bf7c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 41 deletions.
15 changes: 13 additions & 2 deletions MDANSE/Src/MDANSE/MolecularDynamics/Configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,22 @@ def __setitem__(self, name: str, value: ArrayLike) -> None:
:param value: the value of the variable to be set
:type value: numpy.ndarray
"""

item = np.array(value)

if name == "unit_cell":
if item.shape != (3, 3):
raise ValueError(
f"Invalid item dimensions for {name}; a shape of (3, 3) "
f"was expected but data with shape of {item.shape} was "
f"provided."
)
else:
self._variables[name] = value
return

if item.shape != (self._chemical_system.number_of_atoms, 3):
raise ValueError(
f"Invalid item dimensions; a shape of {(self._chemical_system.number_of_atoms, 3)} was "
f"Invalid item dimensions for {name}; a shape of {(self._chemical_system.number_of_atoms, 3)} was "
f"expected but data with shape of {item.shape} was provided."
)

Expand Down
8 changes: 5 additions & 3 deletions MDANSE/Src/MDANSE/MolecularDynamics/Trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,14 @@ def __getitem__(self, frame):
"""

grp = self._h5_file["/configuration"]

configuration = {}
for k, v in grp.items():
configuration[k] = v[frame].astype(np.float64)

for k in self._h5_file.keys():
if k in ("time", "unit_cell"):
configuration[k] = self._h5_file[k][frame].astype(np.float64)

return configuration

def __getstate__(self):
Expand Down Expand Up @@ -229,7 +232,6 @@ def read_com_trajectory(

indexes = [at.index for at in atoms]
masses = np.array([ATOMS_DATABASE[at.symbol]["atomic_weight"] for at in atoms])

grp = self._h5_file["/configuration"]

coords = grp["coordinates"][first:last:step, :, :].astype(np.float64)
Expand Down Expand Up @@ -444,7 +446,7 @@ def __init__(

self._h5_file = h5py.File(self._h5_filename, "w")

self._chemical_system = chemical_system.copy()
self._chemical_system = chemical_system

if selected_atoms is None:
self._selected_atoms = self._chemical_system.atom_list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@
@author: Eric C. Pellegrini
"""

import tempfile
import unittest

import numpy as np

from MDANSE.Chemistry.ChemicalEntity import Atom, ChemicalSystem
from MDANSE.MolecularDynamics.Configuration import RealConfiguration
from MDANSE.MolecularDynamics.Configuration import PeriodicRealConfiguration
from MDANSE.MolecularDynamics.Trajectory import Trajectory, TrajectoryWriter
from MDANSE.MolecularDynamics.UnitCell import UnitCell


class TestTrajectory(unittest.TestCase):
Expand All @@ -64,10 +63,9 @@ def setUp(self):
for i in range(self._nAtoms):
self._chemicalSystem.add_chemical_entity(Atom(symbol="H"))

self._filename = "test_traj.h5"

def test_write_trajectory(self):
tw = TrajectoryWriter(self._filename, self._chemicalSystem)
tf = tempfile.NamedTemporaryFile().name
tw = TrajectoryWriter(tf, self._chemicalSystem, 10)

allCoordinates = []
allUnitCells = []
Expand All @@ -76,15 +74,15 @@ def test_write_trajectory(self):
allTimes.append(i)
allUnitCells.append(np.random.uniform(0, 10, (3, 3)))
allCoordinates.append(np.random.uniform(0, 10, (self._nAtoms, 3)))
conf = RealConfiguration(
self._chemicalSystem, allCoordinates[-1], allUnitCells[-1]
conf = PeriodicRealConfiguration(
self._chemicalSystem, allCoordinates[-1], UnitCell(allUnitCells[-1])
)
self._chemicalSystem.configuration = conf
tw.chemical_system.configuration = conf
tw.dump_configuration(i)

tw.close()

t = Trajectory(self._filename)
t = Trajectory(tf)

for i in range(len(t)):
self.assertTrue(
Expand All @@ -98,7 +96,8 @@ def test_write_trajectory(self):
t.close()

def test_write_trajectory_with_velocities(self):
tw = TrajectoryWriter(self._filename, self._chemicalSystem)
tf = tempfile.NamedTemporaryFile().name
tw = TrajectoryWriter(tf, self._chemicalSystem, 10)

allCoordinates = []
allUnitCells = []
Expand All @@ -109,16 +108,16 @@ def test_write_trajectory_with_velocities(self):
allUnitCells.append(np.random.uniform(0, 10, (3, 3)))
allCoordinates.append(np.random.uniform(0, 10, (self._nAtoms, 3)))
allVelocities.append(np.random.uniform(0, 10, (self._nAtoms, 3)))
conf = RealConfiguration(
self._chemicalSystem, allCoordinates[-1], allUnitCells[-1]
conf = PeriodicRealConfiguration(
self._chemicalSystem, allCoordinates[-1], UnitCell(allUnitCells[-1])
)
conf.variables["velocities"] = allVelocities[-1]
self._chemicalSystem.configuration = conf
tw.dump_configuration(i)

tw.close()

t = Trajectory(self._filename)
t = Trajectory(tf)

for i in range(len(t)):
self.assertTrue(
Expand All @@ -135,7 +134,8 @@ def test_write_trajectory_with_velocities(self):
t.close()

def test_write_trajectory_with_gradients(self):
tw = TrajectoryWriter(self._filename, self._chemicalSystem)
tf = tempfile.NamedTemporaryFile().name
tw = TrajectoryWriter(tf, self._chemicalSystem, 10)

allCoordinates = []
allUnitCells = []
Expand All @@ -148,8 +148,8 @@ def test_write_trajectory_with_gradients(self):
allCoordinates.append(np.random.uniform(0, 10, (self._nAtoms, 3)))
allVelocities.append(np.random.uniform(0, 10, (self._nAtoms, 3)))
allGradients.append(np.random.uniform(0, 10, (self._nAtoms, 3)))
conf = RealConfiguration(
self._chemicalSystem, allCoordinates[-1], allUnitCells[-1]
conf = PeriodicRealConfiguration(
self._chemicalSystem, allCoordinates[-1], UnitCell(allUnitCells[-1])
)
conf.variables["velocities"] = allVelocities[-1]
conf.variables["gradients"] = allGradients[-1]
Expand All @@ -158,7 +158,7 @@ def test_write_trajectory_with_gradients(self):

tw.close()

t = Trajectory(self._filename)
t = Trajectory(tf)

for i in range(len(t)):
self.assertTrue(
Expand All @@ -178,7 +178,8 @@ def test_write_trajectory_with_gradients(self):
t.close()

def test_read_com_trajectory(self):
tw = TrajectoryWriter(self._filename, self._chemicalSystem)
tf = tempfile.NamedTemporaryFile().name
tw = TrajectoryWriter(tf, self._chemicalSystem, 1)

allCoordinates = []
allUnitCells = []
Expand All @@ -189,28 +190,17 @@ def test_read_com_trajectory(self):
allCoordinates.append(
[[0.0, 0.0, 0.0], [8.0, 8.0, 8.0], [4.0, 4.0, 4.0], [2.0, 2.0, 2.0]]
)
conf = RealConfiguration(
self._chemicalSystem, allCoordinates[-1], allUnitCells[-1]
conf = PeriodicRealConfiguration(
self._chemicalSystem, allCoordinates[-1], UnitCell(allUnitCells[-1])
)
self._chemicalSystem.configuration = conf
tw.dump_configuration(i)

tw.close()

t = Trajectory(self._filename)
t = Trajectory(tf)
com_trajectory = t.read_com_trajectory(
[0, 1, 2, 3], np.array([1.0, 1.0, 1.0, 1.0]), 0, 1, 1
self._chemicalSystem.atoms, 0, 1, 1
)
self.assertTrue(np.allclose(com_trajectory, [[1.0, 1.0, 1.0]], rtol=1.0e-6))
self.assertTrue(np.allclose(com_trajectory, [[3.5, 3.5, 3.5]], rtol=1.0e-6))
t.close()


def suite():
loader = unittest.TestLoader()
s = unittest.TestSuite()
s.addTest(loader.loadTestsFromTestCase(TestTrajectory))
return s


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 9f2bf7c

Please sign in to comment.