
Commit

tetris working
mitkotak committed Aug 14, 2024
1 parent f6b2dd2 commit 4ee5e60
Showing 3 changed files with 167 additions and 158 deletions.
5 changes: 3 additions & 2 deletions data.py
@@ -1,5 +1,6 @@
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import radius_graph

def get_random_graph(nodes, cutoff) -> Data:
@@ -50,7 +51,7 @@ def get_tetris():
senders, receivers = edge_index

data = Data(
numbers=torch.ones((len(positions),), dtype=torch.int32), # node features
numbers=torch.ones((len(positions),1)), # node features
pos=positions, # node positions
edge_index=edge_index, # edge indices
y=label # graph label
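For context, a minimal sketch of how one piece ends up in a PyG Data object under the conventions shown above (the positions, cutoff radius, and label below are illustrative, not taken from the repo):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # new import location
from torch_geometric.nn import radius_graph

# Illustrative piece: four unit-spaced blocks of an L-shaped tetromino.
positions = torch.tensor([[0., 0., 0.], [0., 0., 1.], [1., 0., 0.], [2., 0., 0.]])
label = torch.tensor([0])  # class index, assumed

edge_index = radius_graph(positions, r=1.1)  # edges between blocks within the cutoff
data = Data(
    numbers=torch.ones((len(positions), 1)),  # constant scalar node feature, shape [N, 1]
    pos=positions,
    edge_index=edge_index,
    y=label,
)
loader = DataLoader([data], batch_size=1)
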
215 changes: 103 additions & 112 deletions model.py
@@ -6,135 +6,126 @@

from e3nn import o3

from torch_scatter import scatter_mean
from torch_scatter import scatter_mean, scatter_sum


class AtomEmbedding(nn.Module):
"""Embeds atomic atomic_numbers into a learnable vector space."""

def __init__(self, embed_dims: int, max_atomic_number: int):
super().__init__()
self.embed_dims = embed_dims
self.max_atomic_number = max_atomic_number
self.embedding = nn.Embedding(num_embeddings=max_atomic_number, embedding_dim=embed_dims)

def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
atom_embeddings = self.embedding(atomic_numbers)
return atom_embeddings

class MLP(nn.Module):

def __init__(
self,
input_dims: int,
output_dims: int,
hidden_dims: int = 32,
num_layers: int = 2):

super(MLP, self).__init__()
layers = []
for i in range(num_layers - 1):
layers.append(nn.Linear(input_dims if i == 0 else hidden_dims, hidden_dims))
layers.append(nn.LayerNorm(hidden_dims))
layers.append(nn.SiLU())
layers.append(nn.Linear(hidden_dims, output_dims))
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)

class SimpleNetwork(nn.Module):
class Layer(nn.Module):
"""A layer of a simple E(3)-equivariant message passing network."""

sh_lmax: int = 2
lmax: int = 2
init_node_features: int = 16
max_atomic_number: int = 12
num_hops: int = 2
output_dims: int = 1


denominator: float = 1.5

def __init__(self,
relative_vectors_irreps: o3.Irreps,
node_features_irreps: o3.Irreps):
relative_positions_irreps: o3.Irreps,
node_features_irreps: o3.Irreps,
target_irreps: o3.Irreps):

super().__init__()

self.embed = AtomEmbedding(self.init_node_features, self.max_atomic_number)
self.sph = o3.SphericalHarmonics(irreps_out=o3.Irreps.spherical_harmonics(self.sh_lmax), normalize=True, normalization="norm")


# Currently hardcoding 1 layer

# print("node_features_irreps", node_features_irreps)
self.tp = o3.FullTensorProduct(relative_positions_irreps,
node_features_irreps)

self.tp = o3.FullTensorProduct(relative_vectors_irreps.regroup(),
node_features_irreps.regroup(),
filter_ir_out=[o3.Irrep(f"{l}e") for l in range(self.lmax+1)] + [o3.Irrep(f"{l}o") for l in range(self.lmax+1)])
self.linear = o3.Linear(irreps_in=self.tp.irreps_out.regroup(), irreps_out=self.tp.irreps_out.regroup())
# print("TP+Linear", self.linear.irreps_out)
self.mlp = MLP(input_dims = 1, # Since we input the norms, the shape will always be (..., 1)
output_dims = self.tp.irreps_out.num_irreps)


self.elementwise_tp = o3.ElementwiseTensorProduct(o3.Irreps(f"{self.tp.irreps_out.num_irreps}x0e"), self.linear.irreps_out.regroup())
# print("node feature broadcasted", self.elementwise_tp.irreps_out)

# Poor man's filter function (can already feel the judgement). Replicating irreps_array.filter("0e")
self.filter_tp = o3.FullTensorProduct(self.tp.irreps_out.regroup(), o3.Irreps("0e"), filter_ir_out=[o3.Irrep("0e")])
self.register_buffer("dummy_input", torch.ones(1))
tp_irreps = self.tp.irreps_out.regroup() + node_features_irreps

# print("aggregated node features", self.filter_tp.irreps_out)

self.readout_mlp = MLP(input_dims = self.filter_tp.irreps_out.num_irreps,
output_dims = self.output_dims)
self.linear = o3.Linear(irreps_in=tp_irreps,
irreps_out=target_irreps)

self.shortcut = o3.Linear(irreps_in=node_features_irreps,
irreps_out=target_irreps)

def forward(self,
numbers: torch.Tensor,
pos: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int) -> torch.Tensor:

node_features = self.embed(numbers)
senders, receivers = edge_index
relative_vectors = pos[receivers] - pos[senders]

relative_vectors_sh = self.sph(relative_vectors)
relative_vectors_norm = torch.linalg.norm(relative_vectors, axis=-1, keepdims=True)


# Currently hardcoding 1 hop

# Layer 0
node_features,
relative_positions_sh,
senders, receivers) -> torch.Tensor:

# Tensor product of the relative vectors and the neighbouring node features.
node_features_broadcasted = node_features[senders]

tp = self.tp(relative_vectors_sh, node_features_broadcasted)


# Apply linear
tp = self.linear(tp)

# ResNet-style shortcut
shortcut = self.shortcut(shortcut_aggregated)

# Simply multiply each irrep by a learned scalar, based on the norm of the relative vector.
scalars = self.mlp(relative_vectors_norm)
node_features_broadcasted = self.elementwise_tp(scalars, tp)


# Aggregate the node features back.
node_features = scatter_mean(
shortcut_aggregated = scatter_mean(
node_features_broadcasted,
receivers.unsqueeze(1).expand(-1, node_features_broadcasted.size(dim=1)),
dim=0
dim=0,
dim_size=node_features.shape[0]
)

# Tensor product of the relative vectors and the neighbouring node features.
tp = self.tp(relative_positions_sh, node_features_broadcasted)

# Construct message by appending the tensor product to the existing node features
messages = torch.cat([node_features_broadcasted, tp], dim=-1)

# Aggregate the node features
node_feats = scatter_mean(
messages,
receivers.unsqueeze(1).expand(-1, messages.size(dim=1)),
dim=0,
dim_size=node_features.shape[0]
)


# Normalize
node_feats = node_feats / self.denominator

# Apply linear
node_feats = self.linear(node_feats)

# Skipping scalar activation for now

# Add shortcut to node_features
node_features = node_feats + shortcut
return node_features

# # Global readout.
class Model(torch.nn.Module):

sh_lmax: int = 3

# Filter out 0e
node_features = self.filter_tp(node_features, self.dummy_input)
def __init__(self):

super().__init__()

graph_globals = scatter_mean(node_features,
torch.zeros(num_nodes, dtype=torch.int64),
dim=0,
dim_size=8)
return self.readout_mlp(graph_globals)
node_features_irreps = o3.Irreps("1x0e")
relative_positions_irreps = o3.Irreps([(1,(l,(-1)**l)) for l in range(1,self.sh_lmax+1)])
self.sph = o3.SphericalHarmonics(
irreps_out=relative_positions_irreps,
normalize=True, normalization="norm")


output_irreps = [o3.Irreps("32x0e+8x1o+8x2e"), o3.Irreps("32x0e+8x1e+8x1o+8x2e+8x2o")] + [o3.Irreps("0o + 7x0e")]

layers = []
for target_irreps in output_irreps:
layers.append(Layer(
relative_positions_irreps,
node_features_irreps,
target_irreps))
node_features_irreps = target_irreps

self.layers = torch.nn.ModuleList(layers)

def forward(self, graphs):

node_features, pos, edge_index, batch, num_nodes = graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch, graphs.num_nodes
senders, receivers = edge_index
relative_positions = pos[receivers] - pos[senders]

# Apply spherical harmonics
relative_positions_sh = self.sph(relative_positions)

for layer in self.layers:
node_features = layer(
node_features,
relative_positions_sh,
senders,
receivers,
)

# Readout logits
pred = scatter_sum(
node_features,
batch,
dim=0,
dim_size=8) # [num_graphs, 1 + 7] = [8,8]
odd, even1, even2 = pred[:, :1], pred[:, 1:2], pred[:, 2:]
logits = torch.concatenate([odd * even1, -odd * even1, even2], dim=1)
assert logits.shape == (8, 8) # [num_graphs, num_classes]
return logits
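For context, a minimal usage sketch of the new Model on the tetris batch (module and function names taken from the diff; e3nn, torch_scatter and torch_geometric assumed installed):

import torch
from model import Model
from data import get_tetris

graphs = get_tetris()    # batch of the 8 tetris pieces
model = Model()

logits = model(graphs)   # [num_graphs, num_classes] = [8, 8]
print(torch.argmax(logits, dim=1))

The first readout channel is a pseudoscalar (0o), so under a reflection the first two logits (odd * even1 and -odd * even1) swap places, which is what lets the network distinguish the two mirror-image pieces; under pure rotations the logits stay invariant.
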
105 changes: 61 additions & 44 deletions train.py
@@ -1,60 +1,77 @@
import sys
sys.path.append("..")

from model import SimpleNetwork
from data import get_tetris, get_random_graph
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.auto import tqdm
from e3nn import o3

torch._dynamo.config.capture_scalar_outputs = True
# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
torch.set_float32_matmul_precision("high")
import torch._dynamo.config
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

from model import Model
from data import get_tetris, get_random_graph

graphs = get_tetris()
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SimpleNetwork(
relative_vectors_irreps=o3.Irreps.spherical_harmonics(lmax=2),
node_features_irreps=o3.Irreps("16x0e"),
)
def train(steps=200):
model = Model()
model = model.to(device)
opt = optim.Adam(model.parameters(), lr=0.01)

optimizer = torch.optim.Adam(model.parameters())
def loss_fn(model_output, graphs):
logits = model_output
labels = graphs.y # [num_graphs]
loss = F.cross_entropy(logits, labels)
loss = torch.mean(loss)
return loss, logits

def loss_fn(graphs):
logits = model(graphs.numbers,
graphs.pos,
graphs.edge_index,
graphs.num_nodes)
labels = graphs.y.unsqueeze(-1).float() # [num_graphs]
loss = torch.nn.functional.cross_entropy(logits, labels)
return loss, logits
def update_fn(model, opt, graphs):
model_output = model(graphs)
loss, logits = loss_fn(model_output, graphs)

opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()

labels = graphs.y
preds = torch.argmax(logits, dim=1)
accuracy = (preds == labels).float().mean()

def apply_random_rotation(graphs):
"""Apply a random rotation to the nodes of the graph."""
alpha, beta, gamma = torch.rand(3) * 2 * torch.pi - torch.pi
return loss.item(), accuracy.item()

rotated_pos = o3.angles_to_matrix(alpha, beta, gamma) @ graphs.pos.T
rotated_pos = rotated_pos.T
# Compile the update function
update_fn_compiled = torch.compile(update_fn, mode="reduce-overhead")

rotated_graphs = graphs.clone()
rotated_graphs.pos = rotated_pos
return rotated_graphs
# Dataset
graphs = get_tetris()
graphs = graphs.to(device=device)

model.train()
for _ in range(10):

graphs = apply_random_rotation(graphs)
optimizer.zero_grad()
loss, logits = loss_fn(graphs)
loss.backward()
optimizer.step()
# JIT-compile on the first few calls
wall = time.perf_counter()
print("compiling...", flush=True)
# Warmup runs
for i in range(3):
loss, accuracy = update_fn_compiled(model, opt, graphs)
print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
print(f"compilation took {time.perf_counter() - wall:.1f}s")

preds = torch.argmax(logits, dim=1)
accuracy = (preds == graphs.y.squeeze()).float().mean()
# Train
wall = time.perf_counter()
print("training...", flush=True)
for _ in tqdm(range(steps)):
loss, accuracy = update_fn_compiled(model, opt, graphs)

# es = torch.export.export(model,
# (graphs.numbers,
# graphs.pos,
# graphs.edge_index,
# graphs.num_nodes))
# print(es)
if accuracy == 1.0:
break

print(f"final accuracy = {100 * accuracy:.0f}%")
print(f"training took {time.perf_counter() - wall:.1f}s")

if __name__ == "__main__":
train()
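As a quick sanity check, one could verify that the trained classifier is rotation invariant by reusing the rotation helper removed from the old script (a sketch, not part of the commit; o3.angles_to_matrix is the same e3nn call the old code used):

import torch
from e3nn import o3
from model import Model
from data import get_tetris

def apply_random_rotation(graphs):
    """Rotate all node positions by a single random rotation matrix."""
    alpha, beta, gamma = torch.rand(3) * 2 * torch.pi - torch.pi
    R = o3.angles_to_matrix(alpha, beta, gamma)
    rotated = graphs.clone()
    rotated.pos = graphs.pos @ R.T
    return rotated

model = Model()
graphs = get_tetris()
logits = model(graphs)
logits_rotated = model(apply_random_rotation(graphs))
# Rotation invariance: the difference should be numerical noise only.
print((logits - logits_rotated).abs().max())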
