diff --git a/.gitignore b/.gitignore
index 4938ff7d2..e15bcc085 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,4 @@
 */*/mjkey.txt
+**/.DS_STORE
+**/*.pyc
+**/*.swp
diff --git a/README.md b/README.md
index 91063273d..527cb79de 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,11 @@ Some implemented algorithms:
 - Temporal Difference Models (TDMs)
     - [example script](examples/tdm/cheetah.py)
     - [TDM paper](https://arxiv.org/abs/1802.09081)
-    - [Details on implementation](rlkit/torch/tdm/TDMs.md)
+    - [Documentation](docs/TDMs.md)
+ - Hindsight Experience Replay (HER)
+    - [example script](examples/her/her_td3_gym_fetch_reach.py)
+    - [HER paper](https://arxiv.org/abs/1707.01495)
+    - [Documentation](docs/HER.md)
 - Deep Deterministic Policy Gradient (DDPG)
     - [example script](examples/ddpg.py)
     - [DDPG paper](https://arxiv.org/pdf/1509.02971.pdf)
@@ -84,7 +88,7 @@ Alternatively, if you don't want to clone all of `rllab`, a repository containin
 ```bash
 python viskit/viskit/frontend.py LOCAL_LOG_DIR/<exp_prefix>/
 ```
-This `viskit` repo also has a few extra nice features, like plotting multiple Y-axis values at once, figure-splitting on multiple keys, and being able to filter hyperparametrs out.
+This `viskit` repo also has a few extra nice features, like plotting multiple Y-axis values at once, figure-splitting on multiple keys, and being able to filter hyperparameters out.
 
 ## Visualizing a TDM policy
 To visualize a TDM policy, run
@@ -97,7 +101,7 @@ To visualize a TDM policy, run
 Recommended hyperparameters to tune:
 - `max_tau`
 - `reward_scale`
-
+
 ### SAC
 The SAC implementation provided here only uses Gaussian policy, rather than a Gaussian mixture model, as described in the original SAC paper.
 Recommended hyperparameters to tune:
diff --git a/docs/HER.md b/docs/HER.md
new file mode 100644
index 000000000..823ddbad6
--- /dev/null
+++ b/docs/HER.md
@@ -0,0 +1,93 @@
+# Hindsight Experience Replay
+Some notes on the implementation of
+[Hindsight Experience Replay](https://arxiv.org/abs/1707.01495).
+## Expected Results
+If you run the [Fetch example](../examples/her/her_td3_gym_fetch_reach.py),
+then you should get results like this:
+![Fetch HER results](images/FetchReach-v1_HER-TD3.png)
+
+If you run the [Sawyer example](../examples/her/her_td3_multiworld_sawyer_reach.py),
+then you should get results like this:
+![Sawyer HER results](images/SawyerReachXYZEnv-v0_HER-TD3.png)
+
+Note that these examples use HER combined with TD3, and not DDPG.
+TD3 is a newer method that came out after the HER paper, and it seems to work
+better than DDPG.
+
+## Goal-based environments and `ObsDictRelabelingBuffer`
+Some algorithms, like HER, are designed for goal-conditioned environments, such
+as the [OpenAI Gym GoalEnv](https://blog.openai.com/ingredients-for-robotics-research/)
+or the [multiworld MultitaskEnv](https://github.com/vitchyr/multiworld/)
+environments.
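+
+Concretely, here is a rough, illustrative sketch of interacting with one of
+the `gym` robotics environments (this assumes `gym`'s `FetchReach-v1` and
+MuJoCo are installed; the observation layout and reward functions are
+described in more detail below):
+
+```
+import gym
+
+env = gym.make('FetchReach-v1')
+obs = env.reset()  # obs is a dictionary, not a flat vector
+action = env.action_space.sample()
+next_obs, reward, done, info = env.step(action)
+
+# The reward can be recomputed for an arbitrary goal, which is what makes
+# hindsight relabeling possible. gym's robotics environments expose this as
+# `compute_reward`:
+substitute_goal = next_obs['achieved_goal']
+substitute_reward = env.compute_reward(
+    next_obs['achieved_goal'], substitute_goal, info
+)
+print(reward, substitute_reward)
+```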
+
+These environments are different from normal gym environments in that they
+return dictionaries for observations, like this:
+
+```
+env = CarEnv()
+obs = env.reset()
+next_obs, reward, done, info = env.step(action)
+print(obs)
+
+# Output:
+# {
+#     'observation': ...,
+#     'desired_goal': ...,
+#     'achieved_goal': ...,
+# }
+```
+The `GoalEnv` environments also have a reward function with the signature
+```
+def compute_rewards(achieved_goal, desired_goal):
+    # achieved_goal and desired_goal are vectors
+```
+while the `MultitaskEnv` reward function has a signature like
+```
+def compute_rewards(observation, action, next_observation):
+    # observation and next_observation are dictionaries
+```
+This means that normal RL algorithms won't even "type check" with these
+environments.
+To learn more about these environments, check out the URLs above.
+
+`ObsDictRelabelingBuffer` performs hindsight experience replay with
+either type of environment and works by saving specific values from the
+observation dictionary.
+
+## Implementation Difference
+This HER implementation is slightly different from the one presented in the paper.
+Rather than relabeling goals when saving data to the replay buffer, the goals
+are relabeled when sampling from the replay buffer.
+
+
+In other words, HER in the paper does this:
+
+    Data collection
+    1. Sample $(s, a, r, s', g) \sim \text{ENV}$.
+    2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$.
+    3. For i = 1, ..., K:
+        Sample $g_i$ using the future strategy.
+        Recompute the reward $r_i = f(s', g_i)$.
+        Save $(s, a, r_i, s', g_i)$ into replay buffer $\mathcal B$.
+    Train time
+    1. Sample $(s, a, r, s', g)$ from the replay buffer.
+    2. Train the Q function on $(s, a, r, s', g)$.
+
+The implementation here does this:
+
+    Data collection
+    1. Sample $(s, a, r, s', g) \sim \text{ENV}$.
+    2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$.
+    Train time
+    1. Sample $(s, a, r, s', g)$ from the replay buffer.
+    2a. With probability 1/(K+1):
+        Train the Q function on $(s, a, r, s', g)$.
+    2b. With probability 1 - 1/(K+1):
+        Sample $g'$ using the future strategy.
+        Recompute the reward $r' = f(s', g')$.
+        Train the Q function on $(s, a, r', s', g')$.
+
+Both implementations effectively do the same thing: with probability 1/(K+1),
+you train the policy on the goal used during the rollout. Otherwise, you train
+the policy on a resampled goal.
+
diff --git a/rlkit/torch/tdm/TDMs.md b/docs/TDMs.md
similarity index 100%
rename from rlkit/torch/tdm/TDMs.md
rename to docs/TDMs.md
diff --git a/docs/images/FetchReacher-v0_HER-TD3.png b/docs/images/FetchReacher-v0_HER-TD3.png
new file mode 100644
index 000000000..ad1b8668c
Binary files /dev/null and b/docs/images/FetchReacher-v0_HER-TD3.png differ
diff --git a/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png b/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png
new file mode 100644
index 000000000..a76bd8b6a
Binary files /dev/null and b/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png differ
diff --git a/examples/her/her_td3_multiworld_sawyer_reach.py b/examples/her/her_td3_multiworld_sawyer_reach.py
index 23ed27e21..7414f9673 100644
--- a/examples/her/her_td3_multiworld_sawyer_reach.py
+++ b/examples/her/her_td3_multiworld_sawyer_reach.py
@@ -22,7 +22,7 @@ def experiment(variant):
-    env = gym.make('SawyerReachXYEnv-v1')
+    env = gym.make('SawyerReachXYZEnv-v0')
     es = GaussianAndEpislonStrategy(
         action_space=env.action_space,
         max_sigma=.2,