From 45df968439c401b88a2c119d05c9958085606f75 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Tue, 5 Nov 2024 16:53:57 +0200 Subject: [PATCH 01/11] unified gae --- mava/systems/ppo/anakin/ff_ippo.py | 46 +++++++-------------------- mava/systems/ppo/anakin/ff_mappo.py | 47 +++++++--------------------- mava/systems/ppo/anakin/rec_ippo.py | 25 ++------------- mava/systems/ppo/anakin/rec_mappo.py | 25 ++------------- mava/utils/gae.py | 29 +++++++++++++++++ 5 files changed, 57 insertions(+), 115 deletions(-) create mode 100644 mava/utils/gae.py diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 9329b1357..3b15e1a1d 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -36,6 +36,7 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.gae import calculate_gae from mava.utils.jax_utils import ( merge_leading_dims, unreplicate_batch_dim, @@ -81,7 +82,7 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: """Step the environment.""" - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state # SELECT ACTION key, policy_key = jax.random.split(key) @@ -102,7 +103,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra info = timestep.extras["episode_metrics"] transition = PPOTransition( - done, + last_done, action, value, timestep.reward, @@ -110,7 +111,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra last_timestep.observation, info, ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = LearnerState(params, opt_states, key, env_state, timestep, done) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -119,37 +120,10 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra ) # CALCULATE ADVANTAGE - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - def _calculate_gae( - traj_batch: PPOTransition, last_val: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - """Calculate the GAE.""" - - def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: - """Calculate the GAE for a single transition.""" - gae, next_value = gae_and_next_value - done, value, reward = ( - transition.done, - transition.value, - transition.reward, - ) - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae - return (gae, value), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val) + advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" @@ -312,7 +286,7 @@ def _critic_loss_fn( ) params, 
opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done) metric = traj_batch.info return learner_state, (metric, loss_info) @@ -430,6 +404,10 @@ def learner_setup( params = restored_params # Define params to be replicated across devices and batches. + dones = jnp.zeros( + (config.arch.num_envs, config.system.num_agents), + dtype=bool, + ) key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) replicate_learner = (params, opt_states, step_keys) @@ -443,7 +421,7 @@ def learner_setup( # Initialise learner state. params, opt_states, step_keys = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones) return learn, actor_network, init_learner_state diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 1b884c070..e9992b1f5 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -35,6 +35,7 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.gae import calculate_gae from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head @@ -76,7 +77,7 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: """Step the environment.""" - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state # SELECT ACTION key, policy_key = jax.random.split(key) @@ -96,9 +97,9 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra info = timestep.extras["episode_metrics"] transition = PPOTransition( - done, action, value, timestep.reward, log_prob, last_timestep.observation, info + last_done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = LearnerState(params, opt_states, key, env_state, timestep, done) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -107,37 +108,10 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra ) # CALCULATE ADVANTAGE - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - def _calculate_gae( - traj_batch: PPOTransition, last_val: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - """Calculate the GAE.""" - - def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: - """Calculate the GAE for a single transition.""" - gae, next_value = gae_and_next_value - done, value, reward = ( - transition.done, - transition.value, - transition.reward, - ) - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - done) - value - gae = delta + gamma * config.system.gae_lambda 
* (1 - done) * gae - return (gae, value), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val) + advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" @@ -296,7 +270,7 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done) metric = traj_batch.info return learner_state, (metric, loss_info) @@ -414,7 +388,10 @@ def learner_setup( params = restored_params # Define params to be replicated across devices and batches. - key, step_keys = jax.random.split(key) + dones = jnp.zeros( + (config.arch.num_envs, config.system.num_agents), + dtype=bool, + ) opt_states = OptStates(actor_opt_state, critic_opt_state) replicate_learner = (params, opt_states, step_keys) @@ -427,7 +404,7 @@ def learner_setup( # Initialise learner state. params, opt_states, step_keys = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones) return learn, actor_network, init_learner_state diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 3f6048d1a..738c76269 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -49,6 +49,7 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.gae import calculate_gae from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head @@ -179,29 +180,7 @@ def _env_step( # Squeeze out the batch dimension and mask out the value of terminal states. 
last_val = last_val.squeeze(0) - def _calculate_gae( - traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val, last_done) + advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index fb5c348bc..5483a927d 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -49,6 +49,7 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.gae import calculate_gae from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head @@ -175,29 +176,7 @@ def _env_step( # Squeeze out the batch dimension and mask out the value of terminal states. 
last_val = last_val.squeeze(0) - def _calculate_gae( - traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val, last_done) + advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/utils/gae.py b/mava/utils/gae.py new file mode 100644 index 000000000..53bb25731 --- /dev/null +++ b/mava/utils/gae.py @@ -0,0 +1,29 @@ +from typing import Tuple, Union +import chex +import jax +import jax.numpy as jnp + +from mava.systems.ppo.types import PPOTransition, RNNPPOTransition + + +def calculate_gae( + traj_batch: Union[PPOTransition, RNNPPOTransition], last_val: chex.Array, last_done: chex.Array, gamma: float, gae_lambda: float +) -> Tuple[chex.Array, chex.Array]: + def _get_advantages( + carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: + gae, next_value, next_done = carry + done, value, reward = transition.done, transition.value, transition.reward + gamma = gamma + delta = reward + gamma * next_value * (1 - next_done) - value + gae = delta + gamma * gae_lambda * (1 - next_done) * gae + return (gae, value, done), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, last_done), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value \ No newline at end of file From 1cf85343e1157b345a20e87e64feaf4709608a55 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Tue, 5 Nov 2024 17:29:34 +0200 Subject: [PATCH 02/11] fixes --- mava/systems/ppo/anakin/ff_ippo.py | 4 ++-- mava/systems/ppo/anakin/ff_mappo.py | 5 +++-- mava/systems/ppo/types.py | 1 + mava/utils/gae.py | 1 - 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 3b15e1a1d..5b8623ba9 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -410,7 +410,7 @@ def learner_setup( ) key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) + replicate_learner = (params, opt_states, step_keys, dones) # Duplicate learner for update_batch_size. broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape)) @@ -420,7 +420,7 @@ def learner_setup( replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) # Initialise learner state. 
- params, opt_states, step_keys = replicate_learner + params, opt_states, step_keys, dones = replicate_learner init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones) return learn, actor_network, init_learner_state diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index e9992b1f5..b23be41d8 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -392,8 +392,9 @@ def learner_setup( (config.arch.num_envs, config.system.num_agents), dtype=bool, ) + key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) + replicate_learner = (params, opt_states, step_keys, dones) # Duplicate learner for update_batch_size. broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape)) @@ -403,7 +404,7 @@ def learner_setup( replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) # Initialise learner state. - params, opt_states, step_keys = replicate_learner + params, opt_states, step_keys, dones = replicate_learner init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones) return learn, actor_network, init_learner_state diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index f129b89d3..e36b9c871 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -52,6 +52,7 @@ class LearnerState(NamedTuple): key: chex.PRNGKey env_state: State timestep: TimeStep + dones: Done class RNNLearnerState(NamedTuple): diff --git a/mava/utils/gae.py b/mava/utils/gae.py index 53bb25731..7282f62b8 100644 --- a/mava/utils/gae.py +++ b/mava/utils/gae.py @@ -14,7 +14,6 @@ def _get_advantages( ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: gae, next_value, next_done = carry done, value, reward = transition.done, transition.value, transition.reward - gamma = gamma delta = reward + gamma * next_value * (1 - next_done) - value gae = delta + gamma * gae_lambda * (1 - next_done) * gae return (gae, value, done), gae From 18241c091eeb6ef4d1ba678f66cd0bc87ad0e1c7 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Thu, 7 Nov 2024 17:17:15 +0200 Subject: [PATCH 03/11] pre-commit --- mava/systems/ppo/anakin/ff_ippo.py | 4 +++- mava/systems/ppo/anakin/ff_mappo.py | 4 +++- mava/systems/ppo/anakin/rec_ippo.py | 4 +++- mava/systems/ppo/anakin/rec_mappo.py | 4 +++- mava/utils/gae.py | 23 +++++++++++++++++++++-- 5 files changed, 33 insertions(+), 6 deletions(-) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 5b8623ba9..98d4effb8 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -123,7 +123,9 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index b23be41d8..420978a04 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ 
b/mava/systems/ppo/anakin/ff_mappo.py @@ -111,7 +111,9 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 738c76269..157e24bf4 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -180,7 +180,9 @@ def _env_step( # Squeeze out the batch dimension and mask out the value of terminal states. last_val = last_val.squeeze(0) - advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index 5483a927d..4bd7f594a 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -176,7 +176,9 @@ def _env_step( # Squeeze out the batch dimension and mask out the value of terminal states. last_val = last_val.squeeze(0) - advantages, targets = calculate_gae(traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/utils/gae.py b/mava/utils/gae.py index 7282f62b8..b2f02b8fa 100644 --- a/mava/utils/gae.py +++ b/mava/utils/gae.py @@ -1,4 +1,19 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import Tuple, Union + import chex import jax import jax.numpy as jnp @@ -7,7 +22,11 @@ def calculate_gae( - traj_batch: Union[PPOTransition, RNNPPOTransition], last_val: chex.Array, last_done: chex.Array, gamma: float, gae_lambda: float + traj_batch: Union[PPOTransition, RNNPPOTransition], + last_val: chex.Array, + last_done: chex.Array, + gamma: float, + gae_lambda: float, ) -> Tuple[chex.Array, chex.Array]: def _get_advantages( carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition @@ -25,4 +44,4 @@ def _get_advantages( reverse=True, unroll=16, ) - return advantages, advantages + traj_batch.value \ No newline at end of file + return advantages, advantages + traj_batch.value From ccfd8edfdee6c259c81dc30bd4ff021a653976f5 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Thu, 7 Nov 2024 18:12:05 +0200 Subject: [PATCH 04/11] requested changes --- mava/systems/ppo/anakin/ff_ippo.py | 2 +- mava/systems/ppo/anakin/ff_mappo.py | 2 +- mava/systems/ppo/anakin/rec_ippo.py | 2 +- mava/systems/ppo/anakin/rec_mappo.py | 2 +- mava/utils/{gae.py => multistep.py} | 23 ++++++++++++++++++++++- 5 files changed, 26 insertions(+), 5 deletions(-) rename mava/utils/{gae.py => multistep.py} (64%) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 98d4effb8..e65326210 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -36,13 +36,13 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.gae import calculate_gae from mava.utils.jax_utils import ( merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 420978a04..f53cce299 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -35,9 +35,9 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.gae import calculate_gae from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 157e24bf4..e5682f9e3 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -49,9 +49,9 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.gae import calculate_gae from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import 
make_learning_rate diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index 4bd7f594a..1a204498e 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -49,9 +49,9 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.gae import calculate_gae from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate diff --git a/mava/utils/gae.py b/mava/utils/multistep.py similarity index 64% rename from mava/utils/gae.py rename to mava/utils/multistep.py index b2f02b8fa..090aba01b 100644 --- a/mava/utils/gae.py +++ b/mava/utils/multistep.py @@ -27,12 +27,33 @@ def calculate_gae( last_done: chex.Array, gamma: float, gae_lambda: float, + unroll: int = 16 ) -> Tuple[chex.Array, chex.Array]: + """Computes truncated generalized advantage estimates. + + The advantages are computed in a backwards fashion according to the equation: + Âₜ = δₜ + (γλ) * δₜ₊₁ + ... + ... + (γλ)ᵏ⁻ᵗ⁺¹ * δₖ₋₁ + where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ). + See Proximal Policy Optimization Algorithms, Schulman et al.: + https://arxiv.org/abs/1707.06347 + + Args: + traj_batch (B, T, A, ...): a batch of trajectories. + last_val (B, A): value of the final timestep. + last_done (B, A): whether the last timestep was a terminated or truncated. + gamma (float): discount factor. + gae_lambda (float): GAE mixing parameter. + unroll (int): how much XLA should unroll the scan used to calculate GAE. + + Returns Tuple[(B, T, A), (B, T, A)]: advantages and target values. + """ + def _get_advantages( carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: gae, next_value, next_done = carry done, value, reward = transition.done, transition.value, transition.reward + delta = reward + gamma * next_value * (1 - next_done) - value gae = delta + gamma * gae_lambda * (1 - next_done) * gae return (gae, value, done), gae @@ -42,6 +63,6 @@ def _get_advantages( (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, - unroll=16, + unroll=unroll, ) return advantages, advantages + traj_batch.value From f175753e8871786292ac74825d01d955f672ad4e Mon Sep 17 00:00:00 2001 From: Simon Du Toit <90381208+SimonDuToit@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:37:33 +0200 Subject: [PATCH 05/11] Replace A with N (number of agents) Co-authored-by: Sasha Abramowitz --- mava/utils/multistep.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mava/utils/multistep.py b/mava/utils/multistep.py index 090aba01b..8124b664d 100644 --- a/mava/utils/multistep.py +++ b/mava/utils/multistep.py @@ -38,9 +38,9 @@ def calculate_gae( https://arxiv.org/abs/1707.06347 Args: - traj_batch (B, T, A, ...): a batch of trajectories. - last_val (B, A): value of the final timestep. - last_done (B, A): whether the last timestep was a terminated or truncated. + traj_batch (B, T, N, ...): a batch of trajectories. + last_val (B, N): value of the final timestep. + last_done (B, N): whether the last timestep was a terminated or truncated. gamma (float): discount factor. gae_lambda (float): GAE mixing parameter. 
unroll (int): how much XLA should unroll the scan used to calculate GAE. From f6a513e7497fb07cca6b4bc20d610ec9c1e97956 Mon Sep 17 00:00:00 2001 From: Simon Du Toit <90381208+SimonDuToit@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:37:53 +0200 Subject: [PATCH 06/11] same Co-authored-by: Sasha Abramowitz --- mava/utils/multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/utils/multistep.py b/mava/utils/multistep.py index 8124b664d..a122b9130 100644 --- a/mava/utils/multistep.py +++ b/mava/utils/multistep.py @@ -45,7 +45,7 @@ def calculate_gae( gae_lambda (float): GAE mixing parameter. unroll (int): how much XLA should unroll the scan used to calculate GAE. - Returns Tuple[(B, T, A), (B, T, A)]: advantages and target values. + Returns Tuple[(B, T, N), (B, T, N)]: advantages and target values. """ def _get_advantages( From ff96cdbebf12a45303485f2191f110377ceebf33 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Mon, 25 Nov 2024 11:34:04 +0200 Subject: [PATCH 07/11] cleaning --- mava/systems/ppo/sebulba/ff_ippo.py | 28 +++++++++++----------- mava/systems/ppo/types.py | 10 ++++++++ mava/utils/multistep.py | 10 ++++---- mava/utils/training.py | 36 ++++++++++++++++++++++++++++- 4 files changed, 64 insertions(+), 20 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 6f34c0b1a..55d0df7e2 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -41,7 +41,7 @@ from mava.evaluator import make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import OptStates, Params, PPOTransition, SebulbaLearnerState from mava.types import ( ActorApply, CriticApply, @@ -162,7 +162,7 @@ def get_learner_step_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> SebulbaLearnerFn[LearnerState, PPOTransition]: +) -> SebulbaLearnerFn[SebulbaLearnerState, PPOTransition]: """Get the learner function.""" num_envs = config.arch.num_envs @@ -173,16 +173,16 @@ def get_learner_step_fn( actor_update_fn, critic_update_fn = update_fns def _update_step( - learner_state: LearnerState, + learner_state: SebulbaLearnerState, traj_batch: PPOTransition, - ) -> Tuple[LearnerState, Tuple]: + ) -> Tuple[SebulbaLearnerState, Tuple]: """A single update of the network. This function calculates advantages and targets based on the trajectories from the actor and updates the actor and critic networks based on the losses. Args: - learner_state (LearnerState): contains all the items needed for learning. + learner_state (SebulbaLearnerState): contains all the items needed for learning. traj_batch (PPOTransition): the batch of data to learn with. """ @@ -358,13 +358,13 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) + learner_state = SebulbaLearnerState(params, opt_states, key, None, learner_state.timestep) metric = traj_batch.info return learner_state, (metric, loss_info) def learner_fn( - learner_state: LearnerState, traj_batch: PPOTransition - ) -> ExperimentOutput[LearnerState]: + learner_state: SebulbaLearnerState, traj_batch: PPOTransition + ) -> ExperimentOutput[SebulbaLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -394,8 +394,8 @@ def learner_fn( def learner_thread( - learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition], - learner_state: LearnerState, + learn_fn: SebulbaLearnerFn[SebulbaLearnerState, PPOTransition], + learner_state: SebulbaLearnerState, config: DictConfig, eval_queue: Queue, pipeline: Pipeline, @@ -442,9 +442,9 @@ def learner_thread( def learner_setup( key: chex.PRNGKey, config: DictConfig, learner_devices: List ) -> Tuple[ - SebulbaLearnerFn[LearnerState, PPOTransition], + SebulbaLearnerFn[SebulbaLearnerState, PPOTransition], Tuple[ActorApply, CriticApply], - LearnerState, + SebulbaLearnerState, Sharding, ]: """Initialise learner_fn, network and learner state.""" @@ -508,7 +508,7 @@ def learner_setup( update_fns = (actor_optim.update, critic_optim.update) # defines how the learner state is sharded: params, opt and key = sharded, timestep = sharded - learn_state_spec = LearnerState(model_spec, model_spec, data_spec, None, data_spec) + learn_state_spec = SebulbaLearnerState(model_spec, model_spec, data_spec, None, data_spec) learn = get_learner_step_fn(apply_fns, update_fns, config) learn = jax.jit( shard_map( @@ -541,7 +541,7 @@ def learner_setup( ) # Initialise learner state. - init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore + init_learner_state = SebulbaLearnerState(params, opt_states, step_keys, None, None) # type: ignore env.close() return learn, apply_fns, init_learner_state, learner_sharding # type: ignore diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index 3d8a53d7a..5bf1dd769 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -55,6 +55,16 @@ class LearnerState(NamedTuple): dones: Done +class SebulbaLearnerState(NamedTuple): + """State of the learner.""" + + params: Params + opt_states: OptStates + key: chex.PRNGKey + env_state: State + timestep: TimeStep + + class RNNLearnerState(NamedTuple): """State of the `Learner` for recurrent architectures.""" diff --git a/mava/utils/multistep.py b/mava/utils/multistep.py index a122b9130..955f0ccd0 100644 --- a/mava/utils/multistep.py +++ b/mava/utils/multistep.py @@ -27,7 +27,7 @@ def calculate_gae( last_done: chex.Array, gamma: float, gae_lambda: float, - unroll: int = 16 + unroll: int = 16, ) -> Tuple[chex.Array, chex.Array]: """Computes truncated generalized advantage estimates. @@ -36,7 +36,7 @@ def calculate_gae( where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ). See Proximal Policy Optimization Algorithms, Schulman et al.: https://arxiv.org/abs/1707.06347 - + Args: traj_batch (B, T, N, ...): a batch of trajectories. last_val (B, N): value of the final timestep. @@ -44,16 +44,16 @@ def calculate_gae( gamma (float): discount factor. gae_lambda (float): GAE mixing parameter. unroll (int): how much XLA should unroll the scan used to calculate GAE. - + Returns Tuple[(B, T, N), (B, T, N)]: advantages and target values. 
""" - + def _get_advantages( carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: gae, next_value, next_done = carry done, value, reward = transition.done, transition.value, transition.reward - + delta = reward + gamma * next_value * (1 - next_done) - value gae = delta + gamma * gae_lambda * (1 - next_done) * gae return (gae, value, done), gae diff --git a/mava/utils/training.py b/mava/utils/training.py index 77aa98fc6..8173a4f55 100644 --- a/mava/utils/training.py +++ b/mava/utils/training.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from typing import Callable, Tuple, Union +import chex +import jax +import jax.numpy as jnp from omegaconf import DictConfig +from mava.systems.ppo.types import PPOTransition, RNNPPOTransition + def make_learning_rate_schedule(init_lr: float, config: DictConfig) -> Callable: """Makes a very simple linear learning rate scheduler. @@ -62,3 +67,32 @@ def make_learning_rate(init_lr: float, config: DictConfig) -> Union[float, Calla return make_learning_rate_schedule(init_lr, config) else: return init_lr + + +def _calculate_gae( + traj_batch: PPOTransition, + last_val: chex.Array, + last_done: chex.Array, + recurrent: bool, + config: DictConfig, +) -> Tuple[chex.Array, chex.Array]: + def _get_advantages( + carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: + gae, next_value, next_done = carry + done, value, reward = transition.done, transition.value, transition.reward + gamma = config.system.gamma + if not recurrent: + next_done = done + delta = reward + gamma * next_value * (1 - next_done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae + return (gae, value, done), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, last_done), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value From 3778a3ad8fbf8d115ffdcaf2bd73ea4d5814df10 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Mon, 25 Nov 2024 11:36:14 +0200 Subject: [PATCH 08/11] undo mistake --- mava/utils/training.py | 38 ++------------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/mava/utils/training.py b/mava/utils/training.py index 8173a4f55..6d3ddcc45 100644 --- a/mava/utils/training.py +++ b/mava/utils/training.py @@ -12,15 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Tuple, Union +from typing import Callable, Union -import chex -import jax -import jax.numpy as jnp from omegaconf import DictConfig -from mava.systems.ppo.types import PPOTransition, RNNPPOTransition - def make_learning_rate_schedule(init_lr: float, config: DictConfig) -> Callable: """Makes a very simple linear learning rate scheduler. 
@@ -66,33 +61,4 @@ def make_learning_rate(init_lr: float, config: DictConfig) -> Union[float, Calla if config.system.decay_learning_rates: return make_learning_rate_schedule(init_lr, config) else: - return init_lr - - -def _calculate_gae( - traj_batch: PPOTransition, - last_val: chex.Array, - last_done: chex.Array, - recurrent: bool, - config: DictConfig, -) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - if not recurrent: - next_done = done - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value + return init_lr \ No newline at end of file From 348b7a67cb9954bb8aca662fbe0f9a26bff476ba Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Mon, 25 Nov 2024 11:37:54 +0200 Subject: [PATCH 09/11] more cleaning --- mava/utils/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/utils/training.py b/mava/utils/training.py index 6d3ddcc45..77aa98fc6 100644 --- a/mava/utils/training.py +++ b/mava/utils/training.py @@ -61,4 +61,4 @@ def make_learning_rate(init_lr: float, config: DictConfig) -> Union[float, Calla if config.system.decay_learning_rates: return make_learning_rate_schedule(init_lr, config) else: - return init_lr \ No newline at end of file + return init_lr From 425d7a6b36b9ff500babd535176b9b19f9c06b07 Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Tue, 3 Dec 2024 14:41:29 +0200 Subject: [PATCH 10/11] merge conflicts --- mava/systems/ppo/anakin/ff_ippo.py | 2 +- mava/systems/ppo/anakin/ff_mappo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 3e0dd33dd..ad53cee52 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -96,7 +96,7 @@ def _env_step( transition = PPOTransition( last_done, action, value, timestep.reward, log_prob, last_timestep.observation ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = LearnerState(params, opt_states, key, env_state, timestep, done) return learner_state, (transition, timestep.extras["episode_metrics"]) # Step environment for rollout length diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 173a7e698..e0e8ca09e 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -96,7 +96,7 @@ def _env_step( last_done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep, done) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) # Step environment for rollout length learner_state, (traj_batch, episode_metrics) = jax.lax.scan( From 59b815f9bc8131ad12c4c352f8340b854c18810c Mon Sep 17 00:00:00 2001 From: SimonDuToit Date: Tue, 3 Dec 2024 14:52:42 +0200 Subject: [PATCH 11/11] uncapitalize advantage comments --- 
mava/systems/ppo/anakin/ff_ippo.py | 2 +- mava/systems/ppo/anakin/ff_mappo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index ad53cee52..c98d679c0 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -104,7 +104,7 @@ def _env_step( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index e0e8ca09e..750e49ea3 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -103,7 +103,7 @@ def _env_step( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation)
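
Below is a minimal, self-contained sketch of how the unified calculate_gae utility introduced in this series is meant to be called. It is illustrative only: FakeTransition is a toy stand-in for Mava's PPOTransition/RNNPPOTransition (carrying just the done/value/reward fields the scan reads), the time-major (T, B, N) layout and toy sizes are assumptions made here so jax.lax.scan can consume the leading axis, and the gamma/gae_lambda values are arbitrary.

# Illustrative sketch only -- not part of the patch series above.
from typing import NamedTuple, Tuple

import chex
import jax
import jax.numpy as jnp


class FakeTransition(NamedTuple):
    # Toy stand-in for PPOTransition/RNNPPOTransition: only the fields the GAE scan reads.
    done: chex.Array    # (T, B, N); scan consumes the leading (time) axis in this sketch
    value: chex.Array   # (T, B, N)
    reward: chex.Array  # (T, B, N)


def calculate_gae(
    traj_batch: FakeTransition,
    last_val: chex.Array,
    last_done: chex.Array,
    gamma: float,
    gae_lambda: float,
    unroll: int = 16,
) -> Tuple[chex.Array, chex.Array]:
    # Same recursion as the unified utility: scan backwards over time,
    # masking the bootstrap with the *next* step's done flag.
    def _get_advantages(carry, transition):
        gae, next_value, next_done = carry
        done, value, reward = transition.done, transition.value, transition.reward
        delta = reward + gamma * next_value * (1 - next_done) - value
        gae = delta + gamma * gae_lambda * (1 - next_done) * gae
        return (gae, value, done), gae

    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val, last_done),
        traj_batch,
        reverse=True,
        unroll=unroll,
    )
    # Critic targets are advantages plus the stored value estimates.
    return advantages, advantages + traj_batch.value


if __name__ == "__main__":
    T, B, N = 5, 2, 3  # rollout length, parallel envs, agents (arbitrary toy sizes)
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    batch = FakeTransition(
        done=jnp.zeros((T, B, N), dtype=bool),
        value=jax.random.normal(k1, (T, B, N)),
        reward=jax.random.normal(k2, (T, B, N)),
    )
    last_val = jax.random.normal(k3, (B, N))
    last_done = jnp.zeros((B, N), dtype=bool)

    adv, targets = calculate_gae(batch, last_val, last_done, gamma=0.99, gae_lambda=0.95)
    print(adv.shape, targets.shape)  # (5, 2, 3) (5, 2, 3)

The behavioural point of the unification is visible in the scan carry: the bootstrap term is masked with the next step's done flag (seeded with last_done), which is why the feed-forward systems in these patches now store last_done in their transitions and carry a dones field in LearnerState rather than keeping a separate non-recurrent GAE implementation.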