Fixed SDEs being unnecessarily slow to solve when max_steps!=None. #320

Merged 1 commit on Oct 13, 2023

patrick-kidger
Owner

The main changes needed to make this happen are in
patrick-kidger/equinox#548
and as such this commit is fairly small -- it declares a dependency on
a new (as yet unreleased) version of Equinox, and removes the
compatibility shim that was there before.

**How things used to be.**

To explain what's going on here a little more carefully: JAX only
recently added support for communicating to a `custom_vjp` which
input arguments were being perturbed, and which output cotangents were
symbolic zeros. Prior to that, a `custom_vjp` basically just had to
differentiate all inexact arrays.
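
As a rough illustration of that JAX feature (not the code used in Equinox or Diffrax), here is a minimal sketch of a `custom_vjp` registered with `symbolic_zeros=True`. The function `scaled` and the `seen` list are hypothetical names introduced only for this example. With the flag set, each input leaf reaches the forward rule wrapped in an object carrying `.value` and `.perturbed`.

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping list, used only to show what the forward rule sees.
seen = []


@jax.custom_vjp
def scaled(x, t):
    return x * t


def scaled_fwd(x, t):
    # With symbolic_zeros=True, x and t are wrappers with .value and .perturbed.
    seen.append((x.perturbed, t.perturbed))
    return x.value * t.value, (x.value, t.value)


def scaled_bwd(residuals, g):
    x, t = residuals
    # Cotangents for (x, t). A real rule could skip work for unperturbed inputs.
    return g * t, g * x


scaled.defvjp(scaled_fwd, scaled_bwd, symbolic_zeros=True)

# Differentiate with respect to x only, so t should be reported as unperturbed.
grad_x = jax.grad(scaled, argnums=0)(jnp.float32(2.0), jnp.float32(3.0))
```

Here the forward rule should see `x` as perturbed and `t` as not, which is the information a `custom_vjp` previously had no way to obtain.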

However, there are some nondifferentiable inexact arrays, in the
sense that attempting to differentiate them will raise an error.
Solving SDEs has one such array: the time input to a
`VirtualBrownianTree`, which is guarded by an `eqxi.nondifferentiable` to
reflect the fact that Brownian motion is nondifferentiable.

So it used to be the case that the `custom_vjp` underlying
`RecursiveCheckpointAdjoint` would differentiate the overall
make-a-step function with respect to all inexact arrays,
including the time variable, hit the `nondifferentiable` guard, and
crash.
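
The failure mode can be sketched with a stand-in guard. This is an assumption-laden toy: `guard` below is a hypothetical helper built on `jax.custom_jvp`, not the actual implementation of `eqxi.nondifferentiable`, and `step` is a made-up make-a-step function. The guard is harmless while its input is unperturbed and raises as soon as a tangent reaches it.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def guard(x):
    # Identity on the forward pass.
    return x


@guard.defjvp
def guard_jvp(primals, tangents):
    # Only reached when the guarded value is actually being differentiated.
    raise RuntimeError("Unexpected tangent: this value is nondifferentiable.")


def step(y, t):
    # Hypothetical step function: the time t passes through the guard.
    return y * jnp.sin(guard(t))


# Differentiating with respect to y leaves t unperturbed, so the guard is silent.
dy = jax.grad(step, argnums=0)(jnp.float32(2.0), jnp.float32(0.5))

# Differentiating with respect to t perturbs the guarded value and trips it.
try:
    jax.grad(step, argnums=1)(jnp.float32(2.0), jnp.float32(0.5))
    guard_tripped = False
except RuntimeError:
    guard_tripped = True
```

A `custom_vjp` that differentiates every inexact array behaves like the second call, even when the caller only wanted the first.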

One (unsafe) fix would have just been to remove the `nondifferentiable`
guard. In practice I previously took the slower, safer option: silently
switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The
latter is much less efficient, but uses no `custom_vjp`. It therefore
relied on the perturbation and symbolic-zero propagation rules already
present in JAX's AD machinery, and so was safe to use here.

**How things are now.**

What has changed: JAX now supports tracking which inputs are
perturbed, and which cotangents are symbolic zeros.

The Equinox PR above uses this functionality to determine which parts
of the carry need to be differentiated; no longer is it just "all
inexact arrays".

And thus this Diffrax PR removes the compatibility shim that is no
longer needed.

**Implications.**

SDE solving should now be much faster. In particular this fixes the
speed issue reported in #317.
@patrick-kidger
Owner Author

(Ignore the failing CI, it just doesn't know about the as-yet-unreleased version of Equinox.)

@patrick-kidger patrick-kidger merged commit 2cc447f into main Oct 13, 2023
2 checks passed
@patrick-kidger patrick-kidger deleted the speedy-sde branch October 13, 2023 14:16