diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..dd815f1
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,11 @@
+cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
+project(aoti_example)
+
+find_package(CUDA)
+
+find_package(Torch REQUIRED)
+
+add_executable(inference inference.cpp)
+
+target_link_libraries(inference "${TORCH_LIBRARIES}" m)
+set_property(TARGET inference PROPERTY CXX_STANDARD 17)
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..67fee6d
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,18 @@
+train:
+	python train.py
+
+build:
+	cmake -Bbuild -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'`
+	cmake --build build
+
+clean:
+	rm -rf build
+
+profile:
+	nsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop -o profile -f true python train.py
+
+run:
+	./build/inference export/model.so
+
+.PHONY: train build clean profile run
+
diff --git a/debug/scatter_debug.py b/debug/scatter_debug.py
deleted file mode 100644
index 055be6d..0000000
--- a/debug/scatter_debug.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import torch
-
-torch._dynamo.config.capture_scalar_outputs = True
-
-def my_arithmetic(a, b):
-    wrk = torch.zeros(a.size(0), dtype=torch.float32)
-    wrk.scatter_add_(0, b, torch.ones_like(b, dtype=torch.float32))
-    return wrk
-
-model = torch.compile(my_arithmetic, fullgraph=True, disable=True)
-my_a = torch.randn([9])
-my_b = torch.ones(9, dtype=torch.int64)
-print(model(my_a, my_b))
\ No newline at end of file
diff --git a/inference.cpp b/inference.cpp
new file mode 100644
index 0000000..ff405e3
--- /dev/null
+++ b/inference.cpp
@@ -0,0 +1,25 @@
+#include <iostream>
+#include <vector>
+
+#include <torch/torch.h>
+#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
+
+int main(int argc, char *argv[]) {
+
+    char *model_path = NULL;  // e.g. model.so
+    model_path = argv[1];
+
+    c10::InferenceMode mode;
+    torch::inductor::AOTIModelContainerRunnerCuda *runner;
+    runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path, 1);
+    std::vector<torch::Tensor> inputs = {
+        torch::randn({32, 1}, at::kCUDA),
+        torch::randn({32, 3}, at::kCUDA),
+        torch::randn({2, 50}, at::kCUDA),
+        torch::randn({32, 1}, at::kCUDA)
+    };
+    std::vector<torch::Tensor> outputs = runner->run(inputs);
+    std::cout << "Result from the first inference:" << std::endl;
+    std::cout << outputs[0] << std::endl;
+    return 0;
+}
\ No newline at end of file
diff --git a/model.py b/model.py
index b35a1ec..04710bb 100644
--- a/model.py
+++ b/model.py
@@ -6,6 +6,7 @@
 from e3nn import o3
 
+from torch_geometric.data import Data
 from torch_scatter import scatter_mean, scatter_sum
 
 
 class Layer(nn.Module):
@@ -102,9 +103,13 @@ def __init__(self):
 
         self.layers = torch.nn.ModuleList(layers)
 
-    def forward(self, graphs):
-
-        node_features, pos, edge_index, batch, num_nodes = graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch, graphs.num_nodes
+    def forward(self,
+                node_features,
+                pos,
+                edge_index,
+                batch):
+
+        # Passing in graphs makes dynamo angry
         senders, receivers = edge_index
         relative_positions= pos[receivers] - pos[senders]
 
diff --git a/modules.py b/modules.py
deleted file mode 100644
index a023282..0000000
--- a/modules.py
+++ /dev/null
@@ -1,224 +0,0 @@
-from typing import List, Union, Optional, Callable
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-from e3nn import o3
-
-class _MulIndexSliceHelper:
-    irreps: o3.Irreps
-
-    def __init__(self, irreps) -> None:
-        self.irreps = irreps
-
-    def __getitem__(self, index: slice) -> o3.Irreps:
-        if not isinstance(index, slice):
-            raise IndexError("Irreps.slice_by_mul only supports slices.")
-
-        start, stop, stride = index.indices(self.irreps.num_irreps)
-        if stride != 1:
-            raise NotImplementedError("Irreps.slice_by_mul does not support strides.")
-
-        out = []
-        i = 0
-        for mul, ir in self.irreps:
-            if start <= i and i + mul <= stop:
-                out.append((mul, ir))
-            elif start < i + mul and i < stop:
-                out.append((min(stop, i + mul) - max(start, i), ir))
-            i += mul
-        return o3.Irreps(out)
-
-
-def slice_by_mul(irreps):
-    return _MulIndexSliceHelper(irreps)
-
-def filter(
-    irreps,
-    keep: Union["o3.Irreps", List[o3.Irrep]] = None,
-    *,
-    drop: Union["o3.Irreps", List[o3.Irrep]] = None,
-    lmax: int = None,
-) -> "o3.Irreps":
-    if keep is None and drop is None and lmax is None:
-        return self
-    if keep is not None and drop is not None:
-        raise ValueError("Cannot specify both keep and drop")
-    if keep is not None and lmax is not None:
-        raise ValueError("Cannot specify both keep and lmax")
-    if drop is not None and lmax is not None:
-        raise ValueError("Cannot specify both drop and lmax")
-
-    if keep is not None:
-        if isinstance(keep, str):
-            keep = o3.Irreps(keep)
-        if isinstance(keep, o3.Irrep):
-            keep = [keep]
-        keep = {o3.Irrep(ir) for ir in keep}
-        return o3.Irreps([(mul, ir) for mul, ir in irreps if ir in keep])
-
-    if drop is not None:
-        if isinstance(drop, str):
-            drop = o3.Irreps(drop)
-        if isinstance(drop, o3.Irrep):
-            drop = [drop]
-        drop = {o3.Irrep(ir) for ir in drop}
-        return o3.Irreps([(mul, ir) for mul, ir in irreps if ir not in drop])
-
-    if lmax is not None:
-        return o3.Irreps([(mul, ir) for mul, ir in irreps if ir.l <= lmax])
-
-def soft_odd(x):
-    return (1 - torch.exp(-(x**2))) * x
-
-def normalspace(n: int) -> torch.Tensor:
-    return np.sqrt(2) * torch.erfinv(torch.linspace(-1.0, 1.0, n + 2)[1:-1])
-
-
-def normalize_function(phi: Callable[[float], float]) -> Callable[[float], float]:
-    x = normalspace(1_000_001)
-    c = torch.mean(phi(x) ** 2) ** 0.5
-    c = c.item()
-
-    if np.allclose(c, 1.0):
-        return phi
-    else:
-
-        def rho(x):
-            return phi(x) / c
-
-        return rho
-
-def parity_function(phi: Callable[[float], float]) -> int:
-    x = torch.linspace(0.0, 10.0, 256)
-
-    a1, a2 = phi(x), phi(-x)
-    if torch.max(torch.abs(a1 - a2)) < 1e-5:
-        return 1
-    elif torch.max(torch.abs(a1 + a2)) < 1e-5:
-        return -1
-    else:
-        return 0
-
-def is_zero_in_zero(phi: Callable[[float], float]) -> bool:
-    return torch.allclose(phi(torch.Tensor([0.0])), 0.0)
-
-# class ScalarActivation(nn.Module):
-
-#     def __init__(self,
-#                  irreps_in: o3.Irreps,
-#                  acts: List[Optional[Callable[[float], float]]] = None,
-#                  *,
-#                  even_act: Callable[[float], float] = F.gelu,
-#                  odd_act: Callable[[float], float] = soft_odd,
-#                  normalize_act: bool = True):
-
-#         super(ScalarActivation, self).__init__()
-
-#         if acts is None:
-#             acts = [
-#                 {1: even_act, -1: odd_act}[ir.p] if ir.l == 0 else None
-#                 for _, ir in irreps_in
-#             ]
-
-#         assert len(irreps_in) == len(acts), (irreps_in, acts)
-#         irreps_out = []
-#         paths = {}
-
-#         for (mul, (l_in, p_in)), slice_x, act in zip(irreps_in, irreps_in.slices(), acts):
-#             if act is not None:
-#                 if l_in != 0:
-#                     raise ValueError(
-#                         f"Activation: cannot apply an activation function to a non-scalar input. {irreps_in} {acts}"
-#                     )
-
-#                 if normalize_act:
-#                     act = normalize_function(act)
-
-#                 p_out = parity_function(act) if p_in == -1 else p_in
-#                 if p_out == 0:
-#                     raise ValueError(
-#                         "Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd."
-#                     )
-
-#                 irreps_out.append((mul, (0, p_out)))
-#             else:
-#                 irreps_out.append((mul, (l_in, p_in)))
-
-#             paths[l_in] = (slice_x, act)
-
-#         self._same_acts = False
-#         # for performance, if all the activation functions are the same, we can apply it to the contiguous array as well:
-#         if acts and acts.count(acts[0]) == len(acts):
-#             if acts[0] is None:
-#                 self.act = None
-#             else:
-#                 act = acts[0]
-#                 if normalize_act:
-#                     self.act = normalize_function(act)
-
-#         irreps_out = o3.Irreps(irreps_out)
-#         self.irreps_out, _, self.inv = irreps_out.sort()
-#         self.paths = paths
-
-#     def forward(self, input: torch.Tensor):
-
-#         if self._same_acts:
-#             if self.act is None:
-#                 return input
-#             else:
-#                 return self.act(input)
-
-#         chunks = []
-#         for (slice_x, act) in self.paths.values():
-#             if act is None:
-#                 chunks.append(input[..., slice_x])
-#             else:
-#                 chunks.append(act(input[..., slice_x]))
-
-#         return torch.cat([chunks[i] for i in self.inv], dim=-1)
-
-# class Gate(torch.nn.Module):
-#     def __init__(
-#         self,
-#         irreps: o3.Irreps,
-#         even_act: Callable[[float], float] = F.gelu,
-#         odd_act: Callable[[float], float] = soft_odd,
-#         even_gate_act: Callable[[float], float] = F.sigmoid,
-#         odd_gate_act: Callable[[float], float] = F.tanh,
-#         normalize_act: bool = True):
-
-#         scalars_irreps = filter(irreps, keep=["0e", "0o"])
-#         vectors_irreps = filter(irreps, drop=["0e", "0o"])
-
-#         if scalars_irreps.dim < vectors_irreps.num_irreps:
-#             raise ValueError(
-#                 "The input must have at least as many scalars as the number of non-scalar irreps"
-#             )
-#         scalars_extra_irreps = scalars_irreps.slice_by_mul[
-#             : scalars_irreps.irreps.dim - vectors_irreps.irreps.num_irreps
-#         ]
-#         scalars_gates_irreps = scalars_irreps.slice_by_mul[
-#             scalars_irreps.irreps.dim - vectors_irreps.irreps.num_irreps :
-#         ]
-
-#         self.scalars_extra = ScalarActivation(
-#             scalars_extra_irreps,
-#             even_act=even_act,
-#             odd_act=odd_act,
-#             normalize_act=normalize_act
-#         )
-#         self.scalars_gates = ScalarActivation(
-#             scalars_gates_irreps,
-#             even_act=even_gate_act,
-#             odd_act=odd_gate_act,
-#             normalize_act=normalize_act,
-#         )
-
-#         self.elementwise_tp = o3.ElementwiseTensorProduct(scalars_extra_irreps, vectors_irreps)
-
-
-#         self.output_irreps = self.scalars_extra_irreps + self.elementwise_tp.irreps_out
-
-
-
diff --git a/nsys/profile.nsys-rep b/nsys/profile.nsys-rep
new file mode 100644
index 0000000..ec3ed52
Binary files /dev/null and b/nsys/profile.nsys-rep differ
diff --git a/train.py b/train.py
index 6a57c07..01a909c 100644
--- a/train.py
+++ b/train.py
@@ -1,8 +1,12 @@
 import time
+import os
+import nvtx
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
+from torch.fx.experimental.proxy_tensor import make_fx
 
 from tqdm.auto import tqdm
 from e3nn import o3
@@ -24,6 +28,7 @@ def train(steps=200):
     model = model.to(device)
     opt = optim.Adam(model.parameters(), lr=0.01)
 
+    @nvtx.annotate(color="red")
     def loss_fn(model_output, graphs):
         logits = model_output
         labels = graphs.y  # [num_graphs]
@@ -31,8 +36,9 @@ def loss_fn(model_output, graphs):
         loss = torch.mean(loss)
         return loss, logits
 
+    @nvtx.annotate(color="blue")
     def update_fn(model, opt, graphs):
-        model_output = model(graphs)
+        model_output = model(graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch)
         loss, logits = loss_fn(model_output, graphs)
 
         opt.zero_grad(set_to_none=True)
@@ -44,34 +50,54 @@ def update_fn(model, opt, graphs):
         accuracy = (preds == labels).float().mean()
         return loss.item(), accuracy.item()
-
-    # Compile the update function
-    update_fn_compiled = torch.compile(update_fn, mode="reduce-overhead")
-
     # Dataset
     graphs = get_tetris()
     graphs = graphs.to(device=device)
 
-    # compile jit
+    # Compile
+    model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+
     wall = time.perf_counter()
     print("compiling...", flush=True)
 
     # Warmup runs
     for i in range(3):
-        loss, accuracy = update_fn_compiled(model, opt, graphs)
+        loss, accuracy = update_fn(model, opt, graphs)
         print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
     print(f"compilation took {time.perf_counter() - wall:.1f}s")
 
+    from ctypes import cdll
+    libcudart = cdll.LoadLibrary('libcudart.so')
+
     # Train
     wall = time.perf_counter()
     print("training...", flush=True)
-    for _ in tqdm(range(steps)):
-        loss, accuracy = update_fn_compiled(model, opt, graphs)
+    for step in tqdm(range(steps)):
+        if step == 20:
+            libcudart.cudaProfilerStart()
+
+        loss, accuracy = update_fn(model, opt, graphs)
+
+        if step == 30:
+            libcudart.cudaProfilerStop()
 
         if accuracy == 1.0:
             break
 
     print(f"final accuracy = {100 * accuracy:.0f}%")
    print(f"training took {time.perf_counter() - wall:.1f}s")
+
+    # Export model
+    so_path = torch._export.aot_compile(
+        model,
+        args = (graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch),
+        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
+        })
+
+    print("Traced Shapes")
+    print("node_features", graphs.numbers.shape, graphs.numbers.dtype)
+    print("pos", graphs.pos.shape, graphs.pos.dtype)
+    print("edge_index", graphs.edge_index.shape, graphs.edge_index.dtype)
+    print("batch", graphs.batch.shape, graphs.batch.dtype)
 
 if __name__ == "__main__":
     train()
\ No newline at end of file
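
Note (not part of the diff): the exported export/model.so can also be smoke-tested from Python instead of the C++ runner in inference.cpp. The sketch below is an assumption, not the project's code: it assumes a recent PyTorch that ships torch._export.aot_load, and the dummy tensors are placeholders whose shapes/dtypes should be replaced with the "Traced Shapes" that train.py prints after export.

    # Hypothetical Python-side check of the exported shared library.
    # Assumes torch._export.aot_load is available (PyTorch 2.2+) and a CUDA device.
    import torch

    runner = torch._export.aot_load("export/model.so", device="cuda")

    # Placeholder inputs mirroring inference.cpp; integer dtypes for the index
    # tensors are a guess -- check against the shapes/dtypes printed by train.py.
    node_features = torch.randn(32, 1, device="cuda")
    pos = torch.randn(32, 3, device="cuda")
    edge_index = torch.randint(0, 32, (2, 50), dtype=torch.int64, device="cuda")
    batch = torch.zeros(32, dtype=torch.int64, device="cuda")

    with torch.inference_mode():
        out = runner(node_features, pos, edge_index, batch)
    print(out)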