From 3652f91018ddca6955787444592ddf84bfbf521a Mon Sep 17 00:00:00 2001
From: gbg141 <guille_gbg@hotmail.com>
Date: Tue, 14 May 2024 20:38:22 +0200
Subject: [PATCH] CAN Attention: standard management of heads

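Adopt the usual multi-head attention convention: `out_channels` is now the
total width of the attention output rather than the per-head width. Each
head operates on out_channels // heads features; with concat=True the heads
are concatenated back to out_channels, and with concat=False they are
averaged to out_channels // heads. CANLayer asserts that out_channels is
divisible by heads, the CAN network and PoolLayer no longer scale their
input dimensions by heads, and the attention variant used for version "v2"
now defaults to add_self_loops=True. The tests and the CAN training
tutorial are updated to the new convention.

A minimal sketch of the new shape contract (illustrative only; shapes
follow the updated tests, parameter defaults and the toy neighborhood
matrices are assumptions):

    import torch
    from topomodelx.nn.cell.can_layer import CANLayer

    n_cells, in_channels, out_channels, heads = 8, 7, 32, 4
    layer = CANLayer(
        in_channels=in_channels,
        out_channels=out_channels,  # total width, must be divisible by heads
        heads=heads,
        concat=True,
    )
    x_1 = torch.randn(n_cells, in_channels)
    # Toy sparse (n_cells, n_cells) lower/upper neighborhood matrices.
    lower_neighborhood = torch.eye(n_cells).to_sparse()
    upper_neighborhood = torch.eye(n_cells).to_sparse()
    x_out = layer(x_1, lower_neighborhood, upper_neighborhood)
    # concat=True: total width, no longer (n_cells, heads * out_channels).
    assert x_out.shape == (n_cells, out_channels)
    # With concat=False the heads are averaged instead:
    # x_out.shape == (n_cells, out_channels // heads)
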
---
 test/nn/cell/test_can.py        |   6 +-
 test/nn/cell/test_can_layer.py  |  13 ++--
 topomodelx/nn/cell/can.py       |   4 +-
 topomodelx/nn/cell/can_layer.py |  49 ++++++++--------
 tutorials/cell/can_train.ipynb  | 101 ++++++++++++++++----------------
 5 files changed, 86 insertions(+), 87 deletions(-)

diff --git a/test/nn/cell/test_can.py b/test/nn/cell/test_can.py
index 43534402..ecbe5ded 100644
--- a/test/nn/cell/test_can.py
+++ b/test/nn/cell/test_can.py
@@ -17,7 +17,7 @@ def test_forward(self):
             in_channels_1=2,
             out_channels=2,
             dropout=0.5,
-            heads=1,
+            heads=2,
             n_layers=2,
             att_lift=False,
             pooling=True,
@@ -31,8 +31,8 @@ def test_forward(self):
         ).to_sparse()
 
         x_0, x_1 = (
-            torch.tensor(x_0).float().to(device),
-            torch.tensor(x_1).float().to(device),
+            x_0.clone().detach().float().to(device),
+            x_1.clone().detach().float().to(device),
         )
         adjacency_1 = adjacency_1.float().to(device)
         adjacency_2 = adjacency_1.float().to(device)
diff --git a/test/nn/cell/test_can_layer.py b/test/nn/cell/test_can_layer.py
index cf981ff0..d8080c08 100644
--- a/test/nn/cell/test_can_layer.py
+++ b/test/nn/cell/test_can_layer.py
@@ -13,7 +13,7 @@ class TestCANLayer:
     def test_forward(self):
         """Test the forward method of CANLayer."""
         in_channels = 7
-        out_channels = 64
+        out_channels = 66
         dropout_values = [0.5, 0.7]
         heads_values = [1, 3]
         concat_values = [True, False]
@@ -65,12 +65,12 @@ def test_forward(self):
             )
             x_out = can_layer.forward(x_1, lower_neighborhood, upper_neighborhood)
             if concat:
-                assert x_out.shape == (n_cells, out_channels * heads)
-            else:
                 assert x_out.shape == (n_cells, out_channels)
+            else:
+                assert x_out.shape == (n_cells, out_channels // heads)
 
         # Test if there are no non-zero values in the neighborhood
-        heads = 1
+        heads = 3
         concat_list = [True, False]
         skip_connection = True
 
@@ -79,6 +79,7 @@ def test_forward(self):
                 can_layer = CANLayer(
                     in_channels=in_channels,
                     out_channels=out_channels,
+                    heads=heads,
                     concat=concat,
                     skip_connection=skip_connection,
                     version=version,
@@ -89,9 +90,9 @@ def test_forward(self):
                     torch.zeros_like(upper_neighborhood),
                 )
                 if concat:
-                    assert x_out.shape == (n_cells, out_channels * heads)
-                else:
                     assert x_out.shape == (n_cells, out_channels)
+                else:
+                    assert x_out.shape == (n_cells, out_channels // heads)
 
     def test_reset_parameters(self):
         """Test the reset_parameters method of CANLayer."""
diff --git a/topomodelx/nn/cell/can.py b/topomodelx/nn/cell/can.py
index a9e80a3e..a7f00ed1 100644
--- a/topomodelx/nn/cell/can.py
+++ b/topomodelx/nn/cell/can.py
@@ -94,7 +94,7 @@ def __init__(
         for _ in range(n_layers - 1):
             layers.append(
                 CANLayer(
-                    in_channels=out_channels * heads,
+                    in_channels=out_channels,
                     out_channels=out_channels,
                     dropout=dropout,
                     heads=heads,
@@ -110,7 +110,7 @@ def __init__(
                 layers.append(
                     PoolLayer(
                         k_pool=k_pool,
-                        in_channels_0=out_channels * heads,
+                        in_channels_0=out_channels,
                         signal_pool_activation=torch.nn.Sigmoid(),
                         readout=True,
                         **kwargs,
diff --git a/topomodelx/nn/cell/can_layer.py b/topomodelx/nn/cell/can_layer.py
index 912a0298..bae55f4f 100644
--- a/topomodelx/nn/cell/can_layer.py
+++ b/topomodelx/nn/cell/can_layer.py
@@ -439,9 +439,9 @@ def __init__(
         self.dropout = dropout
         self.add_self_loops = add_self_loops
 
-        self.lin = torch.nn.Linear(in_channels, heads * out_channels, bias=False)
-        self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels))
-        self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels))
+        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
+        self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels // heads))
+        self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels // heads))
 
         self.reset_parameters()
 
@@ -468,7 +468,7 @@ def message(self, x_source):
         """
         # Compute the linear transformation on the source features
         x_message = self.lin(x_source).view(
-            -1, self.heads, self.out_channels
+            -1, self.heads, self.out_channels // self.heads
         )  # (n_k_cells, H, C)
 
         # compute the source and target messages
@@ -534,12 +534,13 @@ def forward(self, x_source, neighborhood):
         # If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor
         if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels * self.heads),
+                (x_source.shape[0], self.out_channels),
                 device=x_source.device,
             )  # (n_k_cells, H * C)
         if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels), device=x_source.device
+                (x_source.shape[0], self.out_channels // self.heads),
+                device=x_source.device,
             )  # (n_k_cells, C)
 
         # Add self-loops to the neighborhood matrix if necessary
@@ -559,9 +560,7 @@ def forward(self, x_source, neighborhood):
 
         # if concat true, concatenate the messages for each head. Otherwise, average the messages for each head.
         if self.concat:
-            return aggregated_message.view(
-                -1, self.heads * self.out_channels
-            )  # (n_k_cells, H * C)
+            return aggregated_message.view(-1, self.out_channels)  # (n_k_cells, H * C)
 
         return aggregated_message.mean(dim=1)  # (n_k_cells, C)
 
@@ -613,7 +612,7 @@ def __init__(
         heads: int,
         concat: bool,
         att_activation: torch.nn.Module,
-        add_self_loops: bool = False,
+        add_self_loops: bool = True,
         aggr_func: Literal["sum", "mean", "add"] = "sum",
         initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
         share_weights: bool = False,
@@ -634,17 +633,13 @@ def __init__(
 
         if share_weights:
             self.lin_src = self.lin_dst = torch.nn.Linear(
-                in_channels, heads * out_channels, bias=False
+                in_channels, out_channels, bias=False
             )
         else:
-            self.lin_src = torch.nn.Linear(
-                in_channels, heads * out_channels, bias=False
-            )
-            self.lin_dst = torch.nn.Linear(
-                in_channels, heads * out_channels, bias=False
-            )
+            self.lin_src = torch.nn.Linear(in_channels, out_channels, bias=False)
+            self.lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False)
 
-        self.att_weight = Parameter(torch.Tensor(1, heads, out_channels))
+        self.att_weight = Parameter(torch.Tensor(1, heads, out_channels // heads))
 
         self.reset_parameters()
 
@@ -671,12 +666,12 @@ def message(self, x_source):
         """
         # Compute the linear transformation on the source features
         x_src_message = self.lin_src(x_source).view(
-            -1, self.heads, self.out_channels
+            -1, self.heads, self.out_channels // self.heads
         )  # (n_k_cells, H, C)
 
         # Compute the linear transformation on the source features
         x_dst_message = self.lin_dst(x_source).view(
-            -1, self.heads, self.out_channels
+            -1, self.heads, self.out_channels // self.heads
         )  # (n_k_cells, H, C)
 
         # Get the source and target projections of the neighborhood
@@ -737,12 +732,13 @@ def forward(self, x_source, neighborhood):
         # If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor
         if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels * self.heads),
+                (x_source.shape[0], self.out_channels),
                 device=x_source.device,
             )  # (n_k_cells, H * C)
         if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels), device=x_source.device
+                (x_source.shape[0], self.out_channels // self.heads),
+                device=x_source.device,
             )  # (n_k_cells, C)
 
         # Add self-loops to the neighborhood matrix if necessary
@@ -762,9 +758,7 @@ def forward(self, x_source, neighborhood):
 
         # if concat true, concatenate the messages for each head. Otherwise, average the messages for each head.
         if self.concat:
-            return aggregated_message.view(
-                -1, self.heads * self.out_channels
-            )  # (n_k_cells, H * C)
+            return aggregated_message.view(-1, self.out_channels)  # (n_k_cells, H * C)
 
         return aggregated_message.mean(dim=1)  # (n_k_cells, C)
 
@@ -836,6 +830,9 @@ def __init__(
         assert in_channels > 0, ValueError("Number of input channels must be > 0")
         assert out_channels > 0, ValueError("Number of output channels must be > 0")
         assert heads > 0, ValueError("Number of heads must be > 0")
+        assert out_channels % heads == 0, ValueError(
+            "Number of output channels must be divisible by the number of heads"
+        )
         assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]")
 
         # assert that shared weight is True only if version is v2
@@ -893,7 +890,7 @@ def __init__(
 
         # linear transformation
         if skip_connection:
-            out_channels = out_channels * heads if concat else out_channels
+            out_channels = out_channels if concat else out_channels // heads
             self.lin = Linear(in_channels, out_channels, bias=False)
             self.eps = 1 + 1e-6
 
diff --git a/tutorials/cell/can_train.ipynb b/tutorials/cell/can_train.ipynb
index 122dd622..0cc2cd63 100644
--- a/tutorials/cell/can_train.ipynb
+++ b/tutorials/cell/can_train.ipynb
@@ -104,21 +104,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:06:36.009880829Z",
      "start_time": "2023-05-31T09:06:34.285257706Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "<torch._C.Generator at 0x11d4905d0>"
+       "<torch._C.Generator at 0x11971c5d0>"
       ]
      },
-     "execution_count": 3,
+     "execution_count": 1,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -146,12 +147,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 2,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:13:53.006542411Z",
      "start_time": "2023-05-31T09:13:52.963074076Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [
     {
@@ -186,23 +188,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:13:55.279147916Z",
      "start_time": "2023-05-31T09:13:55.269057585Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip\n",
-      "Processing...\n",
-      "Done!\n"
-     ]
-    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -253,12 +247,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:13:55.832585216Z",
      "start_time": "2023-05-31T09:13:55.815448708Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [],
    "source": [
@@ -298,8 +293,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
+   "execution_count": 5,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
     "class Network(torch.nn.Module):\n",
@@ -324,7 +321,7 @@
     "            n_layers=n_layers,\n",
     "            att_lift=att_lift,\n",
     "        )\n",
-    "        self.lin_0 = torch.nn.Linear(heads * out_channels, 128)\n",
+    "        self.lin_0 = torch.nn.Linear(out_channels, 128)\n",
     "        self.lin_1 = torch.nn.Linear(128, num_classes)\n",
     "\n",
     "    def forward(self, x_0, x_1, adjacency, down_laplacian, up_laplacian):\n",
@@ -338,12 +335,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:13:56.672913561Z",
      "start_time": "2023-05-31T09:13:56.667986426Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [],
    "source": [
@@ -378,12 +376,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 7,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:19:40.411845803Z",
      "start_time": "2023-05-31T09:19:40.408861921Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [
     {
@@ -398,35 +397,35 @@
        "      (0): CANLayer(\n",
        "        (lower_att): MultiHeadCellAttention(\n",
        "          (att_activation): LeakyReLU(negative_slope=0.2)\n",
-       "          (lin): Linear(in_features=11, out_features=64, bias=False)\n",
+       "          (lin): Linear(in_features=11, out_features=32, bias=False)\n",
        "        )\n",
        "        (upper_att): MultiHeadCellAttention(\n",
        "          (att_activation): LeakyReLU(negative_slope=0.2)\n",
-       "          (lin): Linear(in_features=11, out_features=64, bias=False)\n",
+       "          (lin): Linear(in_features=11, out_features=32, bias=False)\n",
        "        )\n",
-       "        (lin): Linear(in_features=11, out_features=64, bias=False)\n",
+       "        (lin): Linear(in_features=11, out_features=32, bias=False)\n",
        "        (aggregation): Aggregation()\n",
        "      )\n",
        "      (1): CANLayer(\n",
        "        (lower_att): MultiHeadCellAttention(\n",
        "          (att_activation): LeakyReLU(negative_slope=0.2)\n",
-       "          (lin): Linear(in_features=64, out_features=64, bias=False)\n",
+       "          (lin): Linear(in_features=32, out_features=32, bias=False)\n",
        "        )\n",
        "        (upper_att): MultiHeadCellAttention(\n",
        "          (att_activation): LeakyReLU(negative_slope=0.2)\n",
-       "          (lin): Linear(in_features=64, out_features=64, bias=False)\n",
+       "          (lin): Linear(in_features=32, out_features=32, bias=False)\n",
        "        )\n",
-       "        (lin): Linear(in_features=64, out_features=64, bias=False)\n",
+       "        (lin): Linear(in_features=32, out_features=32, bias=False)\n",
        "        (aggregation): Aggregation()\n",
        "      )\n",
        "    )\n",
        "  )\n",
-       "  (lin_0): Linear(in_features=64, out_features=128, bias=True)\n",
+       "  (lin_0): Linear(in_features=32, out_features=128, bias=True)\n",
        "  (lin_1): Linear(in_features=128, out_features=2, bias=True)\n",
        ")"
       ]
      },
-     "execution_count": 9,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -446,12 +445,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 8,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:19:41.150933630Z",
      "start_time": "2023-05-31T09:19:41.146986990Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [],
    "source": [
@@ -479,37 +479,38 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2023-05-31T09:19:42.918836083Z",
      "start_time": "2023-05-31T09:19:42.114801039Z"
-    }
+    },
+    "metadata": {}
    },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Epoch: 1 loss: 0.6159 Train_acc: 0.6947\n",
+      "Epoch: 1 loss: 0.6330 Train_acc: 0.6718\n",
       "Test_acc: 0.5965\n",
-      "Epoch: 2 loss: 0.6099 Train_acc: 0.6947\n",
+      "Epoch: 2 loss: 0.6116 Train_acc: 0.6947\n",
       "Test_acc: 0.5965\n",
-      "Epoch: 3 loss: 0.6035 Train_acc: 0.6947\n",
+      "Epoch: 3 loss: 0.6071 Train_acc: 0.6947\n",
       "Test_acc: 0.5965\n",
-      "Epoch: 4 loss: 0.5966 Train_acc: 0.7176\n",
-      "Test_acc: 0.6316\n",
-      "Epoch: 5 loss: 0.5909 Train_acc: 0.7252\n",
+      "Epoch: 4 loss: 0.6027 Train_acc: 0.6947\n",
+      "Test_acc: 0.5965\n",
+      "Epoch: 5 loss: 0.5974 Train_acc: 0.7099\n",
       "Test_acc: 0.6491\n",
-      "Epoch: 6 loss: 0.5983 Train_acc: 0.7099\n",
-      "Test_acc: 0.6316\n",
-      "Epoch: 7 loss: 0.5884 Train_acc: 0.7252\n",
+      "Epoch: 6 loss: 0.5911 Train_acc: 0.7252\n",
       "Test_acc: 0.6491\n",
-      "Epoch: 8 loss: 0.5909 Train_acc: 0.7176\n",
+      "Epoch: 7 loss: 0.5979 Train_acc: 0.7176\n",
+      "Test_acc: 0.6140\n",
+      "Epoch: 8 loss: 0.5826 Train_acc: 0.7252\n",
       "Test_acc: 0.6316\n",
-      "Epoch: 9 loss: 0.5818 Train_acc: 0.7252\n",
+      "Epoch: 9 loss: 0.5908 Train_acc: 0.7252\n",
       "Test_acc: 0.6316\n",
-      "Epoch: 10 loss: 0.5879 Train_acc: 0.7252\n",
+      "Epoch: 10 loss: 0.5839 Train_acc: 0.7252\n",
       "Test_acc: 0.6316\n"
      ]
     }