Skip to content

Commit

Permalink
CCXN Updated
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Apr 4, 2024
1 parent 8d1c1d4 commit bf30e09
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
12 changes: 8 additions & 4 deletions topomodelx/nn/cell/ccxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class CCXN(torch.nn.Module):
Number of CCXN layers.
att : bool
Whether to use attention.
**kwargs : optional
Additional arguments CCXNLayer.
References
----------
Expand All @@ -36,6 +38,7 @@ def __init__(
in_channels_2,
n_layers=2,
att=False,
**kwargs,
):
super().__init__()

Expand All @@ -45,11 +48,12 @@ def __init__(
in_channels_1=in_channels_1,
in_channels_2=in_channels_2,
att=att,
**kwargs,
)
for _ in range(n_layers)
)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
def forward(self, x_0, x_1, adjacency_0, incidence_2_t):
"""Forward computation through layers.
Parameters
Expand All @@ -58,9 +62,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
Input features on the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Input features on the edges (1-cells).
neighborhood_0_to_0 : torch.Tensor, shape = (n_nodes, n_nodes)
adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes)
Adjacency matrix of rank 0 (up).
neighborhood_1_to_2 : torch.Tensor, shape = (n_faces, n_edges)
incidence_2_t : torch.Tensor, shape = (n_faces, n_edges)
Transpose of boundary matrix of rank 2.
Returns
Expand All @@ -73,5 +77,5 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
Final hidden states of the faces (2-cells).
"""
for layer in self.layers:
x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
x_0, x_1, x_2 = layer(x_0, x_1, adjacency_0, incidence_2_t)
return (x_0, x_1, x_2)
18 changes: 8 additions & 10 deletions topomodelx/nn/cell/ccxn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ class CCXNLayer(torch.nn.Module):
Dimension of input features on faces (2-cells).
att : bool, default=False
Whether to use attention.
Notes
-----
This is the architecture proposed for entire complex classification.
**kwargs : optional
Additional arguments for the modules of the CCXN layer.
References
----------
Expand All @@ -45,7 +43,7 @@ class CCXNLayer(torch.nn.Module):
"""

def __init__(
self, in_channels_0, in_channels_1, in_channels_2, att: bool = False
self, in_channels_0, in_channels_1, in_channels_2, att: bool = False, **kwargs
) -> None:
super().__init__()
self.conv_0_to_0 = Conv(
Expand All @@ -55,7 +53,7 @@ def __init__(
in_channels=in_channels_1, out_channels=in_channels_2, att=att
)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None):
def forward(self, x_0, x_1, adjacency_0, incidence_2_t, x_2=None):
r"""Forward pass.
The forward pass was initially proposed in [1]_.
Expand Down Expand Up @@ -97,9 +95,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None):
Input features on the nodes of the cell complex.
x_1 : torch.Tensor, shape = (n_1_cells, channels)
Input features on the edges of the cell complex.
neighborhood_0_to_0 : torch.sparse, shape = (n_0_cells, n_0_cells)
adjacency_0 : torch.sparse, shape = (n_0_cells, n_0_cells)
Neighborhood matrix mapping nodes to nodes (A_0_up).
neighborhood_1_to_2 : torch.sparse, shape = (n_2_cells, n_1_cells)
incidence_2_t : torch.sparse, shape = (n_2_cells, n_1_cells)
Neighborhood matrix mapping edges to faces (B_2^T).
x_2 : torch.Tensor, shape = (n_2_cells, channels)
Input features on the faces of the cell complex.
Expand All @@ -113,10 +111,10 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None):
x_0 = torch.nn.functional.relu(x_0)
x_1 = torch.nn.functional.relu(x_1)

x_0 = self.conv_0_to_0(x_0, neighborhood_0_to_0)
x_0 = self.conv_0_to_0(x_0, adjacency_0)
x_0 = torch.nn.functional.relu(x_0)

x_2 = self.conv_1_to_2(x_1, neighborhood_1_to_2, x_2)
x_2 = self.conv_1_to_2(x_1, incidence_2_t, x_2)
x_2 = torch.nn.functional.relu(x_2)

return x_0, x_1, x_2
8 changes: 3 additions & 5 deletions tutorials/cell/ccxn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,8 @@
" self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)\n",
" self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)\n",
"\n",
" def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):\n",
" x_0, x_1, x_2 = self.base_model(\n",
" x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2\n",
" )\n",
" def forward(self, x_0, x_1, adjacency_0, incidence_2_t):\n",
" x_0, x_1, x_2 = self.base_model(x_0, x_1, adjacency_0, incidence_2_t)\n",
" x_0 = self.lin_0(x_0)\n",
" x_1 = self.lin_1(x_1)\n",
" x_2 = self.lin_2(x_2)\n",
Expand Down Expand Up @@ -436,7 +434,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
"/Users/gbg141/Documents/Projects/TopoModelX/venv_tmx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n"
]
},
Expand Down

0 comments on commit bf30e09

Please sign in to comment.