-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathEMRLD_HC.py
executable file
·372 lines (337 loc) · 14.8 KB
/
EMRLD_HC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
#!/usr/bin/env python3
"""
EMRLD
"""
import random
from copy import deepcopy
import pickle
import cherry as ch
import gym
import numpy as np
import torch
from cherry.algorithms import a2c, trpo
from cherry.models.robotics import LinearValue
from torch import autograd
from torch.distributions.kl import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from tqdm import tqdm
import matplotlib.pyplot as plt
import learn2learn as l2l
from policies import DiagNormalPolicy
from torch.utils.tensorboard import SummaryWriter
import datetime
import time
import argparse
from envs.halfcheetah_forward_backward import HalfCheetahForwardBackwardEnv
def compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states):
    """Refit ``baseline`` on this batch and return generalized advantage estimates.

    Args:
        baseline: value model exposing ``fit(states, returns)`` and ``__call__(states)``.
        tau: GAE lambda parameter forwarded to ``ch.pg.generalized_advantage``.
        gamma: discount factor.
        rewards, dones, states, next_states: per-step tensors from a cherry replay.

    Returns:
        The advantages produced by ``ch.pg.generalized_advantage``.
    """
    # The baseline is re-fitted on the discounted returns of the current batch
    # before it is used to score states.
    discounted_returns = ch.td.discount(gamma, rewards, dones)
    baseline.fit(states, discounted_returns)

    state_values = baseline(states)
    successor_values = baseline(next_states)

    # On steps flagged done, use the successor state's value; otherwise keep the
    # current state's value.
    blended_values = state_values * (1.0 - dones) + successor_values * dones
    terminal_value = torch.zeros(1, device=state_values.device)

    return ch.pg.generalized_advantage(
        tau=tau,
        gamma=gamma,
        rewards=rewards,
        dones=dones,
        values=blended_values,
        next_value=terminal_value,
    )
