[Question] Adjoint of a Symplectic solver #541

Closed
ASKabalan opened this issue Dec 10, 2024 · 5 comments

@ASKabalan

Hello Patrick

Quick question on the adjoint methods

If I use a symplectic solver, is there a way to just do a reverse integration?

With symplectic solvers, running the steps backwards adheres to the discretize-then-optimize strategy, because we can retrace exactly the same path (if I understood correctly).

Is using diffrax.BacksolveAdjoint enough?
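Below is a minimal plain-JAX sketch of the reversibility being relied on here. It is not Diffrax code, and the harmonic-oscillator force, step size and step count are illustrative assumptions. In exact arithmetic, a kick-drift-kick leapfrog step taken with -dt exactly undoes the step taken with +dt. Integrating backwards therefore retraces the forward trajectory, up to floating-point rounding.

```python
import jax
import jax.numpy as jnp


def force(q):
    # Illustrative force (harmonic oscillator, unit mass): F(q) = -q
    return -q


def leapfrog_step(state, dt):
    # Kick-drift-kick leapfrog; calling it with -dt inverts a +dt step
    q, p = state
    p_half = p + 0.5 * dt * force(q)
    q_new = q + dt * p_half
    p_new = p_half + 0.5 * dt * force(q_new)
    return (q_new, p_new)


def integrate(state, dt, n_steps):
    # Constant step size throughout
    return jax.lax.fori_loop(0, n_steps, lambda _, s: leapfrog_step(s, dt), state)


q0 = jnp.array([1.0, 0.5])
p0 = jnp.array([0.0, -0.3])
dt, n_steps = 0.01, 100

q1, p1 = integrate((q0, p0), dt, n_steps)           # forward in time
q_back, p_back = integrate((q1, p1), -dt, n_steps)  # backward in time
# Both differences are small, at floating-point rounding level
print(jnp.max(jnp.abs(q_back - q0)), jnp.max(jnp.abs(p_back - p0)))
```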

Thanks

@patrick-kidger
Owner

I think right now Diffrax doesn't support reverse integration through symplectic solvers.

This kind of behaviour may improve in the future though... #528 :)

@ASKabalan
Author

ASKabalan commented Dec 16, 2024

Hello,

Thank you for your answer

I am extremely interested in a backsolve that works for symplectic solvers.

I created my own leapfrog solver specific to particle-mesh (PM) simulations, and I need to run the backsolve, which is trivial (on paper) since I will only be using a constant step size.

The way to do it is to run the adjoint terms in reverse (so perhaps the solver could have an adjoint step?).
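Here is a minimal plain-JAX sketch of what "running the adjoint in reverse" could look like for a constant-step leapfrog integrator. It is not Diffrax code. The pendulum-like force, the loss, and the helper `backsolve_grad` are hypothetical stand-ins. Each backward iteration first recovers the previous state by stepping with -dt. It then pulls the cotangent back through one forward step with `jax.vjp`, so no intermediate states are stored.

```python
import jax
import jax.numpy as jnp


def force(q):
    # Hypothetical nonlinear force, standing in for a particle-mesh force
    return -jnp.sin(q)


def leapfrog_step(state, dt):
    q, p = state
    p_half = p + 0.5 * dt * force(q)
    q_new = q + dt * p_half
    p_new = p_half + 0.5 * dt * force(q_new)
    return (q_new, p_new)


def integrate(state, dt, n_steps):
    return jax.lax.fori_loop(0, n_steps, lambda _, s: leapfrog_step(s, dt), state)


def backsolve_grad(final_state, final_cot, dt, n_steps):
    """Cotangent of the initial state, rebuilt by integrating backwards."""

    def body(_, carry):
        state, cot = carry
        # 1. Recover the previous state by reversing the step (no storage).
        prev_state = leapfrog_step(state, -dt)
        # 2. Pull the cotangent back through the forward step at that state.
        _, step_vjp = jax.vjp(lambda s: leapfrog_step(s, dt), prev_state)
        (prev_cot,) = step_vjp(cot)
        return (prev_state, prev_cot)

    _, init_cot = jax.lax.fori_loop(0, n_steps, body, (final_state, final_cot))
    return init_cot


def loss(state):
    q, p = state
    return jnp.sum(q**2) + jnp.sum(p**2)


state0 = (jnp.array([1.0, 0.5]), jnp.array([0.0, -0.3]))
dt, n_steps = 0.01, 100

final_state = integrate(state0, dt, n_steps)
g_backsolve = backsolve_grad(final_state, jax.grad(loss)(final_state), dt, n_steps)
# Reference: ordinary reverse-mode autodiff through the stored forward loop
g_reference = jax.grad(lambda s: loss(integrate(s, dt, n_steps)))(state0)
```

The backward pass reconstructs the states instead of storing them, so the two gradients agree only up to floating-point rounding. For long or chaotic integrations, that reconstruction error can grow.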

If this is part of #528, then great!

If not, I would be more than happy to help, either by:

1 - removing the error "NotImplementedError: diffrax.BacksolveAdjoint is only compatible with solvers that take a single term." and allowing us to implement our own adjoint solvers, or
2 - adding an adjoint step to the solver that runs the AdjointODETerms in reverse.

I would find the first solution better (and easier) for me, but the second is perhaps cleaner.

@patrick-kidger
Owner

So I think in the short term the simplest thing to do would be for you to define a custom adjoint method (subclass AbstractAdjoint and fill in your own behaviour). You could probably use the existing BacksolveAdjoint for inspiration, although I won't guarantee that it's exactly what you need.
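The same idea can also be sketched in plain JAX, outside Diffrax's AbstractAdjoint interface (whose exact method signatures are not reproduced here). A hypothetical constant-step leapfrog solve is wrapped in `jax.custom_vjp`, so that `jax.grad` uses reverse integration in its backward pass. Only the final state is kept as a residual. The force, loss, and the names `plain_solve` and `reversible_solve` are illustrative assumptions.

```python
import functools

import jax
import jax.numpy as jnp


def force(q):
    # Hypothetical nonlinear force (illustrative only)
    return -jnp.sin(q)


def leapfrog_step(state, dt):
    q, p = state
    p_half = p + 0.5 * dt * force(q)
    q_new = q + dt * p_half
    p_new = p_half + 0.5 * dt * force(q_new)
    return (q_new, p_new)


def plain_solve(state, dt, n_steps):
    return jax.lax.fori_loop(0, n_steps, lambda _, s: leapfrog_step(s, dt), state)


@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def reversible_solve(state, dt, n_steps):
    return plain_solve(state, dt, n_steps)


def _reversible_solve_fwd(state, dt, n_steps):
    final_state = plain_solve(state, dt, n_steps)
    # Only the final state is saved as a residual: memory does not grow with n_steps
    return final_state, final_state


def _reversible_solve_bwd(dt, n_steps, final_state, cot):
    # Non-differentiable arguments come first, then residuals, then the cotangent
    def body(_, carry):
        state, c = carry
        prev_state = leapfrog_step(state, -dt)  # reverse integration
        _, step_vjp = jax.vjp(lambda s: leapfrog_step(s, dt), prev_state)
        (prev_c,) = step_vjp(c)
        return (prev_state, prev_c)

    _, init_cot = jax.lax.fori_loop(0, n_steps, body, (final_state, cot))
    return (init_cot,)  # one cotangent per differentiable argument


reversible_solve.defvjp(_reversible_solve_fwd, _reversible_solve_bwd)


def loss(state):
    q, p = state
    return jnp.sum(q**2) + jnp.sum(p**2)


state0 = (jnp.array([1.0, 0.5]), jnp.array([0.0, -0.3]))
g_reversible = jax.grad(lambda s: loss(reversible_solve(s, 0.01, 100)))(state0)
g_stored = jax.grad(lambda s: loss(plain_solve(s, 0.01, 100)))(state0)
```

This trades compute for memory. The backward pass recomputes each step from the reconstructed state instead of reading stored ones.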

In the long term I think we'll keep this use-case in mind as part of the reversible backprop work, in which case hopefully it will 'just work' at some point in the future.

@ASKabalan
Author

Thank you very much for your answer 🙏

Before closing the issue, I have a related question.

I am trying to use BacksolveAdjoint with an ODETerm that contains a custom primitive function. I only need reverse-mode autodiff, so I implemented a custom_vjp.

The AdjointTerm evaluation is failing here, because I do not have a differentiation rule for my primitive.

So I wrote this MWE that reproduces the issue:

```python
from jax import core
import jax
import jax.numpy as jnp
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import mlir
from jax import custom_vjp, custom_jvp

square_prim_p = core.Primitive("square_prim")  # Create the primitive

def square_prim(x):
  return square_prim_p.bind(x)

def square_impl(x):
  return jnp.power(x, 2)

def square_abstract_eval(x):
  return core.ShapedArray(x.shape, x.dtype)

def square_lowering(ctx, xc):
  return hlo.MulOp(xc, xc).results


@custom_vjp
def square_vjp(x):
  return square_prim(x)

def square_vjp_fwd(x):
  return square_prim(x), x * 2

def square_vjp_bwd(res, g):
  jax.debug.print("res: {res}, g: {g}",res=res, g=g)
  return (g * res,)


@custom_jvp
def square_jvp(x):
  return square_prim(x)

@square_jvp.defjvp
def square_jvp_impl(primals, tangents):
  x, = primals
  x_dot, = tangents
  primals_out = square_jvp(x)
  tangents_out = 2 * x_dot * x
  return primals_out, tangents_out


square_vjp.defvjp(square_vjp_fwd, square_vjp_bwd)

# Register impl, abstract eval and lowering (no JVP/transpose rule for the primitive)
square_prim_p.def_impl(square_impl)
square_prim_p.def_abstract_eval(square_abstract_eval)
mlir.register_lowering(square_prim_p, square_lowering)

# First-order gradients
vjp_grad = jax.jit(jax.grad(square_vjp))(jnp.array(3.0)) # works
jvp_grad = jax.jit(jax.grad(square_jvp))(jnp.array(3.0)) # works

print(f"vjp_grad: {vjp_grad}")
print(f"jvp_grad: {jvp_grad}")

# Second order: differentiate a function that itself calls jax.vjp
def _fn_vjp(x):
  dy, vjp = jax.vjp(square_vjp, x)
  return vjp(dy)[0]

def _fn_jvp(x):
  dy, vjp = jax.vjp(square_jvp, x)
  return vjp(dy)[0]

jax.grad(_fn_jvp)(jnp.array(3.0)) # works
jax.grad(_fn_vjp)(jnp.array(3.0)) # fails
```

So can I safely say that custom_vjps are not compatible with the BacksolveAdjoint method?

@patrick-kidger
Owner

For your example here, you can fix this via:

```diff
@@ -25,7 +25,7 @@ def square_vjp(x):
   return square_prim(x)
 
 def square_vjp_fwd(x):
-  return square_prim(x), x * 2
+  return square_vjp(x), x * 2
 
 def square_vjp_bwd(res, g):
   jax.debug.print("res: {res}, g: {g}",res=res, g=g)
```

Note that, as you are performing second-order autodifferentiation, the internals of your first-order autodiff rule must themselves be autodifferentiable.
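Here is a small self-contained sketch of this pattern. It uses `jax.pure_callback` as a stand-in for a primitive without a JVP rule, since pure callbacks have no differentiation rules of their own. The names `opaque_square`, `square_broken` and `square_fixed` are illustrative. Both versions give the correct first-order gradient. Only the version whose `fwd` calls the custom_vjp function itself survives second-order differentiation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def opaque_square(x):
    # Stand-in for a custom primitive without a JVP rule:
    # jax.pure_callback has no differentiation rules of its own.
    return jax.pure_callback(
        lambda v: np.asarray(np.square(v), dtype=v.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
    )


def square_bwd(res, g):
    return (g * res,)


@jax.custom_vjp
def square_broken(x):
    return opaque_square(x)


def square_broken_fwd(x):
    # Calls the non-differentiable op directly (as in the original MWE)
    return opaque_square(x), 2 * x


square_broken.defvjp(square_broken_fwd, square_bwd)


@jax.custom_vjp
def square_fixed(x):
    return opaque_square(x)


def square_fixed_fwd(x):
    # Calls the custom_vjp function itself, so an outer differentiation
    # pass uses its custom rule instead of differentiating the callback
    return square_fixed(x), 2 * x


square_fixed.defvjp(square_fixed_fwd, square_bwd)

x = jnp.array(3.0)
print(jax.grad(square_fixed)(x))            # 6.0
print(jax.grad(jax.grad(square_fixed))(x))  # 2.0
```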
