-
-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CWL now compatible with partially nondifferentiable body functions. #548
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
patrick-kidger
changed the title
better loops
CWL now compatible with partially nondifferentiable body functions.
Oct 9, 2023
This is fairly finickity! So, the checkpointed while loop (CWL) needs to figure out which cotangents to propagate backward. There are two criteria for this: (1) only some inputs actually need cotangents at all. Any parts of the carry that are unrelated to those inputs can be skipped. (2) some cotangents might be resolvable as symbolic zeros, and can thus be skipped. Previously, we handled (2) but did not handle (1). In practice this was mostly fine. We just pretended all (inexact) inputs needed cotangents, and let anything superfluous get DCE'd at the end. However, this has one niche problem, which is that it is incompatible with `eqxi.nondifferentiable` -- which is used to mark an inexact input as nondifferentiable, and raise a trace-time error if we attempt to differentiate it. With this change, we're now a bit more careful: we run an additional fixed-point iteration to check which inputs need cotangents, and then skip the rest. Note that this extra fixed-point iteration is actually done using a JVP. This is designed to mirror what JAX does internally: perturbations are tracked using a fixed-point iteration inside the JVP rule for `lax.{while_loop,scan}`. I considered some other alternatives (using VJPs or `pe.dce_jaxpr`) but I think those either fail in corner cases or run into catch-22s.
patrick-kidger
force-pushed
the
better-loops
branch
from
October 9, 2023 06:17
29a43f4
to
5c4a842
Compare
patrick-kidger
added a commit
to patrick-kidger/diffrax
that referenced
this pull request
Oct 9, 2023
The main changes needed to make this happen are in patrick-kidger/equinox#548 and as such this commit is fairly small -- it declares a dependency on a new (as yet unreleasd) version of Equinox, and removes the compatibility shim that was there before. **How things used to be.** To explain what's going on here a little more carefully: JAX only recently added support for communicating to a `custom_vjp` which input arguments were being perturbed, and which output cotangents were symbolic zeros. Prior to that, a `custom_vjp` basically just had to differentiate all inexact arrays. However, there are some nondifferentiable inexact arrays, in the sense that attempting to differentiate them will raise an error. Solving SDEs has one such array: the nondifferentiable input to a VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to reflect the fact that Brownian motion is nondifferentiable. So it used to be the case that the `custom_vjp` underlying `RecursiveCheckpointAdjoint` would differentiate the overall make-a-step function with respect to all inexact arrays, including the time variable, hit the `nondifferentiable` guard, and crash. One (unsafe) fix would have just been to remove the `nondifferentiable` guard. In practice I previously took the slower, safer option: silently switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The latter is much less efficient, but uses no `custom_vjp`, and thus used the perturbation and symbolic-zero propagation rules already present in JAX's AD machinery, and was thus safe to use here. **How things are now.** And so, what has now changed: JAX has now added support for tracking which inputs are perturbed, and which cotangents are symbolic zeros. The Equinox PR above uses this functionality to determine which parts of the carry need to be differentiated; no longer is it just "all inexact arrays". And thus this Diffrax PR removes the compatibility shim that is no longer needed. **Implications.** SDE solving should now be much faster. In particular this fixes the speed issue reported in #317.
This was referenced Oct 9, 2023
Merged
patrick-kidger
added a commit
to patrick-kidger/diffrax
that referenced
this pull request
Oct 13, 2023
The main changes needed to make this happen are in patrick-kidger/equinox#548 and as such this commit is fairly small -- it declares a dependency on a new (as yet unreleasd) version of Equinox, and removes the compatibility shim that was there before. **How things used to be.** To explain what's going on here a little more carefully: JAX only recently added support for communicating to a `custom_vjp` which input arguments were being perturbed, and which output cotangents were symbolic zeros. Prior to that, a `custom_vjp` basically just had to differentiate all inexact arrays. However, there are some nondifferentiable inexact arrays, in the sense that attempting to differentiate them will raise an error. Solving SDEs has one such array: the nondifferentiable input to a VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to reflect the fact that Brownian motion is nondifferentiable. So it used to be the case that the `custom_vjp` underlying `RecursiveCheckpointAdjoint` would differentiate the overall make-a-step function with respect to all inexact arrays, including the time variable, hit the `nondifferentiable` guard, and crash. One (unsafe) fix would have just been to remove the `nondifferentiable` guard. In practice I previously took the slower, safer option: silently switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The latter is much less efficient, but uses no `custom_vjp`, and thus used the perturbation and symbolic-zero propagation rules already present in JAX's AD machinery, and was thus safe to use here. **How things are now.** And so, what has now changed: JAX has now added support for tracking which inputs are perturbed, and which cotangents are symbolic zeros. The Equinox PR above uses this functionality to determine which parts of the carry need to be differentiated; no longer is it just "all inexact arrays". And thus this Diffrax PR removes the compatibility shim that is no longer needed. **Implications.** SDE solving should now be much faster. In particular this fixes the speed issue reported in #317.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This is fairly finickity! So, the checkpointed while loop (CWL) needs
to figure out which cotangents to propagate backward. There are two
criteria for this:
(1) only some inputs actually need cotangents at all. Any parts of the
carry that are unrelated to those inputs can be skipped.
(2) some cotangents might be resolvable as symbolic zeros, and can thus
be skipped.
Previously, we handled (2) but did not handle (1). In practice this
was mostly fine. We just pretended all (inexact) inputs needed
cotangents, and let anything superfluous get DCE'd at the end.
However, this has one niche problem, which is that it is incompatible
with
eqxi.nondifferentiable
-- which is used to mark an inexactinput as nondifferentiable, and raise a trace-time error if we attempt
to differentiate it.
With this change, we're now a bit more careful: we run an additional
fixed-point iteration to check which inputs need cotangents, and then
skip the rest.
Note that this extra fixed-point iteration is actually done using a
JVP. This is designed to mirror what JAX does internally: perturbations
are tracked using a fixed-point iteration inside the JVP rule for
lax.{while_loop,scan}
. I considered some other alternatives (usingVJPs or
pe.dce_jaxpr
) but I think those either fail in corner casesor run into catch-22s.