diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py index 7da8baa7c..40c8bbeab 100644 --- a/src/graphnet/models/graphs/edges/__init__.py +++ b/src/graphnet/models/graphs/edges/__init__.py @@ -5,3 +5,4 @@ and their features. """ from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges +from .minkowski import MinkowskiKNNEdges diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py new file mode 100644 index 000000000..5d1134ec5 --- /dev/null +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -0,0 +1,98 @@ +"""Module containing EdgeDefinitions based on the Minkowski Metric.""" +from typing import Optional, List + +import torch +from torch_geometric.data import Data +from torch_geometric.utils import to_dense_batch +from graphnet.models.graphs.edges.edges import EdgeDefinition + + +def compute_minkowski_distance_mat( + x: torch.Tensor, + y: torch.Tensor, + c: float, + space_coords: Optional[List[int]] = None, + time_coord: Optional[int] = 3, +) -> torch.Tensor: + """Compute all pairwise Minkowski distances. + + Args: + x: First tensor of shape (n, d). + y: Second tensor of shape (m, d). + c: Speed of light, in scaled units. + space_coords: Indices of space coordinates. + time_coord: Index of time coordinate. + + Returns: Matrix of shape (n, m) of all pairwise Minkowski distances. + """ + space_coords = space_coords or [0, 1, 2] + assert x.dim() == 2, "x must be 2-dimensional" + assert y.dim() == 2, "x must be 2-dimensional" + dist = x[:, None] - y[None, :] + pos = dist[:, :, space_coords] + time = dist[:, :, time_coord] * c + return (pos**2).sum(dim=-1) - time**2 + + +class MinkowskiKNNEdges(EdgeDefinition): + """Builds edges between most light-like separated.""" + + def __init__( + self, + nb_nearest_neighbours: int, + c: float, + time_like_weight: float = 1.0, + space_coords: Optional[List[int]] = None, + time_coord: Optional[int] = 3, + ): + """Initialize MinkowskiKNNEdges. + + Args: + nb_nearest_neighbours: Number of neighbours to connect to. + c: Speed of light, in scaled units. + time_like_weight: Preference to time-like over space-like edges. + Scales time_like distances by this value, before finding + nearest neighbours. + space_coords: Coordinates of x, y, z. + time_coord: Coordinate of time. + """ + super().__init__(name=__name__, class_name=self.__class__.__name__) + self.nb_nearest_neighbours = nb_nearest_neighbours + self.c = c + self.time_like_weight = time_like_weight + self.space_coords = space_coords or [0, 1, 2] + self.time_coord = time_coord + + def _construct_edges(self, graph: Data) -> Data: + x, mask = to_dense_batch(graph.x, graph.batch) + count = 0 + row = [] + col = [] + for batch in range(x.shape[0]): + distance_mat = compute_minkowski_distance_mat( + x_masked := x[batch][mask[batch]], + x_masked, + self.c, + self.space_coords, + self.time_coord, + ) + num_points = x_masked.shape[0] + num_edges = min(self.nb_nearest_neighbours, num_points) + col += [ + c + for c in range(num_points) + for _ in range(count, count + num_edges) + ] + distance_mat[distance_mat < 0] *= -self.time_like_weight + distance_mat += ( + torch.eye(distance_mat.shape[0]) * 1e9 + ) # self-loops + distance_sorted = distance_mat.argsort(dim=1) + distance_sorted += count # offset by previous events + row += distance_sorted[:num_edges].flatten().tolist() + count += num_points + + graph.edge_index = torch.tensor( + [row, col], dtype=torch.long, device=graph.x.device + ) + return graph diff --git a/tests/models/test_minkowski.py b/tests/models/test_minkowski.py new file mode 100644 index 000000000..af66196cf --- /dev/null +++ b/tests/models/test_minkowski.py @@ -0,0 +1,160 @@ +"""Unit tests for minkowski based edges.""" +import pytest +import torch +from torch_geometric.data.data import Data + +from graphnet.models.graphs.edges import KNNEdges, MinkowskiKNNEdges +from graphnet.models.graphs.edges.minkowski import ( + compute_minkowski_distance_mat, +) + + +def test_compute_minkowski_distance_mat() -> None: + """Testing the computation of the Minkowski distance matrix.""" + vec1 = torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 1.0, + ], + [ + 1.0, + 0.0, + 0.0, + 1.0, + ], + [ + 1.0, + 0.0, + 1.0, + 2.0, + ], + ] + ) + vec2 = torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + -1.0, + ], + [ + 1.0, + 1.0, + 1.0, + 0.0, + ], + ] + ) + expected11 = torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + -2.0, + ], + [ + 0.0, + 0.0, + 2.0, + 0.0, + ], + [ + 0.0, + 2.0, + 0.0, + 0.0, + ], + [ + -2.0, + 0.0, + 0.0, + 0.0, + ], + ] + ) + expected12 = torch.tensor( + [[-1.0, 3.0], [-3.0, 1.0], [-3.0, 1.0], [-7.0, -3.0]] + ) + expected22 = torch.tensor( + [ + [0.0, 2.0], + [2.0, 0.0], + ] + ) + mat11 = compute_minkowski_distance_mat(vec1, vec1, c=1.0) + mat12 = compute_minkowski_distance_mat(vec1, vec2, c=1.0) + mat22 = compute_minkowski_distance_mat(vec2, vec2, c=1.0) + + assert torch.allclose(mat11, expected11) + assert torch.allclose(mat12, expected12) + assert torch.allclose(mat22, expected22) + + +def test_minkowski_knn_edges() -> None: + """Testing the minkowski knn edge definition.""" + data = Data( + x=torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 1.0, + ], + [ + 1.0, + 0.0, + 0.0, + 1.0, + ], + [ + 1.0, + 0.0, + 1.0, + 2.0, + ], + ] + ) + ) + edge_index = MinkowskiKNNEdges( + nb_nearest_neighbours=2, + c=1.0, + )(data).edge_index + expected = torch.tensor( + [ + [1, 2, 0, 3, 0, 3, 1, 2], + [0, 0, 1, 1, 2, 2, 3, 3], + ] + ) + assert torch.allclose(edge_index[1], expected[1]) + + # Allow for "permutation of connections" in edge_index[1] + assert torch.allclose( + edge_index[0, [0, 1]], expected[0, [0, 1]] + ) or torch.allclose(edge_index[1, [0, 1]], expected[1, [1, 0]]) + assert torch.allclose( + edge_index[0, [2, 3]], expected[0, [2, 3]] + ) or torch.allclose(edge_index[1, [2, 3]], expected[1, [3, 2]]) + assert torch.allclose( + edge_index[0, [4, 5]], expected[0, [4, 5]] + ) or torch.allclose(edge_index[1, [4, 5]], expected[1, [5, 4]]) + assert torch.allclose( + edge_index[0, [6, 7]], expected[0, [6, 7]] + ) or torch.allclose(edge_index[1, [6, 7]], expected[1, [7, 6]])