jaxopt's L-BFGS-B with custom gradient not matching with scipy implementation #620
Here are the values from my scipy and jaxopt runs, in that order: Final objective: 4.84871 vs 1.5213378e+20, which seems concerning. Optimized value of
jithendaraa changed the title from "jaxopt's L-BFGS-B not matching with scipy implementation" to "jaxopt's L-BFGS-B with custom gradient not matching with scipy implementation" on Jan 26, 2025.
Context:
I am migrating code (causalnex's dynotears) from a numpy/scipy implementation to a jax implementation. This essentially involves moving from scipy's L-BFGS-B to jaxopt's implementation, so that I can jit the function and run it faster.
Apart from the `_func(..)` to minimize, the code has a custom `_grad(..)` function defined for the optimization. I converted both `_func()` and `_grad()` to their jax counterparts, and am using `jaxopt.LBFGSB` with the custom grad function, like:

Original numpy/scipy implementation:
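(The original code block does not survive in this copy of the issue. The following is a minimal sketch of the scipy call pattern being described, using a toy quadratic objective; the bounds and starting point are illustrative assumptions, not the original code.)

```python
import numpy as np
from scipy.optimize import minimize

# Toy stand-ins for the real _func / _grad (hypothetical objective).
def _func(params):
    return float(np.sum(params ** 2))  # scalar loss only

def _grad(params):
    return 2.0 * params  # analytic gradient, passed separately via jac=

x0 = np.array([0.5, -0.3, 0.8])
bounds = [(-1.0, 1.0)] * len(x0)

# scipy's L-BFGS-B takes the scalar objective and the gradient via jac=.
result = minimize(_func, x0, jac=_grad, method="L-BFGS-B", bounds=bounds)
print(result.fun, result.x)
```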
My current jaxopt implementation:
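(Likewise a sketch rather than the original snippet: `jaxopt.LBFGSB` is constructed with `value_and_grad=True`, so `fun` must return a `(loss, grad)` pair. The toy objective, bounds, and starting point are assumptions.)

```python
import jax.numpy as jnp
import jaxopt

# Toy stand-ins for the real _func_jax / _grad_jax (hypothetical objective).
def _grad_jax(params):
    return 2.0 * params  # analytic gradient of sum(params**2)

def _func_jax(params):
    # With value_and_grad=True, the solver expects fun to return (loss, grad).
    return jnp.sum(params ** 2), _grad_jax(params)

lower, upper = -jnp.ones(3), jnp.ones(3)

solver = jaxopt.LBFGSB(fun=_func_jax, value_and_grad=True, maxiter=100)
params, state = solver.run(jnp.array([0.5, -0.3, 0.8]), bounds=(lower, upper))
print(state.value, params)
```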
I have ensured that `_func_jax` returns `(loss, _grad_jax(params))`, compared to `_func()`, which returns just the scalar. I'm not expecting exact answers between the scipy/jaxopt implementations, since I understand there will be numerical differences even if seeds are set, but there seems to be a large mismatch between the scipy and jaxopt versions.

I do get some warnings during my run like:
It would really help to understand what is causing these differences (and whether they are expected or not).
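(For reference, one way to sanity-check a custom JAX gradient, a hypothetical diagnostic not in the original report, is to compare it against `jax.grad` of the scalar loss, here using the toy definitions from the sketch above.)

```python
import jax
import jax.numpy as jnp

# Compare the hand-written gradient against JAX autodiff on the scalar loss.
def loss_only(params):
    return _func_jax(params)[0]  # drop the gradient half of the (loss, grad) pair

params0 = jnp.array([0.5, -0.3, 0.8])
auto_grad = jax.grad(loss_only)(params0)
custom_grad = _grad_jax(params0)
print(jnp.max(jnp.abs(auto_grad - custom_grad)))  # ~0 if _grad_jax is correct
```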
Versions: