Skip to content

Commit

Permalink
Fixed SDEs being unnecessarily slow to solve when max_steps!=None.
Browse files Browse the repository at this point in the history
The main changes needed to make this happen are in
patrick-kidger/equinox#548
and as such this commit is fairly small -- it declares a dependency on
a new (as yet unreleasd) version of Equinox, and removes the
compatibility shim that was there before.

**How things used to be.**

To explain what's going on here a little more carefully: JAX only
recently added support for communicating to a `custom_vjp` which
input arguments were being perturbed, and which output cotangents were
symbolic zeros. Prior to that, a `custom_vjp` basically just had to
differentiate all inexact arrays.

However, there are some nondifferentiable inexact arrays, in the
sense that attempting to differentiate them will raise an error.
Solving SDEs has one such array: the nondifferentiable input to a
VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to
reflect the fact that Brownian motion is nondifferentiable.

So it used to be the case that the `custom_vjp` underlying
`RecursiveCheckpointAdjoint` would differentiate the overall
make-a-step function with respect to all inexact arrays,
including the time variable, hit the `nondifferentiable` guard, and
crash.

One (unsafe) fix would have just been to remove the `nondifferentiable`
guard. In practice I previously took the slower, safer option: silently
switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The
latter is much less efficient, but uses no `custom_vjp`, and thus used
the perturbation and symbolic-zero propagation rules already present
in JAX's AD machinery, and was thus safe to use here.

**How things are now.**

And so, what has now changed: JAX has now added support for tracking
which inputs are perturbed, and which cotangents are symbolic zeros.

The Equinox PR above uses this functionality to determine which parts
of the carry need to be differentiated; no longer is it just "all
inexact arrays".

And thus this Diffrax PR removes the compatibility shim that is no
longer needed.

**Implications.**

SDE solving should now be much faster. In particular this fixes the
speed issue reported in #317.
  • Loading branch information
patrick-kidger committed Oct 9, 2023
1 parent 15b6c6e commit fd35b51
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 18 deletions.
18 changes: 1 addition & 17 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.tree_util as jtu
from jax.typing import ArrayLike

from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
from .custom_types import Array, Bool, Int, PyTree, Scalar
from .event import AbstractDiscreteTerminatingEvent
from .global_interpolation import DenseInterpolation
Expand Down Expand Up @@ -633,22 +633,6 @@ def diffeqsolve(
"An SDE should not be solved with adaptive step sizes with Euler's "
"method, as it may not converge to the correct solution."
)
# TODO: remove these lines.
#
# These are to work around an edge case: on the backward pass,
# RecursiveCheckpointAdjoint currently tries to differentiate the overall
# per-step function wrt all floating-point arrays. In particular this includes
# `state.tprev`, which feeds into the control, which feeds into
# VirtualBrownianTree, which can't be differentiated.
# We're waiting on JAX to offer a way of specifying which arguments to a
# custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more
# precisely determine what to differentiate wrt.
#
# We don't replace this in the case of an unsafe SDE because
# RecursiveCheckpointAdjoint will raise an error in that case anyway, so we
# should let the normal error be raised.
if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms):
adjoint = DirectAdjoint()
if is_unsafe_sde(terms):
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

python_requires = "~=3.9"

install_requires = ["jax>=0.4.13", "equinox>=0.10.11"]
install_requires = ["jax>=0.4.13", "equinox>=0.11.1"]

setuptools.setup(
name=name,
Expand Down

0 comments on commit fd35b51

Please sign in to comment.