From 61c0a22ff7749b74f785cad297e3c0386cd2ff10 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Sat, 14 Dec 2024 16:28:24 -0800
Subject: [PATCH 1/7] parallel collection

---
 benchmarl/conf/experiment/base_experiment.yaml | 7 +++++--
 benchmarl/environments/common.py               | 2 +-
 benchmarl/environments/magent/common.py        | 9 +++++----
 benchmarl/environments/meltingpot/common.py    | 6 ++++--
 benchmarl/environments/pettingzoo/common.py    | 7 ++++---
 benchmarl/environments/smacv2/common.py        | 5 +++--
 benchmarl/environments/vmas/common.py          | 5 +++--
 benchmarl/experiment/experiment.py             | 9 +++++++--
 test/conftest.py                               | 1 +
 test/test_pettingzoo.py                        | 3 +++
 10 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index 05d244e6..58d30bf9 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -15,6 +15,9 @@ share_policy_params: True
 prefer_continuous_actions: True
 # If False collection is done using a collector (under no grad). If True, collection is done with gradients.
 collect_with_grad: False
+# In case of non-vectorized environments, whether to run collection on multiple processes
+# If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker each
+parallel_collection: False
 
 # Discount factor
 gamma: 0.9
@@ -51,7 +54,7 @@ max_n_frames: 3_000_000
 on_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially.
+# Otherwise batching will be simulated and each env will be run sequentially or in parallel depending on parallel_collection.
 on_policy_n_envs_per_worker: 10
 # This is the number of times collected_frames_per_batch will be split into minibatches and trained
 on_policy_n_minibatch_iters: 45
@@ -63,7 +66,7 @@ on_policy_minibatch_size: 400
 off_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially.
+# Otherwise batching will be simulated and each env will be run sequentially or in parallel depending on parallel_collection.
 off_policy_n_envs_per_worker: 10
 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
 off_policy_n_optimizer_steps: 1000
diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py
index 63ac150e..cccccdfb 100644
--- a/benchmarl/environments/common.py
+++ b/benchmarl/environments/common.py
@@ -34,7 +34,7 @@ def _type_check_task_config(
         else:
             if warn_on_missing_dataclass:
                 warnings.warn(
-                    "TaskConfig python dataclass not foud, task is being loaded without type checks"
+                    "TaskConfig python dataclass not found, task is being loaded without type checks"
                 )
         return config
diff --git a/benchmarl/environments/magent/common.py b/benchmarl/environments/magent/common.py
index b8964ddd..08cae772 100644
--- a/benchmarl/environments/magent/common.py
+++ b/benchmarl/environments/magent/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 from torchrl.data import Composite
@@ -31,9 +31,10 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
 
         return lambda: PettingZooWrapper(
-            env=self.__get_env(),
+            env=self.__get_env(config),
             return_state=True,
             seed=seed,
             done_on_any=False,
@@ -41,7 +42,7 @@ def get_env_fun(
             device=device,
         )
 
-    def __get_env(self) -> EnvBase:
+    def __get_env(self, config) -> EnvBase:
         try:
             from magent2.environments import (
                 adversarial_pursuit_v4,
@@ -66,7 +67,7 @@ def __get_env(self) -> EnvBase:
         }
         if self.name not in envs:
             raise Exception(f"{self.name} is not an environment of MAgent2")
-        return envs[self.name].parallel_env(**self.config, render_mode="rgb_array")
+        return envs[self.name].parallel_env(**config, render_mode="rgb_array")
 
     def supports_continuous_actions(self) -> bool:
         return False
diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py
index 848ceaa3..57423b13 100644
--- a/benchmarl/environments/meltingpot/common.py
+++ b/benchmarl/environments/meltingpot/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 import torch
@@ -84,11 +84,13 @@ def get_env_fun(
     ) -> Callable[[], EnvBase]:
         from torchrl.envs.libs.meltingpot import MeltingpotEnv
 
+        config = copy.deepcopy(self.config)
+
         return lambda: MeltingpotEnv(
             substrate=self.name.lower(),
             categorical_actions=True,
             device=device,
-            **self.config,
+            **config,
         )
 
     def supports_continuous_actions(self) -> bool:
diff --git a/benchmarl/environments/pettingzoo/common.py b/benchmarl/environments/pettingzoo/common.py
index 1bb3cd15..f6078c99 100644
--- a/benchmarl/environments/pettingzoo/common.py
+++ b/benchmarl/environments/pettingzoo/common.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 #
 
+import copy
 from typing import Callable, Dict, List, Optional
 
 from torchrl.data import Composite
@@ -35,9 +36,9 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
         if self.supports_continuous_actions() and self.supports_discrete_actions():
-            self.config.update({"continuous_actions": continuous_actions})
-
+            config.update({"continuous_actions": continuous_actions})
         return lambda: PettingZooEnv(
             categorical_actions=True,
             device=device,
@@ -45,7 +46,7 @@ def get_env_fun(
             parallel=True,
             return_state=self.has_state(),
             render_mode="rgb_array",
-            **self.config
+            **config
         )
 
     def supports_continuous_actions(self) -> bool:
diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py
index dc87f6b7..08b972a1 100644
--- a/benchmarl/environments/smacv2/common.py
+++ b/benchmarl/environments/smacv2/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 import torch
@@ -42,8 +42,9 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
         return lambda: SMACv2Env(
-            categorical_actions=True, seed=seed, device=device, **self.config
+            categorical_actions=True, seed=seed, device=device, **config
         )
 
     def supports_continuous_actions(self) -> bool:
diff --git a/benchmarl/environments/vmas/common.py b/benchmarl/environments/vmas/common.py
index 9c648045..dd77a249 100644
--- a/benchmarl/environments/vmas/common.py
+++ b/benchmarl/environments/vmas/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 from torchrl.data import Composite
@@ -52,6 +52,7 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
         return lambda: VmasEnv(
             scenario=self.name.lower(),
             num_envs=num_envs,
@@ -60,7 +61,7 @@ def get_env_fun(
             device=device,
             categorical_actions=True,
             clamp_actions=True,
-            **self.config,
+            **config,
         )
 
     def supports_continuous_actions(self) -> bool:
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index bd72a717..777c18e9 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -19,8 +19,9 @@
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictSequential
+
 from torchrl.collectors import SyncDataCollector
-from torchrl.envs import SerialEnv, TransformedEnv
+from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.record.loggers import generate_exp_name
@@ -58,6 +59,7 @@ class ExperimentConfig:
     share_policy_params: bool = MISSING
     prefer_continuous_actions: bool = MISSING
     collect_with_grad: bool = MISSING
+    parallel_collection: bool = MISSING
 
     gamma: float = MISSING
     lr: float = MISSING
@@ -435,8 +437,11 @@ def _setup_task(self):
         transforms_training = Compose(*transforms_training)
 
         if test_env.batch_size == ():
+            env_class = (
+                SerialEnv if not self.config.parallel_collection else ParallelEnv
+            )
             self.env_func = lambda: TransformedEnv(
-                SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func),
+                env_class(self.config.n_envs_per_worker(self.on_policy), env_func),
                 transforms_training.clone(),
             )
         else:
diff --git a/test/conftest.py b/test/conftest.py
index 3f53416e..5ce3bde8 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -29,6 +29,7 @@ def experiment_config(tmp_path) -> ExperimentConfig:
     experiment_config.on_policy_n_envs_per_worker = (
         experiment_config.off_policy_n_envs_per_worker
     ) = 2
+    experiment_config.parallel_collection = False
     experiment_config.off_policy_n_optimizer_steps = 2
     experiment_config.off_policy_train_batch_size = 3
     experiment_config.off_policy_memory_size = 200
diff --git a/test/test_pettingzoo.py b/test/test_pettingzoo.py
index a726c017..8f9169a4 100644
--- a/test/test_pettingzoo.py
+++ b/test/test_pettingzoo.py
@@ -68,13 +68,16 @@ def test_all_algos(
 
     @pytest.mark.parametrize("algo_config", [IppoConfig, MasacConfig])
     @pytest.mark.parametrize("task", list(PettingZooTask))
+    @pytest.mark.parametrize("parallel_collection", [True, False])
     def test_all_tasks(
         self,
         algo_config: AlgorithmConfig,
         task: Task,
+        parallel_collection,
         experiment_config,
         mlp_sequence_config,
     ):
+        experiment_config.parallel_collection = parallel_collection
         task = task.get_from_yaml()
         experiment = Experiment(
             algorithm_config=algo_config.get_from_yaml(),
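A note on what the new flag does: for tasks that are not natively vectorized, collection is batched through one of TorchRL's two batched-environment wrappers, and `parallel_collection` simply picks between them. A minimal, self-contained sketch of that dispatch (illustrative only — the Pendulum factory and worker count are placeholders, not BenchMARL code; assumes TorchRL with Gymnasium installed):

```python
from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import GymEnv


def make_env():
    # Placeholder for any non-vectorized env (batch_size == ()).
    return GymEnv("Pendulum-v1")


def make_collection_env(n_envs: int, parallel_collection: bool):
    # SerialEnv steps the n_envs copies one after another in this process;
    # ParallelEnv spawns one worker process per copy, so each worker
    # contributes frames_per_batch / n_envs frames to every batch.
    env_class = ParallelEnv if parallel_collection else SerialEnv
    return env_class(n_envs, make_env)


if __name__ == "__main__":  # guard needed where workers are spawned
    env = make_collection_env(n_envs=4, parallel_collection=True)
    print(env.rollout(max_steps=8).shape)  # torch.Size([4, 8])
```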
From 267440027b8493085e34f0d05a7e99c5eb7fe08a Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Sat, 14 Dec 2024 16:30:42 -0800
Subject: [PATCH 2/7] parallel collection

---
 benchmarl/conf/experiment/base_experiment.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index 58d30bf9..014aae5f 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -16,7 +16,7 @@ prefer_continuous_actions: True
 # If False collection is done using a collector (under no grad). If True, collection is done with gradients.
 collect_with_grad: False
 # In case of non-vectorized environments, whether to run collection on multiple processes
-# If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker each
+# If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker frames each
 parallel_collection: False
 
 # Discount factor

From bafbe86bbd9ad2ed143cdff33137c6c8a9033cf9 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Sat, 14 Dec 2024 18:40:24 -0800
Subject: [PATCH 3/7] parallel collection

---
 test/test_task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_task.py b/test/test_task.py
index fc1660f6..e6c65c0b 100644
--- a/test/test_task.py
+++ b/test/test_task.py
@@ -34,7 +34,7 @@ def test_loading_tasks(task_name):
     task_name_hydra = cfg.hydra.runtime.choices.task
     assert task_name_hydra == task_name
 
-    warn_message = "TaskConfig python dataclass not foud, task is being loaded without type checks"
+    warn_message = "TaskConfig python dataclass not found, task is being loaded without type checks"
 
     with (
         pytest.warns(match=warn_message)
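Patch 1's hoisting of `config = copy.deepcopy(self.config)` out of each returned lambda matters exactly in the parallel case: with `parallel_collection: True` the env-constructor closure is shipped to worker processes, and a closure over `self` would both drag the whole task object along and let every worker share (and, in the PettingZoo wrapper, mutate) one config dict. A self-contained sketch of the pitfall (the toy `Task` class is hypothetical, purely for illustration):

```python
import copy
from typing import Callable, Dict


class Task:
    def __init__(self, config: Dict):
        self.config = config

    def get_env_fun_shared(self) -> Callable[[], Dict]:
        # Closes over self: every caller sees later mutations of self.config.
        return lambda: dict(self.config)

    def get_env_fun_snapshot(self) -> Callable[[], Dict]:
        # Deep-copies first: the closure owns a frozen, private snapshot.
        config = copy.deepcopy(self.config)
        return lambda: dict(config)


task = Task({"continuous_actions": True})
shared, snapshot = task.get_env_fun_shared(), task.get_env_fun_snapshot()
task.config["continuous_actions"] = False  # e.g. toggled for another env
print(shared()["continuous_actions"])    # False -- saw the mutation
print(snapshot()["continuous_actions"])  # True  -- isolated
```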
From a9db46b65b953828e4be1b4bae90ed4233895731 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Mon, 16 Dec 2024 20:28:24 -0800
Subject: [PATCH 4/7] fixes

---
 benchmarl/algorithms/common.py     |  42 +-----------
 benchmarl/experiment/experiment.py | 100 +++++++++++++++++++++++------
 test/test_pettingzoo.py            |  19 ++++--
 3 files changed, 97 insertions(+), 64 deletions(-)

diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py
index 87e92bbd..96fba1ab 100644
--- a/benchmarl/algorithms/common.py
+++ b/benchmarl/algorithms/common.py
@@ -13,21 +13,13 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torchrl.data import (
     Categorical,
-    Composite,
     LazyTensorStorage,
     OneHot,
     ReplayBuffer,
     TensorDictReplayBuffer,
 )
 from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
-from torchrl.envs import (
-    Compose,
-    EnvBase,
-    InitTracker,
-    TensorDictPrimer,
-    Transform,
-    TransformedEnv,
-)
+from torchrl.envs import Compose, EnvBase, Transform
 from torchrl.objectives import LossModule
 from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
 
@@ -251,38 +243,6 @@ def process_env_fun(
 
         Returns: a function that takes no args and creates an environment
         """
-        if self.has_rnn:
-
-            def model_fun():
-                env = env_fun()
-
-                spec_actor = self.model_config.get_model_state_spec()
-                spec_actor = Composite(
-                    {
-                        group: Composite(
-                            spec_actor.expand(len(agents), *spec_actor.shape),
-                            shape=(len(agents),),
-                        )
-                        for group, agents in self.group_map.items()
-                    }
-                )
-
-                env = TransformedEnv(
-                    env,
-                    Compose(
-                        *(
-                            [InitTracker(init_key="is_init")]
-                            + (
-                                [TensorDictPrimer(spec_actor, reset_key="_reset")]
-                                if len(spec_actor.keys(True, True)) > 0
-                                else []
-                            )
-                        )
-                    ),
-                )
-                return env
-
-            return model_fun
 
         return env_fun
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 777c18e9..212e3c99 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -14,14 +14,23 @@
 from collections import deque, OrderedDict
 from dataclasses import dataclass, MISSING
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictSequential
-
 from torchrl.collectors import SyncDataCollector
-from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
+
+from torchrl.data import Composite
+from torchrl.envs import (
+    EnvBase,
+    InitTracker,
+    ParallelEnv,
+    SerialEnv,
+    TensorDictPrimer,
+    TransformedEnv,
+)
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.record.loggers import generate_exp_name
@@ -361,6 +370,7 @@ def _setup(self):
         self._setup_task()
         self._setup_algorithm()
         self._setup_collector()
+        self._setup_buffers()
         self._setup_name()
         self._setup_logger()
         self._on_setup()
@@ -436,7 +446,21 @@ def _setup_task(self):
         transforms_env = Compose(*transforms_env)
         transforms_training = Compose(*transforms_training)
 
+        self.observation_spec = self.task.observation_spec(test_env)
+        self.info_spec = self.task.info_spec(test_env)
+        self.state_spec = self.task.state_spec(test_env)
+        self.action_mask_spec = self.task.action_mask_spec(test_env)
+        self.action_spec = self.task.action_spec(test_env)
+        self.group_map = self.task.group_map(test_env)
+        self.train_group_map = copy.deepcopy(self.group_map)
+        self.max_steps = self.task.max_steps(test_env)
+
+        if self.model_config.is_rnn:
+            test_env = self._add_rnn_transforms(lambda: test_env)()
+            env_func = self._add_rnn_transforms(env_func)
+
         if test_env.batch_size == ():
+            # If the environment is not vectorized, we simulate vectorization using parallel or serial environments
             env_class = (
                 SerialEnv if not self.config.parallel_collection else ParallelEnv
             )
@@ -453,28 +477,12 @@ def _setup_task(self):
             self.config.sampling_device
         )
 
-        self.observation_spec = self.task.observation_spec(self.test_env)
-        self.info_spec = self.task.info_spec(self.test_env)
-        self.state_spec = self.task.state_spec(self.test_env)
-        self.action_mask_spec = self.task.action_mask_spec(self.test_env)
-        self.action_spec = self.task.action_spec(self.test_env)
-        self.group_map = self.task.group_map(self.test_env)
-        self.train_group_map = copy.deepcopy(self.group_map)
-        self.max_steps = self.task.max_steps(self.test_env)
-
     def _setup_algorithm(self):
         self.algorithm = self.algorithm_config.get_algorithm(experiment=self)
 
         self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
         self.env_func = self.algorithm.process_env_fun(self.env_func)
 
-        self.replay_buffers = {
-            group: self.algorithm.get_replay_buffer(
-                group=group,
-                transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
-            )
-            for group in self.group_map.keys()
-        }
         self.losses = {
             group: self.algorithm.get_loss_and_updater(group)[0]
             for group in self.group_map.keys()
@@ -523,6 +531,15 @@ def _setup_collector(self):
         )
         self.rollout_env = self.env_func().to(self.config.sampling_device)
 
+    def _setup_buffers(self):
+        self.replay_buffers = {
+            group: self.algorithm.get_replay_buffer(
+                group=group,
+                transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
+            )
+            for group in self.group_map.keys()
+        }
+
     def _setup_name(self):
         self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
         self.model_name = self.model_config.associated_class().__name__.lower()
@@ -929,3 +946,48 @@ def _load_experiment(self) -> Experiment:
         )
         self.load_state_dict(loaded_dict)
         return self
+
+    def _add_rnn_transforms(
+        self,
+        env_fun: Callable[[], EnvBase],
+    ) -> Callable[[], EnvBase]:
+        """
+        This function adds RNN specific transforms to the environment
+
+        Args:
+            env_fun (callable): a function that takes no args and creates an environment
+
+        Returns: a function that takes no args and creates an environment
+
+        """
+
+        def model_fun():
+            env = env_fun()
+            group_map = self.task.group_map(env)
+            spec_actor = self.model_config.get_model_state_spec()
+            spec_actor = Composite(
+                {
+                    group: Composite(
+                        spec_actor.expand(len(agents), *spec_actor.shape),
+                        shape=(len(agents),),
+                    )
+                    for group, agents in group_map.items()
+                }
+            )
+
+            out_env = TransformedEnv(
+                env,
+                Compose(
+                    *(
+                        [InitTracker(init_key="is_init")]
+                        + (
+                            [TensorDictPrimer(spec_actor, reset_key="_reset")]
+                            if len(spec_actor.keys(True, True)) > 0
+                            else []
+                        )
+                    )
+                ),
+            )
+            return out_env
+
+        return model_fun
diff --git a/test/test_pettingzoo.py b/test/test_pettingzoo.py
index 8f9169a4..5ce637c5 100644
--- a/test/test_pettingzoo.py
+++ b/test/test_pettingzoo.py
@@ -6,10 +6,12 @@
 
 import pytest
+
 from benchmarl.algorithms import (
     algorithm_config_registry,
     IddpgConfig,
     IppoConfig,
+    IqlConfig,
     IsacConfig,
     MaddpgConfig,
     MappoConfig,
@@ -68,16 +70,13 @@ def test_all_algos(
 
     @pytest.mark.parametrize("algo_config", [IppoConfig, MasacConfig])
     @pytest.mark.parametrize("task", list(PettingZooTask))
-    @pytest.mark.parametrize("parallel_collection", [True, False])
     def test_all_tasks(
         self,
         algo_config: AlgorithmConfig,
         task: Task,
-        parallel_collection,
         experiment_config,
         mlp_sequence_config,
     ):
-        experiment_config.parallel_collection = parallel_collection
         task = task.get_from_yaml()
         experiment = Experiment(
             algorithm_config=algo_config.get_from_yaml(),
@@ -112,16 +111,19 @@ def test_gnn(
         "algo_config", [IddpgConfig, MappoConfig, QmixConfig, MasacConfig]
     )
     @pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
+    @pytest.mark.parametrize("parallel_collection", [True, False])
     def test_gru(
         self,
         algo_config: AlgorithmConfig,
         task: Task,
+        parallel_collection: bool,
        experiment_config,
         gru_mlp_sequence_config,
     ):
         algo_config = algo_config.get_from_yaml()
         if algo_config.has_critic():
             algo_config.share_param_critic = False
+        experiment_config.parallel_collection = parallel_collection
         experiment_config.share_policy_params = False
         task = task.get_from_yaml()
         experiment = Experiment(
@@ -160,17 +162,26 @@ def test_lstm(
         )
         experiment.run()
 
-    @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
+    @pytest.mark.parametrize("algo_config", [MappoConfig, IsacConfig, IqlConfig])
     @pytest.mark.parametrize("prefer_continuous", [True, False])
     @pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
+    @pytest.mark.parametrize("parallel_collection", [True, False])
     def test_reloading_trainer(
         self,
         algo_config: AlgorithmConfig,
         task: Task,
+        parallel_collection,
         experiment_config,
         mlp_sequence_config,
         prefer_continuous,
     ):
+        # To not run the same test twice
+        if (prefer_continuous and not algo_config.supports_continuous_actions()) or (
+            not prefer_continuous and not algo_config.supports_discrete_actions()
+        ):
+            pytest.skip()
+
+        experiment_config.parallel_collection = parallel_collection
         experiment_config.prefer_continuous_actions = prefer_continuous
         algo_config = algo_config.get_from_yaml()
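The `_add_rnn_transforms` helper added above composes two standard TorchRL transforms so that recurrent models find a tracked reset flag and a primed hidden state in every rollout step. A simplified single-agent sketch of the same mechanics (the hidden-state shape and the Pendulum env are illustrative placeholders, not BenchMARL defaults; assumes a recent TorchRL that exports `Unbounded` and has Gymnasium installed):

```python
import torch
from torchrl.data import Composite, Unbounded
from torchrl.envs import Compose, InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

# Spec of the hidden state an RNN policy expects to carry across steps.
hidden_spec = Composite(
    {"recurrent_state": Unbounded(shape=(1, 128), dtype=torch.float32)}
)

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    Compose(
        # Flags the first step after every reset with "is_init", which
        # recurrent models use to zero their hidden state.
        InitTracker(init_key="is_init"),
        # Seeds the tensordict with a zero-filled "recurrent_state" at reset
        # so collectors can thread it through subsequent steps.
        TensorDictPrimer(hidden_spec, reset_key="_reset"),
    ),
)

td = env.reset()
assert "is_init" in td.keys() and "recurrent_state" in td.keys()
```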
From d2b242fef73c656c4b0d1647c45dfb354a9cea0c Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 17 Dec 2024 14:27:29 +0000
Subject: [PATCH 5/7] fixes

---
 benchmarl/experiment/experiment.py | 31 ++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 212e3c99..223fc9a4 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -442,24 +442,30 @@ def _setup_task(self):
         transforms_training = transforms_env + [
             self.task.get_reward_sum_transform(test_env)
         ]
-
         transforms_env = Compose(*transforms_env)
         transforms_training = Compose(*transforms_training)
 
-        self.observation_spec = self.task.observation_spec(test_env)
-        self.info_spec = self.task.info_spec(test_env)
-        self.state_spec = self.task.state_spec(test_env)
-        self.action_mask_spec = self.task.action_mask_spec(test_env)
-        self.action_spec = self.task.action_spec(test_env)
-        self.group_map = self.task.group_map(test_env)
+        # Initialize test env
+        self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
+            self.config.sampling_device
+        )
+
+        self.observation_spec = self.task.observation_spec(self.test_env)
+        self.info_spec = self.task.info_spec(self.test_env)
+        self.state_spec = self.task.state_spec(self.test_env)
+        self.action_mask_spec = self.task.action_mask_spec(self.test_env)
+        self.action_spec = self.task.action_spec(self.test_env)
+        self.group_map = self.task.group_map(self.test_env)
         self.train_group_map = copy.deepcopy(self.group_map)
-        self.max_steps = self.task.max_steps(test_env)
+        self.max_steps = self.task.max_steps(self.test_env)
 
+        # Add rnn transforms here so they do not show in the benchmarl specs
         if self.model_config.is_rnn:
-            test_env = self._add_rnn_transforms(lambda: test_env)()
+            self.test_env = self._add_rnn_transforms(lambda: self.test_env)()
             env_func = self._add_rnn_transforms(env_func)
 
-        if test_env.batch_size == ():
+        # Initialize train env
+        if self.test_env.batch_size == ():
             # If the environment is not vectorized, we simulate vectorization using parallel or serial environments
             env_class = (
                 SerialEnv if not self.config.parallel_collection else ParallelEnv
@@ -469,14 +475,11 @@ def _setup_task(self):
                 transforms_training.clone(),
             )
         else:
+            # Otherwise it is already vectorized
             self.env_func = lambda: TransformedEnv(
                 env_func(), transforms_training.clone()
             )
 
-        self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
-            self.config.sampling_device
-        )
-
     def _setup_algorithm(self):
         self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

From a6952e0c8b4dbb03319860911fa615ec67b0c6e6 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 17 Dec 2024 21:49:22 +0000
Subject: [PATCH 6/7] fixes

---
 benchmarl/experiment/experiment.py | 65 ++++--------------------------
 benchmarl/utils.py                 | 56 ++++++++++++++++++++++++-
 test/test_meltingpot.py            | 28 +++++++++++++
 3 files changed, 90 insertions(+), 59 deletions(-)

diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 223fc9a4..5f3fb3c7 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -15,22 +15,14 @@
 from dataclasses import dataclass, MISSING
 from pathlib import Path
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictSequential
 from torchrl.collectors import SyncDataCollector
 
-from torchrl.data import Composite
-from torchrl.envs import (
-    EnvBase,
-    InitTracker,
-    ParallelEnv,
-    SerialEnv,
-    TensorDictPrimer,
-    TransformedEnv,
-)
+from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.record.loggers import generate_exp_name
@@ -44,7 +36,7 @@
 from benchmarl.experiment.logger import Logger
 from benchmarl.models import GnnConfig, SequenceModelConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import _read_yaml_config, seed_everything
+from benchmarl.utils import _add_rnn_transforms, _read_yaml_config, seed_everything
 
 _has_hydra = importlib.util.find_spec("hydra") is not None
 if _has_hydra:
@@ -461,8 +453,10 @@ def _setup_task(self):
 
         # Add rnn transforms here so they do not show in the benchmarl specs
         if self.model_config.is_rnn:
-            self.test_env = self._add_rnn_transforms(lambda: self.test_env)()
-            env_func = self._add_rnn_transforms(env_func)
+            self.test_env = _add_rnn_transforms(
+                lambda: self.test_env, self.group_map, self.model_config
+            )()
+            env_func = _add_rnn_transforms(env_func, self.group_map, self.model_config)
 
         # Initialize train env
         if self.test_env.batch_size == ():
@@ -949,48 +943,3 @@ def _load_experiment(self) -> Experiment:
         )
         self.load_state_dict(loaded_dict)
         return self
-
-    def _add_rnn_transforms(
-        self,
-        env_fun: Callable[[], EnvBase],
-    ) -> Callable[[], EnvBase]:
-        """
-        This function adds RNN specific transforms to the environment
-
-        Args:
-            env_fun (callable): a function that takes no args and creates an environment
-
-        Returns: a function that takes no args and creates an environment
-
-        """
-
-        def model_fun():
-            env = env_fun()
-            group_map = self.task.group_map(env)
-            spec_actor = self.model_config.get_model_state_spec()
-            spec_actor = Composite(
-                {
-                    group: Composite(
-                        spec_actor.expand(len(agents), *spec_actor.shape),
-                        shape=(len(agents),),
-                    )
-                    for group, agents in group_map.items()
-                }
-            )
-
-            out_env = TransformedEnv(
-                env,
-                Compose(
-                    *(
-                        [InitTracker(init_key="is_init")]
-                        + (
-                            [TensorDictPrimer(spec_actor, reset_key="_reset")]
-                            if len(spec_actor.keys(True, True)) > 0
-                            else []
-                        )
-                    )
-                ),
-            )
-            return out_env
-
-        return model_fun
diff --git a/benchmarl/utils.py b/benchmarl/utils.py
index d2d63ae6..efe36e27 100644
--- a/benchmarl/utils.py
+++ b/benchmarl/utils.py
@@ -6,10 +6,16 @@
 
 import importlib
 import random
-from typing import Any, Dict, Union
+import typing
+from typing import Any, Callable, Dict, List, Union
 
 import torch
 import yaml
+from torchrl.data import Composite
+from torchrl.envs import Compose, EnvBase, InitTracker, TensorDictPrimer, TransformedEnv
+
+if typing.TYPE_CHECKING:
+    from benchmarl.models import ModelConfig
 
 _has_numpy = importlib.util.find_spec("numpy") is not None
 
@@ -53,3 +59,51 @@ def seed_everything(seed: int):
             import numpy
 
             numpy.random.seed(seed)
+
+
+def _add_rnn_transforms(
+    env_fun: Callable[[], EnvBase],
+    group_map: Dict[str, List[str]],
+    model_config: "ModelConfig",
+) -> Callable[[], EnvBase]:
+    """
+    This function adds RNN specific transforms to the environment
+
+    Args:
+        env_fun (callable): a function that takes no args and creates an environment
+        group_map (Dict[str,List[str]]): the group_map of the agents
+        model_config (ModelConfig): the model configuration
+
+    Returns: a function that takes no args and creates an environment
+
+    """
+
+    def model_fun():
+        env = env_fun()
+        spec_actor = model_config.get_model_state_spec()
+        spec_actor = Composite(
+            {
+                group: Composite(
+                    spec_actor.expand(len(agents), *spec_actor.shape),
+                    shape=(len(agents),),
+                )
+                for group, agents in group_map.items()
+            }
+        )
+
+        out_env = TransformedEnv(
+            env,
+            Compose(
+                *(
+                    [InitTracker(init_key="is_init")]
+                    + (
+                        [TensorDictPrimer(spec_actor, reset_key="_reset")]
+                        if len(spec_actor.keys(True, True)) > 0
+                        else []
+                    )
+                )
+            ),
+        )
+        return out_env
+
+    return model_fun
diff --git a/test/test_meltingpot.py b/test/test_meltingpot.py
index 2e7b20f9..07ecb2d4 100644
--- a/test/test_meltingpot.py
+++ b/test/test_meltingpot.py
@@ -10,6 +10,7 @@
 from benchmarl.algorithms import (
     algorithm_config_registry,
     IppoConfig,
+    MappoConfig,
     MasacConfig,
     QmixConfig,
 )
@@ -78,6 +79,33 @@ def test_all_tasks(
         )
         experiment.run()
 
+    @pytest.mark.parametrize("algo_config", [MappoConfig])
+    @pytest.mark.parametrize("task", [MeltingPotTask.COINS])
+    @pytest.mark.parametrize("parallel_collection", [True, False])
+    def test_lstm(
+        self,
+        algo_config: AlgorithmConfig,
+        task: Task,
+        parallel_collection: bool,
+        experiment_config,
+        cnn_lstm_sequence_config,
+    ):
+        algo_config = algo_config.get_from_yaml()
+        if algo_config.has_critic():
+            algo_config.share_param_critic = False
+        experiment_config.parallel_collection = parallel_collection
+        experiment_config.share_policy_params = False
+        task = task.get_from_yaml()
+        experiment = Experiment(
+            algorithm_config=algo_config,
+            model_config=cnn_lstm_sequence_config,
+            critic_model_config=cnn_lstm_sequence_config,
+            seed=0,
+            config=experiment_config,
+            task=task,
+        )
+        experiment.run()
+
     @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
     @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN])
     def test_reloading_trainer(

From a1c10ffff0703ac9bae159fab2c3f200b76e2abf Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 17 Dec 2024 21:53:50 +0000
Subject: [PATCH 7/7] revert buffer update

---
 benchmarl/experiment/experiment.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 5f3fb3c7..2ea0f92c 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -362,7 +362,6 @@ def _setup(self):
         self._setup_task()
         self._setup_algorithm()
         self._setup_collector()
-        self._setup_buffers()
         self._setup_name()
         self._setup_logger()
         self._on_setup()
@@ -480,6 +479,13 @@ def _setup_algorithm(self):
         self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
         self.env_func = self.algorithm.process_env_fun(self.env_func)
 
+        self.replay_buffers = {
+            group: self.algorithm.get_replay_buffer(
+                group=group,
+                transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
+            )
+            for group in self.group_map.keys()
+        }
         self.losses = {
             group: self.algorithm.get_loss_and_updater(group)[0]
             for group in self.group_map.keys()
@@ -528,15 +534,6 @@ def _setup_collector(self):
         )
         self.rollout_env = self.env_func().to(self.config.sampling_device)
 
-    def _setup_buffers(self):
-        self.replay_buffers = {
-            group: self.algorithm.get_replay_buffer(
-                group=group,
-                transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
-            )
-            for group in self.group_map.keys()
-        }
-
     def _setup_name(self):
         self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
         self.model_name = self.model_config.associated_class().__name__.lower()
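Putting the series together, a run that exercises the new code path might look as follows — a usage sketch mirroring the updated tests (the algorithm, task, and worker count are illustrative choices, not prescribed defaults):

```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import PettingZooTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

if __name__ == "__main__":  # ParallelEnv workers need an import-safe entry point
    experiment_config = ExperimentConfig.get_from_yaml()
    # PettingZoo tasks are not vectorized, so collection is batched via
    # SerialEnv by default; this switches it to one process per env.
    experiment_config.parallel_collection = True
    experiment_config.on_policy_n_envs_per_worker = 4

    experiment = Experiment(
        task=PettingZooTask.SIMPLE_TAG.get_from_yaml(),
        algorithm_config=MappoConfig.get_from_yaml(),
        model_config=MlpConfig.get_from_yaml(),
        seed=0,
        config=experiment_config,
    )
    experiment.run()
```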