Fixed SDEs being unnecessarily slow to solve when max_steps!=None. #320
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The main changes needed to make this happen are in
patrick-kidger/equinox#548
and as such this commit is fairly small -- it declares a dependency on
a new (as yet unreleasd) version of Equinox, and removes the
compatibility shim that was there before.
How things used to be.
To explain what's going on here a little more carefully: JAX only
recently added support for communicating to a
custom_vjp
whichinput arguments were being perturbed, and which output cotangents were
symbolic zeros. Prior to that, a
custom_vjp
basically just had todifferentiate all inexact arrays.
However, there are some nondifferentiable inexact arrays, in the
sense that attempting to differentiate them will raise an error.
Solving SDEs has one such array: the nondifferentiable input to a
VirtualBrownianTree, which is guarded by a
eqxi.nondifferentiable
toreflect the fact that Brownian motion is nondifferentiable.
So it used to be the case that the
custom_vjp
underlyingRecursiveCheckpointAdjoint
would differentiate the overallmake-a-step function with respect to all inexact arrays,
including the time variable, hit the
nondifferentiable
guard, andcrash.
One (unsafe) fix would have just been to remove the
nondifferentiable
guard. In practice I previously took the slower, safer option: silently
switch out
RecursiveCheckpointAdjoint
for aDirectAdjoint
. Thelatter is much less efficient, but uses no
custom_vjp
, and thus usedthe perturbation and symbolic-zero propagation rules already present
in JAX's AD machinery, and was thus safe to use here.
How things are now.
And so, what has now changed: JAX has now added support for tracking
which inputs are perturbed, and which cotangents are symbolic zeros.
The Equinox PR above uses this functionality to determine which parts
of the carry need to be differentiated; no longer is it just "all
inexact arrays".
And thus this Diffrax PR removes the compatibility shim that is no
longer needed.
Implications.
SDE solving should now be much faster. In particular this fixes the
speed issue reported in #317.