
Commit

fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation
zaccharieramzi committed May 8, 2023
1 parent 1019f7b commit f6c0ca0
Showing 3 changed files with 281 additions and 217 deletions.
304 changes: 166 additions & 138 deletions docs/notebooks/implicit_diff/maml.ipynb

Large diffs are not rendered by default.

193 changes: 114 additions & 79 deletions docs/notebooks/implicit_diff/maml.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
-jupytext_version: 1.14.4
+jupytext_version: 1.14.5
kernelspec:
display_name: Python 3
language: python
@@ -37,17 +37,18 @@ This notebook shows how to use Model Agnostic Meta-Learning (MAML) for few-shot

```{code-cell} ipython3
%%capture
-%pip install jaxopt flax
+%pip install jaxopt flax matplotlib tqdm
```

```{code-cell} ipython3
from functools import partial
from typing import Any, Sequence
# activate TPUs if available
try:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
-except KeyError:
+except (KeyError, RuntimeError):
print("TPU not found, continuing without it.")
from jax.config import config
@@ -56,15 +57,19 @@ config.update("jax_enable_x64", True)
import jax
from jax import numpy as jnp
from jax import random
+from jax import vmap
+from jax.tree_util import Partial, tree_map
-from jaxopt import LBFGS
+from jaxopt import LBFGS, GradientDescent
+from jaxopt import linear_solve
from jaxopt import OptaxSolver
from jaxopt import tree_util
import optax
# we use flax to construct a small multi-layer perceptron
from flax import linen as nn
+from tqdm.auto import tqdm
# for plotting
import matplotlib.pyplot as plt
@@ -79,7 +84,7 @@ MAX_X = 5
# amount of L2 regularization. Higher regularization values will promote
# task parameters that are closer to each other.
-L2REG = 1e-2
+L2REG = 2  # similar to that of the paper
# for bigger plots
plt.rcParams.update({"figure.figsize": (12, 6)})
@@ -101,14 +106,28 @@ def generate_task(key, n_samples_train=6, n_samples_test=6, min_phase=0.5, max_p
key, _ = random.split(key)
x_train = random.uniform(key, shape=(n_samples_train,)) * (MAX_X - MIN_X) + MIN_X
x_train = x_train.reshape((-1, 1)) # Reshape to feed into MLP later
-y_train = jnp.sin(phase * x_train) * amplitude
+y_train = jnp.sin(x_train - phase) * amplitude
key, _ = random.split(key)
x_test = random.uniform(key, shape=(n_samples_test,)) * (MAX_X - MIN_X) + MIN_X
x_test = x_test.reshape((-1, 1)) # Reshape to feed into MLP later
-y_test = jnp.sin(phase * x_test) * amplitude
+y_test = jnp.sin(x_test - phase) * amplitude
return (x_train, y_train), (x_test, y_test), phase, amplitude
+# the above function generates a single task
+# the next function generates a meta-batch of tasks in a vectorized fashion (that is, without a for loop)
+# the tasks are batched along the first dimension
+@partial(jax.jit, static_argnums=(1, 2, 3))
+def generate_task_batch(key, meta_batch_size=25, n_samples_train=6, n_samples_test=6, min_phase=0.5, max_phase=jnp.pi, min_amplitude=0.1, max_amplitude=0.5):
+    """Generate a batch of toy 1-D regression datasets."""
+    keys = random.split(key, meta_batch_size)
+    tasks = vmap(
+        generate_task,
+        in_axes=(0, None, None, None, None, None, None),
+    )(keys, n_samples_train, n_samples_test, min_phase, max_phase, min_amplitude, max_amplitude)
+    return tasks
```
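A quick sketch of how the batched generator above can be inspected (the call and the expected shapes below are illustrative, assuming the defaults defined above):

```python
# Draw a small meta-batch of 4 tasks and check the leading task dimension.
batch_key = random.PRNGKey(42)
(x_tr, y_tr), (x_te, y_te), phases, amplitudes = generate_task_batch(batch_key, 4, 6, 6)
print(x_tr.shape)    # expected (4, 6, 1): tasks stacked along the first axis
print(phases.shape)  # expected (4,): one phase per task
```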

```{code-cell} ipython3
@@ -117,17 +136,15 @@ fig = plt.figure(figsize=(12, 6))
colors = cm.Set2(jnp.linspace(0, 1, n_tasks))
-data_tasks = []
+data_train, data_test, phase, amplitude = generate_task_batch(random.PRNGKey(0), meta_batch_size=n_tasks)
for task in range(n_tasks):
-key, subkey = random.split(key)
-data_train, data_test, phase, amplitude = generate_task(key)
-# save the samples for later
-data_tasks.append((data_train, data_test))
+phase_ = phase[task]
+amplitude_ = amplitude[task]
# generate the ground truth regression curve for plotting
xs = jnp.linspace(MIN_X, MAX_X, 100)
-ys = jnp.sin(phase * xs) * amplitude
+ys = jnp.sin(xs-phase_) * amplitude_
plt.plot(xs, ys, linewidth=4, label=f'ground truth for task {task+1}', color=colors[task])
plt.xlim((MIN_X, MAX_X))
@@ -149,7 +166,8 @@ We call each one of the curves above a "task". For each task, we have access to

```{code-cell} ipython3
fig = plt.figure(figsize=(12, 6))
-for task, ((x_train, y_train), (x_test, y_test)), in enumerate(data_tasks):
+for task in range(n_tasks):
+((x_train, y_train), (x_test, y_test)) = (data_train[0][task], data_train[1][task]), (data_test[0][task], data_test[1][task])
plt.scatter(x_train, y_train, marker='o', s=50, label=f"Training samples for task {task+1}", color=colors[task])
plt.scatter(x_test, y_test, marker='^', s=80, label=f"Test samples for task {task+1}", color=colors[task])
plt.xlabel('x')
@@ -185,119 +203,136 @@ class SimpleMLP(nn.Module):
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
-x = nn.Dense(feat, name=f'layers_{i}', param_dtype=self.dtype)(x)
+x = nn.Dense(
+    feat,
+    name=f'layers_{i}',
+    param_dtype=self.dtype,
+)(x)
if i != len(self.features) - 1:
-x = nn.swish(x)
+x = nn.relu(x)
return x
```

```{code-cell} ipython3
key, subkey = random.split(random.PRNGKey(0), 2)
dummy_input = random.uniform(key, (1,), dtype=jnp.float64)
-model = SimpleMLP(features=[20, 20, 20, 1], dtype=jnp.float64)
+model = SimpleMLP(features=[40, 40, 1], dtype=jnp.float64)
```
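A minimal sanity check of the regressor defined above (the `(5, 1)` dummy batch is only for illustration):

```python
# The MLP maps a batch of scalar inputs to a batch of scalar outputs.
check_params = model.init(random.PRNGKey(1), dummy_input)
out = model.apply(check_params, jnp.zeros((5, 1)))
print(out.shape)  # expected (5, 1)
```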

The regressor is a neural network with two hidden layers of width 40 and ReLU nonlinearities.

```{code-cell} ipython3
def inner_loss(x, outer_parameters, data, regularization=L2REG):
# x are the task adapted parameters phi prime in the original paper
# outer_parameters are the meta parameters, theta bold in the original paper
samples, targets = data
prediction = model.apply(x, samples)
-mse = jnp.mean((prediction - targets)**2)
+mse = jnp.mean((prediction - targets)**2)  # this is L(phi_prime, D^{tr}_i)
x_m_outer_parameters = tree_util.tree_add_scalar_mul(x, -1, outer_parameters)
reg = (regularization / 2) * tree_util.tree_l2_norm(x_m_outer_parameters, squared=True)
# this is \lambda/2 * ||phi_prime - theta_bold||^2
return mse + reg
-def outer_loss(outer_params, inner_parameters, data):
-    # inner parameters is passed
-    # iterate on the first K-1 tasks
-    loss = 0.
-    task_params = []
-    for (data_train, _), in_params in zip(data, inner_parameters):
-        lbfgs = LBFGS(inner_loss, maxiter=2000, tol=1e-12)
-        in_params_sol, _ = lbfgs.run(in_params, outer_params, data_train)
-        prediction = model.apply(in_params_sol, x_test)
-        loss += jnp.mean((prediction - y_test)**2)
-        task_params.append(in_params_sol)
-    return loss, task_params
+inner_solver = GradientDescent(
+    inner_loss,
+    stepsize=-1,  # using line search
+    maxiter=16,
+    tol=1e-12,
+    maxls=15,
+    acceleration=False,
+    implicit_diff=True,
+    implicit_diff_solve=Partial(
+        linear_solve.solve_cg,
+        maxiter=5,
+        tol=1e-7,
+    ),
+)
+def outer_loss(meta_params, data_train, data_test):
+    in_params_sol, _ = vmap(inner_solver.run, (None, None, 0))(
+        jax.lax.stop_gradient(meta_params),
+        meta_params,
+        data_train,
+    )  # Alg^*(\theta_bold, D^{tr}_i)
+    samples_test, targets_test = data_test
+    prediction = vmap(model.apply)(in_params_sol, samples_test)
+    loss = jnp.mean((prediction - targets_test)**2)  # L(\phi, D^{te}_i)
+    return loss, in_params_sol
```
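In the notation used by the comments above, `inner_loss` and `outer_loss` together implement the iMAML bilevel problem (written here for a meta-batch of $M$ tasks; the per-task averaging is a simplification of the `jnp.mean` over all test predictions):

$$
\min_{\boldsymbol{\theta}} \; \frac{1}{M} \sum_{i=1}^{M} \mathcal{L}\big(\phi_i', \mathcal{D}^{te}_i\big)
\quad \text{with} \quad
\phi_i' = \mathrm{Alg}^\star(\boldsymbol{\theta}, \mathcal{D}^{tr}_i)
= \operatorname*{arg\,min}_{\phi} \; \mathcal{L}(\phi, \mathcal{D}^{tr}_i) + \frac{\lambda}{2}\,\lVert \phi - \boldsymbol{\theta} \rVert_2^2,
$$

where $\lambda$ is `L2REG` and the inner problem is solved approximately by `inner_solver`.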

```{code-cell} ipython3
key, subkey = random.split(random.PRNGKey(0), 2)
-# initialize inner and outer params
-inner_params = []
-for _ in data_tasks:
-    key, subkey = random.split(key)
-    inner_params.append(model.init(key, dummy_input))
-key, subkey = random.split(key)
-outer_params = model.init(key, dummy_input)
+meta_params = model.init(key, dummy_input)
-gradient_subopt = []
-solver = OptaxSolver(opt=optax.adam(1e-3), fun=outer_loss, maxiter=100, has_aux=True)
-state = solver.init_state(outer_params, inner_params, data_tasks[:-1])
+outer_losses = []
+solver = OptaxSolver(
+    opt=optax.adam(1e-3),
+    fun=outer_loss,
+    maxiter=1000,
+    has_aux=True,
+    tol=1e-6,
+)
+data_train, data_test, phase, amplitude = generate_task_batch(
+    key,
+    meta_batch_size=2,
+    n_samples_train=10,
+    n_samples_test=10,
+)
+state = solver.init_state(meta_params, data_train, data_test)
jitted_update = jax.jit(solver.update)
-for it in range(solver.maxiter):
-    outer_params, state = jitted_update(outer_params, state, state.aux, data_tasks[:-1])
-    gradient_subopt.append(solver.l2_optimality_error(outer_params, state.aux, data_tasks[:-1]))
+pbar = tqdm(range(solver.maxiter))
+for it in pbar:
+    key, subkey = random.split(key)
+    data_train, data_test, phase, amplitude = generate_task_batch(
+        key,
+        meta_batch_size=25,
+        n_samples_train=10,
+        n_samples_test=10,
+    )
+    meta_params, state = jitted_update(meta_params, state, data_train, data_test)
+    outer_losses.append(state.value)
+    pbar.set_description(f"Outer loss {state.value:.3f}")
```
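Each call to `jitted_update` above takes one Adam step on the gradient of the outer loss with respect to the meta-parameters. A rough sketch of the quantity being differentiated, reusing the last `data_train`/`data_test` batch from the loop (illustrative only):

```python
# implicit_diff=True lets jax.grad differentiate through the inner
# GradientDescent solve without unrolling its iterations.
(outer_value, adapted_params), meta_grads = jax.value_and_grad(
    outer_loss, has_aux=True)(meta_params, data_train, data_test)
print(outer_value)  # scalar outer loss on the held-out samples
```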

```{code-cell} ipython3
xx = jnp.linspace(MIN_X, MAX_X, 200)
plt.title(f'Training data and predictive model')
-for task, ((x_train, y_train), test) in enumerate(data_tasks[:-1]):
-    plt.scatter(x_train, y_train, marker='o', s=50, label=f"Training samples for task {task+1}", color=colors[task])
-    prediction = jax.lax.stop_gradient(model.apply(state.aux[task], xx.reshape((-1, 1))))
-    plt.plot(xx, prediction.ravel(), color=colors[task], lw=3, label=f"Prediction of model trained on task {task+1}")
+for task in range(n_tasks):
+    ((x_train, y_train), (x_test, y_test)) = (data_train[0][task], data_train[1][task]), (data_test[0][task], data_test[1][task])
+    in_params_sol, _ = inner_solver.run(
+        jax.lax.stop_gradient(meta_params),
+        meta_params,
+        (x_train, y_train),
+    )
+    plt.scatter(x_train, y_train, marker='o', s=50, label=f"Training samples for task {task+1}", color=colors[task])
+    prediction = jax.lax.stop_gradient(model.apply(in_params_sol, xx.reshape((-1, 1))))
+    plt.plot(xx, prediction.ravel(), color=colors[task], lw=3, label=f"Prediction of model trained on task {task+1}")
+    phase_, amplitude_ = phase[task], amplitude[task]
+    plt.plot(xx, amplitude_ * jnp.sin(xx - phase_), color="black", lw=2, alpha=0.7, ls='--', label=f"True function for task {task+1}")
plt.legend(loc='upper center', fontsize=14, bbox_to_anchor=(0.5, -0.1), frameon=False, ncol=3)
plt.xlabel('x')
plt.ylabel('y')
plt.show()
-plt.title('Gradient suboptimality')
-plt.plot(gradient_subopt, lw=3)
+plt.title('Meta-learning curve')
+plt.plot(outer_losses, lw=3)
plt.yscale('log')
plt.ylabel("gradient norm of outer objective")
plt.ylabel("outer loss")
plt.xlabel("iterations")
plt.grid()
plt.show()
```
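The same inner solver can be reused to adapt the meta-learned parameters to a previously unseen task; a short sketch (the `new_*` names are purely illustrative):

```python
# Few-shot adaptation: generate one fresh task and adapt meta_params to it.
key, new_key = random.split(key)
(new_x, new_y), _, new_phase, new_amplitude = generate_task(new_key)
adapted_params, _ = inner_solver.run(
    jax.lax.stop_gradient(meta_params),  # initialization of the inner solve
    meta_params,                         # center of the proximal regularization
    (new_x, new_y),
)
grid = jnp.linspace(MIN_X, MAX_X, 200).reshape((-1, 1))
plt.plot(grid, model.apply(adapted_params, grid), lw=3, label="adapted model")
plt.plot(grid, new_amplitude * jnp.sin(grid - new_phase), ls='--', label="ground truth")
plt.scatter(new_x, new_y, marker='^', s=80, label="few-shot samples")
plt.legend()
plt.show()
```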

```{code-cell} ipython3
-# now we consider the last task, which we haven't used for training
-train, test = data_tasks[-1]
-params, _ = LBFGS(
-    fun=inner_loss, tol=1e-12, maxiter=2000).run(
-    model.init(key, dummy_input), outer_params, train)
-params_without_regularization, _ = LBFGS(
-    fun=inner_loss, tol=1e-12, maxiter=2000).run(
-    model.init(key, dummy_input), outer_params, train, regularization=0)
```

```{code-cell} ipython3
-xx = jnp.linspace(MIN_X, MAX_X, 200)
-prediction = model.apply(params, xx.reshape((-1, 1)))
-prediction_without_regularization = model.apply(params_without_regularization, xx.reshape((-1, 1)))
-plt.title(" Fit with only 3 samples")
-plt.plot(xx, prediction.ravel(), lw=3, label="with meta-learning")
-plt.plot(xx, prediction_without_regularization.ravel(), lw=3, label="without meta-learning")
-plt.scatter(train[0], train[1], marker='^', s=200, label=f"Training samples", color=colors[-1])
-plt.plot(xs, ys, lw=3, label="ground truth")
-plt.ylim((-1, 1))
-plt.xlim(MIN_X, MAX_X)
-plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), frameon=False, ncol=2)
-plt.grid()
-plt.show()
```

```{code-cell} ipython3
```
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -13,3 +13,4 @@ myst-nb
tensorflow-datasets
dm-haiku
flax
+jupytext
