Feat: unified gae #1129

Merged 20 commits on Jan 7, 2025
52 changes: 16 additions & 36 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -38,6 +38,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -78,7 +79,7 @@ def _env_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state

# Select action
key, policy_key = jax.random.split(key)
@@ -93,9 +94,9 @@

done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation
last_done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
learner_state = LearnerState(params, opt_states, key, env_state, timestep, done)
return learner_state, (transition, timestep.extras["episode_metrics"])

# Step environment for rollout length
@@ -104,37 +105,12 @@
)

# Calculate advantage
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

def _calculate_gae(
traj_batch: PPOTransition, last_val: chex.Array
) -> Tuple[chex.Array, chex.Array]:
"""Calculate the GAE."""

def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
"""Calculate the GAE for a single transition."""
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
@@ -283,7 +259,7 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done)
return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
@@ -400,9 +376,13 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
dones = jnp.zeros(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
@@ -412,8 +392,8 @@
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state

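The unified helper imported above lives in `mava/utils/multistep.py`, which is not shown in this diff. A minimal sketch of what `calculate_gae` presumably looks like, reconstructed from the call sites and the removed inline implementations (the positional argument order matches the calls above; the internal names and the `unroll` default are assumptions):

```python
# Sketch only: the real mava/utils/multistep.py is not part of this diff.
# Reconstructed from the call sites and the removed inline _calculate_gae helpers;
# internal names and the unroll default are assumptions.
from typing import Tuple

import chex
import jax
import jax.numpy as jnp


def calculate_gae(
    traj_batch,             # time-major rollout with .done, .value and .reward fields
    last_val: chex.Array,   # bootstrap value for the state after the final step
    last_done: chex.Array,  # done flags entering the state after the final step
    gamma: float,
    gae_lambda: float,
    unroll: int = 16,
) -> Tuple[chex.Array, chex.Array]:
    """Compute GAE advantages and value targets over a batch of trajectories."""

    def _get_advantages(carry: Tuple, transition) -> Tuple:
        gae, next_value, next_done = carry
        done, value, reward = transition.done, transition.value, transition.reward
        # Mask the bootstrap and the carried advantage with the *next* step's done flag.
        delta = reward + gamma * next_value * (1 - next_done) - value
        gae = delta + gamma * gae_lambda * (1 - next_done) * gae
        return (gae, value, done), gae

    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val, last_done),
        traj_batch,
        reverse=True,
        unroll=unroll,
    )
    # Value targets are the advantages plus the stored value estimates.
    return advantages, advantages + traj_batch.value
```

Note that the feed-forward systems now store the done flag entering each step (`last_done`) in the transition, matching the convention the recurrent systems already used; that is what lets a single helper serve both.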
52 changes: 16 additions & 36 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -37,6 +37,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -77,7 +78,7 @@ def _env_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state

# Select action
key, policy_key = jax.random.split(key)
@@ -92,9 +93,9 @@
done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)

transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation
last_done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
learner_state = LearnerState(params, opt_states, key, env_state, timestep, done)
return learner_state, (transition, timestep.extras["episode_metrics"])

# Step environment for rollout length
@@ -103,37 +104,12 @@
)

# Calculate advantage
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

def _calculate_gae(
traj_batch: PPOTransition, last_val: chex.Array
) -> Tuple[chex.Array, chex.Array]:
"""Calculate the GAE."""

def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
"""Calculate the GAE for a single transition."""
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
@@ -285,7 +261,7 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done)
return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
@@ -402,9 +378,13 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
dones = jnp.zeros(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
@@ -414,8 +394,8 @@
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state

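The Anakin `LearnerState` itself (in `mava/systems/ppo/types.py`) is not part of this diff, but the constructor calls in these two files show it gaining a trailing done field. A sketch consistent with the positional order used above; the field names beyond that order are assumptions:

```python
# Sketch of the extended Anakin LearnerState; the real definition lives in
# mava/systems/ppo/types.py and is not shown in this diff. Field names are
# inferred from the positional constructor calls above and may differ.
from typing import Any, NamedTuple

import chex


class LearnerState(NamedTuple):
    """State carried by the Anakin learner between update steps."""

    params: Any        # actor and critic parameters (Params)
    opt_states: Any    # optimiser states (OptStates)
    key: chex.PRNGKey  # random key used for action sampling
    env_state: Any     # vectorised environment state
    timestep: Any      # last timestep returned by the environment
    done: chex.Array   # done flags entering the next rollout step (new in this PR)
```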
27 changes: 4 additions & 23 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -52,6 +52,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -158,29 +159,9 @@ def _env_step(
# Squeeze out the batch dimension and mask out the value of terminal states.
last_val = last_val.squeeze(0)

def _calculate_gae(
traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array
) -> Tuple[chex.Array, chex.Array]:
def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
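For reference, the recursion that the shared helper computes, and that the removed inline versions in the recurrent systems already used, is standard GAE(λ) with the bootstrap masked by the done flag of the following state. In the notation below, which is mine rather than the PR's:

```latex
% GAE(lambda) recursion as implemented by the reverse scan (notation is mine):
\delta_t = r_t + \gamma \, V(s_{t+1}) \, (1 - d_{t+1}) - V(s_t)
\hat{A}_t = \delta_t + \gamma \lambda \, (1 - d_{t+1}) \, \hat{A}_{t+1},
  \qquad \hat{A}_T = 0, \quad V(s_T) = \texttt{last\_val}, \quad d_T = \texttt{last\_done}
\text{targets}_t = \hat{A}_t + V(s_t)
```

The two removed implementations computed the same quantity; they differed only in whether the done stored in a transition referred to the state before or after the step, which is the bookkeeping the unified helper standardises.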
27 changes: 4 additions & 23 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -52,6 +52,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -160,29 +161,9 @@ def _env_step(
# Squeeze out the batch dimension and mask out the value of terminal states.
last_val = last_val.squeeze(0)

def _calculate_gae(
traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array
) -> Tuple[chex.Array, chex.Array]:
def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
28 changes: 14 additions & 14 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -41,7 +41,7 @@
from mava.evaluator import make_ff_eval_act_fn
from mava.networks import FeedForwardActor as Actor
from mava.networks import FeedForwardValueNet as Critic
from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition
from mava.systems.ppo.types import OptStates, Params, PPOTransition, SebulbaLearnerState
from mava.types import (
ActorApply,
CriticApply,
@@ -163,7 +163,7 @@ def get_learner_step_fn(
apply_fns: Tuple[ActorApply, CriticApply],
update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn],
config: DictConfig,
) -> SebulbaLearnerFn[LearnerState, PPOTransition]:
) -> SebulbaLearnerFn[SebulbaLearnerState, PPOTransition]:
"""Get the learner function."""

num_envs = config.arch.num_envs
@@ -174,16 +174,16 @@
actor_update_fn, critic_update_fn = update_fns

def _update_step(
learner_state: LearnerState,
learner_state: SebulbaLearnerState,
traj_batch: PPOTransition,
) -> Tuple[LearnerState, Metrics]:
) -> Tuple[SebulbaLearnerState, Metrics]:
"""A single update of the network.

This function calculates advantages and targets based on the trajectories
from the actor and updates the actor and critic networks based on the losses.

Args:
learner_state (LearnerState): contains all the items needed for learning.
learner_state (SebulbaLearnerState): contains all the items needed for learning.
traj_batch (PPOTransition): the batch of data to learn with.
"""

@@ -359,12 +359,12 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep)
learner_state = SebulbaLearnerState(params, opt_states, key, None, learner_state.timestep)
return learner_state, loss_info

def learner_fn(
learner_state: LearnerState, traj_batch: PPOTransition
) -> Tuple[LearnerState, Metrics]:
learner_state: SebulbaLearnerState, traj_batch: PPOTransition
) -> Tuple[SebulbaLearnerState, Metrics]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -390,8 +390,8 @@ def learner_fn(


def learner_thread(
learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition],
learner_state: LearnerState,
learn_fn: SebulbaLearnerFn[SebulbaLearnerState, PPOTransition],
learner_state: SebulbaLearnerState,
config: DictConfig,
eval_queue: Queue,
pipeline: Pipeline,
@@ -438,9 +438,9 @@ def learner_thread(
def learner_setup(
key: chex.PRNGKey, config: DictConfig, learner_devices: List
) -> Tuple[
SebulbaLearnerFn[LearnerState, PPOTransition],
SebulbaLearnerFn[SebulbaLearnerState, PPOTransition],
Tuple[ActorApply, CriticApply],
LearnerState,
SebulbaLearnerState,
Sharding,
]:
"""Initialise learner_fn, network and learner state."""
@@ -504,7 +504,7 @@ def learner_setup(
update_fns = (actor_optim.update, critic_optim.update)

# defines how the learner state is sharded: params, opt and key = sharded, timestep = sharded
learn_state_spec = LearnerState(model_spec, model_spec, data_spec, None, data_spec)
learn_state_spec = SebulbaLearnerState(model_spec, model_spec, data_spec, None, data_spec)
learn = get_learner_step_fn(apply_fns, update_fns, config)
learn = jax.jit(
shard_map(
Expand Down Expand Up @@ -537,7 +537,7 @@ def learner_setup(
)

# Initialise learner state.
init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore
init_learner_state = SebulbaLearnerState(params, opt_states, step_keys, None, None) # type: ignore
env.close()

return learn, apply_fns, init_learner_state, learner_sharding # type: ignore
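The Sebulba system now uses a dedicated `SebulbaLearnerState` rather than reusing the Anakin `LearnerState`, which gained a done field in this PR. Its definition in `mava/systems/ppo/types.py` is not part of this diff; a sketch consistent with the constructor calls above, with assumed field names:

```python
# Sketch of SebulbaLearnerState; the real definition in mava/systems/ppo/types.py
# is not shown in this diff. Field names are inferred from the positional
# constructor calls above and may differ.
from typing import Any, NamedTuple, Optional

import chex


class SebulbaLearnerState(NamedTuple):
    """Learner state for the Sebulba (actor-thread / learner-thread) architecture.

    Unlike the Anakin LearnerState it carries no done flags, and env_state is
    always passed as None here, presumably because environments are stepped in
    the actor threads rather than inside the learner.
    """

    params: Any               # actor and critic parameters (Params)
    opt_states: Any           # optimiser states (OptStates)
    key: chex.PRNGKey         # random key
    env_state: Optional[Any]  # unused by this learner (None at every call site)
    timestep: Optional[Any]   # last timestep, or None at initialisation
```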