Skip to content

Commit

Permalink
fixed timing
Browse files Browse the repository at this point in the history
  • Loading branch information
mitkotak committed Nov 25, 2024
1 parent eb4d8fb commit a95b28d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
3 changes: 2 additions & 1 deletion train_e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def update_fn(model, opt, graphs):
libcudart.cudaProfilerStart()

loss, accuracy = update_fn(model, opt, graphs)

torch.cuda.synchronize()

timings.append(time.time() - start)
if step == 30:
libcudart.cudaProfilerStop()
Expand Down
15 changes: 8 additions & 7 deletions train_e3nn_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import flax
import jax
import jax.numpy as jnp
import numpy as np
import jraph
import optax
from tqdm.auto import tqdm
Expand Down Expand Up @@ -134,21 +135,21 @@ def update_fn(params, opt_state, graphs):
# compile jit
wall = time.perf_counter()
print("compiling...", flush=True)
_, _, accuracy = update_fn(params, opt_state, graphs)
for _ in range(3):
_, _, accuracy = update_fn(params, opt_state, graphs)
print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
print(f"compilation took {time.perf_counter() - wall:.1f}s")

# Train
wall = time.perf_counter()
print("training...", flush=True)
timings = []
for _ in tqdm(range(steps)):
params, opt_state, accuracy = update_fn(params, opt_state, graphs)

if accuracy == 1.0:
break
start = time.time()
params, opt_state, accuracy = jax.block_until_ready(update_fn(params, opt_state, graphs))
timings.append(time.time() - start)

print(f"final accuracy = {100 * accuracy:.0f}%")
print(f"training took {time.perf_counter() - wall:.1f}s")
print(f"Training time/step {np.mean(timings[20:])*1000:.3f} ms")


if __name__ == "__main__":
Expand Down

0 comments on commit a95b28d

Please sign in to comment.