
Commit

change
mitkotak committed Nov 25, 2024
1 parent 99631e3 commit eb4d8fb
Showing 3 changed files with 117 additions and 119 deletions.
117 changes: 0 additions & 117 deletions train.py

This file was deleted.

119 changes: 117 additions & 2 deletions model.py → train_e3nn.py
@@ -1,10 +1,36 @@
import time
import os
import nvtx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

from torch_geometric.data import Data
from torch_scatter import scatter_mean, scatter_sum

from tqdm.auto import tqdm
from e3nn import o3
from e3nn.util.jit import prepare

import numpy as np

torch.manual_seed(0)

# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
torch.set_float32_matmul_precision("high")
import torch._dynamo.config
import torch._inductor.config

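# coordinate_descent_tuning makes Inductor autotune Triton block sizes more
# aggressively; unique_kernel_names gives each generated kernel a stable,
# readable name in profiler traces.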
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

from data import get_tetris, get_random_graph

device = "cuda" if torch.cuda.is_available() else "cpu"

class Layer(nn.Module):
"""A layer of a simple E(3)-equivariant message passing network."""

@@ -129,4 +155,93 @@ def forward(self,
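        # The final irreps are one odd scalar plus even scalars: odd * even1
        # changes sign under reflection, so the paired logits [odd * even1,
        # -odd * even1] let the network separate the two mirror-image (chiral)
        # tetris pieces, while even2 covers the remaining classes.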
        odd, even1, even2 = pred[:, :1], pred[:, 1:2], pred[:, 2:]
        logits = torch.concatenate([odd * even1, -odd * even1, even2], dim=1)
        assert logits.shape == (8, 8)  # [num_graphs, num_classes]
        return logits


def build_model():
    model = Model()
    return model

def train(steps=200):
    model = prepare(build_model)()
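    # prepare() from e3nn.util.jit readies the model-building function so the
    # resulting module plays well with torch.compile (presumably stripping
    # TorchScript-only machinery; see e3nn.util.jit for details).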
    model = model.to(device)
    opt = optim.Adam(model.parameters(), lr=0.01)

@nvtx.annotate(color="red")
def loss_fn(model_output, graphs):
logits = model_output
labels = graphs.y # [num_graphs]
loss = F.cross_entropy(logits, labels)
loss = torch.mean(loss)
return loss, logits

@nvtx.annotate(color="blue")
def update_fn(model, opt, graphs):
model_output = model(graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch)
loss, logits = loss_fn(model_output, graphs)

opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()

labels = graphs.y
preds = torch.argmax(logits, dim=1)
accuracy = (preds == labels).float().mean()

return loss.item(), accuracy.item()

    # Dataset
    graphs = get_tetris()
    graphs = graphs.to(device=device)

    # Compile
    model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
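    # mode="reduce-overhead" uses CUDA graphs to amortize kernel-launch
    # overhead; fullgraph=True makes compilation fail on any graph break
    # instead of silently falling back to eager.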

    wall = time.perf_counter()
    print("compiling...", flush=True)
    # Warmup runs
    for i in range(3):
        loss, accuracy = update_fn(model, opt, graphs)
    print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
    print(f"compilation took {time.perf_counter() - wall:.1f}s")

    from ctypes import cdll
    libcudart = cdll.LoadLibrary('libcudart.so')
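    # cudaProfilerStart/Stop below bracket the steps that Nsight captures when
    # the script runs under e.g. `nsys profile --capture-range=cudaProfilerApi`.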

    # Train
    timings = []
    print("training...", flush=True)
    for step in tqdm(range(steps)):
        start = time.time()
        if step == 20:
            libcudart.cudaProfilerStart()

        loss, accuracy = update_fn(model, opt, graphs)

        timings.append(time.time() - start)
        if step == 30:
            libcudart.cudaProfilerStop()


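    # Exclude the first 20 steps from the mean so lingering warmup does not
    # skew the per-step timing.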
print(f"final accuracy = {100 * accuracy:.0f}%")
print(f"Training time/step {np.mean(timings[20:])*1000:.3f} ms")

    # # Export model
    # so_path = torch._export.aot_compile(
    #     model,
    #     args=(graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch),
    #     options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so")},
    # )

    # print("node_features", graphs.numbers)
    # print("pos", graphs.pos)
    # print("edge_index", graphs.edge_index)
    # print("batch", graphs.batch)

    # runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), "export/model.so"), 1, device)
    # outputs_export = runner.run([graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch])
    # print(f"output {outputs_export[0]}")


if __name__ == "__main__":
    train()
File renamed without changes.
