Skip to content

Commit

Permalink
Merge pull request #382 from aymgal:pr-hess_inv
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 519014528
  • Loading branch information
JAXopt authors committed Mar 24, 2023
2 parents 36d7a0d + 7f54e31 commit 1019f7b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
56 changes: 55 additions & 1 deletion jaxopt/_src/scipy_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,56 @@
import jax
import jax.numpy as jnp
import jax.tree_util as tree_util
from jax.config import config

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src import projection
from jaxopt._src.tree_util import tree_sub
from jax.tree_util import register_pytree_node_class

import numpy as onp
import scipy as osp
from scipy.optimize import LbfgsInvHessProduct


@register_pytree_node_class
class LbfgsInvHessProductPyTree(LbfgsInvHessProduct):
"""
Registers the LbfgsInvHessProduct object as a PyTree.
This object is typically returned by the L-BFSG-B optimizer to efficiently
store the inverse of the Hessian matrix evaluated at the best-fit parameters.
"""

def __init__(self, sk, yk):
"""
Construct the operator.
This is the same constructor as the original LbfgsInvHessProduct class,
except that numpy has been replaced by jax.numpy and no call to the
numpy.ndarray constuctor is performed.
"""
if sk.shape != yk.shape or sk.ndim != 2:
raise ValueError('sk and yk must have matching shape, (n_corrs, n)')
n_corrs, n = sk.shape
self.dtype = jnp.float64 if config.jax_enable_x64 is True else jnp.float32
self.shape = (n, n)
self.sk = sk
self.yk = yk
self.n_corrs = n_corrs
self.rho = 1 / jnp.einsum('ij,ij->i', sk, yk)


def __repr__(self):
return "LbfgsInvHessProduct(sk={}, yk={})".format(self.sk, self.yk)

def tree_flatten(self):
children = (self.sk, self.yk)
aux_data = None
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)


class ScipyMinimizeInfo(NamedTuple):
Expand All @@ -52,6 +94,7 @@ class ScipyMinimizeInfo(NamedTuple):
success: bool
status: int
iter_num: int
hess_inv: Optional[Union[jnp.ndarray, LbfgsInvHessProductPyTree]]


class ScipyRootInfo(NamedTuple):
Expand Down Expand Up @@ -312,10 +355,21 @@ def scipy_fun(x_onp: onp.ndarray) -> Tuple[onp.ndarray, onp.ndarray]:
options=self.options)

params = tree_util.tree_map(jnp.asarray, onp_to_jnp(res.x))

if hasattr(res, 'hess_inv'):
if isinstance(res.hess_inv, osp.optimize.LbfgsInvHessProduct):
hess_inv = LbfgsInvHessProductPyTree(res.hess_inv.sk,
res.hess_inv.yk)
elif isinstance(res.hess_inv, onp.ndarray):
hess_inv = jnp.asarray(res.hess_inv)
else:
hess_inv = None

info = ScipyMinimizeInfo(fun_val=jnp.asarray(res.fun),
success=res.success,
status=res.status,
iter_num=res.nit)
iter_num=res.nit,
hess_inv=hess_inv)
return base.OptStep(params, info)

def run(self,
Expand Down
34 changes: 34 additions & 0 deletions tests/scipy_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,40 @@ def wrapper(box_len):
self.assertArraysAllClose(array_num, array_custom, atol=1e-2)


class BFGSInverseHessianTest(test_util.JaxoptTestCase):

def setUp(self):
super().setUp()

true_sol = [0.5, -1.2]
# define a very simple objective
self.objective = lambda x: (x[0] - true_sol[0])**2. + (x[1] - true_sol[1])**2.
# analytical derivative of the objective function
self.derivative = lambda x: [(x[0] - true_sol[0]) * 2., (x[1] - true_sol[1]) * 2.]
# define the starting point as a random sample from the domain
r_min, r_max = -5.0, 5.0
onp.random.seed(6574)
self.x0 = r_min + onp.random.rand(2) * (r_max - r_min)

def test_inverse_hessian_bfgs(self):
# perform the bfgs algorithm search
res_scipy = osp.optimize.minimize(self.objective, self.x0,
method='BFGS', jac=self.derivative)
# run the same algorithm via jaxopt
res_jaxopt = ScipyMinimize(method='BFGS', fun=self.objective).run(self.x0)
# compare the two returned inverse Hessian matrices
self.assertAllClose(res_scipy.hess_inv, res_jaxopt.state.hess_inv)

def test_inverse_hessian_lbfgs(self):
# perform the l-bfgs-b algorithm search
res_scipy = osp.optimize.minimize(self.objective, self.x0,
method='L-BFGS-B', jac=self.derivative)
# run the same algorithm via jaxopt
res_jaxopt = ScipyMinimize(method='L-BFGS-B', fun=self.objective).run(self.x0)
# compare the two returned inverse Hessian matrices
self.assertAllClose(res_scipy.hess_inv.todense(), res_jaxopt.state.hess_inv.todense())


if __name__ == '__main__':
# Uncomment the line below in order to run in float64.
# jax.config.update("jax_enable_x64", True)
Expand Down

0 comments on commit 1019f7b

Please sign in to comment.