Question: working principle of RecursiveCheckpointAdjoint #564

Open
TommyGiak opened this issue Jan 10, 2025 · 1 comment

@TommyGiak

Hi! I’ve recently started diving deeper into Neural ODEs, but I feel like I’m missing some foundational concepts regarding checkpointing in backpropagation.

Could you help clarify how backpropagation is implemented through checkpointing in RecursiveCheckpointAdjoint? Specifically, I have a few questions:

  1. What exactly is saved during the forward pass?¹
  2. Is the gradient (with respect to the loss) computed purely through the internal operations of the solver, after recomputing the forward pass using the saved states?
  3. If we need to recompute all the states of the “reverse equation,” why does using more checkpoints result in faster computation?
  4. Why does the implementation need to be recursive to save memory during backpropagation? Could a non-recursive implementation achieve the same efficiency?

I’m a bit confused about these concepts right now 🫠. Thanks in advance for your help!

Footnotes

  ¹ From what I understand, only the states of the ODE solve at certain points (determined by the number of checkpoints) are saved, but what about the computational graph?

@lockwo (Contributor) commented Jan 10, 2025

I'll just remark on a few things I know that might be helpful.

  1. What's being saved? Basically the state at that point, as you note (for more details on how JAX implements checkpointing and the computational graph, see: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html). For diffrax, you can see it in this example:
import diffrax
import jax
from jax import numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

vector_field = lambda t, y, args: -y
term = diffrax.ODETerm(vector_field)
solver = diffrax.Euler()
saveat = diffrax.SaveAt(ts=jnp.linspace(0, 1, 10))

def f(y):
    # Solve dy/dt = -y on [0, 1], keeping 4000 checkpoints for the backward pass,
    # and reduce the saved trajectory to a scalar so it can be differentiated.
    return diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.01, y0=y, saveat=saveat,
        adjoint=diffrax.RecursiveCheckpointAdjoint(4000),
    ).ys.mean()

# print_saved_residuals prints its report itself, so no outer print() is needed.
print_saved_residuals(f, jnp.array(1.0))

The output looks something like:

i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
bool[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
bool[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
f32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
f32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
f32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
bool[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
f32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)
i32[4000] output of jitted function 'diffeqsolve' from /usr/local/lib/python3.10/dist-packages/equinox/_jit.py:244 (_call)

(because the solver state is a collection of arrays). If you change this to 1 checkpoint, you will see these arrays have size 1; a minimal variant is shown below.
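
For instance, here is a minimal sketch reusing the setup above and only changing the number of checkpoints (the exact residuals printed will depend on your diffrax/JAX versions):

def f_one_checkpoint(y):
    # Identical solve, but only 1 checkpoint is retained for the backward pass.
    return diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.01, y0=y, saveat=saveat,
        adjoint=diffrax.RecursiveCheckpointAdjoint(1),
    ).ys.mean()

# The reported residual arrays now have size 1 rather than 4000, i.e. far less
# memory is held between the forward and backward passes.
print_saved_residuals(f_one_checkpoint, jnp.array(1.0))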

  2. I'm not sure I understand this one; what else would the gradient depend on?

  3. As hinted above, this adjoint has a nice tradeoff between memory and speed. If you checkpoint every step, you store more intermediate results for your gradient computation (as you can see, the arrays are larger), so your speed is just the standard cost of backprop. However, if you checkpoint less often, you use less memory, but you have to recompute things in order to get the information you need for backprop, which increases time. The reason this matters is that when solving a DE you might have a much greater depth than in traditional ML (e.g. I have set adaptive solvers with max_steps of 100 million, which is not always a tractable array size to store in memory). A rough timing sketch follows below.
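
To make the tradeoff concrete, here is a rough timing sketch (an illustration only, not a diffrax benchmark; numbers depend on your hardware, and make_loss is just a name used for this example). It reuses term, solver, and saveat from the snippet above:

import timeit

def make_loss(checkpoints):
    # Same solve as above, parameterised by the number of checkpoints.
    def loss(y):
        return diffrax.diffeqsolve(
            term, solver, t0=0, t1=1, dt0=0.001, y0=y, saveat=saveat,
            adjoint=diffrax.RecursiveCheckpointAdjoint(checkpoints),
        ).ys.mean()
    return loss

for checkpoints in (1, 10, 1000):
    grad_fn = jax.jit(jax.grad(make_loss(checkpoints)))
    grad_fn(jnp.array(1.0))  # compile once before timing
    t = timeit.timeit(lambda: grad_fn(jnp.array(1.0)).block_until_ready(), number=10)
    # Fewer checkpoints -> less memory held between passes, but more recomputation
    # in the backward pass, so the gradient takes longer.
    print(f"checkpoints={checkpoints}: {t / 10:.4f} s per gradient")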

  4. I don't know as much about this one. All I'll say is that I believe the "recursive" refers to the multi-level approach in which the while loops are themselves checkpointed (e.g. there are two while loops, one for the integration over time and one for the interpolation/saving potentially between solver steps), so you're effectively "recursing" through levels of gradient computation; see https://cseweb.ucsd.edu/~tzli/cse291/sp2024/lectures/checkpointing.pdf for more. A toy sketch of the general recursive-checkpointing idea is below.
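
For intuition only, here is a toy recursive-checkpointing scheme written directly in JAX. This is a simplified illustration of the general idea, not diffrax's actual implementation (which uses a more sophisticated scheme with a configurable number of checkpoints):

import jax
import jax.numpy as jnp

def step(y):
    # One explicit Euler step of dy/dt = -y with dt = 0.01 (a toy stand-in for a solver step).
    return y - 0.01 * y

def forward(y, n):
    # Run n steps forward without saving anything.
    for _ in range(n):
        y = step(y)
    return y

def backward(y_start, n, ybar):
    # Backpropagate the cotangent ybar through n steps starting from y_start,
    # holding only O(log n) intermediate states at any one time.
    if n == 1:
        _, vjp = jax.vjp(step, y_start)
        (ybar,) = vjp(ybar)
        return ybar
    mid = n // 2
    y_mid = forward(y_start, mid)          # recompute up to the midpoint (extra work, less memory)
    ybar = backward(y_mid, n - mid, ybar)  # backprop through the second half first
    return backward(y_start, mid, ybar)    # then recurse into the first half

n = 64
y0 = jnp.array(1.0)
g_recursive = backward(y0, n, jnp.array(1.0))
g_direct = jax.grad(lambda y: forward(y, n))(y0)
print(g_recursive, g_direct)  # both should equal (1 - 0.01) ** 64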
