diff --git a/train.py b/train.py
deleted file mode 100644
index d20b8d7..0000000
--- a/train.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import time
-import os
-import nvtx
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-from torch.fx.experimental.proxy_tensor import make_fx
-from tqdm.auto import tqdm
-from e3nn import o3
-from e3nn.util.jit import prepare
-
-
-torch.manual_seed(0)
-
-# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
-torch.set_float32_matmul_precision("high")
-import torch._dynamo.config
-import torch._inductor.config
-
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-
-from model import Model
-from data import get_tetris, get_random_graph
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def build_model():
-    model = Model()
-    return model
-
-def train(steps=200):
-    model = prepare(build_model)()
-    model = model.to(device)
-    opt = optim.Adam(model.parameters(), lr=0.01)
-
-    @nvtx.annotate(color="red")
-    def loss_fn(model_output, graphs):
-        logits = model_output
-        labels = graphs.y # [num_graphs]
-        loss = F.cross_entropy(logits, labels)
-        loss = torch.mean(loss)
-        return loss, logits
-
-    @nvtx.annotate(color="blue")
-    def update_fn(model, opt, graphs):
-        model_output = model(graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch)
-        loss, logits = loss_fn(model_output, graphs)
-
-        opt.zero_grad(set_to_none=True)
-        loss.backward()
-        opt.step()
-
-        labels = graphs.y
-        preds = torch.argmax(logits, dim=1)
-        accuracy = (preds == labels).float().mean()
-
-        return loss.item(), accuracy.item()
-
-    # Dataset
-    graphs = get_tetris()
-    graphs = graphs.to(device=device)
-
-    # Compile
-    model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
-
-    wall = time.perf_counter()
-    print("compiling...", flush=True)
-    # Warmup runs
-    for i in range(3):
-        loss, accuracy = update_fn(model, opt, graphs)
-    print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
-    print(f"compilation took {time.perf_counter() - wall:.1f}s")
-
-    from ctypes import cdll
-    libcudart = cdll.LoadLibrary('libcudart.so')
-
-    # Train
-    wall = time.perf_counter()
-    print("training...", flush=True)
-    for step in tqdm(range(steps)):
-        if step == 20:
-            libcudart.cudaProfilerStart()
-
-        loss, accuracy = update_fn(model, opt, graphs)
-
-        if step == 30:
-            libcudart.cudaProfilerStop()
-
-        if accuracy == 1.0:
-            break
-
-    print(f"final accuracy = {100 * accuracy:.0f}%")
-    print(f"training took {time.perf_counter() - wall:.1f}s")
-
-    # # Export model
-    # so_path = torch._export.aot_compile(
-    #     model,
-    #     args = (graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch),
-    #     options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
-    #     })
-
-    # print("node_features", graphs.numbers)
-    # print("pos", graphs.pos)
-    # print("edge_index", graphs.edge_index)
-    # print("batch", graphs.batch)
-
-    # runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), f"export/model.so"), 1, device)
-    # outputs_export = runner.run([graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch])
-    # print(f"output {outputs_export[0]}")
-
-
-if __name__ == "__main__":
-    train()
\ No newline at end of file
diff --git a/model.py b/train_e3nn.py
similarity index 54%
rename from model.py
rename to train_e3nn.py
index b63a1fa..a238580 100644
--- a/model.py
+++ b/train_e3nn.py
@@ -1,10 +1,36 @@
+import time
+import os
+import nvtx
+
 import torch
 import torch.nn as nn
-from e3nn import o3
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch_geometric.data import Data
 from torch_scatter import scatter_mean, scatter_sum
+from tqdm.auto import tqdm
+from e3nn import o3
+from e3nn.util.jit import prepare
+
+import numpy as np
+
+torch.manual_seed(0)
+
+# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
+torch.set_float32_matmul_precision("high")
+import torch._dynamo.config
+import torch._inductor.config
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+from data import get_tetris, get_random_graph
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+


 class Layer(nn.Module):
     """A layer of a simple E(3)-equivariant message passing network."""

@@ -129,4 +155,93 @@ def forward(self,
         odd, even1, even2 = pred[:, :1], pred[:, 1:2], pred[:, 2:]
         logits = torch.concatenate([odd * even1, -odd * even1, even2], dim=1)
         assert logits.shape == (8, 8) # [num_graphs, num_classes]
-        return logits
\ No newline at end of file
+        return logits
+
+
+def build_model():
+    model = Model()
+    return model
+
+def train(steps=200):
+    model = prepare(build_model)()
+    model = model.to(device)
+    opt = optim.Adam(model.parameters(), lr=0.01)
+
+    @nvtx.annotate(color="red")
+    def loss_fn(model_output, graphs):
+        logits = model_output
+        labels = graphs.y # [num_graphs]
+        loss = F.cross_entropy(logits, labels)
+        loss = torch.mean(loss)
+        return loss, logits
+
+    @nvtx.annotate(color="blue")
+    def update_fn(model, opt, graphs):
+        model_output = model(graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch)
+        loss, logits = loss_fn(model_output, graphs)
+
+        opt.zero_grad(set_to_none=True)
+        loss.backward()
+        opt.step()
+
+        labels = graphs.y
+        preds = torch.argmax(logits, dim=1)
+        accuracy = (preds == labels).float().mean()
+
+        return loss.item(), accuracy.item()
+
+    # Dataset
+    graphs = get_tetris()
+    graphs = graphs.to(device=device)
+
+    # Compile
+    model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+
+    wall = time.perf_counter()
+    print("compiling...", flush=True)
+    # Warmup runs
+    for i in range(3):
+        loss, accuracy = update_fn(model, opt, graphs)
+    print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
+    print(f"compilation took {time.perf_counter() - wall:.1f}s")
+
+    from ctypes import cdll
+    libcudart = cdll.LoadLibrary('libcudart.so')
+
+    # Train
+    timings = []
+    print("training...", flush=True)
+    for step in tqdm(range(steps)):
+        start = time.time()
+        if step == 20:
+            libcudart.cudaProfilerStart()
+
+        loss, accuracy = update_fn(model, opt, graphs)
+
+        timings.append(time.time() - start)
+        if step == 30:
+            libcudart.cudaProfilerStop()
+
+
+    print(f"final accuracy = {100 * accuracy:.0f}%")
+    print(f"Training time/step {np.mean(timings[20:])*1000:.3f} ms")
+
+    # # Export model
+    # so_path = torch._export.aot_compile(
+    #     model,
+    #     args = (graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch),
+    #     options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
+    #     })
+
+    # print("node_features", graphs.numbers)
+    # print("pos", graphs.pos)
+    # print("edge_index", graphs.edge_index)
+    # print("batch", graphs.batch)
+
+    # runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), f"export/model.so"), 1, device)
+    # outputs_export = runner.run([graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch])
+    # print(f"output {outputs_export[0]}")
+
+
+if __name__ == "__main__":
+    train()
\ No newline at end of file
diff --git a/tetris_jax_baseline.py b/train_e3nn_jax.py
similarity index 100%
rename from tetris_jax_baseline.py
rename to train_e3nn_jax.py
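
Note (not part of the patch): the new per-step timings in train_e3nn.py wrap update_fn with time.time(). CUDA kernel launches are asynchronous, so a host-side timestamp taken without draining the device can under- or over-count GPU work in any given step. Below is a minimal sketch of a synchronized variant, assuming the same update_fn signature as in the patch; timed_step is a hypothetical helper, not something this diff introduces.

    # Hypothetical helper (not in this patch): synchronize so the host
    # clock observes the full GPU cost of one training step.
    import time
    import torch

    def timed_step(update_fn, model, opt, graphs):
        torch.cuda.synchronize()      # drain kernels queued by earlier steps
        start = time.perf_counter()   # monotonic clock; preferred over time.time()
        loss, accuracy = update_fn(model, opt, graphs)
        torch.cuda.synchronize()      # wait for this step's kernels to finish
        return loss, accuracy, time.perf_counter() - start

Since the loop already brackets steps 20-30 with cudaProfilerStart/cudaProfilerStop, that window can drive an Nsight Systems capture with, e.g., nsys profile --capture-range=cudaProfilerApi python train_e3nn.py.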