diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 83c81f8d..246c13b3 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -21,7 +21,7 @@ jobs:
     - name: Set up Python
       uses: actions/setup-python@v3
       with:
-        python-version: 3.12
+        python-version: 3.11
    - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
diff --git a/README.md b/README.md
index 200979d4..88d7c127 100644
--- a/README.md
+++ b/README.md
@@ -21,10 +21,11 @@ Additionally, we provide an [example project](https://github.com/cpnota/all-exam
 
 ## High-Quality Reference Implementations
 
-The `autonomous-learning-library` separates reinforcement learning agents into two modules: `all.agents`, which provides flexible, high-level implementations of many common algorithms which can be adapted to new problems and environments, and `all.presets` which provides specific instansiations of these agents tuned for particular sets of environments, including Atari games, classic control tasks, and PyBullet robotics simulations. Some benchmark results showing results on-par with published results can be found below:
+The `autonomous-learning-library` separates reinforcement learning agents into two modules: `all.agents`, which provides flexible, high-level implementations of many common algorithms that can be adapted to new problems and environments, and `all.presets`, which provides specific instantiations of these agents tuned for particular sets of environments, including Atari games, classic control tasks, and MuJoCo/PyBullet robotics simulations. Some benchmark results on par with published results can be found below:
 
-![atari40](benchmarks/atari40.png)
-![pybullet](benchmarks/pybullet.png)
+![atari40](benchmarks/atari_40m.png)
+![mujoco](benchmarks/mujoco_v4.png)
+![pybullet](benchmarks/pybullet_v0.png)
 
 As of today, `all` contains implementations of the following deep RL algorithms:
 
diff --git a/all/environments/pybullet.py b/all/environments/pybullet.py
index 1d792ac9..db630dbb 100644
--- a/all/environments/pybullet.py
+++ b/all/environments/pybullet.py
@@ -5,8 +5,8 @@ class PybulletEnvironment(GymEnvironment):
     short_names = {
         "ant": "AntBulletEnv-v0",
         "cheetah": "HalfCheetahBulletEnv-v0",
-        "humanoid": "HumanoidBulletEnv-v0",
         "hopper": "HopperBulletEnv-v0",
+        "humanoid": "HumanoidBulletEnv-v0",
         "walker": "Walker2DBulletEnv-v0",
     }
 
diff --git a/all/experiments/slurm.py b/all/experiments/slurm.py
index 7e4e0903..21f029c2 100644
--- a/all/experiments/slurm.py
+++ b/all/experiments/slurm.py
@@ -89,10 +89,12 @@ def create_sbatch_script(self):
             "output": os.path.join(self.outdir, "all_%A_%a.out"),
             "error": os.path.join(self.outdir, "all_%A_%a.err"),
             "array": "0-" + str(num_experiments - 1),
-            "partition": "1080ti-short",
+            "partition": "gpu-long",
             "ntasks": 1,
+            "cpus-per-task": 4,
             "mem-per-cpu": 4000,
-            "gres": "gpu:1",
+            "gpus-per-node": 1,
+            "time": "7-0",
         }
         sbatch_args.update(self.sbatch_args)
 
diff --git a/all/policies/soft_deterministic.py b/all/policies/soft_deterministic.py
index 74656f08..9d6b3fb2 100644
--- a/all/policies/soft_deterministic.py
+++ b/all/policies/soft_deterministic.py
@@ -20,18 +20,32 @@ class SoftDeterministicPolicy(Approximation):
         kwargs (optional): Any other arguments accepted by all.approximation.Approximation
     """
 
-    def __init__(self, model, optimizer=None, space=None, name="policy", **kwargs):
-        model = SoftDeterministicPolicyNetwork(model, space)
+    def __init__(
+        self,
+        model,
+        optimizer=None,
+        space=None,
+        name="policy",
+        log_std_min=-20,
+        log_std_max=4,
+        **kwargs
+    ):
+        model = SoftDeterministicPolicyNetwork(
+            model, space, log_std_min=log_std_min, log_std_max=log_std_max
+        )
         self._inner_model = model
         super().__init__(model, optimizer, name=name, **kwargs)
 
 
 class SoftDeterministicPolicyNetwork(RLNetwork):
-    def __init__(self, model, space):
+    def __init__(self, model, space, log_std_min=-20, log_std_max=4, log_std_scale=0.5):
         super().__init__(model)
         self._action_dim = space.shape[0]
         self._tanh_scale = torch.tensor((space.high - space.low) / 2).to(self.device)
         self._tanh_mean = torch.tensor((space.high + space.low) / 2).to(self.device)
+        self._log_std_min = log_std_min
+        self._log_std_max = log_std_max
+        self._log_std_scale = log_std_scale
 
     def forward(self, state):
         outputs = super().forward(state)
@@ -41,9 +55,10 @@ def forward(self, state):
 
     def _normal(self, outputs):
         means = outputs[..., 0 : self._action_dim]
-        logvars = outputs[..., self._action_dim :]
-        std = logvars.mul(0.5).exp_()
-        return torch.distributions.normal.Normal(means, std)
+        log_stds = outputs[..., self._action_dim :] * self._log_std_scale
+        clipped_log_stds = torch.clamp(log_stds, self._log_std_min, self._log_std_max)
+        stds = clipped_log_stds.exp_()
+        return torch.distributions.normal.Normal(means, stds)
 
     def _sample(self, normal):
         raw = normal.rsample()
diff --git a/all/presets/continuous/ddpg.py b/all/presets/continuous/ddpg.py
index 4762d252..60d0e7eb 100644
--- a/all/presets/continuous/ddpg.py
+++ b/all/presets/continuous/ddpg.py
@@ -16,8 +16,8 @@
     # Common settings
     "discount_factor": 0.99,
     # Adam optimizer settings
-    "lr_q": 3e-4,
-    "lr_pi": 3e-4,
+    "lr_q": 1e-3,
+    "lr_pi": 1e-3,
     # Training settings
     "minibatch_size": 256,
     "update_frequency": 1,
diff --git a/all/presets/continuous/sac.py b/all/presets/continuous/sac.py
index 460ec2a8..a46d8fc6 100644
--- a/all/presets/continuous/sac.py
+++ b/all/presets/continuous/sac.py
@@ -17,7 +17,7 @@
     "discount_factor": 0.99,
     # Adam optimizer settings
     "lr_q": 1e-3,
-    "lr_pi": 3e-4,
+    "lr_pi": 1e-3,
     # Training settings
     "minibatch_size": 256,
     "update_frequency": 1,
@@ -26,7 +26,7 @@
     "replay_start_size": 5000,
     "replay_buffer_size": 1e6,
     # Exploration settings
-    "temperature_initial": 0.1,
+    "temperature_initial": 1.0,
     "lr_temperature_scaling": 3e-5,
     "entropy_backups": True,
     "entropy_target_scaling": 1.0,
diff --git a/benchmarks/atari40.png b/benchmarks/atari40.png
deleted file mode 100644
index 4e2d8e45..00000000
Binary files a/benchmarks/atari40.png and /dev/null differ
diff --git a/benchmarks/atari_40m.png b/benchmarks/atari_40m.png
new file mode 100644
index 00000000..8940654b
Binary files /dev/null and b/benchmarks/atari_40m.png differ
diff --git a/benchmarks/atari40.py b/benchmarks/atari_40m.py
similarity index 85%
rename from benchmarks/atari40.py
rename to benchmarks/atari_40m.py
index 3dc88c7c..09812e7c 100644
--- a/benchmarks/atari40.py
+++ b/benchmarks/atari_40m.py
@@ -20,8 +20,8 @@ def main():
         agents,
         envs,
         10e6,
-        logdir="benchmarks/atari40",
-        sbatch_args={"partition": "gpu-long"},
+        logdir="benchmarks/atari_40m",
+        sbatch_args={"partition": "gypsum-1080ti"},
     )
 
 
diff --git a/benchmarks/mujoco_v4.png b/benchmarks/mujoco_v4.png
new file mode 100644
index 00000000..f71acf4e
Binary files /dev/null and b/benchmarks/mujoco_v4.png differ
diff --git a/benchmarks/mujoco_v4.py b/benchmarks/mujoco_v4.py
new file mode 100644
index 00000000..746183de
--- /dev/null
+++ b/benchmarks/mujoco_v4.py
@@ -0,0 +1,34 @@
+from all.environments import MujocoEnvironment
+from all.experiments import SlurmExperiment
+from all.presets.continuous import ddpg, ppo, sac
+
+
+def main():
+    frames = int(5e6)
+
+    agents = [ddpg, ppo, sac]
+
+    envs = [
+        MujocoEnvironment(env, device="cuda")
+        for env in [
+            "Ant-v4",
+            "HalfCheetah-v4",
+            "Hopper-v4",
+            "Humanoid-v4",
+            "Walker2d-v4",
+        ]
+    ]
+
+    SlurmExperiment(
+        agents,
+        envs,
+        frames,
+        logdir="benchmarks/mujoco_v4",
+        sbatch_args={
+            "partition": "gpu-long",
+        },
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/pybullet.png b/benchmarks/pybullet.png
deleted file mode 100644
index 1602eb76..00000000
Binary files a/benchmarks/pybullet.png and /dev/null differ
diff --git a/benchmarks/pybullet_v0.png b/benchmarks/pybullet_v0.png
new file mode 100644
index 00000000..27c65759
Binary files /dev/null and b/benchmarks/pybullet_v0.png differ
diff --git a/benchmarks/pybullet.py b/benchmarks/pybullet_v0.py
similarity index 52%
rename from benchmarks/pybullet.py
rename to benchmarks/pybullet_v0.py
index dc045fd6..5e92b8f7 100644
--- a/benchmarks/pybullet.py
+++ b/benchmarks/pybullet_v0.py
@@ -4,21 +4,29 @@
 
 
 def main():
-    frames = int(1e7)
+    frames = int(5e6)
 
     agents = [ddpg, ppo, sac]
 
     envs = [
         PybulletEnvironment(env, device="cuda")
-        for env in PybulletEnvironment.short_names
+        for env in [
+            "AntBulletEnv-v0",
+            "HalfCheetahBulletEnv-v0",
+            "HopperBulletEnv-v0",
+            "HumanoidBulletEnv-v0",
+            "Walker2DBulletEnv-v0",
+        ]
     ]
 
     SlurmExperiment(
         agents,
         envs,
         frames,
-        logdir="benchmarks/pybullet",
-        sbatch_args={"partition": "gpu-long"},
+        logdir="benchmarks/pybullet_v0",
+        sbatch_args={
+            "partition": "gpu-long",
+        },
     )
 
 
diff --git a/docs/source/guide/benchmark_performance.rst b/docs/source/guide/benchmark_performance.rst
index 1a9348c6..831237cd 100644
--- a/docs/source/guide/benchmark_performance.rst
+++ b/docs/source/guide/benchmark_performance.rst
@@ -28,7 +28,7 @@ Additionally, we use the following agent "bodies":
 
 The results were as follows:
 
-.. image:: ../../../benchmarks/atari40.png
+.. image:: ../../../benchmarks/atari_40m.png
 
 For comparison, we look at the results published in the paper, `Rainbow: Combining Improvements in Deep Reinforcement Learning `_:
 
@@ -40,23 +40,29 @@ Our ``dqn`` and ``ddqn`` in particular were better almost across the board.
 While there are some minor implementation differences (for example, we use ``Adam`` for most algorithms instead of ``RMSprop``),
 our agents achieved very similar behavior to the agents tested by DeepMind.
 
+MuJoCo Benchmark
+----------------
+
+`MuJoCo <https://mujoco.org>`_ is "a free and open source physics engine that aims to facilitate research and development in robotics, biomechanics, graphics and animation, and other areas where fast and accurate simulation is needed."
+The MuJoCo Gym environments are a common benchmark in RL research for evaluating agents with continuous action spaces.
+We ran each continuous preset for 5 million timesteps (in this case, timesteps are equal to frames).
+The learning rate was decayed over the course of training using cosine annealing.
+The results were as follows:
+
+.. image:: ../../../benchmarks/mujoco_v4.png
+
+These results are similar to those found elsewhere, and in some cases better.
+However, results can vary based on hyperparameter tuning, implementation specifics, and the random seed.
+
 PyBullet Benchmark
 ------------------
 
 `PyBullet `_ provides a free alternative to the popular MuJoCo robotics environments.
-While MuJoCo requires a license key and can be difficult for independent researchers to afford, PyBullet is free and open.
-Additionally, the PyBullet environments are widely considered more challenging, making them a more discriminant test bed.
-For these reasons, we chose to benchmark the ``all.presets.continuous`` presets using PyBullet.
-
-Similar to the Atari benchmark, we ran each agent for 10 million timesteps (in this case, timesteps are equal to frames).
+We ran each agent for 5 million timesteps (in this case, timesteps are equal to frames).
 The learning rate was decayed over the course of training using cosine annealing.
-To reduce the variance of the updates, we added an extra time feature to the state (t * 0.001, where t is the current timestep).
 The results were as follows:
 
-.. image:: ../../../benchmarks/pybullet.png
-
-PPO was omitted from the plot for Humanoid because it achieved very large negative returns which interfered with the scale of the graph.
-Note, however, that our implementation of soft actor-critic (SAC) is able to solve even this difficult environment.
+.. image:: ../../../benchmarks/pybullet_v0.png
 
 Because most research papers still use MuJoCo, direct comparisons are difficult to come by.
 However, George Sung helpfully benchmarked TD3 and DDPG on several PyBullet environments [here](https://github.com/georgesung/TD3).
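
Reviewer note: the most substantive behavioral change in this patch is in SoftDeterministicPolicyNetwork._normal, which switches from interpreting the second half of the network head as a log-variance to a scaled, clamped log standard deviation. Below is a minimal standalone sketch of that computation for illustration only; the constants mirror the patch defaults, while the function name and the random test tensor are hypothetical and not part of the library.

import torch

# Defaults taken from the patch; treat them as illustrative assumptions.
LOG_STD_MIN = -20
LOG_STD_MAX = 4
LOG_STD_SCALE = 0.5


def normal_from_outputs(outputs, action_dim):
    # The first `action_dim` outputs are the means; the remainder are scaled
    # log standard deviations, clamped before exponentiation so the std stays
    # within [exp(-20), exp(4)] instead of exploding or vanishing.
    means = outputs[..., :action_dim]
    log_stds = outputs[..., action_dim:] * LOG_STD_SCALE
    clipped_log_stds = torch.clamp(log_stds, LOG_STD_MIN, LOG_STD_MAX)
    return torch.distributions.Normal(means, clipped_log_stds.exp())


# Usage sketch: a batch of 2 states with 3-dimensional actions (head outputs 6 values).
outputs = torch.randn(2, 6)
dist = normal_from_outputs(outputs, action_dim=3)
print(dist.rsample().shape)  # torch.Size([2, 3])

Clamping the log standard deviation in this way is a common choice in SAC-style implementations: it bounds the policy's variance without affecting the mean path through the network.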