def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau, task_data, w_a2c, w_bc):
    """Return the behaviour-cloning (+ optional A2C) adaptation loss for one task.

    Args:
        train_episodes: cherry replay of on-policy episodes, or ``None`` to compute
            only the behaviour-cloning (BC) term on the demonstrations.
        learner: policy exposing ``log_prob(states, actions)`` and ``parameters()``.
        baseline: value model forwarded to ``compute_advantages``.
        gamma: discount factor.
        tau: GAE lambda parameter.
        task_data: mapping with ``'state'`` and ``'action'`` demonstration data
            (array-likes convertible with ``torch.as_tensor``).
        w_a2c: weight applied to the A2C advantages (only when ``train_episodes`` is given).
        w_bc: weight applied to the unit BC advantages (only when ``train_episodes`` is given).

    Returns:
        Scalar loss tensor. With ``train_episodes is None`` this is the unweighted
        BC loss; otherwise ``BC(w_bc) + A2C(w_a2c)``.
    """
    # Demonstrations come from host-side data; create them on the same device as
    # the learner's parameters so ``log_prob`` never mixes CPU and CUDA tensors
    # (the learner is moved to CUDA when ``main`` runs with cuda enabled).
    first_param = next(iter(learner.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    demo_states = torch.as_tensor(task_data['state'], dtype=torch.float32, device=device)
    demo_actions = torch.as_tensor(task_data['action'], dtype=torch.float32, device=device)
    # BC is expressed as a policy-gradient loss with a constant advantage of 1.
    demo_adv = torch.ones((demo_states.shape[0], 1), device=device)
    demo_log_probs = learner.log_prob(demo_states, demo_actions)

    if train_episodes is None:
        return a2c.policy_loss(demo_log_probs, demo_adv)

    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()

    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)
    # Advantages are treated as constants for the policy gradient.
    advantages = ch.normalize(advantages).detach()
    log_probs = learner.log_prob(states, actions)

    bc_loss = a2c.policy_loss(demo_log_probs, w_bc * demo_adv)
    a2c_loss = a2c.policy_loss(log_probs, w_a2c * advantages)
    return bc_loss + a2c_loss
def fast_adapt_a2c(clone, train_episodes, adapt_lr, adapt_a2c_lr, baseline, gamma, tau, task_data, w_a2c, w_bc,
                   first_order=False):
    """Take one MAML inner-loop gradient step on ``clone`` and return the updated module.

    The step size is ``adapt_a2c_lr`` when on-policy episodes are supplied and
    ``adapt_a2c_lr`` is positive; in every other case ``adapt_lr`` is used.

    Args:
        clone: differentiable copy of the policy to adapt.
        train_episodes: cherry replay for the A2C term, or ``None`` for BC only.
        adapt_lr: default inner-loop learning rate.
        adapt_a2c_lr: learning rate used instead when A2C data is present and it is > 0.
        baseline, gamma, tau, task_data, w_a2c, w_bc: forwarded to ``maml_a2c_loss``.
        first_order: if True, gradients are taken without building a graph
            through them (no second-order terms).

    Returns:
        The module returned by ``l2l.algorithms.maml.maml_update``.
    """
    use_a2c_rate = train_episodes is not None and adapt_a2c_lr > 0
    step_size = adapt_a2c_lr if use_a2c_rate else adapt_lr

    keep_graph = not first_order
    inner_loss = maml_a2c_loss(train_episodes, clone, baseline, gamma, tau, task_data, w_a2c, w_bc)
    grads = autograd.grad(
        inner_loss,
        clone.parameters(),
        retain_graph=keep_graph,
        create_graph=keep_graph,
    )
    return l2l.algorithms.maml.maml_update(clone, step_size, grads)
def meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline, tau, gamma, adapt_lr, adapt_a2c_lr,
                        traj_data, iteration, w_a2c, w_bc):
    """Average the TRPO surrogate loss and KL over all tasks of one meta-iteration.

    For each task, a fresh clone of ``policy`` is re-adapted (second-order) on
    every replay except the last. The last replay is used as validation data to
    compare the re-adapted policy with the policy stored during rollout.

    Args:
        iteration_replays: per-task lists of replays; the last entry of each is
            the validation replay.
        iteration_policies: per-task adapted policies recorded during rollout.
        policy: meta-policy whose parameters receive the gradient.
        baseline, tau, gamma: forwarded to ``compute_advantages`` / adaptation.
        adapt_lr, adapt_a2c_lr, w_a2c, w_bc: forwarded to ``fast_adapt_a2c``.
        traj_data: demonstration data indexed by task position ``0..n-1``.
        iteration: unused here; kept for interface compatibility.

    Returns:
        ``(mean_loss, mean_kl)`` averaged over ``len(iteration_replays)`` tasks.
    """
    num_tasks = len(iteration_replays)
    total_loss = 0.0
    total_kl = 0.0

    progress = tqdm(zip(iteration_replays, iteration_policies),
                    total=num_tasks,
                    desc='Surrogate Loss',
                    leave=False)
    # NOTE(review): demos are looked up by position (0, 1, ...), while ``main``
    # iterates ``traj_data``'s keys; these agree only when the keys are 0..n-1
    # in insertion order, as the script entry point builds them.
    for task_idx, (task_replays, recorded_policy) in enumerate(progress):
        adapt_replays = task_replays[:-1]
        eval_episodes = task_replays[-1]

        adapted = l2l.clone_module(policy)
        demo = traj_data[task_idx]

        # Re-run the inner loop with second-order gradients so the surrogate
        # loss is differentiable w.r.t. the meta-policy parameters.
        for episodes in adapt_replays:
            adapted = fast_adapt_a2c(adapted, episodes, adapt_lr, adapt_a2c_lr,
                                     baseline, gamma, tau, demo, w_a2c, w_bc, first_order=False)

        states = eval_episodes.state()
        actions = eval_episodes.action()
        next_states = eval_episodes.next_state()
        rewards = eval_episodes.reward()
        dones = eval_episodes.done()

        recorded_dist = recorded_policy.density(states)
        adapted_dist = adapted.density(states)
        total_kl += kl_divergence(adapted_dist, recorded_dist).mean()

        advantages = ch.normalize(
            compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states)
        ).detach()
        recorded_log_probs = recorded_dist.log_prob(actions).mean(dim=1, keepdim=True).detach()
        adapted_log_probs = adapted_dist.log_prob(actions).mean(dim=1, keepdim=True)
        total_loss += trpo.policy_loss(adapted_log_probs, recorded_log_probs, advantages)

    total_kl /= num_tasks
    total_loss /= num_tasks
    return total_loss, total_kl
