train_online.py
#! /usr/bin/env python
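# Example invocation (flag values here are illustrative; defaults are set in the
# flag definitions below):
#   python train_online.py --env_name=HalfCheetah-v2 --max_steps=1000000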
import gym
import tqdm
import wandb
from absl import app, flags
from ml_collections import config_flags

import jaxrl2.extra_envs.dm_control_suite  # imported for its side effect of registering extra dm_control environments
from jaxrl2.agents import SACLearner
from jaxrl2.data import ReplayBuffer
from jaxrl2.evaluation import evaluate
from jaxrl2.wrappers import wrap_gym

FLAGS = flags.FLAGS

flags.DEFINE_string("env_name", "HalfCheetah-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./tmp/", "Tensorboard logging dir.")
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("eval_episodes", 10, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 5000, "Eval interval.")
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
flags.DEFINE_integer(
    "start_training", int(1e4), "Number of training steps to start training."
)
flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.")
flags.DEFINE_boolean("wandb", True, "Log wandb.")
flags.DEFINE_boolean("save_video", False, "Save videos during evaluation.")
config_flags.DEFINE_config_file(
    "config",
    "configs/sac_default.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)


def main(_):
    wandb.init(project="jaxrl2_online")
    wandb.config.update(FLAGS)

    # Training environment: rescale actions and track per-episode statistics.
    env = gym.make(FLAGS.env_name)
    env = wrap_gym(env, rescale_actions=True)
    env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=1)
    env.seed(FLAGS.seed)

    # Separate evaluation environment with an offset seed.
    eval_env = gym.make(FLAGS.env_name)
    eval_env = wrap_gym(eval_env, rescale_actions=True)
    eval_env.seed(FLAGS.seed + 42)

    kwargs = dict(FLAGS.config)
    agent = SACLearner(FLAGS.seed, env.observation_space, env.action_space, **kwargs)

    replay_buffer = ReplayBuffer(
        env.observation_space, env.action_space, FLAGS.max_steps
    )
    replay_buffer.seed(FLAGS.seed)

    observation, done = env.reset(), False
    for i in tqdm.tqdm(
        range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        # Collect with random actions during the warm-up phase, then with the policy.
        if i < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action = agent.sample_actions(observation)
        next_observation, reward, done, info = env.step(action)

        # mask = 1.0 keeps bootstrapping on time-limit truncations; only true
        # terminal states get mask = 0.0.
        if not done or "TimeLimit.truncated" in info:
            mask = 1.0
        else:
            mask = 0.0

        replay_buffer.insert(
            dict(
                observations=observation,
                actions=action,
                rewards=reward,
                masks=mask,
                dones=done,
                next_observations=next_observation,
            )
        )
        observation = next_observation

        if done:
            observation, done = env.reset(), False
            # RecordEpisodeStatistics reports episode return ("r"), length ("l"), and time ("t").
            for k, v in info["episode"].items():
                decode = {"r": "return", "l": "length", "t": "time"}
                wandb.log({f"training/{decode[k]}": v}, step=i)

        # After the warm-up period, update the agent on a sampled minibatch every step.
        if i >= FLAGS.start_training:
            batch = replay_buffer.sample(FLAGS.batch_size)
            update_info = agent.update(batch)

            if i % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    wandb.log({f"training/{k}": v}, step=i)

        # Periodic evaluation on the held-out environment.
        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate(agent, eval_env, num_episodes=FLAGS.eval_episodes)
            for k, v in eval_info.items():
                wandb.log({f"evaluation/{k}": v}, step=i)


if __name__ == "__main__":
    app.run(main)