Fixed the iMAML tutorial #425

Merged
merged 1 commit into google:main from zaccharieramzi:fix-maml-example on May 15, 2023

Conversation

zaccharieramzi
Contributor

Basically, I implemented the correct iMAML training setup:

  • corrected the phase application in the sinusoid tasks
  • made sure to use more than 4 tasks for meta-training
  • used vmap for the outer loss computation (see the sketch at the end of this comment)

Then adapted all the plots.

TPU testing remains to be done.

This fixes #417.

@fabianp don't hesitate to tell me already if some things should be improved, e.g. using LBFGS as the inner solver, a more complex MLP (with swish activation), ...
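
For illustration, here is a minimal sketch (not the tutorial's actual code) of vmapping the outer loss over a batch of sine tasks; the toy linear model and the helper names (sine_task, task_loss) are placeholders, and the real tutorial additionally adapts the meta parameters per task with an implicit-differentiation inner solver before evaluating the outer loss:

```python
# A minimal sketch (not the tutorial's actual code): the model is a toy
# linear layer and the helper names are placeholders. In the tutorial the
# meta parameters are first adapted per task by an inner solver
# (differentiated implicitly) before the outer loss is evaluated.
import jax
import jax.numpy as jnp


def sine_task(key, n_samples=10):
  # Sample one sinusoid regression task: y = amplitude * sin(x + phase).
  k_amp, k_phase, k_x = jax.random.split(key, 3)
  amplitude = jax.random.uniform(k_amp, (), minval=0.1, maxval=5.0)
  phase = jax.random.uniform(k_phase, (), minval=0.0, maxval=jnp.pi)
  x = jax.random.uniform(k_x, (n_samples, 1), minval=-5.0, maxval=5.0)
  y = amplitude * jnp.sin(x + phase)  # phase applied inside the sine
  return x, y


def task_loss(params, x, y):
  pred = x @ params["w"] + params["b"]
  return jnp.mean((pred - y) ** 2)


def outer_loss(meta_params, batch_x, batch_y):
  # vmap over the leading "task" axis and average the per-task losses.
  per_task = jax.vmap(task_loss, in_axes=(None, 0, 0))(
      meta_params, batch_x, batch_y)
  return jnp.mean(per_task)


key = jax.random.PRNGKey(0)
n_tasks = 8  # more than 4 tasks for meta-training
xs, ys = jax.vmap(sine_task)(jax.random.split(key, n_tasks))
meta_params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}
print(outer_loss(meta_params, xs, ys))
```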

@zaccharieramzi
Contributor Author

zaccharieramzi commented Apr 30, 2023

Unfortunately I cannot test on TPU in Colab because the older Colab TPUs are not supported anymore:

RuntimeError: 
As of JAX 0.4.0, JAX only supports TPU VMs, not the older Colab TPUs.

We recommend trying Kaggle Notebooks
(https://www.kaggle.com/code, click on "New Notebook" near the top) which offer
TPU VMs. You have to create an account, log in, and verify your account to get
accelerator support.
Once you do that, there's a new "TPU 1VM v3-8" accelerator option. This gives
you a TPU notebook environment similar to Colab, but using the newer TPU VM
architecture. This should be a less buggy, more performant, and overall better
experience than the older TPU node architecture.

It is also possible to use Colab together with a self-hosted Jupyter kernel
running on a Cloud TPU VM. See
https://research.google.com/colaboratory/local-runtimes.html
for details.

@fabianp Would you like me to test on Kaggle even if the quick button still points to Colab?
Otherwise, tested on GPU it runs twice as fast, so I'd say it is working.

@zaccharieramzi zaccharieramzi marked this pull request as ready for review April 30, 2023 15:51
@zaccharieramzi
Contributor Author

If anyone is interested, I live-streamed the last part of my implementation of this PR on Twitch: https://www.twitch.tv/videos/1807716437

@mblondel
Collaborator

mblondel commented May 3, 2023

Thanks a lot for the contribution @zaccharieramzi.

@fabianp Can you take a look when you get some time? Thanks!

@fabianp
Collaborator

fabianp commented May 3, 2023

sure will do. the streaming is awesome 😄

@fabianp
Collaborator

fabianp commented May 7, 2023

thanks a lot for this @zaccharieramzi. Could you please run jupytext as described here (https://jaxopt.github.io/stable/developer.html#syncing-notebooks) so that we can see the diff more easily?
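
For reference, a minimal sketch of what the syncing boils down to, using jupytext's Python API (the linked developer page describes the equivalent CLI workflow); the file names below are placeholders, not the repo's actual paths:

```python
# Illustrative only: the developer docs describe the CLI workflow
# (roughly `jupytext --sync <notebook>`); this shows the same round-trip
# through jupytext's Python API. File names are placeholders.
import jupytext

# Read the text representation of the tutorial notebook...
notebook = jupytext.read("implicit_diff_maml.md")

# ...and write the paired .ipynb so both versions carry the same content.
jupytext.write(notebook, "implicit_diff_maml.ipynb")
```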

@zaccharieramzi
Contributor Author

@fabianp sure done!
I also added jupytext to the docs requirements, as it was not present.
An alternative to syncing the 2 versions each time would be to use ReviewNB (I use it in some of my repos and it works really well): https://www.reviewnb.com/. I don't know how easy it would be to put in place though with the 2 versioning systems.

@fabianp
Collaborator

fabianp commented May 8, 2023

Thanks a lot @zaccharieramzi, this looks great! And thanks for the link to ReviewNB, it does indeed look great. We use jupytext mostly because JAX uses the same tooling.

One last thing before we merge: could you please squash these into a single commit? (It makes it easier for us to sync with the Google infra.)

@zaccharieramzi
Contributor Author

@fabianp sure done!

@copybara-service copybara-service bot merged commit b934387 into google:main May 15, 2023
vroulet added a commit to vroulet/jaxopt that referenced this pull request Jun 13, 2023
commit 3614817
    added approx wolfe condition to fix bug at high precision

Successfully merging this pull request may close these issues.

Tiny issues w.r.t data sampling in few-shot iMAML example