import json
import os

import numpy as np
import torch
from torch import nn


class MAML():

    def __init__(self,
                 model,
                 num_inner_steps,
                 inner_lr,
                 outer_lr,
                 first_order):
        """Initializes the MAML.

        Args:
            model (Model): Meta-learner. Its forward pass must accept the
                inputs together with an explicit list of parameters.
            num_inner_steps (int): Number of inner-loop optimization steps
                during adaptation.
            inner_lr (float): Learning rate for inner-loop optimization.
            outer_lr (float): Learning rate for outer-loop optimization.
            first_order (bool): Whether to use first-order approximation
                when computing gradient for meta-update.
        """
        self.model = model
        self.model_params = list(model.parameters())
        self.num_inner_steps = num_inner_steps
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.outer_optimizer = torch.optim.Adam(self.model_params,
                                                self.outer_lr)
        self.loss_function = nn.MSELoss()
        self.first_order = first_order

    def inner_loop(self, x_support, y_support):
        """Computes the adapted model parameters on a given support set.

        Args:
            x_support (Tensor): Support set inputs.
            y_support (Tensor): Support set outputs.

        Returns:
            list[Tensor]: Adapted model parameters. Length `num_params`.
            list[float]: Support set loss throughout inner loop.
                Length `num_inner_steps` + 1.
            float: Frobenius norm of initial adaptation gradient.
        """
        support_losses = []
        # Clone outer parameters for adaptation
        adapt_params = [w.clone() for w in self.model_params]
        # Perform `num_inner_steps` iterations of SGD
        for inner_step in range(self.num_inner_steps+1):
            # Compute support loss and adaptation gradient
            pred_y_support = self.model.forward(x_support, adapt_params)
            support_loss = self.loss_function(pred_y_support, y_support)
            support_losses.append(support_loss.item())
            adapt_grad = torch.autograd.grad(support_loss, adapt_params,
                                             create_graph=not self.first_order)
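            # With `create_graph=True` (second-order MAML) the adaptation
            # gradients stay part of the computation graph, so the meta-update
            # in train() can differentiate through the inner-loop updates.
            # With `first_order=True` they are treated as constants (FOMAML).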
            if inner_step == 0:
                init_grad = torch.concat([g.flatten() for g in adapt_grad])
                init_grad_norm = torch.linalg.norm(init_grad).item()
            if inner_step == self.num_inner_steps:
                break
            # Update adaptation parameters
            adapt_params = [w - self.inner_lr * g
                            for w, g in zip(adapt_params, adapt_grad)]
        return adapt_params, support_losses, init_grad_norm

    def inner_loop_plot(self, x_support, y_support, x, inner_steps_plot):
        """Adapts the model on a given support set and periodically
        generates predictions for `x`.

        Args:
            x_support (Tensor): Support set inputs.
            y_support (Tensor): Support set outputs.
            x (Tensor): Prediction inputs.
            inner_steps_plot (list[int]): Inner step(s) at which
                predictions for `x` are generated by the adapted model.
                For example, if `inner_steps_plot` = [1, 2, 10], then
                10 inner-optimization steps are performed, and
                predictions are generated after 0, 1, 2, and 10 steps.

        Returns:
            dict[int, ndarray]: Key is inner optimization step.
                Value is prediction output.
        """
        pred_y = {}
        # Clone outer parameters for adaptation
        adapt_params = [w.clone() for w in self.model_params]
        # Perform `max(inner_steps_plot)` iterations of SGD
        for inner_step in range(max(inner_steps_plot)+1):
            # Compute support loss and adaptation gradient
            pred_y_support = self.model.forward(x_support, adapt_params)
            support_loss = self.loss_function(pred_y_support, y_support)
            if inner_step == 0 or inner_step in inner_steps_plot:
                pred_y[inner_step] = (
                    self.model.forward(x, adapt_params)
                ).detach().cpu().numpy()
            adapt_grad = torch.autograd.grad(support_loss, adapt_params,
                                             create_graph=not self.first_order)
            # Update adaptation parameters
            adapt_params = [w - self.inner_lr * g
                            for w, g in zip(adapt_params, adapt_grad)]
        return pred_y

    def outer_step(self, task_batch):
        """Computes the MAML loss (i.e., mean query loss) on a batch of tasks.

        Args:
            task_batch (list): Batch of tasks. Length `batch_size`.

        Returns:
            Tensor: Query loss averaged over batch. Scalar.
            ndarray: Support set loss throughout inner loop, averaged
                over batch. Shape (`num_inner_steps` + 1,).
            float: Frobenius norm of initial adaptation gradient,
                averaged over batch.
        """
        query_loss_batch = []
        support_losses_batch = []
        init_grad_norm_batch = []
        for task in task_batch:
            x_support, y_support, x_query, y_query = task
            # Obtain adapted parameters from inner loop
            adapt_params, support_losses, init_grad_norm = \
                self.inner_loop(x_support, y_support)
            support_losses_batch.append(support_losses)
            init_grad_norm_batch.append(init_grad_norm)
            # Compute query loss using adapted parameters
            pred_y_query = self.model.forward(x_query, adapt_params)
            query_loss_batch.append(self.loss_function(pred_y_query, y_query))
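        # The MAML objective is the query loss of the *adapted* parameters,
        # averaged over the task batch; backpropagating it (in train()) moves
        # the shared initialization `self.model_params`.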
        mean_query_loss = torch.stack(query_loss_batch).mean()
        mean_support_losses = np.mean(np.array(support_losses_batch), axis=0)
        mean_init_grad_norm = np.mean(np.array(init_grad_norm_batch))
        return mean_query_loss, mean_support_losses, mean_init_grad_norm

    def train(self, data_train, log_interval, filename="./logs/train.json"):
        """Trains the MAML.

        Args:
            data_train (list[list[list[Tensor]]]): Training dataset.
                First dimension has length `num_train_steps//batch_size`.
                Second dimension has length `batch_size`.
                Innermost dimension has length 4 and consists of
                support set inputs, support set outputs, query set inputs,
                and query set outputs, respectively.
            log_interval (int): Frequency of metric logging.
            filename (str): File to which metrics are logged.

        Returns:
            list[float]: Query loss (averaged over batch) throughout
                training loop. Length `num_train_steps`.
            list[float]: Frobenius norm of initial adaptation gradient
                (averaged over batch) throughout training loop.
                Length `num_train_steps`.
        """
        log = {}
        query_losses = []
        init_grad_norms = []
        for outer_step, task_batch in enumerate(data_train):
            mean_query_loss, mean_support_losses, mean_init_grad_norm = \
                self.outer_step(task_batch)
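            # Meta-update: differentiate the mean query loss with respect to
            # the original (pre-adaptation) parameters and take one Adam step.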
            self.outer_optimizer.zero_grad()
            mean_query_loss.backward()
            self.outer_optimizer.step()
            # Log metrics
            if outer_step % log_interval == 0:
                print(f"Iteration {outer_step}: ")
                print("MAML loss (query set loss, batch average): ",
                      f"{mean_query_loss.item():.3f}")
                print("Pre-adaptation support set loss (batch average): ",
                      f"{mean_support_losses[0]:.3f}")
                print("Post-adaptation support set loss (batch average): ",
                      f"{mean_support_losses[-1]:.3f}")
                print("Norm of initial adaptation gradient (batch average): ",
                      f"{mean_init_grad_norm:.3f}")
                print("-"*50)
                log[outer_step] = {
                    "Query set loss": mean_query_loss.item(),
                    "Pre-adaptation support set loss": mean_support_losses[0],
                    "Post-adaptation support set loss": mean_support_losses[-1],
                    "Norm of initial adaptation gradient": mean_init_grad_norm,
                }
            query_losses.append(mean_query_loss.item())
            init_grad_norms.append(mean_init_grad_norm)
        # Save metrics
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "w") as f:
            json.dump(log, f)
        return query_losses, init_grad_norms

    def test(self, data_test, num_inner_steps=None):
        """Evaluates the MAML on test tasks.

        Args:
            data_test (list[list[list[Tensor]]]): Test dataset.
                See description for `data_train` in train().
            num_inner_steps (int): Number of inner-loop optimization steps
                during adaptation. If not provided, then training-time
                value is used.

        Returns:
            float: Query loss averaged over test tasks. Scalar.
            float: 95% confidence-interval half-width of the query loss.
                Scalar.
            ndarray: Support set loss throughout inner loop,
                averaged over test tasks. Shape (`num_inner_steps` + 1,).
            ndarray: 95% confidence-interval half-width of the support set
                loss throughout inner loop. Shape (`num_inner_steps` + 1,).
        """
        query_losses = []
        support_losses = []
        num_test_tasks = len(data_test) * len(data_test[0])
        # Set `num_inner_steps` according to argument value
        if num_inner_steps is not None:
            prev_num_inner_steps = self.num_inner_steps
            self.num_inner_steps = num_inner_steps
        # Evaluate test tasks
        for task_batch in data_test:
            q_loss, s_losses, _ = self.outer_step(task_batch)
            query_losses.append(q_loss.item())
            support_losses.append(s_losses)
        # Reset original value of `num_inner_steps`
        if num_inner_steps is not None:
            self.num_inner_steps = prev_num_inner_steps
        # Compute statistics
        mean_query_loss = np.mean(query_losses)
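        # 95% confidence-interval half-width: 1.96 * standard error of the
        # mean, treating the per-task losses as approximately normal.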
        CI_95_query_loss = 1.96 * np.std(query_losses) / np.sqrt(num_test_tasks)
        np_support_losses = np.array(support_losses)
        mean_support_losses = np.mean(np_support_losses, axis=0)
        CI_95_support_losses = (1.96 * np.std(np_support_losses, axis=0)
                                / np.sqrt(num_test_tasks))
        print(f"Evaluation statistics on {num_test_tasks} test tasks: ")
        print("MAML loss:")
        print(f"Mean: {mean_query_loss:.3f}")
        print(f"95% confidence interval: {CI_95_query_loss:.3f}")
        return (mean_query_loss, CI_95_query_loss,
                mean_support_losses, CI_95_support_losses)
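

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the `SineModel` below and the
    # sinusoid task sampler are assumptions made for this example, not the
    # model or data pipeline used elsewhere in this repository. MAML only
    # requires a model whose forward pass accepts an explicit list of
    # parameters, which is what `SineModel.forward` provides.
    import torch.nn.functional as F

    class SineModel(nn.Module):
        """Tiny MLP whose forward pass takes an explicit list of parameters."""

        def __init__(self, hidden_size=40):
            super().__init__()
            self.w1 = nn.Parameter(0.1 * torch.randn(1, hidden_size))
            self.b1 = nn.Parameter(torch.zeros(hidden_size))
            self.w2 = nn.Parameter(0.1 * torch.randn(hidden_size, 1))
            self.b2 = nn.Parameter(torch.zeros(1))

        def forward(self, x, params=None):
            if params is None:
                params = list(self.parameters())
            w1, b1, w2, b2 = params
            return F.relu(x @ w1 + b1) @ w2 + b2

    def sample_task(num_support=10, num_query=10):
        # One sinusoid regression task: y = amplitude * sin(x + phase).
        amplitude = np.random.uniform(0.1, 5.0)
        phase = np.random.uniform(0.0, np.pi)
        x = torch.rand(num_support + num_query, 1) * 10.0 - 5.0
        y = amplitude * torch.sin(x + phase)
        return [x[:num_support], y[:num_support],
                x[num_support:], y[num_support:]]

    model = SineModel()
    maml = MAML(model, num_inner_steps=1, inner_lr=0.01,
                outer_lr=0.001, first_order=False)
    # 100 outer steps with 4 tasks per batch.
    data_train = [[sample_task() for _ in range(4)] for _ in range(100)]
    maml.train(data_train, log_interval=20, filename="./logs/train_demo.json")
    data_test = [[sample_task() for _ in range(4)] for _ in range(10)]
    maml.test(data_test)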