Skip to content

Commit

Permalink
Add method to create a graph from jumps (#325)
Browse files Browse the repository at this point in the history
* Add method to create a graph from jumps

* Rename SimulationMetrics -> TrajectoryMetrics

* Add method to get single activation energy

* Add min/max energy filters

* Do not depend on labels for graph

* Attach label to node

* Rename test module

* Fix tests (there is one orphan node)

* Update src/gemdat/jumps.py

Co-authored-by: SCiarella <[email protected]>

* Update src/gemdat/jumps.py

Co-authored-by: SCiarella <[email protected]>

* Update src/gemdat/jumps.py

Co-authored-by: SCiarella <[email protected]>

* Fix line lengths

---------

Co-authored-by: SCiarella <[email protected]>
  • Loading branch information
stefsmeets and SCiarella authored Jun 6, 2024
1 parent 28427d5 commit 493612e
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 56 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ jumps.jump_diffusivity(dimensions=3)
To calculate different metrics, such as tracer diffusivity:

```python
from gemdat import SimulationMetrics
from gemdat import TrajectoryMetrics

metrics = SimulationMetrics(diff_trajectory)
metrics = TrajectoryMetrics(diff_trajectory)

metrics.tracer_diffusivity(dimensions=3)
metrics.haven_ratio(dimensions=3)
Expand Down
2 changes: 1 addition & 1 deletion docs/api/gemdat.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
- [gemdat.read_cif][gemdat.io.read_cif]
- [gemdat.load_known_material][gemdat.io.load_known_material]
- [gemdat.SimulationMetrics][gemdat.simulation_metrics.SimulationMetrics]
- [gemdat.TrajectoryMetrics][gemdat.metrics.TrajectoryMetrics]
- [gemdat.Transitions][gemdat.transitions.Transitions]
- [gemdat.Jumps][gemdat.jumps.Jumps]
- [gemdat.Trajectory][gemdat.trajectory.Trajectory]
Expand Down
2 changes: 1 addition & 1 deletion docs/api/gemdat_simulation_metrics.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
::: gemdat.simulation_metrics
::: gemdat.metrics
options:
show_root_heading: false
show_root_toc_entry: false
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ nav:
- gemdat.io: api/gemdat_io.md
- gemdat.plots: api/gemdat_plots.md
- gemdat.rdf: api/gemdat_rdf.md
- gemdat.simulation_metrics: api/gemdat_simulation_metrics.md
- gemdat.metrics: api/gemdat_metrics.md
- gemdat.trajectory: api/gemdat_trajectory.md
- gemdat.transitions: api/gemdat_transitions.md
- gemdat.jumps: api/gemdat_jumps.md
Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from .io import load_known_material, read_cif
from .jumps import Jumps
from .metrics import TrajectoryMetrics
from .orientations import Orientations
from .rdf import radial_distribution
from .shape import ShapeAnalyzer
from .simulation_metrics import SimulationMetrics
from .trajectory import Trajectory
from .transitions import Transitions
from .volume import Volume, trajectory_to_volume
Expand All @@ -18,7 +18,7 @@
'radial_distribution',
'read_cif',
'ShapeAnalyzer',
'SimulationMetrics',
'TrajectoryMetrics',
'Trajectory',
'trajectory_to_volume',
'Transitions',
Expand Down
123 changes: 104 additions & 19 deletions src/gemdat/jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from math import ceil
from typing import TYPE_CHECKING, Callable

import networkx as nx
import numpy as np
import pandas as pd
from pymatgen.core.units import FloatWithUnit
Expand All @@ -13,7 +14,7 @@
from ._plot_backend import plot_backend
from .caching import weak_lru_cache
from .collective import Collective
from .simulation_metrics import SimulationMetrics
from .metrics import TrajectoryMetrics
from .transitions import Transitions, _calculate_transitions_matrix

if TYPE_CHECKING:
Expand Down Expand Up @@ -223,7 +224,7 @@ def collective(self, max_dist: float = 1) -> Collective:
sites = self.transitions.sites

time_step = trajectory.time_step
attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()

max_steps = ceil(1.0 / (attempt_freq * time_step))

Expand All @@ -237,7 +238,7 @@ def collective(self, max_dist: float = 1) -> Collective:

@weak_lru_cache()
def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
"""Calculate activation energies for jumps (UNITS?).
"""Calculate activation energies for jumps in eV.
Parameters
----------
Expand All @@ -251,7 +252,7 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
between site pairs.
"""
trajectory = self.trajectory
attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()

dct = {}

Expand All @@ -260,13 +261,13 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
atom_locations_parts = [
part.atom_locations() for part in self.transitions.split(n_parts)
]
jumps_counter_parts = [part.jumps_counter() for part in self.split(n_parts)]
counter_parts = [part.counter() for part in self.split(n_parts)]
n_floating = self.n_floating

for site_pair in self.site_pairs:
site_start, site_stop = site_pair

n_jumps = np.array([part[site_pair] for part in jumps_counter_parts])
n_jumps = np.array([part[site_pair] for part in counter_parts])

part_time = trajectory.total_time / n_parts

Expand All @@ -292,22 +293,106 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:

return df

def jumps_counter(self) -> Counter:
"""Calculate number of jumps between sites.
@weak_lru_cache()
def counter(self) -> Counter[tuple[str, str]]:
"""Count number of jumps between sites.
Returns
-------
jumps : dict[tuple[str, str], int]
Dictionary with number of jumps per sites combination
counter : Counter[tuple[str, str]]
Dictionary with site pairs as keys and corresponding
number of jumps as dictionary values
"""
labels = self.sites.labels
jumps = Counter(
[
(labels[i], labels[j])
for _, (i, j) in self.data[['start site', 'destination site']].iterrows()
]
)
return jumps
counter: Counter[tuple[str, str]] = Counter()
for (i, j), val in self._counter().items():
counter[labels[i], labels[j]] += val
return counter

@weak_lru_cache()
def _counter(self) -> Counter[tuple[int, int]]:
"""Count number of jumps between sites. Keys are site indices.
Returns
-------
counter : Counter[tuple[int, int]]
Dictionary with site pairs as keys and corresponding
number of jumps as dictionary values
"""
counter = Counter(zip(self.data['start site'], self.data['destination site']))
return counter

def activation_energy_between_sites(self, start: str, stop: str) -> float:
"""Returns activation energy between two sites.
Uses `Jumps.to_graph()` in the background. For a large number of operations,
it is more efficient to query the graph directly.
Parameters
----------
start : str
Label of the start site
stop : str
Label of the stop site
Returns
-------
e_act : float
Activation energy in eV
"""
G = self.to_graph()
edge_data = G.get_edge_data(start, stop)
if not edge_data:
raise IndexError(f'No jumps between ({start}) and ({stop})')
return edge_data['e_act']

@weak_lru_cache()
def to_graph(
self, min_e_act: float | None = None, max_e_act: float | None = None
) -> nx.DiGraph:
"""Create a graph from jumps data.
The edges are weighted by the activation energy. The nodes are indices that
correspond to `Jumps.sites`.
Parameters
----------
min_e_act : float
Reject edges with activation energy below this threshold
max_e_act : float
Reject edges with activation energy above this threshold
Returns
-------
G : nx.DiGraph
A networkx DiGraph object.
"""
min_e_act = min_e_act if min_e_act else float('-inf')
max_e_act = max_e_act if max_e_act else float('inf')

atom_percentage = [site.species.num_atoms for site in self.transitions.occupancy()]

attempt_freq, _ = self.trajectory.metrics().attempt_frequency()
temperature = self.trajectory.metadata['temperature']
kBT = Boltzmann * temperature

G = nx.DiGraph()

for i, site in enumerate(self.sites):
G.add_node(i, label=site.label)

for (start, stop), n_jumps in self._counter().items():
time_perc = atom_percentage[start] * self.trajectory.total_time

eff_rate = n_jumps / time_perc

e_act = -np.log(eff_rate / attempt_freq) * kBT
e_act /= elementary_charge

if min_e_act <= e_act <= max_e_act:
G.add_edge(start, stop, e_act=e_act)

return G

def split(self, n_parts: int) -> list[Jumps]:
"""Split the jumps into parts.
Expand Down Expand Up @@ -336,12 +421,12 @@ def rates(self, n_parts: int = 10) -> pd.DataFrame:
"""
dct = {}

parts = [part.jumps_counter() for part in self.split(n_parts)]
parts = [part.counter() for part in self.split(n_parts)]
part_time = self.trajectory.total_time / n_parts

for site_pair in self.site_pairs:
n_jumps = [part[site_pair] for part in parts]

part_time = self.trajectory.total_time / n_parts
denom = self.n_floating * part_time

jump_freq_mean = np.mean(n_jumps) / denom
Expand Down
8 changes: 4 additions & 4 deletions src/gemdat/simulation_metrics.py → src/gemdat/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from trajectory import Trajectory


class SimulationMetrics:
class TrajectoryMetrics:
"""Class for calculating different metrics and properties from a molecular
dynamics simulation."""

Expand Down Expand Up @@ -115,7 +115,7 @@ def tracer_diffusivity_center_of_mass(
"""
center_of_mass = self.trajectory.center_of_mass()

metrics = SimulationMetrics(center_of_mass)
metrics = TrajectoryMetrics(center_of_mass)

return metrics.tracer_diffusivity(dimensions=dimensions)

Expand Down Expand Up @@ -230,7 +230,7 @@ def amplitudes(self) -> np.ndarray:
return np.asarray(amplitudes)


class SimulationMetricsStd:
class TrajectoryMetricsStd:
"""Class for calculating different metrics and properties from a molecular
dynamics simulation.
Expand All @@ -246,7 +246,7 @@ def __init__(self, trajectories: list[Trajectory]):
trajectories: list[Trajectory]
Input trajectories
"""
self.metrics = [SimulationMetrics(trajectory) for trajectory in trajectories]
self.metrics = [TrajectoryMetrics(trajectory) for trajectory in trajectories]

def speed(self) -> tuple[np.ndarray, np.ndarray]:
"""Calculate mean speed and standard deviations.
Expand Down
4 changes: 1 addition & 3 deletions src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import matplotlib.pyplot as plt
import numpy as np

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -24,7 +22,7 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
fig : matplotlib.figure.Figure
Output figure
"""
metrics = SimulationMetrics(trajectory)
metrics = trajectory.metrics()
speed = metrics.speed()

length = speed.shape[1]
Expand Down
4 changes: 1 addition & 3 deletions src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import numpy as np
from scipy import stats

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -25,7 +23,7 @@ def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
fig : matplotlib.figure.Figure
Output figure
"""
metrics = SimulationMetrics(trajectory)
metrics = trajectory.metrics()

fig, ax = plt.subplots()
ax.hist(metrics.amplitudes(), bins=100, density=True)
Expand Down
4 changes: 1 addition & 3 deletions src/gemdat/plots/plotly/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import numpy as np
import plotly.graph_objects as go

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -24,7 +22,7 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
fig : plotly.graph_objects.Figure.Figure
Output figure
"""
metrics = SimulationMetrics(trajectory)
metrics = trajectory.metrics()
speed = metrics.speed()

length = speed.shape[1]
Expand Down
6 changes: 2 additions & 4 deletions src/gemdat/plots/plotly/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import plotly.graph_objects as go
from scipy import stats

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -33,8 +31,8 @@ def vibrational_amplitudes(
"""

trajectories = trajectory.split(n_parts)
single_metrics = SimulationMetrics(trajectory)
metrics = [SimulationMetrics(trajectory).amplitudes() for trajectory in trajectories]
single_metrics = trajectory.metrics()
metrics = [trajectory.metrics().amplitudes() for trajectory in trajectories]

max_amp = max(max(metric) for metric in metrics)
min_amp = min(min(metric) for metric in metrics)
Expand Down
7 changes: 7 additions & 0 deletions src/gemdat/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
if TYPE_CHECKING:
from pymatgen.core import Structure

from .metrics import TrajectoryMetrics
from .transitions import Transitions
from .volume import Volume

Expand Down Expand Up @@ -613,6 +614,12 @@ def transitions_between_sites(
site_inner_fraction=site_inner_fraction,
)

def metrics(self) -> TrajectoryMetrics:
"""See [gemdat.TrajectoryMetrics][] for more info."""
from .metrics import TrajectoryMetrics

return TrajectoryMetrics(trajectory=self)

@plot_backend
def plot_displacement_per_atom(self, *, module, **kwargs):
"""See [gemdat.plots.displacement_per_atom][] for more info."""
Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymatgen.core import Structure

from .caching import weak_lru_cache
from .simulation_metrics import SimulationMetrics
from .metrics import TrajectoryMetrics
from .utils import bfill, ffill, integer_remap

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -108,7 +108,7 @@ def from_trajectory(
diff_trajectory = trajectory.filter(floating_specie)

if site_radius is None:
vibration_amplitude = SimulationMetrics(diff_trajectory).vibration_amplitude()
vibration_amplitude = TrajectoryMetrics(diff_trajectory).vibration_amplitude()

site_radius = _compute_site_radius(
trajectory=trajectory,
Expand Down
Loading

0 comments on commit 493612e

Please sign in to comment.