Between jaxlib==0.4.32.dev20240807 and jaxlib==0.4.32.dev20240812 I observe a significant performance regression when integrating differential equations with many solver steps (up to 8x slower). Minimal example:
Tested on Ubuntu 22.04 with the CPU backend. On my PC the runtime is 1.8 ms with the nightly jaxlib 20240807 and 14.8 ms with 20240812. The difference is largest when t_steps is large.
To test it quickly, I used the following one-liner, where test_diffrax.py is the script above:

uv venv --python 3.12 && uv pip install diffrax numpy --pre jax==0.4.32.dev20240807 jaxlib==0.4.32.dev20240807 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --reinstall --exclude-newer 2024-09-20 && uv run test_diffrax.py && uv pip install diffrax numpy --pre jax==0.4.32.dev20240807 jaxlib==0.4.32.dev20240812 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --reinstall --exclude-newer 2024-09-20 && uv run test_diffrax.py
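As a stand-in for test_diffrax.py (whose contents are not shown here), the following is a minimal pure-JAX sketch of the same kind of workload: a long compiled loop of small solver steps. It uses a hand-written fixed-step RK4 integrator driven by jax.lax.scan, not Diffrax's solvers, and the test problem dy/dt = -y as well as the names make_solver and benchmark are hypothetical. Running it under both jaxlib nightlies is one possible way to check whether the slowdown also appears outside Diffrax.

```python
import time

import jax
import jax.numpy as jnp


def vector_field(t, y):
    # Hypothetical test problem: linear decay dy/dt = -y (t is unused).
    return -y


def rk4_step(y, t, dt):
    # One classical fourth-order Runge-Kutta step.
    k1 = vector_field(t, y)
    k2 = vector_field(t + dt / 2, y + dt / 2 * k1)
    k3 = vector_field(t + dt / 2, y + dt / 2 * k2)
    k4 = vector_field(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def make_solver(t_steps, t1=1.0):
    dt = t1 / t_steps

    @jax.jit
    def solve(y0):
        def body(y, i):
            y_next = rk4_step(y, i * dt, dt)
            return y_next, y_next  # carry the state and also save it

        _, ys = jax.lax.scan(body, y0, jnp.arange(t_steps))
        return ys  # shape (t_steps, *y0.shape)

    return solve


def benchmark(t_steps, repeats=100):
    solve = make_solver(t_steps)
    y0 = jnp.ones(3)
    solve(y0).block_until_ready()  # warm-up call: compile outside the timed region
    start = time.perf_counter()
    for _ in range(repeats):
        solve(y0).block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    print(f"jax {jax.__version__}")
    for t_steps in (100, 1_000, 10_000):
        print(f"t_steps={t_steps}: {benchmark(t_steps) * 1e3:.3f} ms per solve")
```

The warm-up call keeps compilation time out of the measurement, and block_until_ready is needed because JAX dispatches work asynchronously; without it the loop would mostly time dispatch rather than execution.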
I wasn't sure whether I should have opened this issue on the JAX GitHub instead; let me know if it doesn't belong here.
I'd definitely open this as an issue on the JAX GitHub! Probably this will be due to some change in the XLA compiler between those two builds, which unfortunately means there isn't much we can do about it from Diffrax.
If you're interested in digging into it, you might be able to locate the responsible commit by bisecting through the XLA repo (https://github.com/openxla/xla/). I believe JAX itself hosts several benchmarks, so a benchmark could probably be added there to catch this regression if it ever reappears.
Yes, with the suggested fix I see runtimes similar to the 20240807 build again.
Somewhat unrelated: I can't test this against the newest JAX nightly, because jax.core.ConcreteArray has been removed from JAX (jax-ml/jax@48f24b6), but Diffrax still uses ConcreteArray. I don't know whether you're already aware of this.
Great that the source of the slowdown has been fixed.
Thanks for the heads-up on ConcreteArray. Judging from that PR, it looks like we probably want to use something like jax.core.is_concrete(jax.core.get_aval(x)) instead now. I'll delay our next release of Diffrax until after that JAX release, so that we can include a fix for this at the same time.
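For reference, a different, version-independent way to ask "is this predicate known at trace time?" avoids both ConcreteArray and jax.core entirely: attempt a Python bool conversion and catch jax.errors.ConcretizationTypeError. This is a hedged sketch, not the fix Diffrax adopted and not the is_concrete route suggested above; the helpers static_bool and select are hypothetical, and they only handle scalar predicates.

```python
import jax
import jax.numpy as jnp


def static_bool(pred):
    """Return pred as a Python bool if its value is known at trace time, else None.

    Hypothetical helper: relies only on the public jax.errors module rather than
    the removed jax.core.ConcreteArray.
    """
    try:
        return bool(pred)
    except jax.errors.ConcretizationTypeError:
        # Raised (as TracerBoolConversionError, a subclass) when pred is an
        # abstract tracer, e.g. inside jax.jit.
        return None


def select(pred, on_true, on_false):
    """Pick a branch in Python when pred is static; otherwise emit jnp.where."""
    known = static_bool(pred)
    if known is None:
        return jnp.where(pred, on_true, on_false)
    return on_true if known else on_false
```

The static branch avoids tracing a jnp.where when the predicate is already known, which is the kind of shortcut a ConcreteArray check was typically used for.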