diff --git a/Makefile b/Makefile index 49b877a..ce2d3b8 100644 --- a/Makefile +++ b/Makefile @@ -14,5 +14,5 @@ profile: run: ./build/inference ${PWD}/export/model.so -.PHONY: run nsys build clean +.PHONY: run profile build clean diff --git a/train.py b/train.py index 01a909c..f436430 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ from torch.fx.experimental.proxy_tensor import make_fx from tqdm.auto import tqdm from e3nn import o3 +from e3nn.util.jit import prepare # Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18 torch.set_float32_matmul_precision("high") @@ -23,8 +24,13 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -def train(steps=200): + +def build_model(): model = Model() + return model + +def train(steps=200): + model = prepare(build_model)() model = model.to(device) opt = optim.Adam(model.parameters(), lr=0.01) @@ -50,6 +56,7 @@ def update_fn(model, opt, graphs): accuracy = (preds == labels).float().mean() return loss.item(), accuracy.item() + # Dataset graphs = get_tetris() graphs = graphs.to(device=device)