diff --git a/requirements.txt b/requirements.txt
index a3aff7e..4031f88 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch
-e3nn @ git+https://github.com/mitkotak/e3nn@linear_pt2
+e3nn
 torch-geometric
 torch-cluster
 torch-scatter @ git+https://github.com/rusty1s/pytorch_scatter.git
diff --git a/tetris_jax_baseline.py b/tetris_jax_baseline.py
new file mode 100644
index 0000000..70f9e9a
--- /dev/null
+++ b/tetris_jax_baseline.py
@@ -0,0 +1,157 @@
+# Copied from https://github.com/e3nn/e3nn-jax/blob/245e17eb23deaccad9f2c9cfd40fe40515e3c074/examples/tetris_point.py#L13
+
+import time
+
+import flax
+import jax
+import jax.numpy as jnp
+import jraph
+import optax
+from tqdm.auto import tqdm
+
+import e3nn_jax as e3nn
+
+
+def tetris() -> jraph.GraphsTuple:
+    pos = [
+        [[0, 0, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]],  # chiral_shape_1
+        [[1, 1, 1], [1, 1, 2], [2, 1, 1], [2, 0, 1]],  # chiral_shape_2
+        [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],  # square
+        [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]],  # line
+        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]],  # corner
+        [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0]],  # L
+        [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 1]],  # T
+        [[0, 0, 0], [1, 0, 0], [1, 1, 0], [2, 1, 0]],  # zigzag
+    ]
+    pos = jnp.array(pos, dtype=jnp.float32)
+    labels = jnp.arange(8)
+
+    graphs = []
+
+    for p, l in zip(pos, labels):
+        senders, receivers = e3nn.radius_graph(p, 1.1)
+
+        graphs += [
+            jraph.GraphsTuple(
+                nodes=p.reshape((4, 3)),  # [num_nodes, 3]
+                edges=None,
+                globals=l[None],  # [num_graphs]
+                senders=senders,  # [num_edges]
+                receivers=receivers,  # [num_edges]
+                n_node=jnp.array([len(p)]),  # [num_graphs]
+                n_edge=jnp.array([len(senders)]),  # [num_graphs]
+            )
+        ]
+
+    return jraph.batch(graphs)
+
+
+class Layer(flax.linen.Module):
+    target_irreps: e3nn.Irreps
+    denominator: float
+    sh_lmax: int = 3
+
+    @flax.linen.compact
+    def __call__(self, graphs, positions):
+        target_irreps = e3nn.Irreps(self.target_irreps)
+
+        def update_edge_fn(edge_features, sender_features, receiver_features, globals):
+            sh = e3nn.spherical_harmonics(
+                list(range(1, self.sh_lmax + 1)),
+                positions[graphs.receivers] - positions[graphs.senders],
+                True,
+            )
+            return e3nn.concatenate(
+                [sender_features, e3nn.tensor_product(sender_features, sh)]
+            ).regroup()
+
+        def update_node_fn(node_features, sender_features, receiver_features, globals):
+            node_feats = receiver_features / self.denominator
+            node_feats = e3nn.flax.Linear(target_irreps, name="linear_pre")(node_feats)
+            node_feats = e3nn.scalar_activation(node_feats)
+            node_feats = e3nn.flax.Linear(target_irreps, name="linear_post")(node_feats)
+            shortcut = e3nn.flax.Linear(
+                node_feats.irreps, name="shortcut", force_irreps_out=True
+            )(node_features)
+            return shortcut + node_feats
+
+        return jraph.GraphNetwork(update_edge_fn, update_node_fn)(graphs)
+
+
+class Model(flax.linen.Module):
+    @flax.linen.compact
+    def __call__(self, graphs):
+        positions = e3nn.IrrepsArray("1o", graphs.nodes)
+        graphs = graphs._replace(nodes=jnp.ones((len(positions), 1)))
+
+        layers = 2 * ["32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o"] + ["0o + 7x0e"]
+
+        for irreps in layers:
+            graphs = Layer(irreps, 1.5)(graphs, positions)
+
+        # Readout logits
+        pred = e3nn.scatter_sum(
+            graphs.nodes.array, nel=graphs.n_node
+        )  # [num_graphs, 1 + 7]
+        odd, even1, even2 = pred[:, :1], pred[:, 1:2], pred[:, 2:]
+        logits = jnp.concatenate([odd * even1, -odd * even1, even2], axis=1)
+        assert logits.shape == (len(graphs.n_node), 8)  # [num_graphs, num_classes]
+
+        return logits
+
+
+def train(steps=200):
+    model = Model()
+
+    # Optimizer
+    opt = optax.adam(learning_rate=0.01)
+
+    def loss_fn(params, graphs):
+        logits = model.apply(params, graphs)
+        labels = graphs.globals  # [num_graphs]
+
+        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
+        loss = jnp.mean(loss)
+        return loss, logits
+
+    @jax.jit
+    def update_fn(params, opt_state, graphs):
+        grad_fn = jax.grad(loss_fn, has_aux=True)
+        grads, logits = grad_fn(params, graphs)
+        labels = graphs.globals
+        accuracy = jnp.mean(jnp.argmax(logits, axis=1) == labels)
+
+        updates, opt_state = opt.update(grads, opt_state)
+        params = optax.apply_updates(params, updates)
+        return params, opt_state, accuracy
+
+    # Dataset
+    graphs = tetris()
+
+    # Init
+    init = jax.jit(model.init)
+    params = init(jax.random.PRNGKey(3), graphs)
+    opt_state = opt.init(params)
+
+    # compile jit
+    wall = time.perf_counter()
+    print("compiling...", flush=True)
+    _, _, accuracy = update_fn(params, opt_state, graphs)
+    print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
+    print(f"compilation took {time.perf_counter() - wall:.1f}s")
+
+    # Train
+    wall = time.perf_counter()
+    print("training...", flush=True)
+    for _ in tqdm(range(steps)):
+        params, opt_state, accuracy = update_fn(params, opt_state, graphs)
+
+        if accuracy == 1.0:
+            break
+
+    print(f"final accuracy = {100 * accuracy:.0f}%")
+    print(f"training took {time.perf_counter() - wall:.1f}s")
+
+
+if __name__ == "__main__":
+    train()
\ No newline at end of file
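A note on the readout in tetris_jax_baseline.py above: the last layer emits "0o + 7x0e" per node, and only the pseudoscalar (0o) channel can distinguish the two mirror-image chiral pieces, so the products odd * even1 and -odd * even1 swap under a reflection while the seven even scalars are invariant. A minimal sketch of that sanity check, assuming the file imports as tetris_jax_baseline and allowing float32 tolerance (the snippet is illustrative, not part of the patch):

    import jax
    import jax.numpy as jnp
    from tetris_jax_baseline import Model, tetris

    graphs = tetris()
    model = Model()
    params = model.init(jax.random.PRNGKey(0), graphs)
    logits = model.apply(params, graphs)

    # Mirror all positions through the x = 0 plane. The edge lists were built
    # from interatomic distances, which a reflection preserves, so only the
    # node coordinates need to change.
    mirrored = graphs._replace(nodes=graphs.nodes * jnp.array([-1.0, 1.0, 1.0]))
    logits_m = model.apply(params, mirrored)

    # The two chiral logits swap; the six achiral ones are unchanged.
    assert jnp.allclose(logits[:, 0], logits_m[:, 1], atol=1e-4)
    assert jnp.allclose(logits[:, 1], logits_m[:, 0], atol=1e-4)
    assert jnp.allclose(logits[:, 2:], logits_m[:, 2:], atol=1e-4)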
diff --git a/train.py b/train.py
index c44f7e8..d20b8d7 100644
--- a/train.py
+++ b/train.py
@@ -11,6 +11,9 @@
 from e3nn import o3
 from e3nn.util.jit import prepare
 
+
+torch.manual_seed(0)
+
 # Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
 torch.set_float32_matmul_precision("high")
 import torch._dynamo.config
@@ -93,21 +96,21 @@ def update_fn(model, opt, graphs):
     print(f"final accuracy = {100 * accuracy:.0f}%")
     print(f"training took {time.perf_counter() - wall:.1f}s")
 
-    # Export model
-    so_path = torch._export.aot_compile(
-        model,
-        args = (graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch),
-        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
-        })
+    # # Export model
+    # so_path = torch._export.aot_compile(
+    #     model,
+    #     args = (graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch),
+    #     options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so"),
+    #     })
 
-    print("node_features", graphs.numbers.shape, graphs.numbers.dtype)
-    print("pos", graphs.pos.shape, graphs.pos.dtype)
-    print("edge_index", graphs.edge_index, graphs.edge_index.shape, graphs.edge_index.dtype)
-    print("batch", graphs.batch, graphs.batch.shape, graphs.batch.dtype)
+    # print("node_features", graphs.numbers)
+    # print("pos", graphs.pos)
+    # print("edge_index", graphs.edge_index)
+    # print("batch", graphs.batch)
 
-    runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), f"export/model.so"), 1, device)
-    outputs_export = runner.run([graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch])
-    print(f"output {outputs_export[0]}")
+    # runner = torch._C._aoti.AOTIModelContainerRunnerCuda(os.path.join(os.getcwd(), f"export/model.so"), 1, device)
+    # outputs_export = runner.run([graphs.numbers,graphs.pos,graphs.edge_index,graphs.batch])
+    # print(f"output {outputs_export[0]}")
 
 
 if __name__ == "__main__":
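For reference, restoring the export path commented out above is a matter of uncommenting the block; an annotated sketch of that flow follows, using torch._export.aot_compile and torch._C._aoti.AOTIModelContainerRunnerCuda exactly as they appear in the disabled lines. Both are private PyTorch interfaces that may move between releases, and model, graphs, and device are assumed to be in scope as in train.py:

    import os
    import torch

    # Ahead-of-time compile the trained model into a shared library
    # with AOTInductor; the options dict picks the output path.
    so_path = torch._export.aot_compile(
        model,
        args=(graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch),
        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "export/model.so")},
    )

    # Load the compiled library and run it on the same inputs,
    # bypassing the Python eager path entirely.
    runner = torch._C._aoti.AOTIModelContainerRunnerCuda(
        os.path.join(os.getcwd(), "export/model.so"), 1, device
    )
    outputs_export = runner.run([graphs.numbers, graphs.pos, graphs.edge_index, graphs.batch])
    print(f"output {outputs_export[0]}")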