Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mitkotak committed Aug 16, 2024
2 parents 4606a64 + c61c3d0 commit 218c74b
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,31 @@
- [x] Get `torch.export` pipeline working
- [ ] Call `model.so` successfully in cpp-land

### (Preliminary) Training time comparision with e3nn-jax on RTX A5500

TODO: Need to test out on bigger datasets/models. The throughput is roughly the same but there might be some initialization differences that make JAX converge faster.

- e3nn + Torch 2

```python
compiling...
W0815 01:47:51.930000 140183613732416 torch/fx/experimental/symbolic_shapes.py:4449] [0/0] xindex is not in var_ranges, defaulting to unknown range.
initial accuracy = 25%
compilation took 203.1s
training...
66%|██████████████████████████████████████████████████████████████████▋ | 132/200 [00:00<00:00, 448.17it/s]
final accuracy = 100%
training took 0.3s
```

- e3nn-jax

```python
compiling...
initial accuracy = 25%
compilation took 7.1s
training...
15%|███████████████▎ | 30/200 [00:00<00:00, 473.85it/s]
final accuracy = 100%
training took 0.1s
```

0 comments on commit 218c74b

Please sign in to comment.