diff --git a/topomodelx/nn/cell/ccxn.py b/topomodelx/nn/cell/ccxn.py index b6e3e042..9848df73 100644 --- a/topomodelx/nn/cell/ccxn.py +++ b/topomodelx/nn/cell/ccxn.py @@ -20,6 +20,8 @@ class CCXN(torch.nn.Module): Number of CCXN layers. att : bool Whether to use attention. + **kwargs : optional + Additional arguments CCXNLayer. References ---------- @@ -36,6 +38,7 @@ def __init__( in_channels_2, n_layers=2, att=False, + **kwargs, ): super().__init__() @@ -45,11 +48,12 @@ def __init__( in_channels_1=in_channels_1, in_channels_2=in_channels_2, att=att, + **kwargs, ) for _ in range(n_layers) ) - def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): + def forward(self, x_0, x_1, adjacency_0, incidence_2_t): """Forward computation through layers. Parameters @@ -58,9 +62,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): Input features on the nodes (0-cells). x_1 : torch.Tensor, shape = (n_edges, in_channels_1) Input features on the edges (1-cells). - neighborhood_0_to_0 : torch.Tensor, shape = (n_nodes, n_nodes) + adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes) Adjacency matrix of rank 0 (up). - neighborhood_1_to_2 : torch.Tensor, shape = (n_faces, n_edges) + incidence_2_t : torch.Tensor, shape = (n_faces, n_edges) Transpose of boundary matrix of rank 2. Returns @@ -73,5 +77,5 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): Final hidden states of the faces (2-cells). """ for layer in self.layers: - x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2) + x_0, x_1, x_2 = layer(x_0, x_1, adjacency_0, incidence_2_t) return (x_0, x_1, x_2) diff --git a/topomodelx/nn/cell/ccxn_layer.py b/topomodelx/nn/cell/ccxn_layer.py index 248fd678..e596998b 100644 --- a/topomodelx/nn/cell/ccxn_layer.py +++ b/topomodelx/nn/cell/ccxn_layer.py @@ -25,10 +25,8 @@ class CCXNLayer(torch.nn.Module): Dimension of input features on faces (2-cells). att : bool, default=False Whether to use attention. - - Notes - ----- - This is the architecture proposed for entire complex classification. + **kwargs : optional + Additional arguments for the modules of the CCXN layer. References ---------- @@ -45,7 +43,7 @@ class CCXNLayer(torch.nn.Module): """ def __init__( - self, in_channels_0, in_channels_1, in_channels_2, att: bool = False + self, in_channels_0, in_channels_1, in_channels_2, att: bool = False, **kwargs ) -> None: super().__init__() self.conv_0_to_0 = Conv( @@ -55,7 +53,7 @@ def __init__( in_channels=in_channels_1, out_channels=in_channels_2, att=att ) - def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None): + def forward(self, x_0, x_1, adjacency_0, incidence_2_t, x_2=None): r"""Forward pass. The forward pass was initially proposed in [1]_. @@ -97,9 +95,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None): Input features on the nodes of the cell complex. x_1 : torch.Tensor, shape = (n_1_cells, channels) Input features on the edges of the cell complex. - neighborhood_0_to_0 : torch.sparse, shape = (n_0_cells, n_0_cells) + adjacency_0 : torch.sparse, shape = (n_0_cells, n_0_cells) Neighborhood matrix mapping nodes to nodes (A_0_up). - neighborhood_1_to_2 : torch.sparse, shape = (n_2_cells, n_1_cells) + incidence_2_t : torch.sparse, shape = (n_2_cells, n_1_cells) Neighborhood matrix mapping edges to faces (B_2^T). x_2 : torch.Tensor, shape = (n_2_cells, channels) Input features on the faces of the cell complex. @@ -113,10 +111,10 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None): x_0 = torch.nn.functional.relu(x_0) x_1 = torch.nn.functional.relu(x_1) - x_0 = self.conv_0_to_0(x_0, neighborhood_0_to_0) + x_0 = self.conv_0_to_0(x_0, adjacency_0) x_0 = torch.nn.functional.relu(x_0) - x_2 = self.conv_1_to_2(x_1, neighborhood_1_to_2, x_2) + x_2 = self.conv_1_to_2(x_1, incidence_2_t, x_2) x_2 = torch.nn.functional.relu(x_2) return x_0, x_1, x_2 diff --git a/tutorials/cell/ccxn_train.ipynb b/tutorials/cell/ccxn_train.ipynb index dc86c739..a58c1d41 100644 --- a/tutorials/cell/ccxn_train.ipynb +++ b/tutorials/cell/ccxn_train.ipynb @@ -269,10 +269,8 @@ " self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)\n", " self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)\n", "\n", - " def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):\n", - " x_0, x_1, x_2 = self.base_model(\n", - " x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2\n", - " )\n", + " def forward(self, x_0, x_1, adjacency_0, incidence_2_t):\n", + " x_0, x_1, x_2 = self.base_model(x_0, x_1, adjacency_0, incidence_2_t)\n", " x_0 = self.lin_0(x_0)\n", " x_1 = self.lin_1(x_1)\n", " x_2 = self.lin_2(x_2)\n", @@ -436,7 +434,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", + "/Users/gbg141/Documents/Projects/TopoModelX/venv_tmx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", " return F.mse_loss(input, target, reduction=self.reduction)\n" ] },