diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f2c74cc7e..2fa97bd18 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,7 +6,7 @@ jobs: linters: name: "Python ${{ matrix.python-version }} on ubuntu-latest" runs-on: ubuntu-latest - timeout-minutes: 5 + timeout-minutes: 10 strategy: matrix: diff --git a/examples/Quickstart.ipynb b/examples/Quickstart.ipynb index baf119cda..b3ceb373f 100644 --- a/examples/Quickstart.ipynb +++ b/examples/Quickstart.ipynb @@ -535,7 +535,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "id": "eWjNSGvZ7ALw" }, @@ -571,7 +571,7 @@ " )\n", "\n", " # Initialise observation with obs of all agents.\n", - " obs = env.observation_spec().generate_value()\n", + " obs = env.observation_spec.generate_value()\n", " init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)\n", "\n", " # Initialise actor params and optimiser state.\n", @@ -1111,7 +1111,8 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "mava", + "language": "python", "name": "python3" }, "language_info": { @@ -1124,12 +1125,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } + "version": "3.12.4" } }, "nbformat": 4, diff --git a/examples/advanced_usage/README.md b/examples/advanced_usage/README.md index e5da8511c..3a8b97398 100644 --- a/examples/advanced_usage/README.md +++ b/examples/advanced_usage/README.md @@ -12,7 +12,7 @@ dummy_flashbax_transition = { "observation": jnp.zeros( ( config.system.num_agents, - env.observation_spec().agents_view.shape[1], + env.observation_spec.agents_view.shape[1], ), dtype=jnp.float32, ), diff --git a/examples/advanced_usage/ff_ippo_store_experience.py b/examples/advanced_usage/ff_ippo_store_experience.py index 58d71c795..6f6417f11 100644 --- a/examples/advanced_usage/ff_ippo_store_experience.py +++ b/examples/advanced_usage/ff_ippo_store_experience.py @@ -360,7 +360,7 @@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) @@ -377,7 +377,7 @@ def learner_setup( ) # Initialise observation with obs of all agents. - obs = env.observation_spec().generate_value() + obs = env.observation_spec.generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) # Initialise actor params and optimiser state. 
@@ -507,7 +507,7 @@ def run_experiment(_config: DictConfig) -> None: "observation": jnp.zeros( ( config.system.num_agents, - env.observation_spec().agents_view.shape[1], + env.observation_spec.agents_view.shape[1], ), dtype=jnp.float32, ), diff --git a/mava/configs/default/ff_ippo.yaml b/mava/configs/default/ff_ippo.yaml index 5b938fef5..5246a00fe 100644 --- a/mava/configs/default/ff_ippo.yaml +++ b/mava/configs/default/ff_ippo.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: ppo/ff_ippo - network: mlp # [mlp, cnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/ff_mappo.yaml b/mava/configs/default/ff_mappo.yaml index 9f3953b00..e717ae6d1 100644 --- a/mava/configs/default/ff_mappo.yaml +++ b/mava/configs/default/ff_mappo.yaml @@ -2,8 +2,8 @@ defaults: - logger: logger - arch: anakin - system: ppo/ff_mappo - - network: mlp # [mlp, cnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] + - network: mlp # [mlp, cnn] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/ff_sable.yaml b/mava/configs/default/ff_sable.yaml index 406c605ca..3be27767e 100644 --- a/mava/configs/default/ff_sable.yaml +++ b/mava/configs/default/ff_sable.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: sable/ff_sable - network: ff_retention - - env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mpe] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/mat.yaml b/mava/configs/default/mat.yaml index 9e73740b2..35de665db 100644 --- a/mava/configs/default/mat.yaml +++ b/mava/configs/default/mat.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: mat/mat - network: transformer - - env: rware # [gigastep, lbf, mabrax, matrax, rware, smax, mpe] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/rec_ippo.yaml b/mava/configs/default/rec_ippo.yaml index 91fbab8c9..5ece6c21c 100644 --- a/mava/configs/default/rec_ippo.yaml +++ b/mava/configs/default/rec_ippo.yaml @@ -2,8 +2,8 @@ defaults: - logger: logger - arch: anakin - system: ppo/rec_ippo - - network: rnn # [rnn, rcnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] + - network: rnn # [rnn, rcnn] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/rec_iql.yaml b/mava/configs/default/rec_iql.yaml index 7d805699a..2b553ecc3 100644 --- a/mava/configs/default/rec_iql.yaml +++ b/mava/configs/default/rec_iql.yaml @@ -3,8 +3,8 @@ defaults: - logger: logger - arch: anakin - system: q_learning/rec_iql - - network: rnn # [rnn, rcnn] - - env: smax # [cleaner, connector, gigastep, lbf, matrax, rware, smax] + - network: rnn # [rnn, rcnn] + - env: smax # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax] hydra: searchpath: diff --git a/mava/configs/default/rec_mappo.yaml b/mava/configs/default/rec_mappo.yaml index d3b145f5a..a14eb6f55 100644 --- a/mava/configs/default/rec_mappo.yaml +++ b/mava/configs/default/rec_mappo.yaml @@ -2,8 +2,8 @@ defaults: - logger: logger - arch: 
anakin - system: ppo/rec_mappo - - network: rnn # [rnn, rcnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] + - network: rnn # [rnn, rcnn] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/rec_sable.yaml b/mava/configs/default/rec_sable.yaml index 654441bd4..a49ee47fc 100644 --- a/mava/configs/default/rec_sable.yaml +++ b/mava/configs/default/rec_sable.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: sable/rec_sable - network: rec_retention - - env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax, mpe] + - env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/env/connector.yaml b/mava/configs/env/connector.yaml index 1d2f8adde..5e77cac6a 100644 --- a/mava/configs/env/connector.yaml +++ b/mava/configs/env/connector.yaml @@ -4,7 +4,10 @@ defaults: - scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a] # Further environment config details in "con-10x10x5a" file. -env_name: MaConnector # Used for logging purposes. +env_name: Connector # Used for logging purposes. + +# Choose whether to aggregate individual rewards into a shared team reward or not. +aggregate_rewards: True # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. diff --git a/mava/configs/env/lbf.yaml b/mava/configs/env/lbf.yaml index 5952f1f70..dc130e312 100644 --- a/mava/configs/env/lbf.yaml +++ b/mava/configs/env/lbf.yaml @@ -6,8 +6,8 @@ defaults: env_name: LevelBasedForaging # Used for logging purposes. -# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True. -use_individual_rewards: False # If True, use the list of individual rewards. +# Choose whether to aggregate individual rewards into a shared team reward or not. +aggregate_rewards: True # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. 
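Editor's illustration (not part of the patch): the new `aggregate_rewards` flag in connector.yaml and lbf.yaml replaces `use_individual_rewards` and defaults to a shared team reward. A minimal sketch of what the flag controls, modelled on the `aggregate_rewards` helper added to mava/wrappers/jumanji.py later in this diff:

import jax.numpy as jnp

def aggregate_rewards(reward: jnp.ndarray, num_agents: int) -> jnp.ndarray:
    """Sum individual rewards and repeat the team reward for every agent."""
    team_reward = jnp.sum(reward)
    return jnp.repeat(team_reward, num_agents)

# With env.aggregate_rewards=True (the new default), per-agent rewards such as
# [0.0, 0.5, 1.0] become the shared team reward [1.5, 1.5, 1.5];
# with env.aggregate_rewards=False the individual rewards pass through unchanged.
print(aggregate_rewards(jnp.array([0.0, 0.5, 1.0]), num_agents=3))  # [1.5 1.5 1.5]
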
diff --git a/mava/configs/env/scenario/con-10x10x10a.yaml b/mava/configs/env/scenario/con-10x10x10a.yaml index b050d60a0..9c379741d 100644 --- a/mava/configs/env/scenario/con-10x10x10a.yaml +++ b/mava/configs/env/scenario/con-10x10x10a.yaml @@ -1,5 +1,5 @@ # The config of the 10x10x10a scenario -name: MaConnector-v2 +name: Connector-v2 task_name: con-10x10x10a task_config: diff --git a/mava/configs/env/scenario/con-15x15x23a.yaml b/mava/configs/env/scenario/con-15x15x23a.yaml index 4dad2c5ca..707e3c3a7 100644 --- a/mava/configs/env/scenario/con-15x15x23a.yaml +++ b/mava/configs/env/scenario/con-15x15x23a.yaml @@ -1,5 +1,5 @@ # The config of the 15x15x23a scenario -name: MaConnector-v2 +name: Connector-v2 task_name: con-15x15x23a task_config: diff --git a/mava/configs/env/scenario/con-5x5x3a.yaml b/mava/configs/env/scenario/con-5x5x3a.yaml index a7f5daa76..4aa42c9a2 100644 --- a/mava/configs/env/scenario/con-5x5x3a.yaml +++ b/mava/configs/env/scenario/con-5x5x3a.yaml @@ -1,5 +1,5 @@ # The config of the 5x5x3a scenario -name: MaConnector-v2 +name: Connector-v2 task_name: con-5x5x3a task_config: diff --git a/mava/configs/env/scenario/con-7x7x5a.yaml b/mava/configs/env/scenario/con-7x7x5a.yaml index 40faa8839..980138756 100644 --- a/mava/configs/env/scenario/con-7x7x5a.yaml +++ b/mava/configs/env/scenario/con-7x7x5a.yaml @@ -1,5 +1,5 @@ # The config of the 7x7x5a scenario -name: MaConnector-v2 +name: Connector-v2 task_name: con-7x7x5a task_config: diff --git a/mava/configs/env/vector-connector.yaml b/mava/configs/env/vector-connector.yaml index 647ddd9b9..01f9312b5 100644 --- a/mava/configs/env/vector-connector.yaml +++ b/mava/configs/env/vector-connector.yaml @@ -4,7 +4,10 @@ defaults: - scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a] # Further environment config details in "con-10x10x5a" file. -env_name: VectorMaConnector # Used for logging purposes. +env_name: VectorConnector # Used for logging purposes. + +# Choose whether to aggregate individual rewards into a shared team reward or not. +aggregate_rewards: True # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. diff --git a/mava/systems/mat/anakin/mat.py b/mava/systems/mat/anakin/mat.py index 1db141d19..3d37e5205 100644 --- a/mava/systems/mat/anakin/mat.py +++ b/mava/systems/mat/anakin/mat.py @@ -319,11 +319,11 @@ def learner_setup( # PRNG keys. key, actor_net_key = keys - # Get mock inputs to initialise network. - init_x = env.observation_spec().generate_value() + # Initialise observation: Obs for all agents. + init_x = env.observation_spec.generate_value() init_x = tree.map(lambda x: x[None, ...], init_x) - _, action_space_type = get_action_head(env.action_spec()) + _, action_space_type = get_action_head(env.action_spec) if action_space_type == "discrete": init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index ed2943916..07ef5d39d 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -332,7 +332,7 @@ def learner_setup( # Define network and optimiser. 
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) @@ -351,8 +351,8 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Get mock inputs to initialise network. - obs = env.observation_spec().generate_value() + # Initialise observation with obs of all agents. + obs = env.observation_spec.generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) # Initialise actor params and optimiser state. diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 5e5a0006f..f511f9de7 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -334,7 +334,7 @@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) @@ -353,8 +353,8 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Get mock inputs to initialise network. - obs = env.observation_spec().generate_value() + # Initialise observation with obs of all agents. + obs = env.observation_spec.generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) # Initialise actor params and optimiser state. diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 279238382..90310fa9e 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -428,7 +428,7 @@ def learner_setup( # Define network and optimisers. actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) @@ -457,8 +457,8 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Get mock inputs to initialise network. - init_obs = env.observation_spec().generate_value() + # Initialise observation with obs of all agents. + init_obs = env.observation_spec.generate_value() init_obs = tree.map( lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), init_obs, diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index 96d7d74ac..e7466142a 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -430,7 +430,7 @@ def learner_setup( # Define network and optimiser. 
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) @@ -460,8 +460,8 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Get mock inputs to initialise network. - init_obs = env.observation_spec().generate_value() + # Initialise observation with obs of all agents. + init_obs = env.observation_spec.generate_value() init_obs = tree.map( lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), init_obs, diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py index 10ec5c3ab..c8a7b4cf4 100644 --- a/mava/systems/q_learning/anakin/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -92,7 +92,7 @@ def replicate(x: Any) -> Any: # N: Agent # Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...) - init_obs = env.observation_spec().generate_value() # (N, ...) + init_obs = env.observation_spec.generate_value() # (N, ...) # (B, T, N, ...) init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs) init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1) @@ -130,7 +130,7 @@ def replicate(x: Any) -> Any: init_hidden_state = replicate(init_hidden_state) # Create dummy transition - init_acts = env.action_spec().generate_value() # (N,) + init_acts = env.action_spec.generate_value() # (N,) init_transition = Transition( obs=init_obs, # (N, ...) action=init_acts, diff --git a/mava/systems/q_learning/anakin/rec_qmix.py b/mava/systems/q_learning/anakin/rec_qmix.py index c2bce9796..af7bfc62b 100644 --- a/mava/systems/q_learning/anakin/rec_qmix.py +++ b/mava/systems/q_learning/anakin/rec_qmix.py @@ -94,7 +94,7 @@ def replicate(x: Any) -> Any: # N: Agent # Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...) - init_obs = env.observation_spec().generate_value() # (N, ...) + init_obs = env.observation_spec.generate_value() # (N, ...) # (B, T, N, ...) init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs) init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1) @@ -126,7 +126,7 @@ def replicate(x: Any) -> Any: dtype=float, ) global_env_state_shape = ( - env.observation_spec().generate_value().global_state[0, :].shape + env.observation_spec.generate_value().global_state[0, :].shape ) # NOTE: Env wrapper currently duplicates env state for each agent dummy_global_env_state = jnp.zeros( ( @@ -159,7 +159,7 @@ def replicate(x: Any) -> Any: opt_state = replicate(opt_state) init_hidden_state = replicate(init_hidden_state) - init_acts = env.action_spec().generate_value() + init_acts = env.action_spec.generate_value() # NOTE: term_or_trunc refers to the the joint done, ie. when all agents are done or when the # episode horizon has been reached. We use this exclusively in QMIX. diff --git a/mava/systems/sable/anakin/ff_sable.py b/mava/systems/sable/anakin/ff_sable.py index 33547523c..043d670c0 100644 --- a/mava/systems/sable/anakin/ff_sable.py +++ b/mava/systems/sable/anakin/ff_sable.py @@ -328,17 +328,13 @@ def learner_setup( # Get available TPU cores. 
n_devices = len(jax.devices()) - # Get number of agents. - config.system.num_agents = env.num_agents - # PRNG keys. key, net_key = keys # Get number of agents and actions. action_dim = env.action_dim - n_agents = env.action_spec().shape[0] + n_agents = env.num_agents config.system.num_agents = n_agents - config.system.num_actions = action_dim # Setting the chunksize - many agent problems require chunking agents # Create a dummy decay factor for FF Sable @@ -353,7 +349,7 @@ def learner_setup( # Set positional encoding to False, since ff-sable does not use temporal dependencies. config.network.memory_config.timestep_positional_encoding = False - _, action_space_type = get_action_head(env.action_spec()) + _, action_space_type = get_action_head(env.action_spec) # Define network. sable_network = SableNetwork( @@ -373,7 +369,7 @@ def learner_setup( ) # Get mock inputs to initialise network. - init_obs = env.observation_spec().generate_value() + init_obs = env.observation_spec.generate_value() init_obs = tree.map(lambda x: x[jnp.newaxis, ...], init_obs) # Add batch dim init_hs = get_init_hidden_state(config.network.net_config, config.arch.num_envs) init_hs = tree.map(lambda x: x[0, jnp.newaxis], init_hs) diff --git a/mava/systems/sable/anakin/rec_sable.py b/mava/systems/sable/anakin/rec_sable.py index b1eda03f0..08b609c4a 100644 --- a/mava/systems/sable/anakin/rec_sable.py +++ b/mava/systems/sable/anakin/rec_sable.py @@ -363,9 +363,8 @@ def learner_setup( # Get number of agents and actions. action_dim = env.action_dim - n_agents = env.action_spec().shape[0] + n_agents = env.num_agents config.system.num_agents = n_agents - config.system.num_actions = action_dim # Setting the chunksize - smaller chunks save memory at the cost of speed if config.network.memory_config.timestep_chunk_size: @@ -375,7 +374,7 @@ def learner_setup( else: config.network.memory_config.chunk_size = config.system.rollout_length * n_agents - _, action_space_type = get_action_head(env.action_spec()) + _, action_space_type = get_action_head(env.action_spec) # Define network. sable_network = SableNetwork( @@ -395,7 +394,7 @@ def learner_setup( ) # Get mock inputs to initialise network. - init_obs = env.observation_spec().generate_value() + init_obs = env.observation_spec.generate_value() init_obs = tree.map(lambda x: x[jnp.newaxis, ...], init_obs) # Add batch dim init_hs = get_init_hidden_state(config.network.net_config, config.arch.num_envs) init_hs = tree.map(lambda x: x[0, jnp.newaxis], init_hs) diff --git a/mava/systems/sac/anakin/ff_hasac.py b/mava/systems/sac/anakin/ff_hasac.py index fe78dcf15..ce6255703 100644 --- a/mava/systems/sac/anakin/ff_hasac.py +++ b/mava/systems/sac/anakin/ff_hasac.py @@ -144,16 +144,16 @@ def replicate(x: Any) -> Any: key, actor_key, q1_key, q2_key, q1_target_key, q2_target_key = jax.random.split(key, 6) actor_keys = jax.random.split(actor_key, n_agents) - acts = env.action_spec().generate_value() # all agents actions + acts = env.action_spec.generate_value() # all agents actions act_single = acts[0] # single agents action concat_acts = jnp.concatenate([act_single for _ in range(n_agents)], axis=0) concat_acts_batched = concat_acts[jnp.newaxis, ...] 
# batch + concat of all agents actions - obs = env.observation_spec().generate_value() + obs = env.observation_spec.generate_value() obs_single_batched = tree.map(lambda x: x[0][jnp.newaxis, ...], obs) # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate( action_head, action_dim=env.action_dim, independent_std=False ) @@ -285,7 +285,7 @@ def make_update_fns( actor_net, q_net = networks actor_opt, q_opt, alpha_opt = optims - full_action_shape = (cfg.arch.num_envs, *env.action_spec().shape) + full_action_shape = (cfg.arch.num_envs, *env.action_spec.shape) # losses: def q_loss_fn( diff --git a/mava/systems/sac/anakin/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py index 222654527..25cba2e8c 100644 --- a/mava/systems/sac/anakin/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -104,14 +104,14 @@ def replicate(x: Any) -> Any: key, actor_key, q1_key, q2_key, q1_target_key, q2_target_key = jax.random.split(key, 6) - acts = env.action_spec().generate_value() # all agents actions + acts = env.action_spec.generate_value() # all agents actions act_single_batched = acts[0][jnp.newaxis, ...] # batch single agent action - obs = env.observation_spec().generate_value() + obs = env.observation_spec.generate_value() obs_single_batched = tree.map(lambda x: x[0][jnp.newaxis, ...], obs) # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate( action_head, action_dim=env.action_dim, independent_std=False ) @@ -242,7 +242,7 @@ def make_update_fns( actor_net, q_net = networks actor_opt, q_opt, alpha_opt = optims - full_action_shape = (cfg.arch.num_envs, *env.action_spec().shape) + full_action_shape = (cfg.arch.num_envs, *env.action_spec.shape) # losses: def q_loss_fn( diff --git a/mava/systems/sac/anakin/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py index 4e96cdbf3..15519b87e 100644 --- a/mava/systems/sac/anakin/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -105,16 +105,16 @@ def replicate(x: Any) -> Any: key, actor_key, q1_key, q2_key, q1_target_key, q2_target_key = jax.random.split(key, 6) - acts = env.action_spec().generate_value() # all agents actions + acts = env.action_spec.generate_value() # all agents actions act_single = acts[0] # single agents action joint_acts = jnp.concatenate([act_single for _ in range(n_agents)], axis=0) joint_acts_batched = joint_acts[jnp.newaxis, ...] 
# joint actions with a batch dim - obs = env.observation_spec().generate_value() + obs = env.observation_spec.generate_value() obs_single_batched = tree.map(lambda x: x[0][jnp.newaxis, ...], obs) # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - action_head, _ = get_action_head(env.action_spec()) + action_head, _ = get_action_head(env.action_spec) actor_action_head = hydra.utils.instantiate( action_head, action_dim=env.action_dim, independent_std=False ) @@ -245,7 +245,7 @@ def make_update_fns( actor_net, q_net = networks actor_opt, q_opt, alpha_opt = optims - full_action_shape = (cfg.arch.num_envs, *env.action_spec().shape) + full_action_shape = (cfg.arch.num_envs, *env.action_spec.shape) # losses: def q_loss_fn( diff --git a/mava/types.py b/mava/types.py index d60175f50..26e54e948 100644 --- a/mava/types.py +++ b/mava/types.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Callable, Dict, Generic, Optional, Protocol, Tuple, TypeVar, Union import chex @@ -67,6 +68,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """ ... + @cached_property def observation_spec(self) -> specs.Spec: """Returns the observation spec. @@ -75,6 +77,7 @@ def observation_spec(self) -> specs.Spec: """ ... + @cached_property def action_spec(self) -> specs.Spec: """Returns the action spec. @@ -83,6 +86,7 @@ def action_spec(self) -> specs.Spec: """ ... + @cached_property def reward_spec(self) -> specs.Array: """Describes the reward returned by the environment. By default, this is assumed to be a single float. @@ -92,6 +96,7 @@ def reward_spec(self) -> specs.Array: """ ... + @cached_property def discount_spec(self) -> specs.BoundedArray: """Describes the discount returned by the environment. By default, this is assumed to be a single float between 0 and 1. diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 0a56367c8..26c2fb3bc 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -64,8 +64,8 @@ _jumanji_registry = { "RobotWarehouse": {"generator": RwareRandomGenerator, "wrapper": RwareWrapper}, "LevelBasedForaging": {"generator": LbfRandomGenerator, "wrapper": LbfWrapper}, - "MaConnector": {"generator": ConnectorRandomGenerator, "wrapper": ConnectorWrapper}, - "VectorMaConnector": { + "Connector": {"generator": ConnectorRandomGenerator, "wrapper": ConnectorWrapper}, + "VectorConnector": { "generator": ConnectorRandomGenerator, "wrapper": VectorConnectorWrapper, }, diff --git a/mava/wrappers/gigastep.py b/mava/wrappers/gigastep.py index f395e0536..fc8125cb7 100644 --- a/mava/wrappers/gigastep.py +++ b/mava/wrappers/gigastep.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import TYPE_CHECKING, Dict, Tuple, Union import jax @@ -56,6 +57,10 @@ def __init__( has_global_state (bool): Whether the environment has a global state. Defaults to False. 
""" + self.has_global_state = has_global_state + self.time_limit = env.max_episode_length + self.num_agents = env.n_agents_team1 + self.action_dim = env.n_actions super().__init__(env) assert ( env.discrete_actions @@ -65,10 +70,6 @@ def __init__( ), "Only Vector observations are currently supported for Gigastep environments" self._env: GigastepEnv - self.time_limit = self._env.max_episode_length - self.num_agents = self._env.n_agents_team1 - self.action_dim = self._env.n_actions - self.has_global_state = has_global_state def reset(self, key: PRNGKey) -> Tuple[GigastepState, TimeStep]: """Reset the Gigastep environment. @@ -184,6 +185,7 @@ def get_global_state(self, obs: Array) -> Array: global_obs = jnp.concatenate(obs, axis=0) return jnp.tile(global_obs, (self.num_agents, 1)) + @cached_property def observation_spec(self) -> specs.Spec: agents_view = specs.BoundedArray( (self.num_agents, *self._env.observation_space.shape), @@ -223,12 +225,15 @@ def observation_spec(self) -> specs.Spec: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.Spec: return specs.MultiDiscreteArray(num_values=jnp.full(self.num_agents, self.action_dim)) + @cached_property def reward_spec(self) -> specs.Array: return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") + @cached_property def discount_spec(self) -> specs.BoundedArray: return specs.BoundedArray( shape=(self.num_agents,), dtype=float, minimum=0.0, maximum=1.0, name="discount" diff --git a/mava/wrappers/jaxmarl.py b/mava/wrappers/jaxmarl.py index 23322fc0c..ebb6d6b7b 100644 --- a/mava/wrappers/jaxmarl.py +++ b/mava/wrappers/jaxmarl.py @@ -189,11 +189,11 @@ def __init__( # Making sure the child envs set this correctly. assert time_limit > 0, f"Time limit must be greater than 0, got {time_limit}" + self.has_global_state = has_global_state + self.time_limit = time_limit super().__init__(env) self._env: MultiAgentEnv self.agents = self._env.agents - self.has_global_state = has_global_state - self.time_limit = time_limit self.num_agents = self._env.num_agents # Calling these on init to cache the values in a non-jitted context. 
@@ -251,6 +251,7 @@ def _create_observation( return Observation(**obs_data) + @cached_property def observation_spec(self) -> specs.Spec: agents_view = jaxmarl_space_to_jumanji_spec(merge_space(self._env.observation_spaces)) @@ -285,12 +286,15 @@ def observation_spec(self) -> specs.Spec: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.Spec: return jaxmarl_space_to_jumanji_spec(merge_space(self._env.action_spaces)) + @cached_property def reward_spec(self) -> specs.Array: return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") + @cached_property def discount_spec(self) -> specs.BoundedArray: return specs.BoundedArray( shape=(self.num_agents,), diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 5716d5557..443e37d38 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -24,8 +24,9 @@ from jumanji.env import Environment from jumanji.environments.routing.cleaner import Cleaner from jumanji.environments.routing.cleaner.constants import DIRTY, WALL -from jumanji.environments.routing.connector import MaConnector +from jumanji.environments.routing.connector import Connector from jumanji.environments.routing.connector.constants import ( + AGENT_INITIAL_VALUE, EMPTY, PATH, POSITION, @@ -39,12 +40,18 @@ from mava.types import Observation, ObservationGlobalState, State +def aggregate_rewards(reward: chex.Array, num_agents: int) -> chex.Array: + """Aggregate individual rewards across agents.""" + team_reward = jnp.sum(reward) + return jnp.repeat(team_reward, num_agents) + + class JumanjiMarlWrapper(Wrapper, ABC): def __init__(self, env: Environment, add_global_state: bool): + self.add_global_state = add_global_state super().__init__(env) self.num_agents = self._env.num_agents self.time_limit = self._env.time_limit - self.add_global_state = add_global_state @abstractmethod def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: @@ -91,6 +98,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: return state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalState]]: """Specification of the observation of the environment.""" step_count = specs.BoundedArray( @@ -101,7 +109,7 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta "step_count", ) - obs_spec = self._env.observation_spec() + obs_spec = self._env.observation_spec obs_data = { "agents_view": obs_spec.agents_view, "action_mask": obs_spec.action_mask, @@ -123,7 +131,7 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta @cached_property def action_dim(self) -> chex.Array: """Get the actions dim for each agent.""" - return int(self._env.action_spec().num_values[0]) + return int(self._env.action_spec.num_values[0]) class RwareWrapper(JumanjiMarlWrapper): @@ -144,11 +152,12 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: discount = jnp.repeat(timestep.discount, self.num_agents) return timestep.replace(observation=observation, reward=reward, discount=discount) + @cached_property def observation_spec( self, ) -> specs.Spec[Union[Observation, ObservationGlobalState]]: # need to cast the agents view and global state to floats as we do in modify timestep - inner_spec = super().observation_spec() + inner_spec = super().observation_spec spec = inner_spec.replace(agents_view=inner_spec.agents_view.replace(dtype=float)) if self.add_global_state: spec = 
spec.replace(global_state=inner_spec.global_state.replace(dtype=float))
@@ -171,21 +180,11 @@ def __init__(
         self,
         env: LevelBasedForaging,
         add_global_state: bool = False,
-        use_individual_rewards: bool = False,
+        aggregate_rewards: bool = True,
     ):
         super().__init__(env, add_global_state)
         self._env: LevelBasedForaging
-        self._use_individual_rewards = use_individual_rewards
-
-    def aggregate_rewards(
-        self, timestep: TimeStep, observation: Observation
-    ) -> TimeStep[Observation]:
-        """Aggregate individual rewards across agents."""
-        team_reward = jnp.sum(timestep.reward)
-
-        # Repeat the aggregated reward for each agent.
-        reward = jnp.repeat(team_reward, self.num_agents)
-        return timestep.replace(observation=observation, reward=reward)
+        self._aggregate_rewards = aggregate_rewards
 
     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
         """Modify the timestep for Level-Based Foraging environment and update
@@ -197,18 +196,19 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
             action_mask=timestep.observation.action_mask,
             step_count=jnp.repeat(timestep.observation.step_count, self.num_agents),
         )
-        if self._use_individual_rewards:
-            # The environment returns a list of individual rewards and these are used as is.
-            return timestep.replace(observation=modified_observation)
+        # Whether or not to aggregate the list of individual rewards.
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
 
-        # Aggregate the list of individual rewards and use a single team_reward.
-        return self.aggregate_rewards(timestep, modified_observation)
+        return timestep.replace(observation=modified_observation, reward=reward)
 
+    @cached_property
     def observation_spec(
         self,
     ) -> specs.Spec[Union[Observation, ObservationGlobalState]]:
         # need to cast the agents view and global state to floats as we do in modify timestep
-        inner_spec = super().observation_spec()
+        inner_spec = super().observation_spec
         spec = inner_spec.replace(agents_view=inner_spec.agents_view.replace(dtype=float))
         if self.add_global_state:
             spec = spec.replace(global_state=inner_spec.global_state.replace(dtype=float))
@@ -216,21 +216,49 @@ def observation_spec(
         return spec
 
 
+def switch_perspective(grid: chex.Array, agent_id: int, num_agents: int) -> chex.Array:
+    """
+    Encodes the observation with respect to the current agent defined by `agent_id`.
+    Each agent sees its observations as values `1, 2, 3`. Observations of other agents
+    are shifted cyclically based on their relative position. The mapping is designed
+    such that the ordering of observations remains consistent.
+    For example, in a 3-agent game, if we wanted to switch to agent 1's perspective, then:
+    agent 1's values will change from 4,5,6 -> 1,2,3
+    agent 2's values will change from 7,8,9 -> 4,5,6
+    agent 0's values will change from 1,2,3 -> 7,8,9
+    Agent 0 will be passed observations where it is represented by the values 1,2,3, and agent 1
+    will also be passed observations where it is represented by the values 1,2,3. However, in the
+    state, agent 0 will always be 1,2,3 and agent 1 will always be 4,5,6."""
+    new_grid = grid - AGENT_INITIAL_VALUE  # Center agent values around 0
+    new_grid -= 3 * agent_id  # Move the obs
+    new_grid %= 3 * num_agents  # Keep obs in bounds
+    new_grid += AGENT_INITIAL_VALUE  # 'Un-center' agent obs around 0
+    # Take agent values from rotated grid and empty values from old grid
+    return jnp.where((grid >= AGENT_INITIAL_VALUE), new_grid, grid)
+
+
 class ConnectorWrapper(JumanjiMarlWrapper):
     """Multi-agent wrapper for the MA Connector environment.
 
     Do not use the AgentID wrapper with this env, it has implicit agent IDs.
     """
 
-    def __init__(self, env: MaConnector, add_global_state: bool = False):
+    def __init__(
+        self, env: Connector, add_global_state: bool = False, aggregate_rewards: bool = True
+    ):
         super().__init__(env, add_global_state)
-        self._env: MaConnector
+        self._env: Connector
+        self._aggregate_rewards = aggregate_rewards
+        self.agent_ids = jnp.arange(self.num_agents)
 
     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
         """Modify the timestep for the Connector environment."""
 
         # TARGET = 3 = The number of different types of items on the grid.
         def create_agents_view(grid: chex.Array) -> chex.Array:
+            grid = jax.vmap(switch_perspective, in_axes=(None, 0, None))(
+                grid, self.agent_ids, self.num_agents
+            )
             # Mark position and target of each agent with that agent's normalized index.
             positions = (
                 jnp.where(grid % TARGET == POSITION, jnp.ceil(grid / TARGET), 0) / self.num_agents
@@ -247,16 +275,6 @@ def create_agents_view(grid: chex.Array) -> chex.Array:
             )
             return agents_view
 
-        def aggregate_rewards(
-            timestep: TimeStep,
-        ) -> TimeStep[Observation]:
-            """Aggregate individual rewards and discounts across agents."""
-            team_reward = jnp.sum(timestep.reward)
-            reward = jnp.repeat(team_reward, self.num_agents)
-            return timestep.replace(reward=reward)
-
-        timestep = aggregate_rewards(timestep)
-
         obs_data = {
             "agents_view": create_agents_view(timestep.observation.grid),
             "action_mask": timestep.observation.action_mask,
@@ -266,7 +284,11 @@ def aggregate_rewards(
 
         # The episode is won if all agents have connected.
         extras = timestep.extras | {"won_episode": timestep.extras["ratio_connections"] == 1.0}
-        return timestep.replace(observation=Observation(**obs_data), extras=extras)
+        # Whether or not to aggregate the list of individual rewards.
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
+        return timestep.replace(observation=Observation(**obs_data), reward=reward, extras=extras)
 
     def get_global_state(self, obs: Observation) -> chex.Array:
         """Constructs the global state from the global information
@@ -274,6 +296,7 @@ def aggregate_rewards(
         """
         return jnp.tile(obs.agents_view[..., :3][0], (obs.agents_view.shape[0], 1, 1, 1))
 
+    @cached_property
     def observation_spec(
         self,
     ) -> specs.Spec[Union[Observation, ObservationGlobalState]]:
@@ -294,10 +317,9 @@ def observation_spec(
         )
         obs_data = {
             "agents_view": agents_view,
-            "action_mask": self._env.observation_spec().action_mask,
+            "action_mask": self._env.observation_spec.action_mask,
             "step_count": step_count,
         }
-
         if self.add_global_state:
             global_state = specs.BoundedArray(
                 shape=(self._env.num_agents, self._env.grid_size, self._env.grid_size, 3),
@@ -335,23 +357,30 @@ def _get_location(grid: chex.Array) -> chex.Array:
 
 
 class VectorConnectorWrapper(JumanjiMarlWrapper):
-    """Multi-agent wrapper for the MaConnector environment.
+ """Multi-agent wrapper for the Connector environment. This wrapper transforms the grid-based observation to a vector of features. This env should have the AgentID wrapper applied to it since there is not longer a channel that can encode AgentID information. """ - def __init__(self, env: MaConnector, add_global_state: bool = False): - super().__init__(env, add_global_state) - self._env: MaConnector + def __init__( + self, env: Connector, add_global_state: bool = False, aggregate_rewards: bool = True + ): self.fov = 2 + super().__init__(env, add_global_state) + self._env: Connector + self._aggregate_rewards = aggregate_rewards + self.agent_ids = jnp.arange(self.num_agents) def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: """Modify the timestep for the Connector environment.""" # TARGET = 3 = The number of different types of items on the grid. def create_agents_view(grid: chex.Array) -> chex.Array: + grid = jax.vmap(switch_perspective, in_axes=(None, 0, None))( + grid, self.agent_ids, self.num_agents + ) positions = jnp.where(grid % TARGET == POSITION, True, False) targets = jnp.where((grid % TARGET == 0) & (grid != EMPTY), True, False) paths = jnp.where(grid % TARGET == PATH, True, False) @@ -407,8 +436,12 @@ def _create_one_agent_view(i: int) -> chex.Array: # The episode is won if all agents have connected. extras = timestep.extras | {"won_episode": timestep.extras["ratio_connections"] == 1.0} - return timestep.replace(observation=Observation(**obs_data), extras=extras) + reward = timestep.reward + if self._aggregate_rewards: + reward = aggregate_rewards(reward, self.num_agents) + return timestep.replace(observation=Observation(**obs_data), reward=reward, extras=extras) + @cached_property def observation_spec( self, ) -> specs.Spec[Union[Observation, ObservationGlobalState]]: @@ -420,7 +453,6 @@ def observation_spec( jnp.repeat(self.time_limit, self.num_agents), "step_count", ) - # 2 sets of tiles in fov (blockers and targets) + xy position of agent and target tiles_in_fov = (self.fov * 2 + 1) ** 2 single_agent_obs = 4 + tiles_in_fov * 2 @@ -431,13 +463,11 @@ def observation_spec( minimum=-1.0, maximum=1.0, ) - obs_data = { "agents_view": agents_view, - "action_mask": self._env.observation_spec().action_mask, + "action_mask": self._env.observation_spec.action_mask, "step_count": step_count, } - if self.add_global_state: global_state = specs.BoundedArray( shape=(self.num_agents, self.num_agents * single_agent_obs), @@ -521,6 +551,7 @@ def get_global_state(self, obs: Observation) -> chex.Array: """ return obs.agents_view[..., :3] # (A, R, C, 3) + @cached_property def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalState]]: """Specification of the observation of the environment.""" step_count = specs.BoundedArray( @@ -539,7 +570,7 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta ) obs_data = { "agents_view": agents_view, - "action_mask": self._env.observation_spec().action_mask, + "action_mask": self._env.observation_spec.action_mask, "step_count": step_count, } if self.add_global_state: diff --git a/mava/wrappers/matrax.py b/mava/wrappers/matrax.py index aec95ed16..cb75d2163 100644 --- a/mava/wrappers/matrax.py +++ b/mava/wrappers/matrax.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property from typing import Tuple, Union import chex @@ -29,6 +30,7 @@ class MatraxWrapper(Wrapper): """Multi-agent wrapper for the Matrax environment.""" def __init__(self, env: Environment, add_global_state: bool): + self.add_global_state = add_global_state super().__init__(env) self._env: MatrixGame @@ -36,7 +38,6 @@ def __init__(self, env: Environment, add_global_state: bool): self.action_dim = self._env.num_actions self.time_limit = self._env.time_limit self.action_mask = jnp.ones((self.num_agents, self.num_actions), dtype=bool) - self.add_global_state = add_global_state def modify_timestep( self, timestep: TimeStep @@ -65,6 +66,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: state, timestep = self._env.step(state, action) return state, self.modify_timestep(timestep) + @cached_property def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalState]]: """Specification of the observation of the environment.""" step_count = specs.BoundedArray( @@ -79,7 +81,7 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta bool, "action_mask", ) - obs_spec = self._env.observation_spec() + obs_spec = self._env.observation_spec obs_data = { "agents_view": obs_spec.agent_obs, "action_mask": action_mask, diff --git a/mava/wrappers/observation.py b/mava/wrappers/observation.py index 7e094a6a8..11ed774d4 100644 --- a/mava/wrappers/observation.py +++ b/mava/wrappers/observation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Tuple, Union import chex @@ -70,11 +71,12 @@ def step( return state, timestep + @cached_property def observation_spec( self, ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]: """Specification of the observation of the selected environment.""" - obs_spec = self._env.observation_spec() + obs_spec = self._env.observation_spec num_obs_features = obs_spec.agents_view.shape[-1] + self._env.num_agents dtype = obs_spec.agents_view.dtype agents_view = specs.Array((self._env.num_agents, num_obs_features), dtype, "agents_view") diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e004a3c23..ce6dad462 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,9 +10,9 @@ id-marl-eval @ git+https://github.com/instadeepai/marl-eval jax==0.4.30 jaxlib==0.4.30 jaxmarl @ git+https://github.com/RuanJohn/JaxMARL@unpin-jax # This only unpins the version of Jax. -jumanji @ git+https://github.com/sash-a/jumanji@old_jumanji # Includes a few extra MARL envs +jumanji>= 1.1.0 lbforaging -matrax @ git+https://github.com/instadeepai/matrax@4c5d8aa97214848ea659274f16c48918c13e845b +matrax>= 0.0.5 mujoco==3.1.3 mujoco-mjx==3.1.3 neptune
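End-of-diff note (editor's sketch, not part of the patch): the recurring change across the systems and wrappers above is that Jumanji >= 1.1.0 exposes `observation_spec` / `action_spec` as (cached) properties rather than methods, so call sites drop the trailing parentheses. A minimal illustration of the pattern, assuming a Jumanji-style `specs` module; `DummyEnv` is a made-up stand-in, not a Mava class:

from functools import cached_property

from jumanji import specs

class DummyEnv:
    num_agents = 3

    @cached_property
    def observation_spec(self) -> specs.Array:
        # Built once on first access and cached, in a non-jitted context.
        return specs.Array(shape=(self.num_agents, 5), dtype=float, name="agents_view")

env = DummyEnv()
obs = env.observation_spec.generate_value()  # attribute access: no "()" any more
assert obs.shape == (3, 5)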