
LBFGSB raises TypeError: got an unexpected keyword argument 'bounds' when being differentiated #463

Closed
phinate opened this issue Jul 3, 2023 · 2 comments · Fixed by #464
Labels
bug Something isn't working

Comments

@phinate
Contributor

phinate commented Jul 3, 2023

Minimal example (jaxopt 0.7, Python 3.11.3, in a Jupyter notebook):

import jaxopt
import jax.numpy as jnp
import jax.scipy as jsp
import jax

def pipeline(x, init_pars, bounds, data):
    def fit_objective(pars, data, x):
        return -jsp.stats.norm.logpdf(pars, loc=data+x, scale=1.0)
    solver = jaxopt.LBFGSB(fun=fit_objective, implicit_diff=True, maxiter=500, tol=1e-6)
    return solver.run(init_pars, bounds=bounds, data=data, x=x)[0]

pipeline(0.5, jnp.array(0.0), (jnp.array(0.0), jnp.array(10.0)), 1.0)
# -> Array(1.5, dtype=float32)
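For reference, the negative log-density of a unit-variance normal is a shifted quadratic. That means the bounded minimizer has a closed form, and the gradient that `jax.grad(pipeline)` should return can be worked out by hand. Write d for `data` = 1.0, x = 0.5, and use the bounds [0, 10] from the call above:

$$
\begin{aligned}
f(p; d, x) &= -\log \mathcal{N}(p \mid d + x, 1) = \tfrac{1}{2}(p - d - x)^2 + \tfrac{1}{2}\log(2\pi) \\
p^\star(x) &= \operatorname{clip}(d + x,\ 0,\ 10) = 1.5 \\
\frac{\partial p^\star}{\partial x} &= 1 \quad \text{since } 0 < d + x < 10 \text{ (neither bound is active)}
\end{aligned}
$$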

But if I try to differentiate this with respect to x:

jax.grad(pipeline)(0.5, jnp.array(0.0), (jnp.array(0.0), jnp.array(10.0)), 1.0)

I get the following stack trace:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)

File ~/.pyenv/versions/3.11.3/lib/python3.11/asyncio/base_events.py:607, in run_forever()
    606 while True:
...
   3201             'got an unexpected keyword argument {arg!r}'.format(
   3202                 arg=next(iter(kwargs))))
   3204 return self._bound_arguments_cls(self, arguments)

TypeError: got an unexpected keyword argument 'bounds'

This only happens when I pass additional keyword arguments to the run call: during differentiation it somehow tries to pass bounds as an argument to fit_objective. The behaviour is perhaps a bit ambiguous; maybe bounds could be an __init__ attribute instead?
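The last frames of the trace come from Python's own `inspect` module. `Signature.bind` raises exactly this `'got an unexpected keyword argument ...'` error. The sketch below shows what happens if every keyword passed to `run`, including `bounds`, is bound against the objective's signature. It uses a plain-Python stand-in for `fit_objective` so it runs without JAX. `bind_like_run` is a hypothetical helper, not jaxopt code.

```python
import inspect


def fit_objective(pars, data, x):
    # Stand-in with the same signature as the objective in the report;
    # the body does not matter for argument binding.
    return (pars - (data + x)) ** 2


def bind_like_run(fun, init_params, **kwargs):
    """Hypothetical helper: bind run()-style arguments against fun's own signature.

    Every keyword (including `bounds`) is bound as if it were an argument of fun.
    """
    return inspect.signature(fun).bind(init_params, **kwargs)


message = None
try:
    bind_like_run(fit_objective, 0.0, bounds=(0.0, 10.0), data=1.0, x=0.5)
except TypeError as err:
    # `data` and `x` match parameters of fit_objective; `bounds` does not.
    message = str(err)

print(message)
```

Only `pars`, `data` and `x` appear in `fit_objective`'s signature, so `bounds` is left over and binding fails. This matches the observation that `bounds` is being treated as an argument to `fit_objective`.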

@mblondel mblondel added the bug Something isn't working label Jul 3, 2023
@mblondel
Collaborator

mblondel commented Jul 3, 2023

Thanks for the bug report! I think I know why. We need to register the signature, as done with hyperparams_prox for ProximalGradient: see this line. We need to do the same thing for bounds.
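The idea of registering a signature can be sketched with the standard library alone. This sketch assumes the fix amounts to binding `run`'s arguments against a signature that also contains `bounds`, by analogy with `hyperparams_prox` for `ProximalGradient`. `with_extra_param` is a hypothetical helper, not jaxopt API, and both the variable name `reference_signature` and the insertion position are illustrative.

```python
import inspect


def fit_objective(pars, data, x):
    # Stand-in objective with the signature from the report.
    return (pars - (data + x)) ** 2


def with_extra_param(fun, name, position=1):
    """Hypothetical helper: return fun's signature with one extra parameter.

    The solver-level argument (here `bounds`) is inserted right after the
    parameters being optimized, so binding run()'s arguments no longer fails.
    """
    params = list(inspect.signature(fun).parameters.values())
    params.insert(position, inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD))
    return inspect.Signature(params)


reference_signature = with_extra_param(fit_objective, "bounds")
bound = reference_signature.bind(0.0, bounds=(0.0, 10.0), data=1.0, x=0.5)

print(reference_signature)    # (pars, bounds, data, x)
print(list(bound.arguments))  # ['pars', 'bounds', 'data', 'x']
```

With `bounds` present in the signature used for binding, every argument of the run call has a slot. The remaining arguments (`data`, `x`) can then be forwarded to the objective.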

@phinate
Contributor Author

phinate commented Jul 4, 2023


This was exactly right! Fix in #464 :)