def _trpo_meta_update(policy, baseline, iteration_replays, iteration_policies, tau, gamma, adapt_lr,
                      adapt_a2c_lr, traj_data, iteration, w_a2c, w_bc, meta_lr,
                      backtrack_factor=0.5, ls_max_steps=15, max_kl=0.01):
    """Apply one TRPO meta-update to ``policy`` in place.

    The search direction is obtained by conjugate gradient on the KL
    Hessian-vector product and rescaled to the ``max_kl`` trust region. A
    backtracking line search then starts at ``meta_lr`` and shrinks the step by
    ``backtrack_factor`` for at most ``ls_max_steps`` trials.

    Returns:
        True if a step was accepted (surrogate loss decreased and KL < ``max_kl``),
        otherwise False (``policy`` is left unchanged).
    """
    old_loss, old_kl = meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline, tau, gamma,
                                           adapt_lr, adapt_a2c_lr, traj_data, iteration, w_a2c, w_bc)
    grad = autograd.grad(old_loss, policy.parameters(), retain_graph=True)
    grad = parameters_to_vector([g.detach() for g in grad])
    Fvp = trpo.hessian_vector_product(old_kl, policy.parameters())
    step = trpo.conjugate_gradient(Fvp, grad)
    # Scale the CG direction so the quadratic KL estimate equals max_kl.
    shs = 0.5 * torch.dot(step, Fvp(step))
    lagrange_multiplier = torch.sqrt(shs / max_kl)
    step = step / lagrange_multiplier
    # Reshape the flat step vector into per-parameter tensors.
    step_per_param = [torch.zeros_like(p.data) for p in policy.parameters()]
    vector_to_parameters(step, step_per_param)
    # Free the graph-holding objects before the (memory-hungry) line search.
    del old_kl, Fvp, grad
    old_loss.detach_()

    for ls_step in range(ls_max_steps):
        stepsize = backtrack_factor ** ls_step * meta_lr
        candidate = deepcopy(policy)
        for p, u in zip(candidate.parameters(), step_per_param):
            p.data.add_(u.data, alpha=-stepsize)
        new_loss, kl = meta_surrogate_loss(iteration_replays, iteration_policies, candidate, baseline, tau, gamma,
                                           adapt_lr, adapt_a2c_lr, traj_data, iteration, w_a2c, w_bc)
        if new_loss < old_loss and kl < max_kl:
            for p, u in zip(policy.parameters(), step_per_param):
                p.data.add_(u.data, alpha=-stepsize)
            return True
    return False


def main(
        env_name='HalfCheetahForwardBackward-v1',
        adapt_lr=0.1,
        meta_lr=1.0,
        adapt_steps=2,
        num_iterations=500,
        meta_bsz=24,
        adapt_bsz=20,
        tau=1.00,
        gamma=0.95,
        seed=42,
        num_workers=10,
        cuda=0,
        gpu_index=0,
        is_sparse=False,
        w_a2c=1.0,
        w_bc=1.0,
        load=False,
        policy_path='',
        baseline_path='',
        adapt_a2c_lr=-1,
        traj_data=None,
        sparse_val=0,
        num_traj=None,
):
    """Meta-train a DiagNormalPolicy with BC+A2C inner-loop adaptation and TRPO meta-updates.

    Each iteration, for every task in ``traj_data``:
      * the task direction is set on the vectorized HalfCheetah env,
      * a copy of the policy is adapted for ``adapt_steps`` first-order steps
        (each on ``adapt_bsz`` fresh episodes plus the task's demonstrations),
      * ``adapt_bsz`` validation episodes are collected with the adapted policy.
    After all tasks are processed, one TRPO meta-update is applied to the policy.
    Metrics go to a TensorBoard log under ``Results/EMRLD/<env_name>/``.

    Args:
        env_name: used only to name the log directory (the environment is always
            ``HalfCheetahForwardBackwardEnv``).
        meta_bsz: ignored; the meta batch size is ``len(traj_data)``.
        cuda: truthy to run policy/updates on CUDA.
        gpu_index, is_sparse: unused; kept for interface compatibility.
        load: if True, load policy/baseline state dicts from ``policy_path`` /
            ``baseline_path``.
        traj_data: non-empty dict mapping task index -> dict with ``'state'``,
            ``'action'`` and ``'direction'`` entries.
        num_traj: number of demonstrations per task, for logging only. Defaults
            to ``args.num_traj`` when a module-level CLI ``args`` exists.
        Remaining arguments are learning-rate / RL hyperparameters forwarded to
        the adaptation and meta-update helpers.

    Returns:
        1 on completion.

    Raises:
        ValueError: if ``traj_data`` is None or empty.
    """
    if not traj_data:
        raise ValueError('traj_data must be a non-empty dict mapping task index -> demonstration data')
    if num_traj is None:
        # Backward compatibility: the script entry point exposes its CLI namespace
        # as a module-level ``args``; without it, log 'None' instead of crashing.
        num_traj = getattr(globals().get('args'), 'num_traj', None)

    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device_name = 'cpu'
    if cuda:
        torch.cuda.manual_seed(seed)
        device_name = 'cuda'
    device = torch.device(device_name)

    # The meta batch is one entry per demonstration task.
    meta_bsz = len(traj_data)

    log_path = 'Results/EMRLD/{}/meta_rl_{}'.format(env_name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    writer = SummaryWriter(log_path)
    env = None
    try:
        hyperparams = [
            ('env_name', env_name),
            ('adapt_lr', adapt_lr),
            ('adapt_a2c_lr', adapt_a2c_lr),
            ('meta_lr', meta_lr),
            ('num_iterations', num_iterations),
            ('adapt_bsz', adapt_bsz),
            ('adapt_steps', adapt_steps),
            ('num_workers', num_workers),
            ('seed', seed),
            ('w_a2c', w_a2c),
            ('w_bc', w_bc),
            ('sparse_val', sparse_val),
            ('Traj', num_traj),
            ('meta_bsz', meta_bsz),
        ]
        for tag, value in hyperparams:
            writer.add_text(tag, str(value))

        def make_env():
            return ch.envs.ActionSpaceScaler(HalfCheetahForwardBackwardEnv(sparse_val=sparse_val))

        env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
        env.seed(seed)
        env.set_task(env.sample_tasks(1)[0])
        env = ch.envs.Torch(env)

        policy = DiagNormalPolicy(env.state_size, env.action_size, device=device)
        if cuda:
            policy = policy.to(device)
        baseline = LinearValue(env.state_size, env.action_size)
        print('Meta Batch Size: ', meta_bsz)

        # Per-iteration history (kept for inspection/plotting).
        net_intermediate_reward = []
        net_adaptation_reward = []
        adapt_time = []
        meta_update_time = []

        if load:
            # map_location lets checkpoints saved on GPU load on a CPU-only host.
            policy.load_state_dict(torch.load(policy_path, map_location=device))
            baseline.load_state_dict(torch.load(baseline_path, map_location=device))

        for iteration in range(num_iterations):
            t0 = time.time()
            iteration_reward = 0.0
            bc_reward = 0.0
            iteration_replays = []
            iteration_policies = []

            for task_key in tqdm(traj_data, leave=False, desc='Data'):
                task_data = traj_data[task_key]
                clone = deepcopy(policy)
                env.set_task({'direction': task_data['direction']})
                env.reset()
                task = ch.envs.Runner(env)
                task_replay = []

                for step in range(adapt_steps):
                    # BC-A2C adaptation step.
                    train_episodes = task.run(clone, episodes=adapt_bsz)
                    if step == 0:
                        # Reward of the pre-adaptation policy on this task.
                        bc_reward += train_episodes.reward().sum().item() / adapt_bsz
                    if cuda:
                        train_episodes = train_episodes.to(device, non_blocking=True)
                    task_replay.append(train_episodes)
                    clone = fast_adapt_a2c(clone, train_episodes, adapt_lr, adapt_a2c_lr,
                                           baseline, gamma, tau, task_data, w_a2c, w_bc,
                                           first_order=True)

                valid_episodes = task.run(clone, episodes=adapt_bsz)
                task_replay.append(valid_episodes)
                iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
                iteration_replays.append(task_replay)
                iteration_policies.append(clone)

            # Statistics.
            t1 = time.time()
            print('\nIteration', iteration)
            adaptation_reward = iteration_reward / meta_bsz
            iteration_bc_reward = bc_reward / meta_bsz
            print('BC_reward', iteration_bc_reward)
            print('adaptation_reward', adaptation_reward)
            writer.add_scalar('avg adaptation reward', adaptation_reward, iteration + 1)
            writer.add_scalar('avg BC_reward', iteration_bc_reward, iteration + 1)
            adapt_time.append(t1 - t0)
            net_intermediate_reward.append(iteration_bc_reward)
            net_adaptation_reward.append(adaptation_reward)

            # TRPO meta-optimization.
            if cuda:
                policy = policy.to(device, non_blocking=True)
                baseline = baseline.to(device, non_blocking=True)
                iteration_replays = [[r.to(device, non_blocking=True) for r in task_replays]
                                     for task_replays in iteration_replays]
            _trpo_meta_update(policy, baseline, iteration_replays, iteration_policies, tau, gamma,
                              adapt_lr, adapt_a2c_lr, traj_data, iteration, w_a2c, w_bc, meta_lr)

            t2 = time.time()
            meta_update_time.append(t2 - t1)
            print('Time per iteration: ', t2 - t0)
    finally:
        # Release worker processes and flush TensorBoard events even on failure.
        if env is not None:
            env.close()
        writer.close()
    return 1
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='EMRLD')
    parser.add_argument('--adapt-lr', type=float, default=0.01, metavar='G',
                        help='adapt-lr')
    parser.add_argument('--adapt-a2c-lr', type=float, default=-1, metavar='G',
                        help='adapt-a2c-lr')
    parser.add_argument('--w-a2c', type=float, default=0.2, metavar='G',
                        help='w-rl')
    parser.add_argument('--w-bc', type=float, default=1.0, metavar='G',
                        help='w-bc')
    parser.add_argument('--workers', type=int, default=15, metavar='N',
                        help='Number of workers')
    parser.add_argument('--adapt-bsz', type=int, default=20, metavar='N',
                        help='adapt_bsz')
    parser.add_argument('--meta-bsz', type=int, default=10, metavar='N',
                        help='meta_bsz')
    parser.add_argument('--adapt-steps', type=int, default=1, metavar='N',
                        help='Adapt Steps')
    parser.add_argument('--seed', type=int, default=42, metavar='N',
                        help='Seed')
    parser.add_argument('--max-iter', type=int, default=700, metavar='N',
                        help='max-iter')
    parser.add_argument('--load', action='store_true', default=False,
                        help='Load policy')
    parser.add_argument('--sparse-val', type=float, default=2., metavar='G',
                        help='Sparse Val')
    parser.add_argument('--num-traj', type=int, default=1, metavar='N',
                        help='Number of Trajectories per task')
    # Only two dataset settings exist; any other value previously crashed with a
    # NameError on the undefined data paths, so argparse now rejects it up front.
    parser.add_argument('--exp-num', type=int, default=1, choices=(1, 2), metavar='N',
                        help='1: Good 2:Bad')
    args = parser.parse_args()
    args.policy_path = ''
    args.baseline_path = ''

    if args.exp_num == 1:
        fwd_data_path = 'Traj_Data/Good_Fwd_HC.p'
        bwd_data_path = 'Traj_Data/Good_Bwd_HC.p'
    else:  # args.exp_num == 2
        fwd_data_path = 'Traj_Data/Bad_Fwd_HC.p'
        bwd_data_path = 'Traj_Data/Bad_Bwd_HC.p'
        args.w_a2c = 1.0

    # SECURITY NOTE: pickle.load can execute arbitrary code; only load trusted
    # trajectory files.
    with open(fwd_data_path, 'rb') as fwd_file:
        fwd_data = pickle.load(fwd_file)
    with open(bwd_data_path, 'rb') as bwd_file:
        bwd_data = pickle.load(bwd_file)

    # Keep the first num_traj * 100 samples of each dataset.
    # NOTE(review): presumably each demonstration trajectory is 100 steps long — confirm.
    num_samples = args.num_traj * 100
    for demo in (fwd_data, bwd_data):
        demo['state'] = demo['state'][:num_samples]
        demo['action'] = demo['action'][:num_samples]

    # Alternate forward (even index) and backward (odd index) tasks.
    traj_data = {i: (fwd_data if i % 2 == 0 else bwd_data) for i in range(args.meta_bsz)}

    main(w_a2c=args.w_a2c,
         w_bc=args.w_bc,
         num_workers=args.workers,
         load=args.load,
         policy_path=args.policy_path,
         baseline_path=args.baseline_path,
         adapt_steps=args.adapt_steps,
         adapt_lr=args.adapt_lr,
         adapt_a2c_lr=args.adapt_a2c_lr,
         seed=args.seed,
         adapt_bsz=args.adapt_bsz,
         traj_data=traj_data,
         num_iterations=args.max_iter,
         sparse_val=args.sparse_val)