-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixed the iMAML tutorial #425
Conversation
Unfortunately I cannot test on TPU on Colab because they are not supported anymore:
@fabianp Would you like me to test on Kaggle even if the quick button still points to Colab? |
If anyone is interested, I live streamed the last part of my implementation of this PR on twitch: https://www.twitch.tv/videos/1807716437 |
Thanks a lot for the contribution @zaccharieramzi. @fabianp Can you take a look when you get some time? Thanks! |
sure will do. the streaming is awesome 😄 |
thanks a lot for this @zaccharieramzi . Could you please run jupytext as described here (https://jaxopt.github.io/stable/developer.html#syncing-notebooks) so that we can see more easily the diff ? |
@fabianp sure done! |
Thanks a lot @zaccharieramzi, this looks great! And thanks for the link to reviewnb, it does look great indeed. We use jupytext mostly because JAX uses this same tooling. One last thing before we merge. Could you please squash these into a single commit ? (it makes it easier for us to sync with the Google infra) |
… generation, outer loss computation
b7f4b55
to
f6c0ca0
Compare
@fabianp sure done! |
commit 3614817 Author: Vincent Roulet <[email protected]> Date: Tue Jun 13 15:07:47 2023 -0700 added approx wolfe condition to fix bug at high precision commit d2211bc Author: Vincent Roulet <[email protected]> Date: Mon Jun 12 13:43:05 2023 -0700 fix attempt for failed test correctness lbfgsb commit 024bca8 Merge: c21479e 257b673 Author: Vincent Roulet <[email protected]> Date: Mon Jun 12 09:06:42 2023 -0700 Merge branch 'main' into zoom_linesearch_as_iterative_linesearch_solver commit 257b673 Merge: d40b6d7 0cfa882 Author: JAXopt authors <[email protected]> Date: Mon Jun 12 08:20:07 2023 -0700 Merge pull request google#440 from mblondel:lbfb_failure PiperOrigin-RevId: 539658221 commit 0cfa882 Author: Mathieu Blondel <[email protected]> Date: Mon Jun 12 16:28:14 2023 +0200 Drop Python 3.7 support. commit c21479e Merge: 675ae9d d108ddf Author: Vincent Roulet <[email protected]> Date: Fri Jun 9 15:33:21 2023 -0700 merging with main branch commit d108ddf Author: Srinivas Vasudevan <[email protected]> Date: Tue Jun 6 13:49:34 2023 -0700 Internal change PiperOrigin-RevId: 538281622 commit 414b5b9 Author: Guillaume Dalle <[email protected]> Date: Sat May 27 09:14:40 2023 +0200 Fix typos commit dcae685 Author: Mathieu Blondel <[email protected]> Date: Fri May 26 23:26:05 2023 +0200 Release v0.7. commit 83c1370 Author: Chansoo Lee <[email protected]> Date: Fri May 26 09:21:31 2023 -0700 Internal change PiperOrigin-RevId: 535636577 commit af48e7b Author: Zaccharie Ramzi <[email protected]> Date: Mon May 8 20:19:37 2023 +0200 fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation commit 5392a81 Author: Fabian Pedregosa <[email protected]> Date: Wed Feb 15 15:33:31 2023 +0100 Misc improvements in resnet_flax example. * Better data augmentation, leading to 88% accuracy (from 70%) * Plots showing the data augmentation in action. * Options use the same format as distributed training examples. * Changed solver from Adam to SGD for better accuracy. commit a260cfb Author: Emily Fertig <[email protected]> Date: Tue Apr 25 08:47:41 2023 -0700 Internal change PiperOrigin-RevId: 526978774 commit e593c89 Author: Vincent Roulet <[email protected]> Date: Tue Apr 4 17:54:20 2023 -0700 Fixed prox to handle pytrees Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters Added tests commit 675ae9d Author: Vincent Roulet <[email protected]> Date: Fri Jun 9 13:27:24 2023 -0700 integrated new zoom linesearch in all solvers, simplifying them commit 8da43d3 Author: Vincent Roulet <[email protected]> Date: Wed Jun 7 14:20:41 2023 -0700 minor edit commit 416e687 Author: Vincent Roulet <[email protected]> Date: Wed Jun 7 14:10:16 2023 -0700 fix copyright year commit 8f2d2c2 Author: Vincent Roulet <[email protected]> Date: Wed Jun 7 11:29:03 2023 -0700 fixed dtypes commit 8ca6e67 Author: Vincent Roulet <[email protected]> Date: Tue Jun 6 22:19:02 2023 -0700 minor edits commit 4c847ea Author: Vincent Roulet <[email protected]> Date: Tue Jun 6 21:41:53 2023 -0700 convert zoom_linesearch into an IterativeLineSearchSolver commit d40b6d7 Author: Srinivas Vasudevan <[email protected]> Date: Tue Jun 6 13:49:34 2023 -0700 Internal change PiperOrigin-RevId: 538281622 commit e87b9b9 Merge: 58ce7cb 3ccb6b9 Author: JAXopt authors <[email protected]> Date: Sat May 27 08:28:11 2023 -0700 Merge pull request google#435 from gdalle:patch-1 PiperOrigin-RevId: 535860339 commit 3ccb6b9 Author: Guillaume Dalle <[email protected]> Date: Sat May 27 09:14:40 2023 +0200 Fix typos commit 58ce7cb Merge: 541bbaa 7cf0567 Author: JAXopt authors <[email protected]> Date: Fri May 26 14:54:44 2023 -0700 Merge pull request google#434 from mblondel:release_0.7 PiperOrigin-RevId: 535720960 commit 7cf0567 Author: Mathieu Blondel <[email protected]> Date: Fri May 26 23:26:05 2023 +0200 Release v0.7. commit 541bbaa Author: Chansoo Lee <[email protected]> Date: Fri May 26 09:21:31 2023 -0700 Internal change PiperOrigin-RevId: 535636577 commit b934387 Merge: b3b6a0d f6c0ca0 Author: JAXopt authors <[email protected]> Date: Mon May 15 07:04:49 2023 -0700 Merge pull request google#425 from zaccharieramzi:fix-maml-example PiperOrigin-RevId: 532098566 commit f6c0ca0 Author: Zaccharie Ramzi <[email protected]> Date: Mon May 8 20:19:37 2023 +0200 fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation commit b3b6a0d Merge: 4aa9bc9 fff693f Author: JAXopt authors <[email protected]> Date: Thu Apr 27 07:24:22 2023 -0700 Merge pull request google#401 from fabianp:resnet_flax PiperOrigin-RevId: 527570739 commit 4aa9bc9 Author: Emily Fertig <[email protected]> Date: Tue Apr 25 08:47:41 2023 -0700 Internal change PiperOrigin-RevId: 526978774 commit fff693f Author: Fabian Pedregosa <[email protected]> Date: Wed Feb 15 15:33:31 2023 +0100 Misc improvements in resnet_flax example. * Better data augmentation, leading to 88% accuracy (from 70%) * Plots showing the data augmentation in action. * Options use the same format as distributed training examples. * Changed solver from Adam to SGD for better accuracy. commit 4edd8ac Merge: 674a992 7da12ec Author: JAXopt authors <[email protected]> Date: Wed Apr 12 14:00:46 2023 -0700 Merge pull request google#420 from vroulet:fix_prox_pytree PiperOrigin-RevId: 523798658 commit 7da12ec Author: Vincent Roulet <[email protected]> Date: Tue Apr 4 17:54:20 2023 -0700 Fixed prox to handle pytrees Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters Added tests commit 674a992 Merge: 1019f7b 18c4bd3 Author: JAXopt authors <[email protected]> Date: Wed Apr 5 01:39:26 2023 -0700 Merge pull request google#418 from LawrenceMMStewart:main PiperOrigin-RevId: 521986072 commit 18c4bd3 Author: LawrenceMMStewart <[email protected]> Date: Fri Mar 31 12:12:57 2023 +0200 added control variate to make_perturbed_argmax commit 1019f7b Merge: 36d7a0d 7f54e31 Author: JAXopt authors <[email protected]> Date: Thu Mar 23 18:16:45 2023 -0700 Merge pull request google#382 from aymgal:pr-hess_inv PiperOrigin-RevId: 519014528 commit 36d7a0d Author: Quentin Berthet <[email protected]> Date: Tue Mar 21 06:08:00 2023 -0700 Internal change PiperOrigin-RevId: 518250976 commit 7f54e31 Merge: a4f3956 ea8e0f1 Author: Aymeric Galan <[email protected]> Date: Thu Mar 16 15:13:49 2023 +0100 Merge remote-tracking branch 'upstream/main' into pr-hess_inv commit ea8e0f1 Merge: cb6ed9a e196ece Author: JAXopt authors <[email protected]> Date: Wed Mar 15 13:40:56 2023 -0700 Merge pull request google#412 from froystig:jit-bisect-test PiperOrigin-RevId: 516917371 commit e196ece Author: Roy Frostig <[email protected]> Date: Wed Mar 15 17:25:40 2023 +0000 avoid closing over dynamic jax tracers in the bisection solver Internally in jaxopt, we (should) try to maintain that the parameters to a solver class are "static" from jax's point of view. One reason for this is that class attributes might be read by any of the class' methods, including `run`. Meanwhile a bound `run` method serves as the solver function, which is passed through jaxopt's core `custom_root` mechanism in order to set it up with an implicit-diff-based custom VJP. Currently, that `custom_root` mechanism assumes that the solver function it receives has, in its closure, no arrays that are involved in any of jax's differentiation or staging. Re-stated using jax-internal jargon: `custom_root` assumes that the solver function it receives does not have tracers in its closure. But: a bound Python method (e.g. `o.run`) carries its bound instance (e.g. `o`) in its closure. The code in `bisection_test.py` did not conform to this requirement that all class attributes are static (in the jax transformation sense). Specifically, it constructed a `Bisection` instance, within a jitted function, given parameters (`lower` and `upper`) that depend on inputs to the jitted function. This change fixes that by hoisting the construction of this `Bisection` out from the jitted function (and marking it a static argument). Doing this fixes a jax "tracer leak" error raised in the jaxopt CI recently. This was not an issue until jax released version 0.4.4, for the rather technical reason that jax changed its `jit` implementation such that it eagerly stages out its function argument. This in turn led jax to encounter "jit tracers" (corresponding to `Bisection.{lower,upper}`) within the closure of a solver function (`Bisection.run`) in the course of custom-differentiating the solver function. commit cb6ed9a Author: Emily Fertig <[email protected]> Date: Wed Mar 15 10:43:20 2023 -0700 Internal change PiperOrigin-RevId: 516868349 commit a4f3956 Author: Aymeric Galan <[email protected]> Date: Thu Mar 9 16:32:48 2023 +0100 Attempt to fix failing test on python 3.9 regarding 32 vs 64-bits numbers commit abe44e4 Author: Aymeric Galan <[email protected]> Date: Thu Jan 19 13:05:54 2023 +0100 Add inverse hessian approximation to the returned state Add custom pytree registration for LbfgsInvHessProduct result Fix issue with undefined class Add docstring Remove drepecated comments Fix scipy.optimize module not found LbfgsInvHessProductPyTree constructor now compliant with JAX commit 040c8fc Author: Yash Katariya <[email protected]> Date: Tue Feb 21 15:24:39 2023 -0800 Internal change PiperOrigin-RevId: 511320677 commit a51d5ed Merge: 52d56ab f65001b Author: JAXopt authors <[email protected]> Date: Fri Feb 17 04:34:13 2023 -0800 Merge pull request google#398 from mblondel:add_isotonic_module PiperOrigin-RevId: 510398823 commit 52d56ab Merge: 0472831 6e6a0ab Author: JAXopt authors <[email protected]> Date: Fri Feb 17 01:08:50 2023 -0800 Merge pull request google#397 from mblondel:remove_matplotlib PiperOrigin-RevId: 510364159 commit f65001b Author: Mathieu Blondel <[email protected]> Date: Thu Feb 16 19:49:06 2023 +0100 Add isotonic module. commit 6e6a0ab Author: Mathieu Blondel <[email protected]> Date: Thu Feb 16 19:34:24 2023 +0100 Update requirements. commit 0472831 Author: Peter Hawkins <[email protected]> Date: Thu Feb 9 09:03:52 2023 -0800 Internal change PiperOrigin-RevId: 508389531 commit e1d8355 Merge: 0c8b25b 730b5a6 Author: JAXopt authors <[email protected]> Date: Thu Feb 9 07:43:10 2023 -0800 Merge pull request google#394 from mblondel:release_0.6 PiperOrigin-RevId: 508371158 commit 730b5a6 Author: Mathieu Blondel <[email protected]> Date: Thu Feb 9 16:02:32 2023 +0100 Release v0.6.
Basically implemented the correct iMAML training setup:
Then adapted all the plots.
TPU testing remains to be done.
This fixes #417 .
@fabianp don't hesitate to tell me already if some things should be improved w.r.t for example using LBFGS as the inner solver, a more complex MLP (with swish activation), ...