Skip to content

Commit

Permalink
added prepare API
Browse files Browse the repository at this point in the history
  • Loading branch information
mitkotak committed Aug 15, 2024
1 parent 2de2b26 commit b9a28ff
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ profile:
run:
./build/inference ${PWD}/export/model.so

.PHONY: run nsys build clean
.PHONY: run profile build clean

9 changes: 8 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.fx.experimental.proxy_tensor import make_fx
from tqdm.auto import tqdm
from e3nn import o3
from e3nn.util.jit import prepare

# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
torch.set_float32_matmul_precision("high")
Expand All @@ -23,8 +24,13 @@

device = "cuda" if torch.cuda.is_available() else "cpu"

def train(steps=200):

def build_model():
model = Model()
return model

def train(steps=200):
model = prepare(build_model)()
model = model.to(device)
opt = optim.Adam(model.parameters(), lr=0.01)

Expand All @@ -50,6 +56,7 @@ def update_fn(model, opt, graphs):
accuracy = (preds == labels).float().mean()

return loss.item(), accuracy.item()

# Dataset
graphs = get_tetris()
graphs = graphs.to(device=device)
Expand Down

0 comments on commit b9a28ff

Please sign in to comment.