Use `jaxopt.LBFGS` on `Haiku.Module` Parameters? #424
Comments
If I try to exclude that
I guess the problem is with the
Hi @Xemin0.
Sorry for the late reply. If I do that - use
I originally thought it was because the parameters from the Haiku APIs are a dictionary, but now I suspect it has something to do with the forward method of the network created by the Haiku `Sequential` module. I will provide minimal code to reproduce the error message.
Minimal code to reproduce the issue:

This throws an error. Here's the full traceback:
??? I thought the
Well, setting
Oof, just checked the documentation again. It did say
No, when
How to use `jaxopt.LBFGS` on `Haiku.Module` parameters, which are stored as a dictionary?

I am trying to create a wrapper for the loss function that takes both the `tree_leaves` and `tree_def` of the original parameters, so that the flattened parameters can be passed into the `jaxopt.LBFGS` solver. I have initialized the `LBFGS` solver using:

Sample code I wrote to update the weights using the LBFGS solver:
But it throws an error at the `lbfgs_solver` line. Here's the full traceback and error messages:
It seems the solver is not able to handle a pytree. Any tips on how I can work around it?
Sample code to create a simple MLP network, in case it's needed: