Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixed SDEs being unnecessarily slow to solve when max_steps!=None.
The main changes needed to make this happen are in patrick-kidger/equinox#548 and as such this commit is fairly small -- it declares a dependency on a new (as yet unreleasd) version of Equinox, and removes the compatibility shim that was there before. **How things used to be.** To explain what's going on here a little more carefully: JAX only recently added support for communicating to a `custom_vjp` which input arguments were being perturbed, and which output cotangents were symbolic zeros. Prior to that, a `custom_vjp` basically just had to differentiate all inexact arrays. However, there are some nondifferentiable inexact arrays, in the sense that attempting to differentiate them will raise an error. Solving SDEs has one such array: the nondifferentiable input to a VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to reflect the fact that Brownian motion is nondifferentiable. So it used to be the case that the `custom_vjp` underlying `RecursiveCheckpointAdjoint` would differentiate the overall make-a-step function with respect to all inexact arrays, including the time variable, hit the `nondifferentiable` guard, and crash. One (unsafe) fix would have just been to remove the `nondifferentiable` guard. In practice I previously took the slower, safer option: silently switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The latter is much less efficient, but uses no `custom_vjp`, and thus used the perturbation and symbolic-zero propagation rules already present in JAX's AD machinery, and was thus safe to use here. **How things are now.** And so, what has now changed: JAX has now added support for tracking which inputs are perturbed, and which cotangents are symbolic zeros. The Equinox PR above uses this functionality to determine which parts of the carry need to be differentiated; no longer is it just "all inexact arrays". And thus this Diffrax PR removes the compatibility shim that is no longer needed. **Implications.** SDE solving should now be much faster. In particular this fixes the speed issue reported in #317.
- Loading branch information