
Use jaxopt.LBFGS on Haiku.Module Parameters? #424

Closed
Xemin0 opened this issue Apr 20, 2023 · 9 comments
Labels
question Further information is requested

Comments


Xemin0 commented Apr 20, 2023

How can I use jaxopt.LBFGS on Haiku.Module parameters, which are stored as a dictionary?

I am trying to create a wrapper

def mseWrapper(flatten_params, tree_struct, x, y_true):

for the loss function, which takes both the tree leaves and the tree_def of the original parameters, so that the flattened parameters can be passed to the jaxopt.LBFGS solver.
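The body of the wrapper is roughly the following (a sketch; mse_loss here stands for my underlying MSE loss on the unflattened parameter dictionary):

from jax.tree_util import tree_unflatten

def mseWrapper(flatten_params, tree_struct, x, y_true):
    # Rebuild the Haiku parameter dictionary from the flat leaves,
    # then evaluate the underlying loss on the reconstructed pytree.
    params = tree_unflatten(tree_struct, flatten_params)
    return mse_loss(params, x, y_true)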

I have initialized the LBFGS solver with:

lbfgs_solver = jaxopt.LBFGS(fun=mseWrapper, value_and_grad=True,
                            maxiter=500, history_size=4)

Here is the sample code I wrote to update the weights with the LBFGS solver:

# LBFGS
# Flatten the params dictionary
flat_params, treedef = tree_flatten(updated_params)
for i in range(iter_lbfgs):
    flat_params, opt_state = lbfgs_solver.run(flat_params, tree_struct=treedef, x=inputs, y_true=labels)
    loss = opt_state.value

    print(f'\r[Train LBFGS Step:{i}/{iter_lbfgs}]\tLoss:{loss:.4f}', end='')
# Unflatten the optimized leaves back into the original parameter dictionary
updated_params = tree_unflatten(treedef, flat_params)

But it throws an error at the lbfgs_solver.run line.

Here is the full traceback and error message:

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in run(self, init_params, *args, **kwargs)
    253       run = decorator(run)
    254 
--> 255     return run(init_params, *args, **kwargs)
    256 
    257 

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py in wrapped_solver_fun(*args, **kwargs)
    249     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    252 
    253   return wrapped_solver_fun

    [... skipping hidden 4 frame]

~/opt/anaconda3/lib/python3.9/site-packages/jax/_src/core.py in concrete_aval(x)
   1338   if hasattr(x, '__jax_array__'):
   1339     return concrete_aval(x.__jax_array__())
-> 1340   raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1341                    "type")
   1342 

TypeError: Value PyTreeDef({'linear': {'b': *, 'w': *}, 'linear_1': {'b': *, 'w': *}, 'linear_2': {'b': *, 'w': *}, 'linear_3': {'b': *, 'w': *}, 'linear_4': {'b': *, 'w': *}, 'linear_5': {'b': *, 'w': *}, 'linear_6': {'b': *, 'w': *}, 'linear_7': {'b': *, 'w': *}}) with type <class 'jaxlib.xla_extension.pytree.PyTreeDef'> is not a valid JAX type

It seems the solver is not able to handle the PyTreeDef being passed as an argument.
Any tips on how I can work around this?


Sample code to create a simple MLP network, in case it is needed:

# Define the Network
def myNet(x: jax.Array) -> jax.Array:
    mlp = hk.Sequential([
        hk.Linear(5), jax.nn.relu,
        hk.Linear(2),
    ])
    return mlp(x)

# Make the network
network = hk.without_apply_rng(hk.transform(myNet))

# Create sample input for network initialization
# 10 Equi-distant points between (-1, 1)
lb, ub = (-1, 1)
num_pts = 10
sample_x = jnp.reshape(jnp.linspace(lb, ub, num_pts) , (-1,1))

# Initialize the network parameters
init_params = network.init(jax.random.PRNGKey(seed = 0), sample_x)

Xemin0 commented Apr 20, 2023

If I try to exclude the tree_def by making it a global variable, I still get an error: TypeError: iteration over a 0-d array.

/var/folders/s4/_mlskmg11_7cd7yhvldl5cv80000gn/T/ipykernel_4579/21316609.py in trainApprox(params, x, y, opt_state, layers, network, iter_adam, iter_lbfgs, loss_fn)
     21     flatt_params, tree_def = tree_flatten(updated_params)
     22     for i in range(iter_lbfgs):
---> 23         flatten_params, opt_state = lbfgs_solver.run(flat_params, x = x, y_true = y, sizes = layers )
     24         loss = opt_state.value

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in run(self, init_params, *args, **kwargs)
    253       run = decorator(run)
    254 
--> 255     return run(init_params, *args, **kwargs)
    256 
    257 

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py in wrapped_solver_fun(*args, **kwargs)
    249     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    252 
    253   return wrapped_solver_fun

    [... skipping hidden 5 frame]

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py in solver_fun_flat(*flat_args)
    205     def solver_fun_flat(*flat_args):
    206       args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207       return solver_fun(*args, **kwargs)
    208 
    209     def solver_fun_fwd(*flat_args):

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in _run(self, init_params, *args, **kwargs)
    195            *args,
    196            **kwargs) -> OptStep:
--> 197     state = self.init_state(init_params, *args, **kwargs)
    198 
    199     # We unroll the very first iteration. This allows `init_val` and `body_fun`

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/lbfgs.py in init_state(self, init_params, *args, **kwargs)
    282         stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
    283       )
--> 284     (value, aux), grad = self._value_and_grad_with_aux(init_params, *args, **kwargs)
    285     return LbfgsState(value=value,
    286                       grad=grad,

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in value_and_grad_fun(*a, **kw)
     75       fun_ = lambda *a, **kw: (fun(*a, **kw)[0], None)
     76       def value_and_grad_fun(*a, **kw):
---> 77         v, g = fun(*a, **kw)
     78         return (v, None), g
     79 

~/opt/anaconda3/lib/python3.9/site-packages/jax/_src/array.py in __iter__(self)
    301   def __iter__(self):
    302     if self.ndim == 0:
--> 303       raise TypeError("iteration over a 0-d array")  # same as numpy error
    304     else:
    305       assert self.is_fully_replicated or self.is_fully_addressable

TypeError: iteration over a 0-d array


Xemin0 commented Apr 20, 2023

I guess the problem is with the transformed function network.apply during the model's forward call.


mblondel commented Apr 21, 2023

Hi @Xemin0. jaxopt.LBFGS should natively support pytrees. So you should be able to replace def mseWrapper(flatten_params, tree_struct, x, y_true) with def mseWrapper(params, x, y_true).
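For instance, something along these lines should work directly on the parameter dictionary (a rough sketch, reusing network, init_params, and sample_x from the snippets in this thread):

import jaxopt

def mse_loss(params, x, y_true):
    # params is the Haiku parameter dictionary (a pytree); no flattening needed.
    y_pred = network.apply(params, x)
    return jnp.mean(jnp.sum((y_pred - y_true) ** 2, axis=1))

lbfgs_solver = jaxopt.LBFGS(fun=mse_loss, maxiter=500, history_size=4)
updated_params, opt_state = lbfgs_solver.run(init_params, x=sample_x, y_true=jnp.sin(sample_x))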

mblondel added the question label Apr 21, 2023

Xemin0 commented Apr 25, 2023

Sorry for the late reply.

If I do that, i.e. use mseWrapper(params, x, y_true) and call lbfgs_solver.run on it, I get the error shown above: TypeError: iteration over a 0-d array.

I originally thought it was because the parameters from the Haiku APIs are a dictionary, but now I suspect it has something to do with the forward method of the network created by the Haiku Sequential module.

I will provide minimal code to reproduce the error message.


Xemin0 commented Apr 25, 2023

Minimal Code to reproduce the issue:

'''
Minimal Code to test jaxopt.LBFGS on model using haiku APIs
'''
import haiku as hk
import jax
import jax.numpy as jnp
from jax import vmap

from jaxopt import LBFGS

# Define the Network
def myNet(x: jax.Array) -> jax.Array:
    mlp = hk.Sequential([
        hk.Linear(5), jax.nn.relu,
        hk.Linear(2),
    ])
    return mlp(x)


# Make the network (by using hk.transform)
# The forward pass of the transformed network:
# - instead of `predicted = myNet(x)`,
# - use `predicted = network.apply(params, x)`
network = hk.without_apply_rng(hk.transform(myNet))

# MSE loss defined for the network
def mse_loss(params, x, y_true):
    # vectorize the forward pass over the batch dimension (axis 0 of x)
    # output shape: [batch_size, output_size]
    y_pred = vmap(network.apply, in_axes=(None, 0))(params, x)
    diff = y_pred - y_true
    return jnp.mean(jnp.sum(diff ** 2, axis=1))

# Create sample input for the network 
# 10 Equi-distant points between (-1, 1)
lb, ub = (-1, 1)
num_pts = 10
sample_x = jnp.reshape(jnp.linspace(lb, ub, num_pts) , (-1,1))

# Initialize the network parameters
init_params = network.init(jax.random.PRNGKey(seed = 0), sample_x)

# Initialize the Optimizer
lbfgs_solver = LBFGS(fun=mse_loss, value_and_grad=True,
                     maxiter=500, history_size=4)

# Perform one train/optimization step 
# to have the network approximate jnp.sin function
updated_params, opt_state = lbfgs_solver.run(init_params, x = sample_x, y_true = jnp.sin(sample_x))


Xemin0 commented Apr 25, 2023

This throws an error; here is the full traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/s4/_mlskmg11_7cd7yhvldl5cv80000gn/T/ipykernel_1593/221376305.py in <module>
     41 # Perform one train/optimizer step
     42 # to have the network approximate jnp.sin function
---> 43 updated_params, opt_state = lbfgs_solver.run(init_params, x = sample_x, y_true = jnp.sin(sample_x))

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in run(self, init_params, *args, **kwargs)
    253       run = decorator(run)
    254 
--> 255     return run(init_params, *args, **kwargs)
    256 
    257 

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py in wrapped_solver_fun(*args, **kwargs)
    249     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    252 
    253   return wrapped_solver_fun

    [... skipping hidden 5 frame]

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py in solver_fun_flat(*flat_args)
    205     def solver_fun_flat(*flat_args):
    206       args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207       return solver_fun(*args, **kwargs)
    208 
    209     def solver_fun_fwd(*flat_args):

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in _run(self, init_params, *args, **kwargs)
    195            *args,
    196            **kwargs) -> OptStep:
--> 197     state = self.init_state(init_params, *args, **kwargs)
    198 
    199     # We unroll the very first iteration. This allows `init_val` and `body_fun`

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/lbfgs.py in init_state(self, init_params, *args, **kwargs)
    282         stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
    283       )
--> 284     (value, aux), grad = self._value_and_grad_with_aux(init_params, *args, **kwargs)
    285     return LbfgsState(value=value,
    286                       grad=grad,

~/opt/anaconda3/lib/python3.9/site-packages/jaxopt/_src/base.py in value_and_grad_fun(*a, **kw)
     75       fun_ = lambda *a, **kw: (fun(*a, **kw)[0], None)
     76       def value_and_grad_fun(*a, **kw):
---> 77         v, g = fun(*a, **kw)
     78         return (v, None), g
     79 

~/opt/anaconda3/lib/python3.9/site-packages/jax/_src/array.py in __iter__(self)
    301   def __iter__(self):
    302     if self.ndim == 0:
--> 303       raise TypeError("iteration over a 0-d array")  # same as numpy error
    304     else:
    305       assert self.is_fully_replicated or self.is_fully_addressable

TypeError: iteration over a 0-d array


mblondel commented May 3, 2023

mse_loss returns only the value, not the value and the gradient. So you need to set value_and_grad=False.
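Concretely, with mse_loss as defined above, the call would look like this (a minimal sketch):

# mse_loss returns only the scalar loss, so let the solver differentiate it
# itself; value_and_grad defaults to False.
lbfgs_solver = LBFGS(fun=mse_loss, maxiter=500, history_size=4)
updated_params, opt_state = lbfgs_solver.run(init_params, x=sample_x, y_true=jnp.sin(sample_x))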


Xemin0 commented May 3, 2023

mse_loss returns only the value, not the value and the gradient. So you need to set value_and_grad=False.

??? I thought the value_and_grad argument tells the LBFGS solver to use jax.value_and_grad instead of jax.grad when doing the autodifferentiation??

Well, setting value_and_grad=False (which is the default value) did work... I am confused now, but thank you!!


Oof, I just checked the documentation again. It does say value_and_grad indicates whether fun (here mse_loss) returns just the value or both the value and the gradient. Hmm, this is embarrassing...

Xemin0 closed this as completed May 3, 2023

mblondel commented May 3, 2023

No, when value_and_grad=True, it means that the provided function returns both the value and the gradient (this is a way of overriding the gradient).
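For example, if you did want to keep value_and_grad=True, the function passed to the solver would need to return that pair itself, e.g. via jax.value_and_grad (a sketch; a hand-written gradient could be plugged in the same way):

# loss_and_grad returns (value, grad), which is what value_and_grad=True expects.
loss_and_grad = jax.value_and_grad(mse_loss)

lbfgs_solver = LBFGS(fun=loss_and_grad, value_and_grad=True,
                     maxiter=500, history_size=4)
updated_params, opt_state = lbfgs_solver.run(init_params, x=sample_x, y_true=jnp.sin(sample_x))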
