
added baseline stuff + timing
mitkotak committed Oct 22, 2024
1 parent ca28c83 commit 9172669
Showing 3 changed files with 174 additions and 14 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 torch
-e3nn @ git+https://github.com/mitkotak/e3nn@linear_pt2
+e3nn
 torch-geometric
 torch-cluster
 torch-scatter @ git+https://github.com/rusty1s/pytorch_scatter.git
157 changes: 157 additions & 0 deletions tetris_jax_baseline.py
@@ -0,0 +1,157 @@
# Copied from https://github.com/e3nn/e3nn-jax/blob/245e17eb23deaccad9f2c9cfd40fe40515e3c074/examples/tetris_point.py#L13

import time

import flax
import jax
import jax.numpy as jnp
import jraph
import optax
from tqdm.auto import tqdm

import e3nn_jax as e3nn


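# Build the eight-tetromino dataset as one batched jraph.GraphsTuple. Two of
# the shapes are mirror images of each other, so the classifier must be
# sensitive to chirality.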
def tetris() -> jraph.GraphsTuple:
    pos = [
        [[0, 0, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]],  # chiral_shape_1
        [[1, 1, 1], [1, 1, 2], [2, 1, 1], [2, 0, 1]],  # chiral_shape_2
        [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],  # square
        [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]],  # line
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]],  # corner
        [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0]],  # L
        [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 1]],  # T
        [[0, 0, 0], [1, 0, 0], [1, 1, 0], [2, 1, 0]],  # zigzag
    ]
    pos = jnp.array(pos, dtype=jnp.float32)
    labels = jnp.arange(8)

    graphs = []

    for p, l in zip(pos, labels):
        # Connect nodes within a 1.1 cutoff, i.e. unit-grid neighbors.
        senders, receivers = e3nn.radius_graph(p, 1.1)

        graphs += [
            jraph.GraphsTuple(
                nodes=p.reshape((4, 3)),  # [num_nodes, 3]
                edges=None,
                globals=l[None],  # [num_graphs]
                senders=senders,  # [num_edges]
                receivers=receivers,  # [num_edges]
                n_node=jnp.array([len(p)]),  # [num_graphs]
                n_edge=jnp.array([len(senders)]),  # [num_graphs]
            )
        ]

    return jraph.batch(graphs)


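# One equivariant message-passing layer: messages concatenate sender features
# with their tensor product against spherical harmonics of the edge vector;
# node updates apply Linear -> scalar activation -> Linear plus a linear
# shortcut projected onto the output irreps.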
class Layer(flax.linen.Module):
    target_irreps: e3nn.Irreps
    denominator: float
    sh_lmax: int = 3

    @flax.linen.compact
    def __call__(self, graphs, positions):
        target_irreps = e3nn.Irreps(self.target_irreps)

        def update_edge_fn(edge_features, sender_features, receiver_features, globals):
            # Spherical harmonics of the relative edge vector, degrees 1..sh_lmax.
            sh = e3nn.spherical_harmonics(
                list(range(1, self.sh_lmax + 1)),
                positions[graphs.receivers] - positions[graphs.senders],
                True,
            )
            return e3nn.concatenate(
                [sender_features, e3nn.tensor_product(sender_features, sh)]
            ).regroup()

        def update_node_fn(node_features, sender_features, receiver_features, globals):
            # receiver_features holds the summed incoming messages; the constant
            # denominator keeps activations at roughly unit scale.
            node_feats = receiver_features / self.denominator
            node_feats = e3nn.flax.Linear(target_irreps, name="linear_pre")(node_feats)
            node_feats = e3nn.scalar_activation(node_feats)
            node_feats = e3nn.flax.Linear(target_irreps, name="linear_post")(node_feats)
            shortcut = e3nn.flax.Linear(
                node_feats.irreps, name="shortcut", force_irreps_out=True
            )(node_features)
            return shortcut + node_feats

        return jraph.GraphNetwork(update_edge_fn, update_node_fn)(graphs)


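# Three-layer network over constant node features; geometry enters only
# through the spherical harmonics inside each Layer. The final layer emits
# "0o + 7x0e": one odd (pseudoscalar) channel plus seven even scalars.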
class Model(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, graphs):
        positions = e3nn.IrrepsArray("1o", graphs.nodes)
        graphs = graphs._replace(nodes=jnp.ones((len(positions), 1)))

        layers = 2 * ["32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o"] + ["0o + 7x0e"]

        for irreps in layers:
            graphs = Layer(irreps, 1.5)(graphs, positions)

        # Readout logits: sum node features per graph, then combine the odd
        # scalar with an even one so mirrored shapes get opposite logits.
        pred = e3nn.scatter_sum(
            graphs.nodes.array, nel=graphs.n_node
        )  # [num_graphs, 1 + 7]
        odd, even1, even2 = pred[:, :1], pred[:, 1:2], pred[:, 2:]
        logits = jnp.concatenate([odd * even1, -odd * even1, even2], axis=1)
        assert logits.shape == (len(graphs.n_node), 8)  # [num_graphs, num_classes]

        return logits


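# train() times two phases separately: the first update_fn call (which
# triggers jit compilation) and the optimization loop itself, which stops
# early once accuracy reaches 100%.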
def train(steps=200):
    model = Model()

    # Optimizer
    opt = optax.adam(learning_rate=0.01)

    def loss_fn(params, graphs):
        logits = model.apply(params, graphs)
        labels = graphs.globals  # [num_graphs]

        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
        loss = jnp.mean(loss)
        return loss, logits

    @jax.jit
    def update_fn(params, opt_state, graphs):
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, logits = grad_fn(params, graphs)
        labels = graphs.globals
        accuracy = jnp.mean(jnp.argmax(logits, axis=1) == labels)

        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, accuracy

    # Dataset
    graphs = tetris()

    # Init
    init = jax.jit(model.init)
    params = init(jax.random.PRNGKey(3), graphs)
    opt_state = opt.init(params)

    # Compile the jitted update once and time it separately.
    wall = time.perf_counter()
    print("compiling...", flush=True)
    _, _, accuracy = update_fn(params, opt_state, graphs)
    print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
    print(f"compilation took {time.perf_counter() - wall:.1f}s")

    # Train
    wall = time.perf_counter()
    print("training...", flush=True)
    for _ in tqdm(range(steps)):
        params, opt_state, accuracy = update_fn(params, opt_state, graphs)

        if accuracy == 1.0:
            break

    print(f"final accuracy = {100 * accuracy:.0f}%")
    print(f"training took {time.perf_counter() - wall:.1f}s")


if __name__ == "__main__":
    train()
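For reference, the baseline runs standalone, assuming jax, flax, jraph, optax, and e3nn-jax are installed (they are not pinned in requirements.txt, which covers the torch path). A minimal sanity check of the batched dataset, assuming the file above is importable as tetris_jax_baseline:

from tetris_jax_baseline import tetris, train

graphs = tetris()
print(graphs.n_node)       # [4 4 4 4 4 4 4 4] -> four nodes per shape
print(graphs.globals)      # [0 1 2 3 4 5 6 7] -> one label per graph
print(graphs.nodes.shape)  # (32, 3) -> all eight shapes in one batch

train(steps=200)           # prints compile time, then training-loop time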
29 changes: 16 additions & 13 deletions train.py
@@ -11,6 +11,9 @@
 from e3nn import o3
 from e3nn.util.jit import prepare
 
+
+torch.manual_seed(0)
+
 # Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
 torch.set_float32_matmul_precision("high")
 import torch._dynamo.config
@@ -93,21 +96,21 @@ def update_fn(model, opt, graphs):
     print(f"final accuracy = {100 * accuracy:.0f}%")
     print(f"training took {time.perf_counter() - wall:.1f}s")
 
-    # Export model
-    so_path = torch._export.aot_compile(
-        model,
-        args = (graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch),
-        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
-        })
+    # # Export model
+    # so_path = torch._export.aot_compile(
+    #     model,
+    #     args = (graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch),
+    #     options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
+    #     })
 
-    print("node_features", graphs.numbers.shape, graphs.numbers.dtype)
-    print("pos", graphs.pos.shape, graphs.pos.dtype)
-    print("edge_index", graphs.edge_index, graphs.edge_index.shape, graphs.edge_index.dtype)
-    print("batch", graphs.batch, graphs.batch.shape, graphs.batch.dtype)
+    # print("node_features", graphs.numbers)
+    # print("pos", graphs.pos)
+    # print("edge_index", graphs.edge_index)
+    # print("batch", graphs.batch)
 
-    runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), f"export/model.so"), 1, device)
-    outputs_export = runner.run([graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch])
-    print(f"output {outputs_export[0]}")
+    # runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), f"export/model.so"), 1, device)
+    # outputs_export = runner.run([graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch])
+    # print(f"output {outputs_export[0]}")
 
 
 if __name__ == "__main__":