From d239600afe647aebfe0b9bb89031d2a546816400 Mon Sep 17 00:00:00 2001
From: Coerulatus
Date: Sat, 11 May 2024 11:52:20 +0000
Subject: [PATCH 1/2] efficient implementation of attention

---
 .pre-commit-config.yaml                       |  2 +-
 .../nn/hypergraph/allset_transformer_layer.py | 37 ++++++++++---------
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ca6c2d41..9538a6e4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.10
+  python: python3.11
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/topomodelx/nn/hypergraph/allset_transformer_layer.py b/topomodelx/nn/hypergraph/allset_transformer_layer.py
index 1e805e4c..44c985c2 100644
--- a/topomodelx/nn/hypergraph/allset_transformer_layer.py
+++ b/topomodelx/nn/hypergraph/allset_transformer_layer.py
@@ -4,6 +4,8 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch_geometric.utils import softmax
+from torch_scatter import scatter
 
 from topomodelx.base.message_passing import MessagePassing
 
@@ -257,8 +259,10 @@ def forward(self, x_source, neighborhood):
 
         # Obtain Y from Eq(8) in AllSet paper [1]
         # Skip-connection (broadcased)
-        x_message_on_target = x_message_on_target + self.multihead_att.Q_weight
-
+        x_message_on_target = x_message_on_target + self.multihead_att.Q_weight.permute(
+            1, 0, 2
+        )
+        x_message_on_target = x_message_on_target.unsqueeze(2)
         # Permute: n,h,q,c -> n,q,h,c
         x_message_on_target = x_message_on_target.permute(0, 2, 1, 3)
         x_message_on_target = self.ln0(
@@ -368,21 +372,12 @@ def attention(self, x_source, neighborhood):
         torch.Tensor, shape = (n_target_cells, heads, number_queries, n_source_cells)
             Attention weights: one scalar per message between a source and a target cell.
         """
-        x_K = torch.matmul(x_source, self.K_weight)
-        alpha = torch.matmul(self.Q_weight, x_K.transpose(1, 2))
-        expanded_alpha = torch.sparse_coo_tensor(
-            indices=neighborhood.indices(),
-            values=alpha.permute(*torch.arange(alpha.ndim - 1, -1, -1))[
-                self.source_index_j
-            ],
-            size=[
-                neighborhood.shape[0],
-                neighborhood.shape[1],
-                alpha.shape[1],
-                alpha.shape[0],
-            ],
-        )
-        return torch.sparse.softmax(expanded_alpha, dim=1).to_dense().transpose(1, 3)
+        x_K = torch.matmul(x_source, self.K_weight).permute(1, 0, 2)
+        alpha = (x_K * self.Q_weight.permute(1, 0, 2)).sum(-1)
+        alpha = F.leaky_relu(alpha, 0.2)
+        alpha_soft = softmax(alpha[self.source_index_j], index=self.target_index_i)
+
+        return alpha_soft
 
     def forward(self, x_source, neighborhood):
         """Forward pass.
@@ -410,7 +405,13 @@ def forward(self, x_source, neighborhood):
         attention_values = self.attention(x_source, neighborhood)
 
         x_message = torch.matmul(x_source, self.V_weight)
-        return torch.matmul(attention_values, x_message)
+
+        x_message = x_message.permute(1, 0, 2)[
+            self.source_index_j
+        ] * attention_values.unsqueeze(-1)
+        x_message = scatter(x_message, self.target_index_i, dim=0, reduce="sum")
+
+        return x_message
 
 
 class MLP(nn.Sequential):

From d702fafd59c577501be0a39caaadcd9a9169c03c Mon Sep 17 00:00:00 2001
From: Coerulatus
Date: Sat, 11 May 2024 16:37:46 +0000
Subject: [PATCH 2/2] linting fix

---
 topomodelx/nn/hypergraph/allset_transformer_layer.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/topomodelx/nn/hypergraph/allset_transformer_layer.py b/topomodelx/nn/hypergraph/allset_transformer_layer.py
index 44c985c2..8e91f441 100644
--- a/topomodelx/nn/hypergraph/allset_transformer_layer.py
+++ b/topomodelx/nn/hypergraph/allset_transformer_layer.py
@@ -375,9 +375,7 @@ def attention(self, x_source, neighborhood):
         x_K = torch.matmul(x_source, self.K_weight).permute(1, 0, 2)
         alpha = (x_K * self.Q_weight.permute(1, 0, 2)).sum(-1)
         alpha = F.leaky_relu(alpha, 0.2)
-        alpha_soft = softmax(alpha[self.source_index_j], index=self.target_index_i)
-
-        return alpha_soft
+        return softmax(alpha[self.source_index_j], index=self.target_index_i)
 
     def forward(self, x_source, neighborhood):
         """Forward pass.
@@ -409,10 +407,7 @@ def forward(self, x_source, neighborhood):
         x_message = x_message.permute(1, 0, 2)[
             self.source_index_j
         ] * attention_values.unsqueeze(-1)
-        x_message = scatter(x_message, self.target_index_i, dim=0, reduce="sum")
-
-        return x_message
-
+        return scatter(x_message, self.target_index_i, dim=0, reduce="sum")
 
 class MLP(nn.Sequential):
     """MLP Module.
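A minimal self-contained sketch of the attention pattern these patches switch to: per-message logits are normalized with a segment softmax over each target cell's incoming messages and then aggregated with a scatter-sum, rather than materializing a dense (targets x sources) attention tensor. All names, shapes, and sizes below are illustrative assumptions (a single query per head), not the layer's actual attributes.

# Standalone sketch (assumed shapes, single query per head) of the
# softmax-over-incidences + scatter-sum attention used in the patches above.
import torch
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter

torch.manual_seed(0)
heads, in_dim, hidden = 4, 8, 16
n_source, n_target, n_messages = 10, 3, 25

x_source = torch.randn(n_source, in_dim)
K_weight = torch.randn(heads, in_dim, hidden)  # hypothetical stand-ins for the
Q_weight = torch.randn(heads, 1, hidden)       # layer's learned parameters
V_weight = torch.randn(heads, in_dim, hidden)

# One (target, source) index pair per message in the neighborhood.
source_index_j = torch.randint(0, n_source, (n_messages,))
target_index_i = torch.randint(0, n_target, (n_messages,))

# Per-source, per-head attention logits.
x_K = torch.matmul(x_source, K_weight).permute(1, 0, 2)  # (n_source, heads, hidden)
alpha = (x_K * Q_weight.permute(1, 0, 2)).sum(-1)        # (n_source, heads)
alpha = F.leaky_relu(alpha, 0.2)

# Segment softmax: normalize only over messages that share a target cell.
alpha_soft = softmax(alpha[source_index_j], index=target_index_i)  # (n_messages, heads)

# Weight the value messages and sum them into their target cells.
x_message = torch.matmul(x_source, V_weight).permute(1, 0, 2)     # (n_source, heads, hidden)
x_message = x_message[source_index_j] * alpha_soft.unsqueeze(-1)  # (n_messages, heads, hidden)
x_target = scatter(x_message, target_index_i, dim=0, dim_size=n_target, reduce="sum")
print(x_target.shape)  # torch.Size([3, 4, 16])

Compared with the previous sparse_coo_tensor / torch.sparse.softmax path, this formulation never densifies the attention weights, so memory scales with the number of incidences rather than with n_target * n_source.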