From 12d6794b74ff628f8509ec081b4c2e4e4f03e641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 23 Jan 2025 15:28:12 -0800 Subject: [PATCH] Update SimplicialDnDLifting to work with new design --- .../simplicial/test_SimplicialDnDLifting.py | 57 +++++++++++---- .../liftings/graph2simplicial/dnd_lifting.py | 72 ++++++++++++++----- 2 files changed, 97 insertions(+), 32 deletions(-) diff --git a/test/transforms/liftings/simplicial/test_SimplicialDnDLifting.py b/test/transforms/liftings/simplicial/test_SimplicialDnDLifting.py index 7537a22b..bd3344ca 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialDnDLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialDnDLifting.py @@ -4,8 +4,9 @@ import torch -from modules.data.utils.utils import load_manual_graph -from modules.transforms.liftings.graph2simplicial.dnd_lifting import ( +from topobenchmark.data.utils import load_manual_graph +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialDnDLifting, ) @@ -18,8 +19,16 @@ def setup_method(self): self.data = load_manual_graph() # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialDnDLifting(complex_dim=6, signed=True) - self.lifting_unsigned = SimplicialDnDLifting(complex_dim=6, signed=False) + lifting_map = SimplicialDnDLifting(complex_dim=6) + + self.lifting_signed = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + signed=True, + ) + self.lifting_unsigned = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + signed=False, + ) random.seed(42) @@ -47,10 +56,14 @@ def test_lift_topology(self): ] ) - U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_1.to_dense()) + U, S_unsigned, V = torch.svd( + lifted_data_unsigned.incidence_1.to_dense() + ) U, S_signed, V = torch.svd(lifted_data_signed.incidence_1.to_dense()) assert torch.allclose( - expected_incidence_1_singular_values_unsigned, S_unsigned, atol=1.0e-04 + expected_incidence_1_singular_values_unsigned, + S_unsigned, + atol=1.0e-04, ), "Something is wrong with unsigned incidence_1 (nodes to edges)." assert torch.allclose( expected_incidence_1_singular_values_signed, S_signed, atol=1.0e-04 @@ -122,10 +135,14 @@ def test_lift_topology(self): ] ) - U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_2.to_dense()) + U, S_unsigned, V = torch.svd( + lifted_data_unsigned.incidence_2.to_dense() + ) U, S_signed, V = torch.svd(lifted_data_signed.incidence_2.to_dense()) assert torch.allclose( - expected_incidence_2_singular_values_unsigned, S_unsigned, atol=1.0e-04 + expected_incidence_2_singular_values_unsigned, + S_unsigned, + atol=1.0e-04, ), "Something is wrong with unsigned incidence_2 (edges to triangles)." assert torch.allclose( expected_incidence_2_singular_values_signed, S_signed, atol=1.0e-04 @@ -253,10 +270,14 @@ def test_lift_topology(self): ] ) - U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_3.to_dense()) + U, S_unsigned, V = torch.svd( + lifted_data_unsigned.incidence_3.to_dense() + ) U, S_signed, V = torch.svd(lifted_data_signed.incidence_3.to_dense()) assert torch.allclose( - expected_incidence_3_singular_values_unsigned, S_unsigned, atol=1.0e-04 + expected_incidence_3_singular_values_unsigned, + S_unsigned, + atol=1.0e-04, ), "Something is wrong with unsigned incidence_3 (edges to tetrahedrons)." assert torch.allclose( expected_incidence_3_singular_values_signed, S_signed, atol=1.0e-04 @@ -384,10 +405,14 @@ def test_lift_topology(self): ] ) - U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_4.to_dense()) + U, S_unsigned, V = torch.svd( + lifted_data_unsigned.incidence_4.to_dense() + ) U, S_signed, V = torch.svd(lifted_data_signed.incidence_4.to_dense()) assert torch.allclose( - expected_incidence_4_singular_values_unsigned, S_unsigned, atol=1.0e-04 + expected_incidence_4_singular_values_unsigned, + S_unsigned, + atol=1.0e-04, ), "Something is wrong with unsigned incidence_4." assert torch.allclose( expected_incidence_4_singular_values_signed, S_signed, atol=1.0e-04 @@ -459,10 +484,14 @@ def test_lift_topology(self): ] ) - U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_5.to_dense()) + U, S_unsigned, V = torch.svd( + lifted_data_unsigned.incidence_5.to_dense() + ) U, S_signed, V = torch.svd(lifted_data_signed.incidence_5.to_dense()) assert torch.allclose( - expected_incidence_5_singular_values_unsigned, S_unsigned, atol=1.0e-04 + expected_incidence_5_singular_values_unsigned, + S_unsigned, + atol=1.0e-04, ), "Something is wrong with unsigned incidence_5." assert torch.allclose( expected_incidence_5_singular_values_signed, S_signed, atol=1.0e-04 diff --git a/topobenchmark/transforms/liftings/graph2simplicial/dnd_lifting.py b/topobenchmark/transforms/liftings/graph2simplicial/dnd_lifting.py index 37c819e3..28e61d9a 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/dnd_lifting.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/dnd_lifting.py @@ -1,46 +1,76 @@ +"""This modules implements the DnD lifting. + +The DnD lifting introduces a novel, non-deterministic, and somewhat +lighthearted approach to transforming graphs into simplicial complexes. +Inspired by the game mechanics of Dungeons & Dragons (D&D), this method +incorporates elements of randomness and character attributes to determine +the formation of simplices. This lifting aims to add an element of whimsy +and unpredictability to the graph-to-simplicial complex transformation +process, while still providing a serious and fully functional methodology. + +Each vertex in the graph is assigned the following attributes: degree centrality, +clustering coefficient, closeness centrality, eigenvector centrality, +betweenness centrality, and pagerank. +Simplices are created based on the neighborhood within a distance determined by +a D20 dice roll + the attribute value. The randomness from the dice roll, +modified by the node's attributes, ensures a non-deterministic process for each +lifting. The dice roll is influenced by different attributes based on the level +of the simplex being formed. The different attributes for different levels of +simplices are used in the order shown above, based on the role of those attributes +in the context of the graph structure. +""" + import random from itertools import combinations import networkx as nx +import torch from toponetx.classes import SimplicialComplex -from torch_geometric.data import Data -from modules.transforms.liftings.graph2simplicial.base import Graph2SimplicialLifting +from topobenchmark.transforms.liftings import LiftingMap + +class SimplicialDnDLifting(LiftingMap): + """Lifts graphs to simplicial complex domain -class SimplicialDnDLifting(Graph2SimplicialLifting): - r"""Lifts graphs to simplicial complex domain using a Dungeons & Dragons inspired system. + Uses a Dungeons & Dragons inspired system. Parameters ---------- - **kwargs : optional - Additional arguments for the class. + complex_dim : int + Dimension of the subcomplex. """ - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, complex_dim=2): + super().__init__() + self.complex_dim = complex_dim - def lift_topology(self, data: Data) -> dict: - r"""Lifts the topology of a graph to a simplicial complex using Dungeons & Dragons (D&D) inspired mechanics. + def lift(self, domain): + """Lifts the topology of a graph to a simplicial complex. + + Uses Dungeons & Dragons (D&D) inspired mechanics. Parameters ---------- - data : Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + toponetx.SimplicialComplex + Lifted simplicial complex. """ - graph = self._generate_graph_from_data(data) + graph = domain + simplicial_complex = SimplicialComplex() characters = self._assign_attributes(graph) simplices = [set() for _ in range(2, self.complex_dim + 1)] for node in graph.nodes: - simplicial_complex.add_node(node, features=data.x[node]) + simplicial_complex.add_node( + node, x=torch.tensor(domain.nodes[node]["x"]) + ) for node in graph.nodes: character = characters[node] @@ -57,7 +87,10 @@ def lift_topology(self, data: Data) -> dict: for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) - return self._get_lifted_topology(simplicial_complex, graph) + # because ComplexData pads unexisting dimensions with empty matrices + simplicial_complex.practical_dim = self.complex_dim + + return simplicial_complex def _assign_attributes(self, graph): """Assign D&D-inspired attributes based on node properties.""" @@ -81,7 +114,10 @@ def _assign_attributes(self, graph): return attributes def _roll_dice(self, attributes, k): - """Simulate a D20 dice roll influenced by node attributes where a different attribute is used based on the simplex level.""" + """Simulate a D20 dice roll influenced by node attributes. + + A different attribute is used based on the simplex level. + """ attribute = None if k == 1: