Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CWL now compatible with partially nondifferentiable body functions. #548

Merged
merged 5 commits into from
Oct 9, 2023

Conversation

patrick-kidger
Copy link
Owner

@patrick-kidger patrick-kidger commented Oct 9, 2023

This is fairly finickity! So, the checkpointed while loop (CWL) needs
to figure out which cotangents to propagate backward. There are two
criteria for this:
(1) only some inputs actually need cotangents at all. Any parts of the
carry that are unrelated to those inputs can be skipped.
(2) some cotangents might be resolvable as symbolic zeros, and can thus
be skipped.

Previously, we handled (2) but did not handle (1). In practice this
was mostly fine. We just pretended all (inexact) inputs needed
cotangents, and let anything superfluous get DCE'd at the end.

However, this has one niche problem, which is that it is incompatible
with eqxi.nondifferentiable -- which is used to mark an inexact
input as nondifferentiable, and raise a trace-time error if we attempt
to differentiate it.

With this change, we're now a bit more careful: we run an additional
fixed-point iteration to check which inputs need cotangents, and then
skip the rest.

Note that this extra fixed-point iteration is actually done using a
JVP. This is designed to mirror what JAX does internally: perturbations
are tracked using a fixed-point iteration inside the JVP rule for
lax.{while_loop,scan}. I considered some other alternatives (using
VJPs or pe.dce_jaxpr) but I think those either fail in corner cases
or run into catch-22s.

@patrick-kidger patrick-kidger changed the title better loops CWL now compatible with partially nondifferentiable body functions. Oct 9, 2023
This is fairly finickity! So, the checkpointed while loop (CWL) needs
to figure out which cotangents to propagate backward. There are two
criteria for this:
(1) only some inputs actually need cotangents at all. Any parts of the
    carry that are unrelated to those inputs can be skipped.
(2) some cotangents might be resolvable as symbolic zeros, and can thus
    be skipped.

Previously, we handled (2) but did not handle (1). In practice this
was mostly fine. We just pretended all (inexact) inputs needed
cotangents, and let anything superfluous get DCE'd at the end.

However, this has one niche problem, which is that it is incompatible
with `eqxi.nondifferentiable` -- which is used to mark an inexact
input as nondifferentiable, and raise a trace-time error if we attempt
to differentiate it.

With this change, we're now a bit more careful: we run an additional
fixed-point iteration to check which inputs need cotangents, and then
skip the rest.

Note that this extra fixed-point iteration is actually done using a
JVP. This is designed to mirror what JAX does internally: perturbations
are tracked using a fixed-point iteration inside the JVP rule for
`lax.{while_loop,scan}`. I considered some other alternatives (using
VJPs or `pe.dce_jaxpr`) but I think those either fail in corner cases
or run into catch-22s.
patrick-kidger added a commit to patrick-kidger/diffrax that referenced this pull request Oct 9, 2023
The main changes needed to make this happen are in
patrick-kidger/equinox#548
and as such this commit is fairly small -- it declares a dependency on
a new (as yet unreleasd) version of Equinox, and removes the
compatibility shim that was there before.

**How things used to be.**

To explain what's going on here a little more carefully: JAX only
recently added support for communicating to a `custom_vjp` which
input arguments were being perturbed, and which output cotangents were
symbolic zeros. Prior to that, a `custom_vjp` basically just had to
differentiate all inexact arrays.

However, there are some nondifferentiable inexact arrays, in the
sense that attempting to differentiate them will raise an error.
Solving SDEs has one such array: the nondifferentiable input to a
VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to
reflect the fact that Brownian motion is nondifferentiable.

So it used to be the case that the `custom_vjp` underlying
`RecursiveCheckpointAdjoint` would differentiate the overall
make-a-step function with respect to all inexact arrays,
including the time variable, hit the `nondifferentiable` guard, and
crash.

One (unsafe) fix would have just been to remove the `nondifferentiable`
guard. In practice I previously took the slower, safer option: silently
switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The
latter is much less efficient, but uses no `custom_vjp`, and thus used
the perturbation and symbolic-zero propagation rules already present
in JAX's AD machinery, and was thus safe to use here.

**How things are now.**

And so, what has now changed: JAX has now added support for tracking
which inputs are perturbed, and which cotangents are symbolic zeros.

The Equinox PR above uses this functionality to determine which parts
of the carry need to be differentiated; no longer is it just "all
inexact arrays".

And thus this Diffrax PR removes the compatibility shim that is no
longer needed.

**Implications.**

SDE solving should now be much faster. In particular this fixes the
speed issue reported in #317.
@patrick-kidger patrick-kidger merged commit 520ac35 into main Oct 9, 2023
2 checks passed
@patrick-kidger patrick-kidger deleted the better-loops branch October 9, 2023 19:45
patrick-kidger added a commit to patrick-kidger/diffrax that referenced this pull request Oct 13, 2023
The main changes needed to make this happen are in
patrick-kidger/equinox#548
and as such this commit is fairly small -- it declares a dependency on
a new (as yet unreleasd) version of Equinox, and removes the
compatibility shim that was there before.

**How things used to be.**

To explain what's going on here a little more carefully: JAX only
recently added support for communicating to a `custom_vjp` which
input arguments were being perturbed, and which output cotangents were
symbolic zeros. Prior to that, a `custom_vjp` basically just had to
differentiate all inexact arrays.

However, there are some nondifferentiable inexact arrays, in the
sense that attempting to differentiate them will raise an error.
Solving SDEs has one such array: the nondifferentiable input to a
VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to
reflect the fact that Brownian motion is nondifferentiable.

So it used to be the case that the `custom_vjp` underlying
`RecursiveCheckpointAdjoint` would differentiate the overall
make-a-step function with respect to all inexact arrays,
including the time variable, hit the `nondifferentiable` guard, and
crash.

One (unsafe) fix would have just been to remove the `nondifferentiable`
guard. In practice I previously took the slower, safer option: silently
switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The
latter is much less efficient, but uses no `custom_vjp`, and thus used
the perturbation and symbolic-zero propagation rules already present
in JAX's AD machinery, and was thus safe to use here.

**How things are now.**

And so, what has now changed: JAX has now added support for tracking
which inputs are perturbed, and which cotangents are symbolic zeros.

The Equinox PR above uses this functionality to determine which parts
of the carry need to be differentiated; no longer is it just "all
inexact arrays".

And thus this Diffrax PR removes the compatibility shim that is no
longer needed.

**Implications.**

SDE solving should now be much faster. In particular this fixes the
speed issue reported in #317.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant