diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index 892acb78..9c9a4ffc 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -5,7 +5,7 @@ # import pathlib - +import warnings from abc import ABC, abstractmethod from dataclasses import asdict, dataclass from typing import Any, Callable, Dict, List, Optional, Sequence @@ -157,6 +157,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # _check_spec(tensordict, self.output_spec) return tensordict + def share_params_with(self, other_model): + """Share paramters with another identical model model. + + This function modifies in-place the parameters of ``other_model`` to reference the parameters of ``self`` + + Args: + other_model (Model): the model that will share the parameters of ``self``. + + """ + if ( + self.share_params != other_model.share_params + or self.centralised != other_model.centralised + or self.input_has_agent_dim != other_model.input_has_agent_dim + or self.input_spec != other_model.input_spec + or self.output_spec != other_model.output_spec + ): + raise warnings.warn( + "Sharing parameters with models that are not identical. " + "This might result in unintended behavior or error." + ) + for param, other_param in zip(self.parameters(), other_model.parameters()): + other_param.data[:] = param.data + ############################### # Abstract methods to implement ############################### diff --git a/test/test_models.py b/test/test_models.py index c7671f56..191edfb0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -19,6 +19,76 @@ from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec +def _get_input_and_output_specs( + centralised, + input_has_agent_dim, + model_name, + share_params, + n_agents, + in_features=2, + out_features=4, + x=12, + y=12, +): + + if model_name == "cnn": + multi_agent_input_shape = (n_agents, x, y, in_features) + single_agent_input_shape = (x, y, in_features) + else: + multi_agent_input_shape = (n_agents, in_features) + single_agent_input_shape = in_features + + other_multi_agent_input_shape = (n_agents, in_features) + other_single_agent_input_shape = in_features + + if input_has_agent_dim: + input_spec = CompositeSpec( + { + "agents": CompositeSpec( + { + "observation": UnboundedContinuousTensorSpec( + shape=multi_agent_input_shape + ), + "other": UnboundedContinuousTensorSpec( + shape=other_multi_agent_input_shape + ), + }, + shape=(n_agents,), + ) + } + ) + else: + input_spec = CompositeSpec( + { + "observation": UnboundedContinuousTensorSpec( + shape=single_agent_input_shape + ), + "other": UnboundedContinuousTensorSpec( + shape=other_single_agent_input_shape + ), + }, + ) + + if output_has_agent_dim(centralised=centralised, share_params=share_params): + output_spec = CompositeSpec( + { + "agents": CompositeSpec( + { + "out": UnboundedContinuousTensorSpec( + shape=(n_agents, out_features) + ) + }, + shape=(n_agents,), + ) + }, + ) + else: + output_spec = CompositeSpec( + {"out": UnboundedContinuousTensorSpec(shape=(out_features,))}, + ) + return input_spec, output_spec + + @pytest.mark.parametrize("model_name", model_config_registry.keys()) def test_loading_simple_models(model_name): with initialize(version_base=None, config_path="../benchmarl/conf"): @@ -72,7 +142,7 @@ def test_loading_sequence_models(model_name, intermediate_size=10): ], ) def test_models_forward_shape( - share_params, centralised, input_has_agent_dim, model_name, batch_size + share_params, centralised, input_has_agent_dim, model_name, batch_size, n_agents=3 ): if not input_has_agent_dim and not centralised: pytest.skip() # this combination should never happen @@ -94,68 +164,84 @@ def test_models_forward_shape( else: config = model_config_registry[model_name].get_from_yaml() - n_agents = 2 - x = 12 - y = 12 - channels = 3 - out_features = 4 + input_spec, output_spec = _get_input_and_output_specs( + centralised=centralised, + input_has_agent_dim=input_has_agent_dim, + model_name=model_name if isinstance(model_name, str) else model_name[0], + share_params=share_params, + n_agents=n_agents, + ) - if "cnn" in model_name: - multi_agent_tensor = torch.rand((*batch_size, n_agents, x, y, channels)) - single_agent_tensor = torch.rand((*batch_size, x, y, channels)) - else: - multi_agent_tensor = torch.rand((*batch_size, n_agents, channels)) - single_agent_tensor = torch.rand((*batch_size, channels)) + model = config.get_model( + input_spec=input_spec, + output_spec=output_spec, + share_params=share_params, + centralised=centralised, + input_has_agent_dim=input_has_agent_dim, + n_agents=n_agents, + device="cpu", + agent_group="agents", + action_spec=None, + ) + input_td = input_spec.expand(batch_size).rand() + out_td = model(input_td) + assert output_spec.expand(batch_size).is_in(out_td) - other_multi_agent_tensor = torch.rand((*batch_size, n_agents, channels)) - other_single_agent_tensor = torch.rand((*batch_size, channels)) - if input_has_agent_dim: - input_spec = CompositeSpec( - { - "agents": CompositeSpec( - { - "observation": UnboundedContinuousTensorSpec( - shape=multi_agent_tensor.shape[len(batch_size) :] - ), - "other": UnboundedContinuousTensorSpec( - shape=other_multi_agent_tensor.shape[len(batch_size) :] - ), - }, - shape=(n_agents,), - ) - } - ) - else: - input_spec = CompositeSpec( - { - "observation": UnboundedContinuousTensorSpec( - shape=single_agent_tensor.shape[len(batch_size) :] - ), - "other": UnboundedContinuousTensorSpec( - shape=other_single_agent_tensor.shape[len(batch_size) :] - ), - }, - ) +@pytest.mark.parametrize("input_has_agent_dim", [True, False]) +@pytest.mark.parametrize("centralised", [True, False]) +@pytest.mark.parametrize("share_params", [True, False]) +@pytest.mark.parametrize( + "model_name", + [ + *model_config_registry.keys(), + ["cnn", "gnn", "mlp"], + ["cnn", "mlp", "gnn"], + ["cnn", "mlp"], + ], +) +@pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)]) +def test_share_params_between_models( + share_params, + centralised, + input_has_agent_dim, + model_name, + batch_size, + n_agents=3, +): + if not input_has_agent_dim and not centralised: + pytest.skip() # this combination should never happen + if ("gnn" in model_name) and ( + not input_has_agent_dim + or (isinstance(model_name, list) and model_name[0] != "gnn") + ): + pytest.skip("gnn model needs agent dim as input") + torch.manual_seed(0) - if output_has_agent_dim(centralised=centralised, share_params=share_params): - output_spec = CompositeSpec( - { - "agents": CompositeSpec( - { - "out": UnboundedContinuousTensorSpec( - shape=(n_agents, out_features) - ) - }, - shape=(n_agents,), - ) - }, + input_spec, output_spec = _get_input_and_output_specs( + centralised=centralised, + input_has_agent_dim=input_has_agent_dim, + model_name=model_name if isinstance(model_name, str) else model_name[0], + share_params=share_params, + n_agents=n_agents, + ) + input_spec2, output_spec2 = _get_input_and_output_specs( + centralised=centralised, + input_has_agent_dim=input_has_agent_dim, + model_name=model_name if isinstance(model_name, str) else model_name[0], + share_params=share_params, + n_agents=n_agents, + ) + + if isinstance(model_name, List): + config = SequenceModelConfig( + model_configs=[ + model_config_registry[config].get_from_yaml() for config in model_name + ], + intermediate_sizes=[4] * (len(model_name) - 1), ) else: - output_spec = CompositeSpec( - {"out": UnboundedContinuousTensorSpec(shape=(out_features,))}, - ) - + config = model_config_registry[model_name].get_from_yaml() model = config.get_model( input_spec=input_spec, output_spec=output_spec, @@ -167,9 +253,22 @@ def test_models_forward_shape( agent_group="agents", action_spec=None, ) - input_td = input_spec.expand(batch_size).rand() - out_td = model(input_td) - assert output_spec.expand(batch_size).is_in(out_td) + second_model = config.get_model( + input_spec=input_spec2, + output_spec=output_spec2, + share_params=share_params, + centralised=centralised, + input_has_agent_dim=input_has_agent_dim, + n_agents=n_agents, + device="cpu", + agent_group="agents", + action_spec=None, + ) + for param, second_param in zip(model.parameters(), second_model.parameters()): + assert not torch.eq(param, second_param).any() + model.share_params_with(second_model) + for param, second_param in zip(model.parameters(), second_model.parameters()): + assert torch.eq(param, second_param).all() class TestGnn: