# run_d4rl_iql.py
import os
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import iql as learner
import d4rl_utils
from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate
from ml_collections import config_flags
import pickle
from flax.training import checkpoints

FLAGS = flags.FLAGS
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
flags.DEFINE_string('save_dir', None, 'Logging dir (if not None, save params).')
flags.DEFINE_integer('seed', np.random.choice(1000000), 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 10000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 25000, 'Save interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'd4rl_test',
    'group': 'iql_test',
    'name': 'iql_{env_name}',
})
config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('config', learner.get_default_config(), lock_config=False)
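
# Both config dicts are registered as ml_collections config flags, so individual
# entries can be overridden from the command line, e.g. (assuming the IQL config
# returned by learner.get_default_config() exposes an `expectile` entry):
#   python run_d4rl_iql.py --wandb.project=my_project --config.expectile=0.9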


def get_normalization(dataset):
    """Reward scale for the dataset: (best episode return - worst episode return) / 1000."""
    returns = []
    ret = 0
    for r, term in zip(dataset['rewards'], dataset['dones_float']):
        ret += r
        if term:
            returns.append(ret)
            ret = 0
    return (max(returns) - min(returns)) / 1000
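
# Worked example: if the best episode return in the dataset is 400 and the worst
# is -100, get_normalization returns (400 - (-100)) / 1000 = 0.5, and every
# reward is then divided by 0.5 in main() below.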


def main(_):
    # Create wandb logger
    setup_wandb(FLAGS.config.to_dict(), **FLAGS.wandb)

    if FLAGS.save_dir is not None:
        FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, wandb.config.exp_prefix, wandb.config.experiment_id)
        os.makedirs(FLAGS.save_dir, exist_ok=True)
        print(f'Saving config to {FLAGS.save_dir}/config.pkl')
        with open(os.path.join(FLAGS.save_dir, 'config.pkl'), 'wb') as f:
            pickle.dump(get_flag_dict(), f)

    # Load the offline dataset and rescale its rewards.
    env = d4rl_utils.make_env(FLAGS.env_name)
    dataset = d4rl_utils.get_dataset(env)
    normalizing_factor = get_normalization(dataset)
    dataset = dataset.copy({'rewards': dataset['rewards'] / normalizing_factor})

    # Build the IQL agent; the single example batch is used to initialize shapes.
    example_batch = dataset.sample(1)
    agent = learner.create_learner(FLAGS.seed,
                                   example_batch['observations'],
                                   example_batch['actions'],
                                   max_steps=FLAGS.max_steps,
                                   **FLAGS.config)

    # Offline training loop: sample a batch, update the agent, and periodically
    # log, evaluate, and checkpoint.
    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):
        batch = dataset.sample(FLAGS.batch_size)
        agent, update_info = agent.update(batch)

        if i % FLAGS.log_interval == 0:
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
            eval_info = evaluate(policy_fn, env, num_episodes=FLAGS.eval_episodes)
            eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
            wandb.log(eval_metrics, step=i)

        if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
            checkpoints.save_checkpoint(FLAGS.save_dir, agent, i)


if __name__ == '__main__':
    app.run(main)
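
# Example invocation (a sketch; flag names follow the definitions above):
#   python run_d4rl_iql.py --env_name=halfcheetah-medium-v2 \
#       --save_dir=/tmp/iql_runs --eval_interval=5000
#
# A saved checkpoint can later be restored with flax's checkpoint utilities,
# e.g. after re-creating a matching agent object:
#   agent = checkpoints.restore_checkpoint(FLAGS.save_dir, target=agent)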