Question: working principle of RecursiveCheckpointAdjoint #564
I'll just remark on a few things I know that might be helpful.
```python
import diffrax
import jax
from jax import numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

vector_field = lambda t, y, args: -y
term = diffrax.ODETerm(vector_field)
solver = diffrax.Euler()
saveat = diffrax.SaveAt(ts=jnp.linspace(0, 1, 10))

def f(y):
    return diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=1,
        dt0=0.01,
        y0=y,
        saveat=saveat,
        adjoint=diffrax.RecursiveCheckpointAdjoint(4000),
    ).ys.mean()

# print_saved_residuals prints its report itself (and returns None),
# so it should not be wrapped in print().
print_saved_residuals(f, jnp.array(1.0))
```

The output shows residual arrays of size 4000 (because the state is a bunch of arrays). If you change this to be 1 checkpoint, you will see these arrays are size 1.
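The memory-for-recompute trade being described can be sketched in plain Python. This is a toy fixed-interval checkpointing scheme applied to a hand-rolled Euler solve of a made-up nonlinear ODE, not diffrax's actual implementation (diffrax uses a recursive, treeverse-style schedule that chooses checkpoint placement more cleverly), but the principle is the same: store only every few states, and rebuild the rest on the fly during the backward pass.

```python
def step(y, h):
    # One explicit Euler step of dy/dt = -y**2. The nonlinearity matters:
    # the backward pass genuinely needs the intermediate states y_k.
    return y - h * y * y

def grad_full(y0, h, n):
    # Baseline: store every state (O(n) memory), then run the chain rule
    # backwards through the stored trajectory.
    ys = [y0]
    for _ in range(n):
        ys.append(step(ys[-1], h))
    g = 1.0
    for k in reversed(range(n)):
        g *= 1.0 - 2.0 * h * ys[k]  # d(step)/dy evaluated at y_k
    return g

def grad_checkpointed(y0, h, n, every):
    # Checkpointing: store only every `every`-th state (O(n/every) memory).
    ckpts = {0: y0}
    y = y0
    for k in range(n):
        y = step(y, h)
        if (k + 1) % every == 0:
            ckpts[k + 1] = y
    g = 1.0
    for k in reversed(range(n)):
        base = (k // every) * every  # nearest checkpoint at or before k
        y = ckpts[base]
        for _ in range(k - base):    # recompute y_k forward from it
            y = step(y, h)
        g *= 1.0 - 2.0 * h * y
    return g

# Both give the same gradient; only the memory/compute trade-off differs.
print(grad_full(1.0, 0.01, 100), grad_checkpointed(1.0, 0.01, 100, 10))
```

With fewer checkpoints you pay for the saved memory in extra forward recomputation; the recursive schedule diffrax uses keeps that extra cost logarithmic rather than linear per segment.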
|
Hi! I’ve recently started diving deeper into Neural ODEs, but I feel like I’m missing some foundational concepts regarding checkpointing in backpropagation.
Could you help clarify how backpropagation is implemented through checkpointing in RecursiveCheckpointAdjoint? Specifically, I have a few questions:
I’m a bit confused about these concepts right now 🫠. Thanks in advance for your help!
Footnotes
From what I understand, only the states of the ODE solve at certain points (determined by the number of checkpoints) are saved, but what about the computational graph? ↩
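On the footnote's question: in JAX no explicit computational graph has to be stored between checkpoints, because rematerialization simply re-runs the (traced) forward function during the backward pass to regenerate whatever residuals are needed. A minimal illustration with plain `jax.checkpoint` (a toy function, not diffrax internals):

```python
import jax
import jax.numpy as jnp

def g(x):
    # A toy forward computation whose backward pass needs residuals
    # (the cosines of the intermediate values).
    return jnp.sin(jnp.sin(x))

# jax.checkpoint (a.k.a. jax.remat) tells JAX not to save those residuals;
# instead, g is recomputed during the backward pass.
g_remat = jax.checkpoint(g)

# The gradients are identical; only memory vs. recompute differs.
print(jax.grad(g)(1.0), jax.grad(g_remat)(1.0))
```

RecursiveCheckpointAdjoint applies this same idea to the solver loop, with checkpointed solver states serving as the restart points for recomputation.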