diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ca6c2d41..9538a6e4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.10
+  python: python3.11

 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/topomodelx/nn/hypergraph/allset_transformer_layer.py b/topomodelx/nn/hypergraph/allset_transformer_layer.py
index 1e805e4c..8e91f441 100644
--- a/topomodelx/nn/hypergraph/allset_transformer_layer.py
+++ b/topomodelx/nn/hypergraph/allset_transformer_layer.py
@@ -4,6 +4,8 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch_geometric.utils import softmax
+from torch_scatter import scatter

 from topomodelx.base.message_passing import MessagePassing

@@ -257,8 +259,10 @@ def forward(self, x_source, neighborhood):

         # Obtain Y from Eq(8) in AllSet paper [1]
         # Skip-connection (broadcased)
-        x_message_on_target = x_message_on_target + self.multihead_att.Q_weight
-
+        x_message_on_target = x_message_on_target + self.multihead_att.Q_weight.permute(
+            1, 0, 2
+        )
+        x_message_on_target = x_message_on_target.unsqueeze(2)
         # Permute: n,h,q,c -> n,q,h,c
         x_message_on_target = x_message_on_target.permute(0, 2, 1, 3)
         x_message_on_target = self.ln0(
@@ -368,21 +372,10 @@ def attention(self, x_source, neighborhood):
         torch.Tensor, shape = (n_target_cells, heads, number_queries, n_source_cells)
             Attention weights: one scalar per message between a source and a target cell.
         """
-        x_K = torch.matmul(x_source, self.K_weight)
-        alpha = torch.matmul(self.Q_weight, x_K.transpose(1, 2))
-        expanded_alpha = torch.sparse_coo_tensor(
-            indices=neighborhood.indices(),
-            values=alpha.permute(*torch.arange(alpha.ndim - 1, -1, -1))[
-                self.source_index_j
-            ],
-            size=[
-                neighborhood.shape[0],
-                neighborhood.shape[1],
-                alpha.shape[1],
-                alpha.shape[0],
-            ],
-        )
-        return torch.sparse.softmax(expanded_alpha, dim=1).to_dense().transpose(1, 3)
+        x_K = torch.matmul(x_source, self.K_weight).permute(1, 0, 2)
+        alpha = (x_K * self.Q_weight.permute(1, 0, 2)).sum(-1)
+        alpha = F.leaky_relu(alpha, 0.2)
+        return softmax(alpha[self.source_index_j], index=self.target_index_i)

     def forward(self, x_source, neighborhood):
         """Forward pass.
@@ -410,8 +403,11 @@ def forward(self, x_source, neighborhood):
         """
         attention_values = self.attention(x_source, neighborhood)
         x_message = torch.matmul(x_source, self.V_weight)
-        return torch.matmul(attention_values, x_message)
+        x_message = x_message.permute(1, 0, 2)[
+            self.source_index_j
+        ] * attention_values.unsqueeze(-1)
+        return scatter(x_message, self.target_index_i, dim=0, reduce="sum")


 class MLP(nn.Sequential):
     """MLP Module.
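Note (not part of the diff): the rewritten attention() and forward() replace the dense sparse-softmax/matmul with a per-target softmax over messages followed by a scatter-sum. Below is a minimal, self-contained sketch of that pattern, assuming toy sizes and random tensors; source_index_j / target_index_i stand in for the layer's message index buffers and only the softmax/scatter mechanics mirror the diff.

# Sketch only: illustrates torch_geometric.utils.softmax + torch_scatter.scatter,
# not the actual AllSetTransformer layer. All shapes here are illustrative assumptions.
import torch
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter

n_source, n_target, heads, channels = 3, 2, 2, 4
source_index_j = torch.tensor([0, 1, 2, 2])  # source cell of each message
target_index_i = torch.tensor([0, 0, 1, 1])  # target cell of each message

# One attention logit per (source cell, head), passed through leaky ReLU as in attention().
alpha = F.leaky_relu(torch.randn(n_source, heads), 0.2)

# Softmax is taken over the messages that share a target, independently per head.
att = softmax(alpha[source_index_j], index=target_index_i)    # (n_messages, heads)

# Head-wise value vectors are gathered per source and weighted by attention...
x_v = torch.randn(n_source, heads, channels)
messages = x_v[source_index_j] * att.unsqueeze(-1)            # (n_messages, heads, channels)

# ...then summed onto their target cells, replacing the old dense sparse matmul.
out = scatter(messages, target_index_i, dim=0, reduce="sum")  # (n_target, heads, channels)
assert out.shape == (n_target, heads, channels)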