diff --git a/train_e3nn.py b/train_e3nn.py index a238580..5f25888 100644 --- a/train_e3nn.py +++ b/train_e3nn.py @@ -217,7 +217,8 @@ def update_fn(model, opt, graphs): libcudart.cudaProfilerStart() loss, accuracy = update_fn(model, opt, graphs) - + torch.cuda.synchronize() + timings.append(time.time() - start) if step == 30: libcudart.cudaProfilerStop() diff --git a/train_e3nn_jax.py b/train_e3nn_jax.py index a8f93bc..5aa5b8e 100644 --- a/train_e3nn_jax.py +++ b/train_e3nn_jax.py @@ -5,6 +5,7 @@ import flax import jax import jax.numpy as jnp +import numpy as np import jraph import optax from tqdm.auto import tqdm @@ -134,21 +135,21 @@ def update_fn(params, opt_state, graphs): # compile jit wall = time.perf_counter() print("compiling...", flush=True) - _, _, accuracy = update_fn(params, opt_state, graphs) + for _ in range(3): + _, _, accuracy = update_fn(params, opt_state, graphs) print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True) print(f"compilation took {time.perf_counter() - wall:.1f}s") # Train wall = time.perf_counter() - print("training...", flush=True) + timings = [] for _ in tqdm(range(steps)): - params, opt_state, accuracy = update_fn(params, opt_state, graphs) - - if accuracy == 1.0: - break + start = time.time() + params, opt_state, accuracy = jax.block_until_ready(update_fn(params, opt_state, graphs)) + timings.append(time.time() - start) print(f"final accuracy = {100 * accuracy:.0f}%") - print(f"training took {time.perf_counter() - wall:.1f}s") + print(f"Training time/step {np.mean(timings[20:])*1000:.3f} ms") if __name__ == "__main__":