diff --git a/test/nn/cell/test_can.py b/test/nn/cell/test_can.py index 43534402..ecbe5ded 100644 --- a/test/nn/cell/test_can.py +++ b/test/nn/cell/test_can.py @@ -17,7 +17,7 @@ def test_forward(self): in_channels_1=2, out_channels=2, dropout=0.5, - heads=1, + heads=2, n_layers=2, att_lift=False, pooling=True, @@ -31,8 +31,8 @@ def test_forward(self): ).to_sparse() x_0, x_1 = ( - torch.tensor(x_0).float().to(device), - torch.tensor(x_1).float().to(device), + x_0.clone().detach().float().to(device), + x_1.clone().detach().float().to(device), ) adjacency_1 = adjacency_1.float().to(device) adjacency_2 = adjacency_1.float().to(device) diff --git a/test/nn/cell/test_can_layer.py b/test/nn/cell/test_can_layer.py index cf981ff0..d8080c08 100644 --- a/test/nn/cell/test_can_layer.py +++ b/test/nn/cell/test_can_layer.py @@ -13,7 +13,7 @@ class TestCANLayer: def test_forward(self): """Test the forward method of CANLayer.""" in_channels = 7 - out_channels = 64 + out_channels = 66 dropout_values = [0.5, 0.7] heads_values = [1, 3] concat_values = [True, False] @@ -65,12 +65,12 @@ def test_forward(self): ) x_out = can_layer.forward(x_1, lower_neighborhood, upper_neighborhood) if concat: - assert x_out.shape == (n_cells, out_channels * heads) - else: assert x_out.shape == (n_cells, out_channels) + else: + assert x_out.shape == (n_cells, out_channels // heads) # Test if there are no non-zero values in the neighborhood - heads = 1 + heads = 3 concat_list = [True, False] skip_connection = True @@ -79,6 +79,7 @@ def test_forward(self): can_layer = CANLayer( in_channels=in_channels, out_channels=out_channels, + heads=heads, concat=concat, skip_connection=skip_connection, version=version, @@ -89,9 +90,9 @@ def test_forward(self): torch.zeros_like(upper_neighborhood), ) if concat: - assert x_out.shape == (n_cells, out_channels * heads) - else: assert x_out.shape == (n_cells, out_channels) + else: + assert x_out.shape == (n_cells, out_channels // heads) def test_reset_parameters(self): """Test the reset_parameters method of CANLayer.""" diff --git a/topomodelx/nn/cell/can.py b/topomodelx/nn/cell/can.py index a9e80a3e..a7f00ed1 100644 --- a/topomodelx/nn/cell/can.py +++ b/topomodelx/nn/cell/can.py @@ -94,7 +94,7 @@ def __init__( for _ in range(n_layers - 1): layers.append( CANLayer( - in_channels=out_channels * heads, + in_channels=out_channels, out_channels=out_channels, dropout=dropout, heads=heads, @@ -110,7 +110,7 @@ def __init__( layers.append( PoolLayer( k_pool=k_pool, - in_channels_0=out_channels * heads, + in_channels_0=out_channels, signal_pool_activation=torch.nn.Sigmoid(), readout=True, **kwargs, diff --git a/topomodelx/nn/cell/can_layer.py b/topomodelx/nn/cell/can_layer.py index 912a0298..bae55f4f 100644 --- a/topomodelx/nn/cell/can_layer.py +++ b/topomodelx/nn/cell/can_layer.py @@ -439,9 +439,9 @@ def __init__( self.dropout = dropout self.add_self_loops = add_self_loops - self.lin = torch.nn.Linear(in_channels, heads * out_channels, bias=False) - self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels)) - self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels)) + self.lin = torch.nn.Linear(in_channels, out_channels, bias=False) + self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels // heads)) + self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels // heads)) self.reset_parameters() @@ -468,7 +468,7 @@ def message(self, x_source): """ # Compute the linear transformation on the source features x_message = self.lin(x_source).view( - -1, self.heads, self.out_channels + -1, self.heads, self.out_channels // self.heads ) # (n_k_cells, H, C) # compute the source and target messages @@ -534,12 +534,13 @@ def forward(self, x_source, neighborhood): # If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor if not neighborhood.values().nonzero().size(0) > 0 and self.concat: return torch.zeros( - (x_source.shape[0], self.out_channels * self.heads), + (x_source.shape[0], self.out_channels), device=x_source.device, ) # (n_k_cells, H * C) if not neighborhood.values().nonzero().size(0) > 0 and not self.concat: return torch.zeros( - (x_source.shape[0], self.out_channels), device=x_source.device + (x_source.shape[0], self.out_channels // self.heads), + device=x_source.device, ) # (n_k_cells, C) # Add self-loops to the neighborhood matrix if necessary @@ -559,9 +560,7 @@ def forward(self, x_source, neighborhood): # if concat true, concatenate the messages for each head. Otherwise, average the messages for each head. if self.concat: - return aggregated_message.view( - -1, self.heads * self.out_channels - ) # (n_k_cells, H * C) + return aggregated_message.view(-1, self.out_channels) # (n_k_cells, H * C) return aggregated_message.mean(dim=1) # (n_k_cells, C) @@ -613,7 +612,7 @@ def __init__( heads: int, concat: bool, att_activation: torch.nn.Module, - add_self_loops: bool = False, + add_self_loops: bool = True, aggr_func: Literal["sum", "mean", "add"] = "sum", initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform", share_weights: bool = False, @@ -634,17 +633,13 @@ def __init__( if share_weights: self.lin_src = self.lin_dst = torch.nn.Linear( - in_channels, heads * out_channels, bias=False + in_channels, out_channels, bias=False ) else: - self.lin_src = torch.nn.Linear( - in_channels, heads * out_channels, bias=False - ) - self.lin_dst = torch.nn.Linear( - in_channels, heads * out_channels, bias=False - ) + self.lin_src = torch.nn.Linear(in_channels, out_channels, bias=False) + self.lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False) - self.att_weight = Parameter(torch.Tensor(1, heads, out_channels)) + self.att_weight = Parameter(torch.Tensor(1, heads, out_channels // heads)) self.reset_parameters() @@ -671,12 +666,12 @@ def message(self, x_source): """ # Compute the linear transformation on the source features x_src_message = self.lin_src(x_source).view( - -1, self.heads, self.out_channels + -1, self.heads, self.out_channels // self.heads ) # (n_k_cells, H, C) # Compute the linear transformation on the source features x_dst_message = self.lin_dst(x_source).view( - -1, self.heads, self.out_channels + -1, self.heads, self.out_channels // self.heads ) # (n_k_cells, H, C) # Get the source and target projections of the neighborhood @@ -737,12 +732,13 @@ def forward(self, x_source, neighborhood): # If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor if not neighborhood.values().nonzero().size(0) > 0 and self.concat: return torch.zeros( - (x_source.shape[0], self.out_channels * self.heads), + (x_source.shape[0], self.out_channels), device=x_source.device, ) # (n_k_cells, H * C) if not neighborhood.values().nonzero().size(0) > 0 and not self.concat: return torch.zeros( - (x_source.shape[0], self.out_channels), device=x_source.device + (x_source.shape[0], self.out_channels // self.heads), + device=x_source.device, ) # (n_k_cells, C) # Add self-loops to the neighborhood matrix if necessary @@ -762,9 +758,7 @@ def forward(self, x_source, neighborhood): # if concat true, concatenate the messages for each head. Otherwise, average the messages for each head. if self.concat: - return aggregated_message.view( - -1, self.heads * self.out_channels - ) # (n_k_cells, H * C) + return aggregated_message.view(-1, self.out_channels) # (n_k_cells, H * C) return aggregated_message.mean(dim=1) # (n_k_cells, C) @@ -836,6 +830,9 @@ def __init__( assert in_channels > 0, ValueError("Number of input channels must be > 0") assert out_channels > 0, ValueError("Number of output channels must be > 0") assert heads > 0, ValueError("Number of heads must be > 0") + assert out_channels % heads == 0, ValueError( + "Number of output channels must be divisible by the number of heads" + ) assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]") # assert that shared weight is True only if version is v2 @@ -893,7 +890,7 @@ def __init__( # linear transformation if skip_connection: - out_channels = out_channels * heads if concat else out_channels + out_channels = out_channels if concat else out_channels // heads self.lin = Linear(in_channels, out_channels, bias=False) self.eps = 1 + 1e-6 diff --git a/tutorials/cell/can_train.ipynb b/tutorials/cell/can_train.ipynb index 122dd622..0cc2cd63 100644 --- a/tutorials/cell/can_train.ipynb +++ b/tutorials/cell/can_train.ipynb @@ -104,21 +104,22 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:06:36.009880829Z", "start_time": "2023-05-31T09:06:34.285257706Z" - } + }, + "metadata": {} }, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -146,12 +147,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:53.006542411Z", "start_time": "2023-05-31T09:13:52.963074076Z" - } + }, + "metadata": {} }, "outputs": [ { @@ -186,23 +188,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.279147916Z", "start_time": "2023-05-31T09:13:55.269057585Z" - } + }, + "metadata": {} }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip\n", - "Processing...\n", - "Done!\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -253,12 +247,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.832585216Z", "start_time": "2023-05-31T09:13:55.815448708Z" - } + }, + "metadata": {} }, "outputs": [], "source": [ @@ -298,8 +293,10 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": 5, + "metadata": { + "metadata": {} + }, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", @@ -324,7 +321,7 @@ " n_layers=n_layers,\n", " att_lift=att_lift,\n", " )\n", - " self.lin_0 = torch.nn.Linear(heads * out_channels, 128)\n", + " self.lin_0 = torch.nn.Linear(out_channels, 128)\n", " self.lin_1 = torch.nn.Linear(128, num_classes)\n", "\n", " def forward(self, x_0, x_1, adjacency, down_laplacian, up_laplacian):\n", @@ -338,12 +335,13 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:56.672913561Z", "start_time": "2023-05-31T09:13:56.667986426Z" - } + }, + "metadata": {} }, "outputs": [], "source": [ @@ -378,12 +376,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:40.411845803Z", "start_time": "2023-05-31T09:19:40.408861921Z" - } + }, + "metadata": {} }, "outputs": [ { @@ -398,35 +397,35 @@ " (0): CANLayer(\n", " (lower_att): MultiHeadCellAttention(\n", " (att_activation): LeakyReLU(negative_slope=0.2)\n", - " (lin): Linear(in_features=11, out_features=64, bias=False)\n", + " (lin): Linear(in_features=11, out_features=32, bias=False)\n", " )\n", " (upper_att): MultiHeadCellAttention(\n", " (att_activation): LeakyReLU(negative_slope=0.2)\n", - " (lin): Linear(in_features=11, out_features=64, bias=False)\n", + " (lin): Linear(in_features=11, out_features=32, bias=False)\n", " )\n", - " (lin): Linear(in_features=11, out_features=64, bias=False)\n", + " (lin): Linear(in_features=11, out_features=32, bias=False)\n", " (aggregation): Aggregation()\n", " )\n", " (1): CANLayer(\n", " (lower_att): MultiHeadCellAttention(\n", " (att_activation): LeakyReLU(negative_slope=0.2)\n", - " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " (lin): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (upper_att): MultiHeadCellAttention(\n", " (att_activation): LeakyReLU(negative_slope=0.2)\n", - " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " (lin): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", - " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " (lin): Linear(in_features=32, out_features=32, bias=False)\n", " (aggregation): Aggregation()\n", " )\n", " )\n", " )\n", - " (lin_0): Linear(in_features=64, out_features=128, bias=True)\n", + " (lin_0): Linear(in_features=32, out_features=128, bias=True)\n", " (lin_1): Linear(in_features=128, out_features=2, bias=True)\n", ")" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -446,12 +445,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:41.150933630Z", "start_time": "2023-05-31T09:19:41.146986990Z" - } + }, + "metadata": {} }, "outputs": [], "source": [ @@ -479,37 +479,38 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:42.918836083Z", "start_time": "2023-05-31T09:19:42.114801039Z" - } + }, + "metadata": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 0.6159 Train_acc: 0.6947\n", + "Epoch: 1 loss: 0.6330 Train_acc: 0.6718\n", "Test_acc: 0.5965\n", - "Epoch: 2 loss: 0.6099 Train_acc: 0.6947\n", + "Epoch: 2 loss: 0.6116 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 3 loss: 0.6035 Train_acc: 0.6947\n", + "Epoch: 3 loss: 0.6071 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 4 loss: 0.5966 Train_acc: 0.7176\n", - "Test_acc: 0.6316\n", - "Epoch: 5 loss: 0.5909 Train_acc: 0.7252\n", + "Epoch: 4 loss: 0.6027 Train_acc: 0.6947\n", + "Test_acc: 0.5965\n", + "Epoch: 5 loss: 0.5974 Train_acc: 0.7099\n", "Test_acc: 0.6491\n", - "Epoch: 6 loss: 0.5983 Train_acc: 0.7099\n", - "Test_acc: 0.6316\n", - "Epoch: 7 loss: 0.5884 Train_acc: 0.7252\n", + "Epoch: 6 loss: 0.5911 Train_acc: 0.7252\n", "Test_acc: 0.6491\n", - "Epoch: 8 loss: 0.5909 Train_acc: 0.7176\n", + "Epoch: 7 loss: 0.5979 Train_acc: 0.7176\n", + "Test_acc: 0.6140\n", + "Epoch: 8 loss: 0.5826 Train_acc: 0.7252\n", "Test_acc: 0.6316\n", - "Epoch: 9 loss: 0.5818 Train_acc: 0.7252\n", + "Epoch: 9 loss: 0.5908 Train_acc: 0.7252\n", "Test_acc: 0.6316\n", - "Epoch: 10 loss: 0.5879 Train_acc: 0.7252\n", + "Epoch: 10 loss: 0.5839 Train_acc: 0.7252\n", "Test_acc: 0.6316\n" ] }