diff --git a/examples/04_training/08_train_grit_model.py b/examples/04_training/08_train_grit_model.py
new file mode 100644
index 000000000..cf244cbe3
--- /dev/null
+++ b/examples/04_training/08_train_grit_model.py
@@ -0,0 +1,245 @@
+"""Example of training a GRIT model."""
+
+import os
+from typing import Any, Dict, List, Optional
+
+from pytorch_lightning.loggers import WandbLogger
+import torch
+from torch.optim.adam import Adam
+
+from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
+from graphnet.data.constants import FEATURES, TRUTH
+from graphnet.models import StandardModel
+from graphnet.models.detector.prometheus import Prometheus
+from graphnet.models.gnn import GRIT
+from graphnet.models.graphs import KNNGraphRRWP
+from graphnet.models.task.reconstruction import EnergyReconstruction
+from graphnet.training.callbacks import PiecewiseLinearLR
+from graphnet.training.loss_functions import LogCoshLoss
+from graphnet.utilities.argparse import ArgumentParser
+from graphnet.utilities.logging import Logger
+from graphnet.data import GraphNeTDataModule
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
+
+# Constants
+features = FEATURES.PROMETHEUS
+truth = TRUTH.PROMETHEUS
+
+
+def main(
+    path: str,
+    pulsemap: str,
+    target: str,
+    truth_table: str,
+    gpus: Optional[List[int]],
+    max_epochs: int,
+    early_stopping_patience: int,
+    batch_size: int,
+    num_workers: int,
+    wandb: bool = False,
+) -> None:
+    """Run example."""
+    # Construct Logger
+    logger = Logger()
+
+    # Initialise Weights & Biases (W&B) run
+    if wandb:
+        # Make sure W&B output directory exists
+        wandb_dir = "./wandb/"
+        os.makedirs(wandb_dir, exist_ok=True)
+        wandb_logger = WandbLogger(
+            project="example-script",
+            entity="graphnet-team",
+            save_dir=wandb_dir,
+            log_model=True,
+        )
+
+    logger.info(f"features: {features}")
+    logger.info(f"truth: {truth}")
+
+    # Configuration
+    config: Dict[str, Any] = {
+        "path": path,
+        "pulsemap": pulsemap,
+        "batch_size": batch_size,
+        "num_workers": num_workers,
+        "target": target,
+        "early_stopping_patience": early_stopping_patience,
+        "fit": {
+            "gpus": gpus,
+            "max_epochs": max_epochs,
+            "distribution_strategy": "ddp_find_unused_parameters_true",
+        },
+        "dataset_reference": (
+            SQLiteDataset if path.endswith(".db") else ParquetDataset
+        ),
+    }
+
+    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
+    run_name = "grit_{}_example".format(config["target"])
+    if wandb:
+        # Log configuration to W&B
+        wandb_logger.experiment.config.update(config)
+
+    walk_length = 6
+    graph_definition = KNNGraphRRWP(
+        detector=Prometheus(),
+        input_feature_names=features,
+        nb_nearest_neighbours=5,
+        walk_length=walk_length,
+    )
+    dm = GraphNeTDataModule(
+        dataset_reference=config["dataset_reference"],
+        dataset_args={
+            "truth": truth,
+            "truth_table": truth_table,
+            "features": features,
+            "graph_definition": graph_definition,
+            "pulsemaps": [config["pulsemap"]],
+            "path": config["path"],
+        },
+        train_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
+        test_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
+    )
+
+    training_dataloader = dm.train_dataloader
+    validation_dataloader = dm.val_dataloader
+
+    # Building model
+    backbone = GRIT(
+        nb_inputs=graph_definition.nb_outputs,
+        hidden_dim=32,
+        ksteps=walk_length,
+    )
+
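+    # The target energy spans several orders of magnitude, so the task is
+    # trained on log10(energy); `transform_inference` maps predictions
+    # back to energy as 10^x.
+    task = EnergyReconstruction(
+        hidden_size=backbone.nb_outputs,
+        target_labels=config["target"],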
+        loss_function=LogCoshLoss(),
+        transform_prediction_and_target=lambda x: torch.log10(x),
+        transform_inference=lambda x: torch.pow(10, x),
+    )
+
+    model = StandardModel(
+        graph_definition=graph_definition,
+        backbone=backbone,
+        tasks=[task],
+        optimizer_class=Adam,
+        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
+        scheduler_class=PiecewiseLinearLR,
+        scheduler_kwargs={
+            "milestones": [
+                0,
+                len(training_dataloader) / 2,
+                len(training_dataloader) * config["fit"]["max_epochs"],
+            ],
+            "factors": [1e-2, 1, 1e-02],
+        },
+        scheduler_config={
+            "interval": "step",
+        },
+    )
+
+    # Training model
+    model.fit(
+        training_dataloader,
+        validation_dataloader,
+        early_stopping_patience=config["early_stopping_patience"],
+        logger=wandb_logger if wandb else None,
+        **config["fit"],
+    )
+
+    # Get predictions
+    additional_attributes = model.target_labels
+    assert isinstance(additional_attributes, list)  # mypy
+
+    results = model.predict_as_dataframe(
+        validation_dataloader,
+        additional_attributes=additional_attributes + ["event_no"],
+        gpus=config["fit"]["gpus"],
+    )
+
+    # Save predictions and model to file
+    db_name = path.split("/")[-1].split(".")[0]
+    path = os.path.join(archive, db_name, run_name)
+    logger.info(f"Writing results to {path}")
+    os.makedirs(path, exist_ok=True)
+
+    results.to_csv(f"{path}/results.csv")
+
+    model.save(f"{path}/model.pth")
+    model.save_state_dict(f"{path}/state_dict.pth")
+    model.save_config(f"{path}/model_config.yml")
+
+
+if __name__ == "__main__":
+
+    # Parse command-line arguments
+    parser = ArgumentParser(
+        description="""
+Train GRIT model without the use of config files.
+"""
+    )
+
+    parser.add_argument(
+        "--path",
+        help="Path to dataset file (default: %(default)s)",
+        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
+    )
+
+    parser.add_argument(
+        "--pulsemap",
+        help="Name of pulsemap to use (default: %(default)s)",
+        default="total",
+    )
+
+    parser.add_argument(
+        "--target",
+        help=(
+            "Name of feature to use as regression target (default: "
+            "%(default)s)"
+        ),
+        default="total_energy",
+    )
+
+    parser.add_argument(
+        "--truth-table",
+        help="Name of truth table to be used (default: %(default)s)",
+        default="mc_truth",
+    )
+
+    parser.with_standard_arguments(
+        "gpus",
+        ("max-epochs", 1),
+        "early-stopping-patience",
+        ("batch-size", 16),
+        "num-workers",
+    )
+
+    parser.add_argument(
+        "--wandb",
+        action="store_true",
+        help="If True, Weights & Biases are used to track the experiment.",
+    )
+
+    args, unknown = parser.parse_known_args()
+
+    main(
+        args.path,
+        args.pulsemap,
+        args.target,
+        args.truth_table,
+        args.gpus,
+        args.max_epochs,
+        args.early_stopping_patience,
+        args.batch_size,
+        args.num_workers,
+        args.wandb,
+    )
diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py
index 08d699931..307fcb8c6 100644
--- a/src/graphnet/models/components/embedding.py
+++ b/src/graphnet/models/components/embedding.py
@@ -7,6 +7,51 @@
 from typing import Optional
 
 from pytorch_lightning import LightningModule
+from torch_geometric.utils import add_self_loops
+from torch_geometric.data import Data
+from torch_sparse import coalesce
+
+from graphnet.models.utils import full_edge_index
+
+
+class LinearEdgeEncoder(LightningModule):
+    """Linear encoding for edge attributes."""
+
+    def __init__(self, dim_emb: int):
+        """Construct `LinearEdgeEncoder`.
+
+        Args:
+            dim_emb: Embedding dimension.
+        """
+        super().__init__()
+
+        self.in_dim = 1  # TODO: generalize to more edge features -PW
+        self.encoder = torch.nn.Linear(self.in_dim, dim_emb)
+
+    def forward(self, data: Data) -> Data:
+        """Forward pass."""
+        data.edge_attr = self.encoder(data.edge_attr.view(-1, self.in_dim))
+        return data
+
+
+class LinearNodeEncoder(LightningModule):
+    """Linear encoding for nodes."""
+
+    def __init__(self, dim_in: int, dim_emb: int):
+        """Construct `LinearNodeEncoder`.
+
+        Args:
+            dim_in: Input dimension.
+            dim_emb: Embedding dimension.
+        """
+        super().__init__()
+
+        self.encoder = torch.nn.Linear(dim_in, dim_emb)
+
+    def forward(self, data: Data) -> Data:
+        """Forward pass."""
+        data.x = self.encoder(data.x)
+        return data
 
 
 class SinusoidalPosEmb(LightningModule):
@@ -172,3 +217,188 @@ def forward(
         sin_emb = self.sin_emb(1024 * four_distance.clip(-4, 4))
         rel_attn = self.projection(sin_emb)
         return rel_attn
+
+
+class RRWPLinearNodeEncoder(LightningModule):
+    """Relative random walk probability node encoder.
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/encoder/rrwp_encoder.py
+    """
+
+    def __init__(
+        self,
+        emb_dim: int,
+        out_dim: int,
+        use_bias: bool = False,
+        apply_norm: bool = True,
+        norm_layer: nn.Module = nn.BatchNorm1d,
+        pe_name: str = "rrwp",
+    ):
+        """Construct `RRWPLinearNodeEncoder`.
+
+        Args:
+            emb_dim: Embedding dimension.
+            out_dim: Output dimension.
+            use_bias: Apply bias to linear layer.
+            apply_norm: Apply normalization layer.
+            norm_layer: Normalization layer.
+            pe_name: Positional encoding name.
+        """
+        super().__init__()
+        self.name = pe_name
+        self.apply_norm = apply_norm
+
+        self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias)
+        torch.nn.init.xavier_uniform_(self.fc.weight)
+
+        if self.apply_norm:
+            self.norm = norm_layer(out_dim)
+
+    def forward(self, data: Data) -> Data:
+        """Forward pass."""
+        rrwp = data[f"{self.name}"]
+        rrwp = self.fc(rrwp)
+
+        if self.apply_norm:
+            rrwp = self.norm(rrwp)
+
+        if "x" in data:
+            data.x = data.x + rrwp
+        else:
+            data.x = rrwp
+
+        return data
+
+
+class RRWPLinearEdgeEncoder(LightningModule):
+    """Relative random walk probability edge encoder.
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/encoder/rrwp_encoder.py
+    """
+
+    def __init__(
+        self,
+        emb_dim: int,
+        out_dim: int,
+        use_bias: bool = False,
+        apply_norm: bool = True,
+        norm_layer: nn.Module = nn.BatchNorm1d,
+        pad_to_full_graph: bool = True,
+        fill_value: float = 0.0,
+        add_node_attr_as_self_loop: bool = False,
+        overwrite_old_attr: bool = False,
+    ):
+        """Construct `RRWPLinearEdgeEncoder`.
+
+        Args:
+            emb_dim: Embedding dimension.
+            out_dim: Output dimension.
+            use_bias: Apply bias to linear layer.
+            apply_norm: Apply normalization layer.
+            norm_layer: Normalization layer.
+            pad_to_full_graph: Pad edges to fully-connected graph.
+            fill_value: Fill value for padding.
+            add_node_attr_as_self_loop: Add self loop edges with node attr.
+            overwrite_old_attr: Overwrite old edge attr.
+ """ + super().__init__() + + self.emb_dim = emb_dim + self.out_dim = out_dim + self.add_node_attr_as_self_loop = add_node_attr_as_self_loop + self.overwrite_old_attr = overwrite_old_attr + self.apply_norm = apply_norm + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + self.pad_to_full_graph = pad_to_full_graph + self.fill_value = 0.0 + + padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value + self.register_buffer("padding", padding) + + if self.apply_norm: + self.norm = norm_layer(out_dim) + + def forward(self, data: Data) -> Data: + """Forward pass.""" + rrwp_idx = data.rrwp_index + rrwp_val = data.rrwp_val + edge_index = data.edge_index + edge_attr = data.edge_attr + + rrwp_val = self.fc(rrwp_val) + + if edge_attr is None: + edge_attr = edge_index.new_zeros( + edge_index.size(1), rrwp_val.size(1) + ) + # zero padding for non-existing edges + + if self.overwrite_old_attr: + out_idx, out_val = rrwp_idx, rrwp_val + else: + edge_index, edge_attr = add_self_loops( + edge_index, edge_attr, num_nodes=data.num_nodes, fill_value=0.0 + ) + + out_idx, out_val = coalesce( + torch.cat([edge_index, rrwp_idx], dim=1), + torch.cat([edge_attr, rrwp_val], dim=0), + data.num_nodes, + data.num_nodes, + op="add", + ) + + if self.pad_to_full_graph: + edge_index_full = full_edge_index(out_idx, batch=data.batch) + edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) + # zero padding to fully-connected graphs + out_idx = torch.cat([out_idx, edge_index_full], dim=1) + out_val = torch.cat([out_val, edge_attr_pad], dim=0) + out_idx, out_val = coalesce( + out_idx, out_val, data.num_nodes, data.num_nodes, op="add" + ) + + if self.apply_norm: + out_val = self.norm(out_val) + + data.edge_index, data.edge_attr = out_idx, out_val + return data + + +class RWSELinearNodeEncoder(LightningModule): + """Random walk structural node encoding.""" + + def __init__( + self, + emb_dim: int, + out_dim: int, + use_bias: bool = False, + ): + """Construct `RWSELinearEdgeEncoder`. + + Args: + emb_dim: Embedding dimension. + out_dim: Output dimension. + use_bias: Apply bias to linear layer. 
+ """ + super().__init__() + + self.emb_dim = emb_dim + self.out_dim = out_dim + + self.encoder = nn.Linear(emb_dim, out_dim, bias=use_bias) + + def forward(self, data: Data) -> Data: + """Forward pass.""" + rwse = data.rwse + x = data.x + + rwse = self.encoder(rwse) + + data.x = torch.cat((x, rwse), dim=1) + + return data diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py index 4f04067f9..386f17818 100644 --- a/src/graphnet/models/components/layers.py +++ b/src/graphnet/models/components/layers.py @@ -3,17 +3,25 @@ from typing import Any, Callable, Optional, Sequence, Union, List import torch +import torch.nn as nn from torch.functional import Tensor from torch_geometric.nn import EdgeConv -from torch_geometric.nn.pool import knn_graph +from torch_geometric.nn.pool import ( + knn_graph, + global_mean_pool, + global_add_pool, +) from torch_geometric.typing import Adj, PairTensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset -import torch.nn as nn +from torch_geometric.data import Data from torch.nn.functional import linear from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer -from torch_geometric.utils import to_dense_batch +from torch_geometric.utils import to_dense_batch, softmax +from torch_scatter import scatter + from pytorch_lightning import LightningModule +from torch_geometric.utils import degree class DynEdgeConv(EdgeConv, LightningModule): @@ -593,3 +601,402 @@ def forward( ) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x + + +class GritSparseMHA(LightningModule): + """Proposed Attention Computation for GRIT. + + Original code: + https://github.com/LiamMa/GRIT/blob/main/grit/layer/grit_layer.py + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_heads: int, + use_bias: bool, + clamp: float = 5.0, + dropout: float = 0.0, + activation: nn.Module = nn.ReLU, + edge_enhance: bool = True, + ): + """Construct 'GritSparseMHA'. + + Args: + in_dim: Dimension of the input tensor. + out_dim: Dimension of the output tensor. + num_heads: Number of attention heads. + use_bias: Apply bias the key and value linear layers. + clamp: Clamp the absolute value of the attention scores to a value. + dropout: Dropout layer probability. + activation: Uninstantiated activation function. + E.g. `torch.nn.ReLU` + edge_enhance: Applies learnable weight matrix with node-pair in + output node calculation. 
+ """ + super().__init__() + + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = nn.Dropout(dropout) + self.clamp = abs(clamp) if clamp is not None else None + self.edge_enhance = edge_enhance + + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True) + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + nn.init.xavier_normal_(self.Q.weight) + nn.init.xavier_normal_(self.K.weight) + nn.init.xavier_normal_(self.E.weight) + nn.init.xavier_normal_(self.V.weight) + + self.Aw = nn.Parameter( + torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True + ) + nn.init.xavier_normal_(self.Aw) + + # TODO: Better activation function handling -PW + self.activation = activation() + + if self.edge_enhance: + self.VeRow = nn.Parameter( + torch.zeros(self.out_dim, self.num_heads, self.out_dim), + requires_grad=True, + ) + nn.init.xavier_normal_(self.VeRow) + + def forward(self, data: Data) -> Data: + """Forward pass.""" + Q_x = self.Q(data.x) + K_x = self.K(data.x) + V_x = self.V(data.x) + + if data.get("edge_attr", None) is not None: + E = self.E(data.edge_attr) + else: + E = None + + Q_x = Q_x.view(-1, self.num_heads, self.out_dim) + K_x = K_x.view(-1, self.num_heads, self.out_dim) + V_x = V_x.view(-1, self.num_heads, self.out_dim) + + # Applying Eq. 2 of the GRIT paper: + src = K_x[data.edge_index[0]] # (num relative) x num_heads x out_dim + dest = Q_x[data.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication + if E is not None: + E = E.view(-1, self.num_heads, self.out_dim * 2) + E_w, E_b = E[:, :, : self.out_dim], E[:, :, self.out_dim :] + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt( + torch.relu(-score) + ) + score = score + E_b + + score = self.activation(score) + e_t = score # ehat_ij + + # Output edge + if E is not None: + wE = score.flatten(1) + + # Complete attention calculation + score = torch.einsum("ehd, dhc->ehc", score, self.Aw) + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + score = softmax(score, index=data.edge_index[1]).to( + dtype=data.x.dtype + ) # (num relative) x num_heads x 1 + score = self.dropout(score) + + # Aggregate with Attn-Score + V_x_weighted = ( + V_x[data.edge_index[0]] * score + ) # (num relative) x num_heads x out_dim + wV = torch.zeros_like( + V_x, dtype=score.dtype + ) # (num nodes in batch) x num_heads x out_dim + scatter(V_x_weighted, data.edge_index[1], dim=0, out=wV, reduce="add") + + # Adds the second term (W_Ev ehhat_ij) in the last line of Eq. 2 + if self.edge_enhance and E is not None: + rowV = scatter( + e_t * score, data.edge_index[1], dim=0, reduce="add" + ) + rowV = torch.einsum("nhd, dhc -> nhc", rowV, self.VeRow) + wV = wV + rowV + + return wV, wE + + +class GritTransformerLayer(LightningModule): + """Proposed Transformer Layer for GRIT. 
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/layer/grit_layer.py
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        norm: nn.Module = nn.BatchNorm1d,
+        residual: bool = True,
+        deg_scaler: bool = True,
+        activation: nn.Module = nn.ReLU,
+        norm_edges: bool = True,
+        update_edges: bool = True,
+        batch_norm_momentum: float = 0.1,
+        batch_norm_runner: bool = True,
+        rezero: bool = False,
+        enable_edge_transform: bool = True,
+        attn_bias: bool = False,
+        attn_dropout: float = 0.0,
+        attn_clamp: float = 5.0,
+        attn_activation: nn.Module = nn.ReLU,
+        attn_edge_enhance: bool = True,
+    ):
+        """Construct `GritTransformerLayer`.
+
+        Args:
+            in_dim: Dimension of the input tensor.
+            out_dim: Dimension of the output tensor.
+            num_heads: Number of attention heads.
+            dropout: Dropout layer probability.
+            norm: Uninstantiated normalization layer.
+                Must be either `torch.nn.BatchNorm1d` or `torch.nn.LayerNorm`.
+            residual: Apply residual connections.
+            deg_scaler: Apply degree scaling after MHA.
+            activation: Uninstantiated activation function.
+                E.g. `torch.nn.ReLU`
+            norm_edges: Apply normalization to edges.
+            update_edges: Update edges after layer.
+            batch_norm_momentum: Momentum of batch normalization.
+            batch_norm_runner: Track running stats of batch normalization.
+            rezero: Apply learnable scaling parameters.
+            enable_edge_transform: Apply a FC to edges at the start
+                of the layer.
+            attn_bias: Add bias to keys and values in MHA block.
+            attn_dropout: Attention dropout.
+            attn_clamp: Clamp absolute value of attention scores to a value.
+            attn_activation: Uninstantiated activation function for MHA
+                block. E.g. `torch.nn.ReLU`
+            attn_edge_enhance: Applies learnable weight matrix with node-pair
+                in output node calculation in MHA block.
+        """
+        super().__init__()
+
+        self.in_channels = in_dim
+        self.out_channels = out_dim
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.residual = residual
+        self.update_edges = update_edges
+        self.batch_norm_momentum = batch_norm_momentum
+        self.batch_norm_runner = batch_norm_runner
+        self.rezero = rezero
+        self.deg_scaler = deg_scaler
+        self.activation = activation()
+
+        self.attention = GritSparseMHA(
+            in_dim=in_dim,
+            out_dim=out_dim // num_heads,
+            num_heads=num_heads,
+            use_bias=attn_bias,
+            dropout=attn_dropout,
+            clamp=attn_clamp,
+            activation=attn_activation,
+            edge_enhance=attn_edge_enhance,
+        )
+
+        self.fc1_x = nn.Linear(out_dim // num_heads * num_heads, out_dim)
+        if enable_edge_transform:
+            self.fc1_e = nn.Linear(out_dim // num_heads * num_heads, out_dim)
+        else:
+            self.fc1_e = nn.Identity()
+
+        if self.deg_scaler:
+            self.deg_coef = nn.Parameter(
+                torch.zeros(1, out_dim // num_heads * num_heads, 2)
+            )
+            nn.init.xavier_normal_(self.deg_coef)
+
+        if norm == nn.LayerNorm:
+            self.norm1_x = norm(out_dim)
+            self.norm1_e = norm(out_dim) if norm_edges else nn.Identity()
+        elif norm == nn.BatchNorm1d:
+            self.norm1_x = norm(
+                out_dim,
+                track_running_stats=self.batch_norm_runner,
+                eps=1e-5,
+                momentum=self.batch_norm_momentum,
+            )
+            self.norm1_e = (
+                norm(
+                    out_dim,
+                    track_running_stats=self.batch_norm_runner,
+                    eps=1e-5,
+                    momentum=self.batch_norm_momentum,
+                )
+                if norm_edges
+                else nn.Identity()
+            )
+        else:
+            raise ValueError(
+                "GritTransformerLayer normalization layer must be "
+                "'LayerNorm' or 'BatchNorm1d'!"
+            )
+
+        # FFN for x
+        self.FFN_x_layer1 = nn.Linear(out_dim, out_dim * 2)
+        self.FFN_x_layer2 = nn.Linear(out_dim * 2, out_dim)
+
+        if norm == nn.LayerNorm:
+            self.norm2_x = norm(out_dim)
+        elif norm == nn.BatchNorm1d:
+            self.norm2_x = norm(
+                out_dim,
+                track_running_stats=self.batch_norm_runner,
+                eps=1e-5,
+                momentum=self.batch_norm_momentum,
+            )
+
+        if self.rezero:  # Learnable scaling parameters
+            self.alpha1_x = nn.Parameter(torch.zeros(1, 1))
+            self.alpha2_x = nn.Parameter(torch.zeros(1, 1))
+            self.alpha1_e = nn.Parameter(torch.zeros(1, 1))
+
+        self.dropout1 = nn.Dropout(dropout)  # Post-attention dropout on x
+        self.dropout2 = nn.Dropout(dropout)  # Post-attention dropout on e
+        self.dropout3 = nn.Dropout(dropout)  # Post-FFN dropout on x
+
+    def forward(self, data: Data) -> Data:
+        """Forward pass."""
+        x = data.x
+        num_nodes = data.num_nodes
+        log_deg = torch.log10(
+            degree(data.edge_index[0], num_nodes=num_nodes, dtype=data.x.dtype)
+            + 1
+        )
+        log_deg = log_deg.view(data.num_nodes, 1)
+
+        x_attn_residual = x  # for first residual connection
+        e_values_in = data.get("edge_attr", None)
+        e = None
+
+        # Attention outputs
+        x_attn_out, e_attn_out = self.attention(data)
+
+        x = x_attn_out.view(num_nodes, -1)
+        x = self.dropout1(x)
+
+        # Apply degree scaler if enabled
+        if self.deg_scaler:
+            x = torch.stack([x, x * log_deg], dim=-1)
+            x = (x * self.deg_coef).sum(dim=-1)
+
+        x = self.fc1_x(x)
+        if e_attn_out is not None:
+            e = e_attn_out.flatten(1)
+            e = self.dropout2(e)
+            e = self.fc1_e(e)
+
+        if self.residual:
+            if self.rezero:
+                x = x * self.alpha1_x
+            x = x_attn_residual + x
+
+            if e is not None:
+                if self.rezero:
+                    e = e * self.alpha1_e
+                e = e + e_values_in
+
+        x = self.norm1_x(x)
+        if e is not None:
+            e = self.norm1_e(e)
+
+        # FFN for x
+        x_ffn_residual = x  # Residual over the FFN
+        x = self.FFN_x_layer1(x)
+        x = self.activation(x)
+        x = self.dropout3(x)
+        x = self.FFN_x_layer2(x)
+
+        if self.residual:
+            if self.rezero:
+                x = x * self.alpha2_x
+            x = x_ffn_residual + x  # residual connection
+
+        x = self.norm2_x(x)
+
+        data.x = x
+        if self.update_edges:
+            data.edge_attr = e
+        else:
+            data.edge_attr = e_values_in
+
+        return data
+
+
+# TODO: This is a prediction head... we probably want only the graph
+# stuff here and let the Tasks handle the last layer. -PW
+class SANGraphHead(LightningModule):
+    """SAN prediction head for graph prediction tasks.
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/head/san_graph.py
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int = 1,
+        L: int = 2,
+        activation: nn.Module = nn.ReLU,
+        pooling: str = "mean",
+    ):
+        """Construct `SANGraphHead`.
+
+        Args:
+            dim_in: Input dimension.
+            dim_out: Output dimension.
+            L: Number of hidden layers.
+            activation: Uninstantiated activation function.
+                E.g. `torch.nn.ReLU`
+            pooling: Node-wise pooling operation. Either "mean" or "add".
+        """
+        super().__init__()
+        if pooling == "mean":
+            self.pooling_fun = global_mean_pool
+        elif pooling == "add":
+            self.pooling_fun = global_add_pool
+        else:
+            raise RuntimeError("Currently supports only 'add' or 'mean'.")
+
+        fc_layers = [
+            nn.Linear(dim_in // 2**n, dim_in // 2 ** (n + 1), bias=True)
+            for n in range(L)
+        ]
+        assert dim_in // 2**L >= dim_out, "Too much dim reduction!"
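+        # Each hidden layer halves the dimension, so the stack maps
+        # dim_in -> dim_in / 2 -> ... -> dim_in / 2**L before the final
+        # projection to dim_out appended below.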
+        fc_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True))
+        self.fc_layers = nn.ModuleList(fc_layers)
+        self.L = L
+        self.activation = activation()
+        self.dim_out = dim_out
+
+    def forward(self, data: Data) -> Tensor:
+        """Forward pass."""
+        graph_emb = self.pooling_fun(data.x, data.batch)
+        for i in range(self.L):
+            graph_emb = self.fc_layers[i](graph_emb)
+            graph_emb = self.activation(graph_emb)
+        graph_emb = self.fc_layers[self.L](graph_emb)
+        # The final linear layer projects to `dim_out`; any further
+        # transformation is left to the Task layer.
+        return graph_emb
diff --git a/src/graphnet/models/gnn/__init__.py b/src/graphnet/models/gnn/__init__.py
index 8bf94af19..9bbc1fb20 100644
--- a/src/graphnet/models/gnn/__init__.py
+++ b/src/graphnet/models/gnn/__init__.py
@@ -7,3 +7,4 @@
 from .RNN_tito import RNN_TITO
 from .icemix import DeepIce
 from .particlenet import ParticleNeT
+from .grit import GRIT
diff --git a/src/graphnet/models/gnn/grit.py b/src/graphnet/models/gnn/grit.py
new file mode 100644
index 000000000..f42928a3f
--- /dev/null
+++ b/src/graphnet/models/gnn/grit.py
@@ -0,0 +1,164 @@
+"""Implementation of GRIT, a graph transformer model.
+
+Original author: Liheng Ma
+Original code: https://github.com/LiamMa/GRIT
+Paper: "Graph Inductive Biases in Transformers without Message Passing",
+    https://arxiv.org/abs/2305.17589
+
+Adapted by: Philip Weigel
+"""
+
+import torch.nn as nn
+from torch import Tensor
+from torch_geometric.data import Data
+
+from graphnet.models.gnn.gnn import GNN
+
+from graphnet.models.components.layers import (
+    GritTransformerLayer,
+    SANGraphHead,
+)
+from graphnet.models.components.embedding import (
+    RRWPLinearEdgeEncoder,
+    RRWPLinearNodeEncoder,
+    LinearNodeEncoder,
+    LinearEdgeEncoder,
+    RWSELinearNodeEncoder,
+)
+
+
+class GRIT(GNN):
+    """GRIT is a graph transformer model.
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/network/grit_model.py
+    """
+
+    def __init__(
+        self,
+        nb_inputs: int,
+        hidden_dim: int,
+        nb_outputs: int = 1,
+        ksteps: int = 21,
+        n_layers: int = 10,
+        n_heads: int = 8,
+        pad_to_full_graph: bool = True,
+        add_node_attr_as_self_loop: bool = False,
+        dropout: float = 0.0,
+        fill_value: float = 0.0,
+        norm: nn.Module = nn.BatchNorm1d,
+        attn_dropout: float = 0.2,
+        edge_enhance: bool = True,
+        update_edges: bool = True,
+        attn_clamp: float = 5.0,
+        activation: nn.Module = nn.ReLU,
+        attn_activation: nn.Module = nn.ReLU,
+        norm_edges: bool = True,
+        enable_edge_transform: bool = True,
+        pred_head_layers: int = 2,
+        pred_head_activation: nn.Module = nn.ReLU,
+        pred_head_pooling: str = "mean",
+        position_encoding: str = "NoPE",
+    ):
+        """Construct `GRIT` model.
+
+        Args:
+            nb_inputs: Number of inputs.
+            hidden_dim: Size of hidden dimension.
+            nb_outputs: Size of output dimension.
+            ksteps: Number of random walk steps.
+            n_layers: Number of GRIT layers.
+            n_heads: Number of heads in MHA.
+            pad_to_full_graph: Pad to form fully-connected graph.
+            add_node_attr_as_self_loop: Adds node attr as a self-edge.
+            dropout: Dropout probability.
+            fill_value: Padding value.
+            norm: Uninstantiated normalization layer.
+                Either `torch.nn.BatchNorm1d` or `torch.nn.LayerNorm`.
+            attn_dropout: Attention dropout probability.
+            edge_enhance: Applies learnable weight matrix with node-pair in
+                output node calculation for MHA.
+            update_edges: Update edge values after GRIT layer.
+            attn_clamp: Clamp absolute value of attention scores to a value.
+            activation: Uninstantiated activation function.
+                E.g. `torch.nn.ReLU`
+            attn_activation: Uninstantiated attention activation function.
+                E.g. `torch.nn.ReLU`
+            norm_edges: Apply normalization layer to edges.
+            enable_edge_transform: Apply transformation to edges.
+            pred_head_layers: Number of layers in the prediction head.
+            pred_head_activation: Uninstantiated prediction head activation
+                function. E.g. `torch.nn.ReLU`
+            pred_head_pooling: Pooling function to use for the prediction
+                head, either "mean" (default) or "add".
+            position_encoding: Method of position encoding. One of "NoPE",
+                "RRWP", or "RWSE".
+        """
+        super().__init__(nb_inputs, nb_outputs)
+        self.position_encoding = position_encoding.lower()
+        if self.position_encoding == "nope":
+            encoders = [
+                LinearNodeEncoder(nb_inputs, hidden_dim),
+                LinearEdgeEncoder(hidden_dim),
+            ]
+        elif self.position_encoding == "rrwp":
+            encoders = [
+                LinearNodeEncoder(nb_inputs, hidden_dim),
+                LinearEdgeEncoder(hidden_dim),
+                RRWPLinearNodeEncoder(ksteps, hidden_dim),
+                RRWPLinearEdgeEncoder(
+                    ksteps,
+                    hidden_dim,
+                    pad_to_full_graph=pad_to_full_graph,
+                    add_node_attr_as_self_loop=add_node_attr_as_self_loop,
+                    fill_value=fill_value,
+                ),
+            ]
+        elif self.position_encoding == "rwse":
+            encoders = [
+                LinearNodeEncoder(nb_inputs, hidden_dim - (ksteps - 1)),
+                RWSELinearNodeEncoder(ksteps - 1, hidden_dim),
+            ]
+        else:
+            raise ValueError(
+                f"Unknown position encoding: '{position_encoding}'. "
+                "Must be one of 'NoPE', 'RRWP', or 'RWSE'."
+            )
+        self.encoders = nn.ModuleList(encoders)
+
+        layers = []
+        for _ in range(n_layers):
+            layers.append(
+                GritTransformerLayer(
+                    in_dim=hidden_dim,
+                    out_dim=hidden_dim,
+                    num_heads=n_heads,
+                    dropout=dropout,
+                    activation=activation,
+                    attn_dropout=attn_dropout,
+                    norm=norm,
+                    residual=True,
+                    norm_edges=norm_edges,
+                    enable_edge_transform=enable_edge_transform,
+                    update_edges=update_edges,
+                    attn_activation=attn_activation,
+                    attn_clamp=attn_clamp,
+                    attn_edge_enhance=edge_enhance,
+                )
+            )
+        self.layers = nn.ModuleList(layers)
+        self.head = SANGraphHead(
+            dim_in=hidden_dim,
+            dim_out=nb_outputs,
+            L=pred_head_layers,
+            activation=pred_head_activation,
+            pooling=pred_head_pooling,
+        )
+
+    def forward(self, x: Data) -> Tensor:
+        """Forward pass."""
+        for encoder in self.encoders:
+            x = encoder(x)
+
+        # Apply GRIT layers
+        for layer in self.layers:
+            x = layer(x)
+
+        # Graph head
+        x = self.head(x)
+
+        return x
diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py
index e5db7d735..e6970bc62 100644
--- a/src/graphnet/models/graphs/__init__.py
+++ b/src/graphnet/models/graphs/__init__.py
@@ -6,4 +6,9 @@
 """
 
 from .graph_definition import GraphDefinition
-from .graphs import KNNGraph, EdgelessGraph
+from .graphs import (
+    KNNGraph,
+    EdgelessGraph,
+    KNNGraphRRWP,
+    KNNGraphRWSE,
+)
diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py
index ad11df041..26088d1d5 100644
--- a/src/graphnet/models/graphs/edges/__init__.py
+++ b/src/graphnet/models/graphs/edges/__init__.py
@@ -5,5 +5,11 @@ and their features.
 """
 
-from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges
+from .edges import (
+    EdgeDefinition,
+    KNNEdges,
+    RadialEdges,
+    EuclideanEdges,
+    KNNDistanceEdges,
+)
 from .minkowski import MinkowskiKNNEdges
diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py
index d953ae0a0..2ce43fc68 100644
--- a/src/graphnet/models/graphs/edges/edges.py
+++ b/src/graphnet/models/graphs/edges/edges.py
@@ -6,6 +6,8 @@
 import torch
 from torch_geometric.nn import knn_graph, radius_graph
 from torch_geometric.data import Data
+from torch_geometric.utils import to_undirected
+from torch_geometric.utils.num_nodes import maybe_num_nodes
 
 from graphnet.models.utils import calculate_distance_matrix
 from graphnet.models import Model
@@ -82,6 +84,59 @@ def _construct_edges(self, graph: Data) -> Data:
         return graph
 
 
+class KNNDistanceEdges(KNNEdges):
+    """Builds edges from the k-nearest neighbours with distance attribute."""
+
+    def __init__(
+        self,
+        nb_nearest_neighbours: int,
+        columns: List[int] = [0, 1, 2],
+    ):
+        """K-NN edge definition with edge distances.
+
+        Will connect nodes together with their `nb_nearest_neighbours`
+        nearest neighbours in the feature space given by `columns`. Each
+        edge is assigned the distance between the two nodes it connects
+        as its attribute.
+
+        Args:
+            nb_nearest_neighbours: number of neighbours.
+            columns: Node features to use for distance calculation.
+                Defaults to [0, 1, 2].
+        """
+        # Base class constructor
+        super().__init__(
+            nb_nearest_neighbours=nb_nearest_neighbours,
+            columns=columns,
+        )
+
+    def _construct_edges(self, graph: Data) -> Data:
+        """Define K-NN edges."""
+        graph = super()._construct_edges(graph)
+
+        if graph.edge_index.numel() == 0:  # Check if edge_index is empty
+            num_nodes = graph.num_nodes
+            self_loops = torch.arange(num_nodes).repeat(2, 1)
+            graph.edge_index = self_loops
+
+        graph.num_nodes = maybe_num_nodes(graph.edge_index)
+        graph.edge_index = to_undirected(
+            graph.edge_index, num_nodes=graph.num_nodes
+        )
+        position_data = graph.x[:, self._columns]
+
+        src, tgt = graph.edge_index
+
+        src_pos = position_data[src]
+        tgt_pos = position_data[tgt]
+        diff = src_pos - tgt_pos
+
+        # Shape: [num_edges, 1]
+        graph.edge_attr = torch.norm(diff, p=2, dim=-1).unsqueeze(1)
+
+        return graph
+
+
 class RadialEdges(EdgeDefinition):
     """Builds graph from a sphere of chosen radius centred at each node."""
diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py
index e642ed06c..658e1320c 100644
--- a/src/graphnet/models/graphs/graphs.py
+++ b/src/graphnet/models/graphs/graphs.py
@@ -2,12 +2,20 @@
 from typing import List, Optional, Dict, Union, Any
 
 import torch
+import numpy as np
 from numpy.random import Generator
+from torch_geometric.data import Data
+
 from .graph_definition import GraphDefinition
 from graphnet.models.detector import Detector
-from graphnet.models.graphs.edges import KNNEdges
+from graphnet.models.graphs.edges import (
+    EdgeDefinition,
+    KNNEdges,
+    KNNDistanceEdges,
+)
 from graphnet.models.graphs.nodes import NodeDefinition, NodesAsPulses
+from graphnet.models.utils import add_full_rrwp, get_rw_landing_probs
 
 
 class KNNGraph(GraphDefinition):
@@ -23,6 +31,7 @@ def __init__(
         seed: Optional[Union[int, Generator]] = None,
         nb_nearest_neighbours: int = 8,
         columns: List[int] = [0, 1, 2],
+        distance_as_edge_feature: bool = False,
         **kwargs: Any,
     ) -> None:
         """Construct k-nn graph representation.
@@ -42,12 +51,17 @@ def __init__(
                 Defaults to 8.
             columns: node feature columns used for distance calculation.
                 Defaults to [0, 1, 2].
+            distance_as_edge_feature: Add edge distances as an edge feature.
+                Defaults to False.
         """
         # Base class constructor
+        edge_definition = (
+            KNNDistanceEdges if distance_as_edge_feature else KNNEdges
+        )
         super().__init__(
             detector=detector,
             node_definition=node_definition or NodesAsPulses(),
-            edge_definition=KNNEdges(
+            edge_definition=edge_definition(
                 nb_nearest_neighbours=nb_nearest_neighbours,
                 columns=columns,
             ),
@@ -100,3 +114,156 @@ def __init__(
             seed=seed,
             **kwargs,
         )
+
+
+class KNNGraphRRWP(GraphDefinition):
+    """KNN Graph with relative random walk probabilities (RRWP).
+
+    Similar to `KNNGraph`, but with four extra fields containing absolute
+    and relative positional encodings derived from RRWP:
+
+    `abs_pe = graph["rrwp"]  # RRWP absolute positional encoding values`
+    `rrwp_val = graph["rrwp_val"]  # Non-zero values of the RRWP tensor`
+    `rrwp_index = graph["rrwp_index"]  # Corresponding row, col indices`
+    `degree = graph["deg"]  # Degree of each node (num. of incoming edges)`
+    """
+
+    def __init__(
+        self,
+        detector: Detector,
+        node_definition: Optional[NodeDefinition] = None,
+        edge_definition: Optional[EdgeDefinition] = None,
+        input_feature_names: Optional[List[str]] = None,
+        dtype: Optional[torch.dtype] = torch.float,
+        perturbation_dict: Optional[Dict[str, float]] = None,
+        seed: Optional[Union[int, Generator]] = None,
+        nb_nearest_neighbours: int = 8,
+        columns: List[int] = [0, 1, 2],
+        walk_length: int = 8,
+        **kwargs: Any,
+    ) -> None:
+        """Construct k-nn graph representation.
+
+        Args:
+            detector: Detector that represents your data.
+            node_definition: Definition of nodes in the graph.
+            edge_definition: Definition of edges in the graph.
+            input_feature_names: Name of input feature columns.
+            dtype: data type for node features.
+            perturbation_dict: Dictionary mapping a feature name to a
+                standard deviation according to which the values for this
+                feature should be randomly perturbed. Defaults to None.
+            seed: seed or Generator used to randomly sample perturbations.
+                Defaults to None.
+            nb_nearest_neighbours: Number of edges for each node.
+                Defaults to 8.
+            columns: node feature columns used for distance calculation.
+                Defaults to [0, 1, 2].
+            walk_length: number of steps for the random walk.
+                Defaults to 8.
+        """
+        # Base class constructor
+        super().__init__(
+            detector=detector,
+            node_definition=node_definition or NodesAsPulses(),
+            edge_definition=edge_definition
+            or KNNDistanceEdges(
+                nb_nearest_neighbours=nb_nearest_neighbours,
+                columns=columns,
+            ),
+            dtype=dtype,
+            input_feature_names=input_feature_names,
+            perturbation_dict=perturbation_dict,
+            seed=seed,
+            **kwargs,
+        )
+        self.walk_length = walk_length
+
+    def forward(  # type: ignore
+        self,
+        input_features: np.ndarray,
+        input_feature_names: List[str],
+        **kwargs,
+    ) -> Data:
+        """Forward pass."""
+        graph = super().forward(input_features, input_feature_names, **kwargs)
+        graph = add_full_rrwp(graph, walk_length=self.walk_length)
+
+        return graph
+
+
+class KNNGraphRWSE(GraphDefinition):
+    """KNN Graph with random walk structural encoding (RWSE).
+
+    Similar to `KNNGraph`, but with an additional field containing the
+    values obtained from RWSE.
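+    For each node, RWSE stores the probability that a k-step random walk
+    starting at that node returns to it, for k = 1, ..., `walk_length` - 1.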
+    The encoding can be accessed via
+
+    `rwse = graph["rwse"]  # random walk structural encoding`
+    """
+
+    def __init__(
+        self,
+        detector: Detector,
+        node_definition: Optional[NodeDefinition] = None,
+        edge_definition: Optional[EdgeDefinition] = None,
+        input_feature_names: Optional[List[str]] = None,
+        dtype: Optional[torch.dtype] = torch.float,
+        perturbation_dict: Optional[Dict[str, float]] = None,
+        seed: Optional[Union[int, Generator]] = None,
+        nb_nearest_neighbours: int = 8,
+        columns: List[int] = [0, 1, 2],
+        walk_length: int = 8,
+        **kwargs: Any,
+    ) -> None:
+        """Construct k-nn graph representation.
+
+        Args:
+            detector: Detector that represents your data.
+            node_definition: Definition of nodes in the graph.
+            edge_definition: Definition of edges in the graph.
+            input_feature_names: Name of input feature columns.
+            dtype: data type for node features.
+            perturbation_dict: Dictionary mapping a feature name to a
+                standard deviation according to which the values for this
+                feature should be randomly perturbed. Defaults to None.
+            seed: seed or Generator used to randomly sample perturbations.
+                Defaults to None.
+            nb_nearest_neighbours: Number of edges for each node.
+                Defaults to 8.
+            columns: node feature columns used for distance calculation.
+                Defaults to [0, 1, 2].
+            walk_length: number of steps for the random walk.
+                Defaults to 8.
+        """
+        # Base class constructor
+        super().__init__(
+            detector=detector,
+            node_definition=node_definition or NodesAsPulses(),
+            edge_definition=edge_definition
+            or KNNEdges(
+                nb_nearest_neighbours=nb_nearest_neighbours,
+                columns=columns,
+            ),
+            dtype=dtype,
+            input_feature_names=input_feature_names,
+            perturbation_dict=perturbation_dict,
+            seed=seed,
+            **kwargs,
+        )
+        self.walk_length = walk_length
+
+    def forward(  # type: ignore
+        self,
+        input_features: np.ndarray,
+        input_feature_names: List[str],
+        **kwargs,
+    ) -> Data:
+        """Forward pass."""
+        graph = super().forward(input_features, input_feature_names, **kwargs)
+        # `get_rw_landing_probs` expects a list of integer steps.
+        ksteps = list(range(1, self.walk_length))
+        graph.rwse = get_rw_landing_probs(
+            ksteps=ksteps, edge_index=graph.edge_index, edge_weight=None
+        )
+        return graph
diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py
index 11b73d06f..06cf41f8c 100644
--- a/src/graphnet/models/utils.py
+++ b/src/graphnet/models/utils.py
@@ -1,13 +1,17 @@
 """Utility functions for `graphnet.models`."""
 
-from typing import List, Tuple, Any, Union
-from torch_geometric.nn import knn_graph
-from torch_geometric.data import Batch
+from typing import List, Tuple, Any, Union, Optional
+
 import torch
 from torch import Tensor, LongTensor
-from torch_geometric.utils import homophily
-from torch_geometric.data import Data
+from torch_geometric.nn import knn_graph
+from torch_geometric.data import Batch, Data
+from torch_geometric.utils import homophily, to_dense_adj
+from torch_geometric.utils.num_nodes import maybe_num_nodes
+
+from torch_scatter import scatter, scatter_add
+from torch_sparse import SparseTensor
 
 
 def calculate_xyzt_homophily(
@@ -116,3 +120,188 @@ def get_fields(data: Union[Data, List[Data]], fields: List[str]) -> Tensor:
             torch.cat([d[label].reshape(-1, 1) for d in data], dim=0)
         )
     return torch.cat(labels, dim=1)
+
+
+def full_edge_index(
+    edge_index: Tensor, batch: Optional[Tensor] = None
+) -> Tensor:
+    """Return the fully-connected edge index for each graph in a batch.
+
+    For every graph in the batch, an edge index containing all node pairs
+    (including self-loops) is constructed; edges already present in the
+    input `edge_index` are not excluded.
+    Implementation inspired by `torch_geometric.utils.to_dense_adj`.
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/encoder/rrwp_encoder.py
+
+    Args:
+        edge_index: The edge indices.
+        batch: Batch vector, which assigns each node to a specific example.
+
+    Returns:
+        Fully-connected edge index.
+    """
+    if batch is None:
+        batch = edge_index.new_zeros(edge_index.max().item() + 1)
+
+    batch_size = batch.max().item() + 1
+    one = batch.new_ones(batch.size(0))
+    num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce="add")
+    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
+
+    negative_index_list = []
+    for i in range(batch_size):
+        n = num_nodes[i].item()
+        size = [n, n]
+        adj = torch.ones(size, dtype=torch.short, device=edge_index.device)
+
+        adj = adj.view(size)
+        _edge_index = adj.nonzero(as_tuple=False).t().contiguous()
+        negative_index_list.append(_edge_index + cum_nodes[i])
+
+    edge_index_full = torch.cat(negative_index_list, dim=1).contiguous()
+    return edge_index_full
+
+
+@torch.no_grad()
+def add_full_rrwp(
+    data: Data,
+    walk_length: int = 8,
+    attr_name_abs: str = "rrwp",
+    attr_name_rel: str = "rrwp",
+    add_identity: bool = True,
+    spd: bool = False,
+) -> Data:
+    """Add relative random walk probabilities.
+
+    Original code:
+    https://github.com/LiamMa/GRIT/blob/main/grit/transform/rrwp.py
+
+    Args:
+        data: Input data.
+        walk_length: Number of random walk steps for the encoding.
+        attr_name_abs: Absolute position encoding name.
+        attr_name_rel: Relative position encoding name.
+        add_identity: Add identity matrix to position encoding.
+        spd: Use shortest path distances.
+    """
+    num_nodes = data.num_nodes
+    edge_index, edge_weight = data.edge_index, data.edge_weight
+
+    adj = SparseTensor.from_edge_index(
+        edge_index,
+        edge_weight,
+        sparse_sizes=(num_nodes, num_nodes),
+    )
+
+    # Compute D^{-1} A:
+    deg_inv = 1.0 / adj.sum(dim=1)
+    deg_inv[deg_inv == float("inf")] = 0
+    adj = adj * deg_inv.view(-1, 1)
+    adj = adj.to_dense()
+
+    pe_list = []
+    i = 0
+    if add_identity:
+        pe_list.append(torch.eye(num_nodes, dtype=data.x.dtype))
+        i = i + 1
+
+    out = adj
+    pe_list.append(adj)
+
+    if walk_length > 2:
+        for j in range(i + 1, walk_length):
+            out = out @ adj
+            pe_list.append(out)
+
+    pe = torch.stack(pe_list, dim=-1)  # n x n x k
+
+    abs_pe = pe.diagonal().transpose(0, 1)  # n x k
+
+    rel_pe = SparseTensor.from_dense(pe, has_value=True)
+    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
+    rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0)
+    # GRIT performs right-multiplication while `adj` is row-normalized,
+    # so the order of row and col must be switched.
+    # Note: both can work, but the current version is more reasonable.
+
+    if spd:
+        spd_idx = walk_length - torch.arange(walk_length)
+        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
+        val = torch.argmax(val, dim=-1)
+        rel_pe_val = torch.nn.functional.one_hot(val, walk_length).type(
+            torch.float
+        )
+        abs_pe = torch.zeros_like(abs_pe)
+
+    data[attr_name_abs] = abs_pe
+    data[f"{attr_name_rel}_index"] = rel_pe_idx
+    data[f"{attr_name_rel}_val"] = rel_pe_val
+
+    return data
+
+
+def get_rw_landing_probs(
+    ksteps: List,
+    edge_index: Tensor,
+    edge_weight: Optional[Tensor] = None,
+    num_nodes: Optional[int] = None,
+    space_dim: int = 0,
+) -> Tensor:
+    """Compute random walk landing probabilities for given list of K steps.
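+
+    Entry (i, k) of the returned tensor is the i-th diagonal element of
+    the k-step transition matrix P^k, i.e. the probability that a k-step
+    random walk starting at node i ends at node i.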
+
+    Original code:
+    https://github.com/ETH-DISCO/Benchmarking-PEs
+
+    Args:
+        ksteps: List of k-steps for which to compute the RW landings.
+        edge_index: PyG sparse representation of the graph.
+        edge_weight: (optional) Edge weights.
+        num_nodes: (optional) Number of nodes in the graph.
+        space_dim: (optional) Estimated dimensionality of the space. Used to
+            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
+            In euclidean space, this correction means that the height of
+            the gaussian distribution stays almost constant across the
+            number of steps, if `space_dim` is the dimension of the
+            euclidean space.
+
+    Returns:
+        2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs.
+    """
+    if edge_weight is None:
+        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
+    num_nodes = maybe_num_nodes(edge_index, num_nodes)
+    source = edge_index[0]
+
+    # Out degrees
+    deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes)
+    deg_inv = deg.pow(-1.0)
+    deg_inv.masked_fill_(deg_inv == float("inf"), 0)
+
+    if edge_index.numel() == 0:
+        P = edge_index.new_zeros((1, num_nodes, num_nodes))
+    else:
+        # P = D^-1 * A
+        # 1 x (Num nodes) x (Num nodes)
+        P = torch.diag(deg_inv) @ to_dense_adj(
+            edge_index, max_num_nodes=num_nodes
+        )
+    rws = []
+    if ksteps == list(range(min(ksteps), max(ksteps) + 1)):
+        # Efficient way if ksteps are a consecutive sequence
+        Pk = P.clone().detach().matrix_power(min(ksteps))
+        for k in range(min(ksteps), max(ksteps) + 1):
+            rws.append(
+                torch.diagonal(Pk, dim1=-2, dim2=-1) * (k ** (space_dim / 2))
+            )
+            Pk = Pk @ P
+    else:
+        # Explicitly raising P to power k for each k \in ksteps.
+        for k in ksteps:
+            rws.append(
+                torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1)
+                * (k ** (space_dim / 2))
+            )
+
+    # (Num nodes) x (K steps)
+    rw_landing = torch.cat(rws, dim=0).transpose(0, 1)
+    return rw_landing