Commit d702faf

linting fix
Coerulatus committed May 11, 2024
1 parent d239600 commit d702faf
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions topomodelx/nn/hypergraph/allset_transformer_layer.py
@@ -375,9 +375,7 @@ def attention(self, x_source, neighborhood):
         x_K = torch.matmul(x_source, self.K_weight).permute(1, 0, 2)
         alpha = (x_K * self.Q_weight.permute(1, 0, 2)).sum(-1)
         alpha = F.leaky_relu(alpha, 0.2)
-        alpha_soft = softmax(alpha[self.source_index_j], index=self.target_index_i)
-
-        return alpha_soft
+        return softmax(alpha[self.source_index_j], index=self.target_index_i)

     def forward(self, x_source, neighborhood):
         """Forward pass.
@@ -409,10 +407,7 @@ def forward(self, x_source, neighborhood):
         x_message = x_message.permute(1, 0, 2)[
             self.source_index_j
         ] * attention_values.unsqueeze(-1)
-        x_message = scatter(x_message, self.target_index_i, dim=0, reduce="sum")
-
-        return x_message
-
+        return scatter(x_message, self.target_index_i, dim=0, reduce="sum")

 class MLP(nn.Sequential):
     """MLP Module.
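For context on what the two consolidated return statements compute: the attention scores are normalized with a softmax taken per target index, and the attention-weighted messages are then summed per target index. The sketch below is an illustration only, not code from the repository; the helper names segment_softmax and scatter_sum, the toy tensors, and the use of plain torch.Tensor.index_add are assumptions standing in for whatever softmax and scatter utilities the module actually imports.

import torch

def segment_softmax(scores, index, num_segments):
    # Softmax of `scores` within each group given by `index`
    # (the role played by softmax(alpha[...], index=...) in the diff above).
    scores = scores - scores.max()  # subtract a constant for numerical stability
    exp = scores.exp()
    denom = torch.zeros(num_segments, *scores.shape[1:]).index_add(0, index, exp)
    return exp / denom[index].clamp(min=1e-16)

def scatter_sum(messages, index, num_segments):
    # Sum rows of `messages` that share the same target index
    # (the role played by scatter(x_message, target_index_i, dim=0, reduce="sum")).
    out = torch.zeros(num_segments, *messages.shape[1:])
    return out.index_add(0, index, messages)

# Toy usage: 4 incidence pairs, 3 targets, 2 attention heads, feature size 5.
target_index_i = torch.tensor([0, 0, 1, 2])
alpha = torch.randn(4, 2)  # one score per pair and head
attention_values = segment_softmax(alpha, target_index_i, num_segments=3)
x_message = torch.randn(4, 2, 5) * attention_values.unsqueeze(-1)
aggregated = scatter_sum(x_message, target_index_i, num_segments=3)  # shape (3, 2, 5)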
