Commit 05056d6: first commit

mitkotak committed Jul 25, 2024 (0 parents)
Showing 6 changed files with 273 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*__pycache__
27 changes: 27 additions & 0 deletions data.py
@@ -0,0 +1,27 @@
import torch
from torch_geometric.data import Data
def get_random_graph(nodes, cutoff):
    """Build a random point cloud and connect all pairs closer than `cutoff`."""
    positions = torch.randn(nodes, 3)

    displacements = positions[:, None, :] - positions[None, :, :]
    distance_matrix = torch.linalg.norm(displacements, dim=-1)
    assert distance_matrix.shape == (nodes, nodes)

    # Connect pairs within the cutoff, excluding self-edges: a zero-length
    # relative vector cannot be normalized for the spherical harmonics.
    senders, receivers = torch.nonzero((distance_matrix < cutoff) & (distance_matrix > 0)).T

    z = torch.zeros(len(positions), dtype=torch.int32)  # Dummy atomic numbers (all zeros)

    # Create edge index tensor by stacking senders and receivers
    edge_index = torch.stack([senders, receivers], dim=0)

    # Create a PyTorch Geometric Data object
    graph = Data(
        pos=positions,
        relative_vectors=positions[receivers] - positions[senders],
        numbers=z,
        edge_index=edge_index,
        num_nodes=len(positions),
    )

    return graph
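For reference, a quick shape check of the graph constructor (a sketch; the node count and cutoff are arbitrary):

from data import get_random_graph

graph = get_random_graph(nodes=5, cutoff=2.5)
print(graph.pos.shape)               # torch.Size([5, 3])
print(graph.edge_index.shape)        # torch.Size([2, E]) for E edges within the cutoff
print(graph.relative_vectors.shape)  # torch.Size([E, 3])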
166 changes: 166 additions & 0 deletions model.py
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from e3nn import o3
from utils import scatter_mean


class AtomEmbedding(nn.Module):
"""Embeds atomic atomic_numbers into a learnable vector space."""

def __init__(self, embed_dims: int, max_atomic_number: int):
super().__init__()
self.embed_dims = embed_dims
self.max_atomic_number = max_atomic_number
self.embedding = nn.Embedding(num_embeddings=max_atomic_number, embedding_dim=embed_dims)

def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
atom_embeddings = self.embedding(atomic_numbers)
return atom_embeddings

class MLP(nn.Module):

def __init__(
self,
input_dims: int,
output_dims: int,
hidden_dims: int = 32,
num_layers: int = 2):

super(MLP, self).__init__()
layers = []
for i in range(num_layers - 1):
layers.append(nn.Linear(input_dims if i == 0 else hidden_dims, hidden_dims))
layers.append(nn.LayerNorm(hidden_dims))
layers.append(nn.SiLU())
layers.append(nn.Linear(hidden_dims, output_dims))
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)

class SimpleNetwork(nn.Module):
"""A layer of a simple E(3)-equivariant message passing network."""

sh_lmax: int = 2
lmax: int = 2
init_node_features: int = 16
max_atomic_number: int = 12
num_hops: int = 2
output_dims: int = 1

def __init__(self,
relative_vectors_irreps: o3.Irreps,
node_features_irreps: o3.Irreps):

super().__init__()

self.embed = AtomEmbedding(self.init_node_features, self.max_atomic_number)
self.sph = o3.SphericalHarmonics(irreps_out=o3.Irreps.spherical_harmonics(self.sh_lmax), normalize=True, normalization="norm")


# Currently hardcoding 2 layers

print("node_features_irreps", node_features_irreps)

# Layer 0
self.tp = o3.experimental.FullTensorProductv2(relative_vectors_irreps.regroup(),
node_features_irreps.regroup(),
filter_ir_out=[o3.Irrep(f"{l}e") for l in range(self.lmax+1)] + [o3.Irrep(f"{l}o") for l in range(self.lmax+1)])
self.linear = o3.Linear(irreps_in=self.tp.irreps_out.regroup(), irreps_out=self.tp.irreps_out.regroup())
print("TP+Linear", self.linear.irreps_out)
        self.mlp = MLP(input_dims=1,  # the input is the edge-length norm, so it is always (..., 1)
                       output_dims=self.tp.irreps_out.num_irreps)


self.elementwise_tp = o3.experimental.ElementwiseTensorProductv2(o3.Irreps(f"{self.tp.irreps_out.num_irreps}x0e"), self.linear.irreps_out.regroup())
print("node feature broadcasted", self.elementwise_tp.irreps_out)


# Layer 1
node_features_irreps = self.elementwise_tp.irreps_out

print("node_features_irreps", node_features_irreps)
self.tp2 = o3.experimental.FullTensorProductv2(relative_vectors_irreps.regroup(),
node_features_irreps.regroup(),
filter_ir_out=[o3.Irrep(f"{l}e") for l in range(self.lmax+1)] + [o3.Irrep(f"{l}o") for l in range(self.lmax+1)])
self.linear2 = o3.Linear(irreps_in=self.tp2.irreps_out.regroup(), irreps_out=self.tp2.irreps_out.regroup())
print("Layer 1 TP+Linear", self.linear2.irreps_out)
        self.mlp2 = MLP(input_dims=1,  # the input is the edge-length norm, so it is always (..., 1)
                        output_dims=self.tp2.irreps_out.num_irreps)


print(f"Layer 1 scalars {self.tp2.irreps_out.num_irreps}x0e")
print(f"Layer 1 node_features {self.linear2.irreps_out.regroup()}")
self.elementwise_tp2 = o3.experimental.ElementwiseTensorProductv2(o3.Irreps(f"{self.tp2.irreps_out.num_irreps}x0e"), self.linear2.irreps_out.regroup())
print("Layer 1 node feature broadcasted", self.elementwise_tp2.irreps_out)

        # Poor man's filter (can already feel the judgement): replicates irreps_array.filter("0e")
        # by tensoring with a constant 0e scalar and keeping only the 0e outputs.
        # The features reaching this filter come out of layer 1, so the input
        # irreps must match the layer-1 output, not the layer-0 output.
        self.filter_tp = o3.experimental.FullTensorProductv2(self.elementwise_tp2.irreps_out.regroup(), o3.Irreps("0e"), filter_ir_out=[o3.Irrep("0e")])
self.register_buffer("dummy_input", torch.ones(1))

print("aggregated node features", self.filter_tp.irreps_out)

self.readout_mlp = MLP(input_dims = self.filter_tp.irreps_out.num_irreps,
output_dims = self.output_dims)

def forward(self,
numbers: torch.Tensor,
relative_vectors: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int) -> torch.Tensor:

        node_features = self.embed(numbers)
        senders, receivers = edge_index

        relative_vectors_sh = self.sph(relative_vectors)
        relative_vectors_norm = torch.linalg.norm(relative_vectors, dim=-1, keepdim=True)


        # Currently hardcoding 2 hops


# Layer 0

# Tensor product of the relative vectors and the neighbouring node features.
node_features_broadcasted = node_features[senders]

tp = self.tp(relative_vectors_sh, node_features_broadcasted)


# Apply linear
tp = self.linear(tp)


# Simply multiply each irrep by a learned scalar, based on the norm of the relative vector.
scalars = self.mlp(relative_vectors_norm)
node_features_broadcasted = self.elementwise_tp(scalars, tp)


# Aggregate the node features back.
node_features = scatter_mean(
node_features_broadcasted,
index=receivers,
output_dim = node_features.shape[0]
)

# Layer 1
node_features_broadcasted = node_features[senders]

tp2 = self.tp2(relative_vectors_sh, node_features_broadcasted)
tp2 = self.linear2(tp2)
scalars2 = self.mlp2(relative_vectors_norm)
node_features_broadcasted = self.elementwise_tp2(scalars2, tp2)
node_features = scatter_mean(
node_features_broadcasted,
index=receivers,
output_dim = node_features.shape[0]
)

        # Global readout.

# Filter out 0e
node_features = self.filter_tp(node_features, self.dummy_input)

graph_globals = scatter_mean(node_features, output_dim=[num_nodes])
return self.readout_mlp(graph_globals)
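As an aside, the "poor man's filter" works because tensoring any irrep with a constant scalar (0e) leaves it unchanged, so restricting filter_ir_out to 0e keeps exactly the scalar channels. A minimal standalone sketch (irreps chosen arbitrarily; assumes the pinned e3nn fork from requirements.txt):

import torch
from e3nn import o3

irreps_in = o3.Irreps("4x0e+4x1o")
filter_tp = o3.experimental.FullTensorProductv2(irreps_in, o3.Irreps("0e"), filter_ir_out=[o3.Irrep("0e")])
x = irreps_in.randn(10, -1)            # (10, 16) random features
scalars = filter_tp(x, torch.ones(1))  # (10, 4): only the 0e channels survive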
2 changes: 2 additions & 0 deletions requirements.txt
@@ -0,0 +1,2 @@
e3nn @ git+https://github.com/mitkotak/e3nn@linear_pt2
torch-geometric
23 changes: 23 additions & 0 deletions train.py
@@ -0,0 +1,23 @@
import sys
sys.path.append("..")

from model import SimpleNetwork
from data import get_random_graph
import torch
from e3nn import o3

graph = get_random_graph(5, 2.5)
model = SimpleNetwork(
relative_vectors_irreps=o3.Irreps.spherical_harmonics(lmax=2),
node_features_irreps=o3.Irreps("16x0e"),
)

# Currently turning torch.compile off since o3.Linear still needs weights.
# Also need to confirm that the model is working.
model = torch.compile(model, fullgraph=True, disable=True)


model(graph.numbers,
graph.relative_vectors,
graph.edge_index,
graph.num_nodes)
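If the forward pass succeeds, the model returns one scalar per graph. A quick shape check (a sketch; the print is added here for illustration):

out = model(graph.numbers,
            graph.relative_vectors,
            graph.edge_index,
            graph.num_nodes)
print(out.shape)  # expected: torch.Size([1, 1]) for a single graph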
54 changes: 54 additions & 0 deletions utils.py
@@ -0,0 +1,54 @@
## Scatter mean function (courtesy of ChatGPT)

import torch

def scatter_mean(input, index=None, output_dim=None):
    if index is not None:
        # Case 1: per-index mean, vectorized with index_add_. `output_dim`
        # (an int), if given, fixes the number of output rows; otherwise it
        # is inferred from the largest index.
        output_size = int(output_dim) if output_dim is not None else int(index.max().item()) + 1
        output = torch.zeros(output_size, input.size(1), device=input.device, dtype=input.dtype)
        count = torch.zeros(output_size, device=input.device, dtype=input.dtype)
        output.index_add_(0, index, input)
        count.index_add_(0, index, torch.ones_like(index, dtype=input.dtype))
        return output / count.clamp(min=1).unsqueeze(-1)

    elif output_dim is not None:
        # Case 2: no index; `output_dim` is a list of contiguous segment sizes.
        output = torch.zeros(len(output_dim), input.size(1), device=input.device, dtype=input.dtype)

        start_idx = 0
        for i, dim in enumerate(output_dim):
            end_idx = start_idx + dim
            if dim > 0:
                output[i] = input[start_idx:end_idx].mean(dim=0)
            start_idx = end_idx

        return output

    else:
        raise ValueError("Either 'index' or 'output_dim' must be specified.")

# # Example usage for Case 1 (index specified):
# input1 = torch.randn(3000, 144)
# index1 = torch.randint(0, 1000, (3000,))
# output1 = scatter_mean(input1, index=index1)
# print("Output shape (Case 1):", output1.shape)

# # Example usage for Case 2 (index skipped, output_dim specified):
# input2 = torch.randn(3000, 144)
# output_dim = [3000]
# output2 = scatter_mean(input2, output_dim=output_dim)
# print("Output shape (Case 2):", output2.shape)

# # Example usage for Case 3 (both specified: index plus an explicit row count):
# input = torch.randn(3000, 144)
# index = torch.randint(0, 1000, (3000,))

# output = scatter_mean(input, index=index, output_dim=1000)
# print(output.size())  # Should print torch.Size([1000, 144])
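A quick self-check of the vectorized Case 1 against a naive per-group loop (a sketch; sizes are arbitrary):

import torch
from utils import scatter_mean

x = torch.randn(20, 4)
idx = torch.randint(0, 6, (20,))

fast = scatter_mean(x, index=idx, output_dim=6)

slow = torch.zeros(6, 4)
for i in range(6):
    mask = idx == i
    if mask.any():
        slow[i] = x[mask].mean(dim=0)

assert torch.allclose(fast, slow, atol=1e-6)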
