
jaxopt's L-BFGS-B with custom gradient not matching with scipy implementation #620

Open
jithendaraa opened this issue Jan 26, 2025 · 1 comment


jithendaraa commented Jan 26, 2025

Context:

I am migrating code (causalnex's dynotears) from a numpy/scipy implementation to a jax implementation. This essentially involves moving from scipy's L-BFGS-B to jaxopt's implementation so that I can jit the function and run it faster.

Apart from the _func(..) to minimize, the code defines a custom _grad(..) function for the optimization. I converted both _func() and _grad() to their jax counterparts, and am using jaxopt.LBFGSB with the custom grad function as shown below.

Original numpy/scipy implementation

import warnings

import numpy as np
import scipy.optimize as sopt

# _func, _grad, _h, bnds, d_vars, p_orders, max_iter and h_tol come from
# the surrounding dynotears code.

# initialise matrix, weights and constraints
wa_est = np.zeros(2 * (p_orders + 1) * d_vars**2)
wa_new = np.zeros(2 * (p_orders + 1) * d_vars**2)
rho, alpha, h_value, h_new = 1.0, 0.0, np.inf, np.inf

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == np.inf):
        wa_new = sopt.minimize(
            _func, 
            wa_est, 
            method="L-BFGS-B", 
            jac=_grad, 
            bounds=bnds
        ).x
        h_new = _h(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10

    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")

My current jaxopt implementation

import warnings

import jax.numpy as jnp
import numpy as np
from jaxopt import LBFGSB

# bnds is a list of (lower, upper) tuples, where upper might have None values.
# Make it compatible with what jaxopt.LBFGSB expects: a (lowers, uppers) pair
# of arrays, with None replaced by +inf.
np_bnds = np.array(bnds)  # object dtype because of the Nones
lowers = jnp.array(np_bnds[:, 0].astype(float))
cleaned_uppers = np.where(np_bnds[:, 1] == None, np.inf, np_bnds[:, 1])  # elementwise None check
uppers = jnp.array(cleaned_uppers.astype(float))
jnp_lbfgs_bounds = (lowers, uppers)

# _func_jax returns (loss, _grad_jax(params)), hence value_and_grad=True
lbfgsb_solver = LBFGSB(fun=_func_jax, value_and_grad=True)

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == jnp.inf):
        wa_new = lbfgsb_solver.run(
            wa_est, 
            bounds=jnp_lbfgs_bounds
        ).params

        h_new = _h_jax(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10

    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")

I have ensured that _func_jax returns (loss, _grad_jax(params)), whereas _func() returns just the scalar loss (the gradient is supplied separately via jac=_grad). I'm not expecting exact agreement between the scipy and jaxopt implementations, since I understand there will be numerical differences even when seeds are set, but the mismatch between the two versions is large.
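
For reference, here is a minimal toy sketch of the (value, grad) contract I am relying on with value_and_grad=True (toy quadratic with made-up names, not the dynotears objective):

import jax.numpy as jnp
from jaxopt import LBFGSB

# Toy objective illustrating the contract value_and_grad=True expects:
# fun(params) -> (scalar_loss, gradient_with_same_shape_as_params).
def func_and_grad(x):
    loss = jnp.sum((x - 1.0) ** 2)
    grad = 2.0 * (x - 1.0)  # analytic gradient
    return loss, grad

solver = LBFGSB(fun=func_and_grad, value_and_grad=True)
bounds = (jnp.zeros(3), jnp.full(3, 2.0))  # elementwise (lowers, uppers)
res = solver.run(jnp.zeros(3), bounds=bounds)
print(res.params)  # should be close to [1., 1., 1.]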

I do get some warnings during my run, like:

WARNING: jaxopt.ZoomLineSearch: No interval satisfying curvature condition.Consider increasing maximal possible stepsize of the linesearch.
WARNING: jaxopt.ZoomLineSearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
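
For what it's worth, the warning seems to point at the line-search configuration. Assuming jaxopt.LBFGSB exposes the same line-search knobs as jaxopt.LBFGS (max_stepsize, maxls, linesearch; I haven't verified that this fixes anything), one could try:

from jaxopt import LBFGSB

# Assumption: LBFGSB accepts the same line-search options as LBFGS.
lbfgsb_solver = LBFGSB(
    fun=_func_jax,
    value_and_grad=True,
    max_stepsize=10.0,  # raise the maximal stepsize, as the warning suggests
    maxls=30,           # allow more line-search iterations
)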

It would really help to understand what is causing these differences (and whether or not they are expected).

Versions:

jax: 0.4.31
jaxopt: 0.8.3
numpy: 1.23.5
scipy: 1.13.1

jithendaraa (Author) commented Jan 26, 2025

Here are the values from my scipy and jaxopt runs, in that order:

Final objective: 4.84871 vs 1.5213378e+20, which seems concerning.

Optimized value of wa_new (scipy/numpy)

array([0.00000000e+00, 1.58362135e+00, 0.00000000e+00, 8.64218226e-01,
       0.00000000e+00, 5.30177996e-06, 0.00000000e+00, 7.23155393e-06,
       1.08522093e-04, 3.21722607e-01, 0.00000000e+00, 1.58837897e+00,
       0.00000000e+00, 2.17025397e-01, 5.02410566e-01, 9.24620810e-06,
       2.14839298e-01, 4.42866934e-05, 0.00000000e+00, 0.00000000e+00,
       4.66671845e-06, 9.10319126e-05, 2.85940340e-05, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 8.72885996e-04,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.54268367e-02,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       3.84189744e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       4.18412632e-05, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 1.66891599e-01, 1.05293901e-01, 0.00000000e+00,
       0.00000000e+00, 1.10260318e-01, 1.31643372e-01, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 8.44444677e-01, 6.77765312e-01,
       0.00000000e+00, 5.05999801e-02, 1.00552653e-01, 0.00000000e+00,
       0.00000000e+00, 1.54582784e-01, 0.00000000e+00, 1.13265831e+00,
       0.00000000e+00, 0.00000000e+00, 7.67622627e-01, 0.00000000e+00,
       1.23048506e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       3.12075774e-03, 9.20508971e-03, 0.00000000e+00, 0.00000000e+00,
       2.24813920e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 3.25643453e-01, 0.00000000e+00, 0.00000000e+00,
       5.90566523e-02, 1.91172397e-01, 0.00000000e+00, 3.41801208e-02,
       0.00000000e+00, 1.21313949e-01, 7.98333390e-02, 0.00000000e+00])

Optimized value of wa_new (jaxopt)

[0.00000000e+00 5.46834469e-01 0.00000000e+00 2.78253078e-01
 7.89659377e-03 1.30613953e-01 0.00000000e+00 6.36773586e-01
 8.86588693e-01 1.54841435e+00 0.00000000e+00 8.11460614e-01
 0.00000000e+00 2.19483122e-01 3.30503821e-01 1.15662307e-01
 8.80117357e-01 1.58519089e-01 0.00000000e+00 7.15284869e-02
 7.65197736e-04 1.42049563e+00 1.31482124e-01 1.29014403e-01
 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.60738841e-01
 0.00000000e+00 2.11134859e-04 7.69150443e-04 0.00000000e+00
 4.57316667e-01 5.89280963e-01 1.16028547e+00 1.55092672e-01
 8.82897153e-02 0.00000000e+00 8.81871283e-02 0.00000000e+00
 1.20857454e-04 3.02751184e-01 7.04137236e-02 0.00000000e+00
 4.84485656e-01 3.28849182e-02 8.90319526e-01 5.75939834e-04
 4.59955156e-01 0.00000000e+00 7.57855771e-04 0.00000000e+00
 0.00000000e+00 3.89564373e-02 5.76256029e-02 2.97483569e-03
 5.17785847e-01 1.83005854e-01 2.27913812e-01 1.99340269e-01
 0.00000000e+00 4.21245098e-01 4.41471756e-01 3.43075842e-01
 0.00000000e+00 4.60929386e-02 9.65917408e-02 7.06027970e-02
 3.14517925e-03 1.59189552e-01 1.21413296e-13 5.95388710e-01
 2.26187472e-12 8.46469775e-02 8.20553839e-01 8.88178420e-16
 1.55108944e-01 2.44511813e-02 0.00000000e+00 0.00000000e+00
 6.37098476e-02 4.20914859e-01 1.23776801e-01 1.65834844e-01
 4.72670466e-01 9.22673717e-02 8.98242220e-02 0.00000000e+00
 0.00000000e+00 1.64450020e-01 1.21536679e-04 5.83359003e-02
 8.76099318e-02 1.33519948e-01 0.00000000e+00 9.29643065e-02
 9.17113125e-02 3.66380751e-01 1.61518916e-01 7.18804449e-02]
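
In case a subtly wrong custom gradient is behind the blow-up, one sanity check is to compare _grad_jax against jax.grad of the scalar loss (a sketch; the _loss_only wrapper is made up for illustration):

import jax
import jax.numpy as jnp

def _loss_only(params):
    # _func_jax returns (loss, grad); keep only the scalar loss
    return _func_jax(params)[0]

x_test = jnp.asarray(wa_est)  # any feasible point
autodiff_grad = jax.grad(_loss_only)(x_test)
custom_grad = _func_jax(x_test)[1]
print(jnp.max(jnp.abs(autodiff_grad - custom_grad)))  # should be ~0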

@jithendaraa jithendaraa changed the title jaxopt's L-BFGS-B not matching with scipy implementation jaxopt's L-BFGS-B with custom gradient not matching with scipy implementation Jan 26, 2025