diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2812bcc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*__pycache__
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..9c0db77
--- /dev/null
+++ b/data.py
@@ -0,0 +1,31 @@
+import torch
+from torch_geometric.data import Data
+
+
+def get_random_graph(nodes, cutoff):
+    """Build a random 3-D point cloud and connect pairs closer than `cutoff`."""
+    positions = torch.randn(nodes, 3)
+
+    distance_matrix = positions[:, None, :] - positions[None, :, :]
+    distance_matrix = torch.linalg.norm(distance_matrix, dim=-1)
+    assert distance_matrix.shape == (nodes, nodes)
+
+    # Exclude self-edges: their zero relative vectors cannot be normalized
+    # by the spherical-harmonics embedding downstream.
+    adjacency = (distance_matrix < cutoff) & ~torch.eye(nodes, dtype=torch.bool)
+    senders, receivers = torch.nonzero(adjacency).T
+
+    z = torch.zeros(len(positions), dtype=torch.int32)  # atomic numbers (all dummy)
+
+    # Edge index tensor, stacked as [senders, receivers]
+    edge_index = torch.stack([senders, receivers], dim=0)
+
+    graph = Data(
+        pos=positions,
+        relative_vectors=positions[receivers] - positions[senders],
+        numbers=z,
+        edge_index=edge_index,
+        num_nodes=len(positions),
+    )
+
+    return graph
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..f76200a
--- /dev/null
+++ b/model.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+from e3nn import o3
+
+from utils import scatter_mean
+
+
+class AtomEmbedding(nn.Module):
+    """Embeds atomic numbers into a learnable vector space."""
+
+    def __init__(self, embed_dims: int, max_atomic_number: int):
+        super().__init__()
+        self.embed_dims = embed_dims
+        self.max_atomic_number = max_atomic_number
+        self.embedding = nn.Embedding(num_embeddings=max_atomic_number, embedding_dim=embed_dims)
+
+    def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
+        return self.embedding(atomic_numbers)
+
+
+class MLP(nn.Module):
+    """Multi-layer perceptron with LayerNorm + SiLU hidden blocks."""
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        hidden_dims: int = 32,
+        num_layers: int = 2,
+    ):
+        super().__init__()
+        layers = []
+        for i in range(num_layers - 1):
+            layers.append(nn.Linear(input_dims if i == 0 else hidden_dims, hidden_dims))
+            layers.append(nn.LayerNorm(hidden_dims))
+            layers.append(nn.SiLU())
+        # A single-layer MLP maps straight from the input dimensions.
+        layers.append(nn.Linear(hidden_dims if num_layers > 1 else input_dims, output_dims))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class SimpleNetwork(nn.Module):
+    """A simple two-hop E(3)-equivariant message passing network."""
+
+    sh_lmax: int = 2
+    lmax: int = 2
+    init_node_features: int = 16
+    max_atomic_number: int = 12
+    num_hops: int = 2
+    output_dims: int = 1
+
+    def __init__(self,
+                 relative_vectors_irreps: o3.Irreps,
+                 node_features_irreps: o3.Irreps):
+        super().__init__()
+
+        self.embed = AtomEmbedding(self.init_node_features, self.max_atomic_number)
+        self.sph = o3.SphericalHarmonics(
+            irreps_out=o3.Irreps.spherical_harmonics(self.sh_lmax),
+            normalize=True,
+            normalization="norm",
+        )
+
+        # Keep every irrep up to lmax, both parities.
+        filter_ir_out = ([o3.Irrep(f"{l}e") for l in range(self.lmax + 1)]
+                         + [o3.Irrep(f"{l}o") for l in range(self.lmax + 1)])
+
+        # Currently hardcoding 2 layers
+        print("node_features_irreps", node_features_irreps)
+
+        # Layer 0
+        self.tp = o3.experimental.FullTensorProductv2(
+            relative_vectors_irreps.regroup(),
+            node_features_irreps.regroup(),
+            filter_ir_out=filter_ir_out,
+        )
+        self.linear = o3.Linear(irreps_in=self.tp.irreps_out.regroup(),
+                                irreps_out=self.tp.irreps_out.regroup())
+        print("TP+Linear", self.linear.irreps_out)
+        self.mlp = MLP(input_dims=1,  # edge norms are always (..., 1)
+                       output_dims=self.tp.irreps_out.num_irreps)
+
+        self.elementwise_tp = o3.experimental.ElementwiseTensorProductv2(
+            o3.Irreps(f"{self.tp.irreps_out.num_irreps}x0e"),
+            self.linear.irreps_out.regroup(),
+        )
+        print("node feature broadcasted", self.elementwise_tp.irreps_out)
+
+        # Layer 1
+        node_features_irreps = self.elementwise_tp.irreps_out
+        print("node_features_irreps", node_features_irreps)
+        self.tp2 = o3.experimental.FullTensorProductv2(
+            relative_vectors_irreps.regroup(),
+            node_features_irreps.regroup(),
+            filter_ir_out=filter_ir_out,
+        )
+        self.linear2 = o3.Linear(irreps_in=self.tp2.irreps_out.regroup(),
+                                 irreps_out=self.tp2.irreps_out.regroup())
+        print("Layer 1 TP+Linear", self.linear2.irreps_out)
+        self.mlp2 = MLP(input_dims=1,  # edge norms are always (..., 1)
+                        output_dims=self.tp2.irreps_out.num_irreps)
+
+        print(f"Layer 1 scalars {self.tp2.irreps_out.num_irreps}x0e")
+        print(f"Layer 1 node_features {self.linear2.irreps_out.regroup()}")
+        self.elementwise_tp2 = o3.experimental.ElementwiseTensorProductv2(
+            o3.Irreps(f"{self.tp2.irreps_out.num_irreps}x0e"),
+            self.linear2.irreps_out.regroup(),
+        )
+        print("Layer 1 node feature broadcasted", self.elementwise_tp2.irreps_out)
+
+        # Poor man's filter, replicating irreps_array.filter("0e").
+        # NOTE: built from the *layer 1* output irreps; the earlier draft used
+        # self.tp.irreps_out, which no longer matches the node features here.
+        self.filter_tp = o3.experimental.FullTensorProductv2(
+            self.elementwise_tp2.irreps_out.regroup(),
+            o3.Irreps("0e"),
+            filter_ir_out=[o3.Irrep("0e")],
+        )
+        self.register_buffer("dummy_input", torch.ones(1))
+
+        print("aggregated node features", self.filter_tp.irreps_out)
+
+        self.readout_mlp = MLP(input_dims=self.filter_tp.irreps_out.num_irreps,
+                               output_dims=self.output_dims)
+
+    def forward(self,
+                numbers: torch.Tensor,
+                relative_vectors: torch.Tensor,
+                edge_index: torch.Tensor,
+                num_nodes: int) -> torch.Tensor:
+        """Two message-passing hops followed by a global mean readout."""
+        node_features = self.embed(numbers)
+        senders, receivers = edge_index
+
+        relative_vectors_sh = self.sph(relative_vectors)
+        relative_vectors_norm = torch.linalg.norm(relative_vectors, dim=-1, keepdim=True)
+
+        # Layer 0: tensor product of edge vectors with sender node features.
+        node_features_broadcasted = node_features[senders]
+        tp = self.tp(relative_vectors_sh, node_features_broadcasted)
+        tp = self.linear(tp)
+
+        # Multiply each irrep by a learned scalar of the edge length.
+        scalars = self.mlp(relative_vectors_norm)
+        node_features_broadcasted = self.elementwise_tp(scalars, tp)
+
+        # Aggregate messages back onto the receiving nodes; pass num_nodes so
+        # nodes without incoming edges still get a (zero) row.
+        node_features = scatter_mean(
+            node_features_broadcasted,
+            index=receivers,
+            output_dim=num_nodes,
+        )
+
+        # Layer 1
+        node_features_broadcasted = node_features[senders]
+        tp2 = self.tp2(relative_vectors_sh, node_features_broadcasted)
+        tp2 = self.linear2(tp2)
+        scalars2 = self.mlp2(relative_vectors_norm)
+        node_features_broadcasted = self.elementwise_tp2(scalars2, tp2)
+        node_features = scatter_mean(
+            node_features_broadcasted,
+            index=receivers,
+            output_dim=num_nodes,
+        )
+
+        # Global readout: keep only the 0e scalars, then mean over the graph.
+        node_features = self.filter_tp(node_features, self.dummy_input)
+        graph_globals = scatter_mean(node_features, output_dim=[num_nodes])
+        return self.readout_mlp(graph_globals)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a50f078
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+e3nn @ git+https://github.com/mitkotak/e3nn@linear_pt2
+torch-geometric
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..9857926
--- /dev/null
+++ b/train.py
@@ -0,0 +1,24 @@
+import sys
+
+sys.path.append("..")
+
+import torch
+from e3nn import o3
+
+from data import get_random_graph
+from model import SimpleNetwork
+
+graph = get_random_graph(5, 2.5)
+model = SimpleNetwork(
+    relative_vectors_irreps=o3.Irreps.spherical_harmonics(lmax=2),
+    node_features_irreps=o3.Irreps("16x0e"),
+)
+
+# Currently turning off since Linear still needs weights
+# Also need to confirm that the model is working
+model = torch.compile(model, fullgraph=True, disable=True)
+
+model(graph.numbers,
+      graph.relative_vectors,
+      graph.edge_index,
+      graph.num_nodes)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..03c5a56
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,39 @@
+"""Scatter-mean helpers (vectorized)."""
+
+import torch
+
+
+def scatter_mean(input, index=None, output_dim=None):
+    """Segment-mean the rows of `input`.
+
+    Case 1 (`index` given): rows sharing an index value are averaged.
+    Passing `output_dim` as an int fixes the number of output rows, so
+    trailing indices with no entries still get a (zero) row.
+
+    Case 2 (`output_dim` a list of segment lengths): contiguous segments
+    of the given lengths are averaged.
+    """
+    if index is not None:
+        # Honour an explicit row count; otherwise infer it from the indices.
+        if isinstance(output_dim, int):
+            output_size = output_dim
+        else:
+            output_size = index.max().item() + 1
+        output = torch.zeros(output_size, input.size(1), device=input.device)
+        output.index_add_(0, index, input.to(output.dtype))
+        counts = torch.bincount(index, minlength=output_size).to(output.dtype)
+        # clamp avoids 0/0 for indices that never occur.
+        return output / counts.clamp(min=1).unsqueeze(1)
+
+    elif output_dim is not None:
+        output = torch.zeros(len(output_dim), input.size(1), device=input.device)
+        start_idx = 0
+        for i, dim in enumerate(output_dim):
+            end_idx = start_idx + dim
+            if dim > 0:
+                output[i] = input[start_idx:end_idx].sum(dim=0) / dim
+            start_idx = end_idx
+        return output
+
+    else:
+        raise ValueError("Either 'index' or 'output_dim' must be specified.")