diff --git a/.github/unittest/install_dependencies_nightly.sh b/.github/unittest/install_dependencies_nightly.sh
index 9368d200..6f90c4f9 100644
--- a/.github/unittest/install_dependencies_nightly.sh
+++ b/.github/unittest/install_dependencies_nightly.sh
@@ -9,10 +9,13 @@ python -m pip install torch
 # Not using nightly torch
 # python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall
 
+cd ../BenchMARL
+pip install -e .
+pip uninstall --yes torchrl
+pip uninstall --yes tensordict
+
 cd ..
 python -m pip install git+https://github.com/pytorch-labs/tensordict.git
 git clone https://github.com/pytorch/rl.git
 cd rl
 python setup.py develop
-cd ../BenchMARL
-pip install -e .
diff --git a/README.md b/README.md
index 0d73cf36..9e421edb 100644
--- a/README.md
+++ b/README.md
@@ -258,12 +258,12 @@ agent group. Here is a table of the models implemented in BenchMARL
 |--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
 | [MLP](benchmarl/models/mlp.py) |      Yes      |              Yes              |              Yes              |
 | [GNN](benchmarl/models/gnn.py) |      Yes      |              No               |              No               |
+| [CNN](benchmarl/models/cnn.py) |      Yes      |              Yes              |              Yes              |
 
 And the ones that are _work in progress_
 
 | Name               | Decentralized | Centralized with local inputs | Centralized with global input |
 |--------------------|:-------------:|:-----------------------------:|:-----------------------------:|
-| CNN                |      Yes      |              Yes              |              Yes              |
 | RNN (GRU and LSTM) |      Yes      |              Yes              |              Yes              |
diff --git a/benchmarl/conf/model/layers/cnn.yaml b/benchmarl/conf/model/layers/cnn.yaml
new file mode 100644
index 00000000..d299ed57
--- /dev/null
+++ b/benchmarl/conf/model/layers/cnn.yaml
@@ -0,0 +1,18 @@
+
+name: cnn
+
+mlp_num_cells: [32]
+mlp_layer_class: torch.nn.Linear
+mlp_activation_class: torch.nn.Tanh
+mlp_activation_kwargs: null
+mlp_norm_class: null
+mlp_norm_kwargs: null
+
+cnn_num_cells: [32, 32, 32]
+cnn_kernel_sizes: 3
+cnn_strides: 1
+cnn_paddings: 0
+cnn_activation_class: torch.nn.Tanh
+cnn_activation_kwargs: null
+cnn_norm_class: null
+cnn_norm_kwargs: null
diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py
index c5ad3a47..33510c48 100644
--- a/benchmarl/models/__init__.py
+++ b/benchmarl/models/__init__.py
@@ -4,10 +4,11 @@
 #  LICENSE file in the root directory of this source tree.
 #
 
+from .cnn import Cnn, CnnConfig
 from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
 from .gnn import Gnn, GnnConfig
 from .mlp import Mlp, MlpConfig
 
-classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig"]
+classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig", "Cnn", "CnnConfig"]
 
-model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig}
+model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig, "cnn": CnnConfig}
diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py
new file mode 100644
index 00000000..0ecb3621
--- /dev/null
+++ b/benchmarl/models/cnn.py
@@ -0,0 +1,262 @@
+#  Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+#  This source code is licensed under the license found in the
+#  LICENSE file in the root directory of this source tree.
+#
+
+from dataclasses import dataclass, MISSING
+from typing import List, Optional, Sequence, Tuple, Type, Union
+
+import torch
+
+from tensordict import TensorDictBase
+from torch import nn
+from torchrl.modules import ConvNet, MLP, MultiAgentConvNet, MultiAgentMLP
+
+from benchmarl.models.common import Model, ModelConfig
+
+
+def _number_conv_outputs(
+    n_conv_inputs: Union[int, Tuple[int, int]],
+    paddings: List[Union[int, Tuple[int, int]]],
+    kernel_sizes: List[Union[int, Tuple[int, int]]],
+    strides: List[Union[int, Tuple[int, int]]],
+) -> Tuple[int, int]:
+    if not isinstance(n_conv_inputs, int):
+        n_conv_inputs_x, n_conv_inputs_y = n_conv_inputs
+    else:
+        n_conv_inputs_x = n_conv_inputs_y = n_conv_inputs
+    for kernel_size, padding, stride in zip(kernel_sizes, paddings, strides):
+        if not isinstance(kernel_size, int):
+            kernel_size_x, kernel_size_y = kernel_size
+        else:
+            kernel_size_x = kernel_size_y = kernel_size
+        if not isinstance(padding, int):
+            padding_x, padding_y = padding
+        else:
+            padding_x = padding_y = padding
+        if not isinstance(stride, int):
+            stride_x, stride_y = stride
+        else:
+            stride_x = stride_y = stride
+
+        n_conv_inputs_x = (
+            n_conv_inputs_x + 2 * padding_x - kernel_size_x
+        ) // stride_x + 1
+        n_conv_inputs_y = (
+            n_conv_inputs_y + 2 * padding_y - kernel_size_y
+        ) // stride_y + 1
+
+    return n_conv_inputs_x, n_conv_inputs_y
+
+
+class Cnn(Model):
+    """Convolutional Neural Network (CNN) model.
+
+    Args:
+
+        cnn_num_cells (int or Sequence of int): number of cells of
+            every layer in between the input and output. If an integer is
+            provided, every layer will have the same number of cells. If an
+            iterable is provided, the linear layers ``out_features`` will match
+            the content of num_cells.
+        cnn_kernel_sizes (int or Sequence of int): Kernel size(s) of the
+            conv network. If iterable, the length must match the depth,
+            defined by the ``num_cells`` or depth arguments.
+        cnn_strides (int or Sequence of int): Stride(s) of the conv network. If
+            iterable, the length must match the depth, defined by the
+            ``num_cells`` or depth arguments.
+        cnn_paddings (int or Sequence of int): padding size for every layer.
+        cnn_activation_class (Type[nn.Module] or callable): activation
+            class or constructor to be used.
+        cnn_activation_kwargs (dict or list of dicts, optional): kwargs to be used
+            with the activation class. A list of kwargs of length ``depth``
+            can also be passed, with one element per layer.
+        cnn_norm_class (Type or callable, optional): normalization class or
+            constructor, if any.
+        cnn_norm_kwargs (dict or list of dicts, optional): kwargs to be used with
+            the normalization layers. A list of kwargs of length ``depth`` can
+            also be passed, with one element per layer.
+        mlp_num_cells (int or Sequence[int]): number of cells of every layer in between the input and output. If
+            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
+            the linear layers ``out_features`` will match the content of num_cells.
+        mlp_layer_class (Type[nn.Module]): class to be used for the linear layers.
+        mlp_activation_class (Type[nn.Module]): activation class to be used.
+        mlp_activation_kwargs (dict, optional): kwargs to be used with the activation class.
+        mlp_norm_class (Type, optional): normalization class, if any.
+        mlp_norm_kwargs (dict, optional): kwargs to be used with the normalization layers.
+
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(
+            input_spec=kwargs.pop("input_spec"),
+            output_spec=kwargs.pop("output_spec"),
+            agent_group=kwargs.pop("agent_group"),
+            input_has_agent_dim=kwargs.pop("input_has_agent_dim"),
+            n_agents=kwargs.pop("n_agents"),
+            centralised=kwargs.pop("centralised"),
+            share_params=kwargs.pop("share_params"),
+            device=kwargs.pop("device"),
+            action_spec=kwargs.pop("action_spec"),
+        )
+
+        self.x = self.input_leaf_spec.shape[-3]
+        self.y = self.input_leaf_spec.shape[-2]
+        self.input_features = self.input_leaf_spec.shape[-1]
+
+        self.output_features = self.output_leaf_spec.shape[-1]
+
+        mlp_net_kwargs = {
+            "_".join(k.split("_")[1:]): v
+            for k, v in kwargs.items()
+            if k.startswith("mlp_")
+        }
+        cnn_net_kwargs = {
+            "_".join(k.split("_")[1:]): v
+            for k, v in kwargs.items()
+            if k.startswith("cnn_")
+        }
+
+        if self.input_has_agent_dim:
+            self.cnn = MultiAgentConvNet(
+                in_features=self.input_features,
+                n_agents=self.n_agents,
+                centralised=self.centralised,
+                share_params=self.share_params,
+                device=self.device,
+                **cnn_net_kwargs,
+            )
+            example_net = self.cnn._empty_net
+
+        else:
+            self.cnn = nn.ModuleList(
+                [
+                    ConvNet(
+                        in_features=self.input_features,
+                        device=self.device,
+                        **cnn_net_kwargs,
+                    )
+                    for _ in range(self.n_agents if not self.share_params else 1)
+                ]
+            )
+            example_net = self.cnn[0]
+
+        out_features = example_net.out_features
+        out_x, out_y = _number_conv_outputs(
+            n_conv_inputs=(self.x, self.y),
+            kernel_sizes=example_net.kernel_sizes,
+            paddings=example_net.paddings,
+            strides=example_net.strides,
+        )
+        cnn_output_size = out_features * out_x * out_y
+
+        if self.output_has_agent_dim:
+            self.mlp = MultiAgentMLP(
+                n_agent_inputs=cnn_output_size,
+                n_agent_outputs=self.output_features,
+                n_agents=self.n_agents,
+                centralised=self.centralised,
+                share_params=self.share_params,
+                device=self.device,
+                **mlp_net_kwargs,
+            )
+        else:
+            self.mlp = nn.ModuleList(
+                [
+                    MLP(
+                        in_features=cnn_output_size,
+                        out_features=self.output_features,
+                        device=self.device,
+                        **mlp_net_kwargs,
+                    )
+                    for _ in range(self.n_agents if not self.share_params else 1)
+                ]
+            )
+
+    def _perform_checks(self):
+        super()._perform_checks()
+
+        if self.input_has_agent_dim and self.input_leaf_spec.shape[-4] != self.n_agents:
+            raise ValueError(
+                "If the CNN input has the agent dimension,"
+                " the fourth to last spec dimension should be the number of agents"
+            )
+        if (
+            self.output_has_agent_dim
+            and self.output_leaf_spec.shape[-2] != self.n_agents
+        ):
+            raise ValueError(
+                "If the CNN output has the agent dimension,"
+                " the second to last spec dimension should be the number of agents"
+            )
+
+    def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Gather in_key
+        input = tensordict.get(self.in_key)
+        # BenchMARL images are X, Y, C -> we convert them to C, X, Y for processing in TorchRL models
+        input = input.transpose(-3, -1).transpose(-2, -1)
+
+        # Has multi-agent input dimension
+        if self.input_has_agent_dim:
+            cnn_out = self.cnn.forward(input)
+            if not self.output_has_agent_dim:
+                # If we are here the module is centralised and parameter shared.
+                # Thus the multi-agent dimension has been expanded,
+                # We remove it without loss of data
+                cnn_out = cnn_out[..., 0, :]
+
+        # Does not have multi-agent input dimension
+        else:
+            if not self.share_params:
+                cnn_out = torch.stack(
+                    [net(input) for net in self.cnn],
+                    dim=-2,
+                )
+            else:
+                cnn_out = self.cnn[0](input)
+
+        # Cnn output has multi-agent input dimension
+        if self.output_has_agent_dim:
+            res = self.mlp.forward(cnn_out)
+        else:
+            if not self.share_params:
+                res = torch.stack(
+                    [net(cnn_out) for net in self.mlp],
+                    dim=-2,
+                )
+            else:
+                res = self.mlp[0](cnn_out)
+
+        tensordict.set(self.out_key, res)
+        return tensordict
+
+
+@dataclass
+class CnnConfig(ModelConfig):
+    """Dataclass config for a :class:`~benchmarl.models.Cnn`."""
+
+    cnn_num_cells: Sequence[int] = MISSING
+    cnn_kernel_sizes: Sequence[int] = MISSING
+    cnn_strides: Sequence[int] = MISSING
+    cnn_paddings: Sequence[int] = MISSING
+    cnn_activation_class: Type[nn.Module] = MISSING
+
+    mlp_num_cells: Sequence[int] = MISSING
+    mlp_layer_class: Type[nn.Module] = MISSING
+    mlp_activation_class: Type[nn.Module] = MISSING
+
+    cnn_activation_kwargs: Optional[dict] = None
+    cnn_norm_class: Type[nn.Module] = None
+    cnn_norm_kwargs: Optional[dict] = None
+
+    mlp_activation_kwargs: Optional[dict] = None
+    mlp_norm_class: Type[nn.Module] = None
+    mlp_norm_kwargs: Optional[dict] = None
+
+    @staticmethod
+    def associated_class():
+        return Cnn
diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py
index 93cb1c25..6b26674a 100644
--- a/benchmarl/models/gnn.py
+++ b/benchmarl/models/gnn.py
@@ -170,7 +170,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                         *batch_size,
                         self.n_agents,
                         self.output_features,
-                    )[:, i]
+                    )[..., i, :]
                     for i, gnn in enumerate(self.gnns)
                 ],
                 dim=-2,
diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst
index 638e4e61..95be5e70 100644
--- a/docs/source/concepts/components.rst
+++ b/docs/source/concepts/components.rst
@@ -111,3 +111,5 @@ agent group. Here is a table of the models implemented in BenchMARL
    +---------------------------------+---------------+-------------------------------+-------------------------------+
    | :class:`~benchmarl.models.Gnn`  | Yes           | No                            | No                            |
    +---------------------------------+---------------+-------------------------------+-------------------------------+
+   | :class:`~benchmarl.models.Cnn`  | Yes           | Yes                           | Yes                           |
+   +---------------------------------+---------------+-------------------------------+-------------------------------+
diff --git a/test/test_models.py b/test/test_models.py
index db579658..dfe6b57a 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -3,15 +3,20 @@
 #  This source code is licensed under the license found in the
 #  LICENSE file in the root directory of this source tree.
 #
+import importlib
+from typing import List
 
 import pytest
+import torch
 
 from benchmarl.hydra_config import load_model_config_from_hydra
 from benchmarl.models import model_config_registry
-from benchmarl.models.common import SequenceModelConfig
+from benchmarl.models.common import output_has_agent_dim, SequenceModelConfig
 from hydra import compose, initialize
 
+from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
+
 
 @pytest.mark.parametrize("model_name", model_config_registry.keys())
 def test_loading_simple_models(model_name):
@@ -29,7 +34,7 @@
 
 
 @pytest.mark.parametrize("model_name", model_config_registry.keys())
-def test_loading_sequence_models(model_name, intermidiate_size=10):
+def test_loading_sequence_models(model_name, intermediate_size=10):
     with initialize(version_base=None, config_path="../benchmarl/conf"):
         cfg = compose(
             config_name="config",
@@ -40,13 +45,117 @@
                 f"model/layers@model.layers.l1={model_name}",
                 f"model/layers@model.layers.l2={model_name}",
                 f"+model/layers@model.layers.l3={model_name}",
-                f"model.intermediate_sizes={[intermidiate_size,intermidiate_size]}",
+                f"model.intermediate_sizes={[intermediate_size,intermediate_size]}",
             ],
         )
         hydra_model_config = load_model_config_from_hydra(cfg.model)
         layer_config = model_config_registry[model_name].get_from_yaml()
         yaml_config = SequenceModelConfig(
             model_configs=[layer_config, layer_config, layer_config],
-            intermediate_sizes=[intermidiate_size, intermidiate_size],
+            intermediate_sizes=[intermediate_size, intermediate_size],
         )
         assert hydra_model_config == yaml_config
+
+
+@pytest.mark.parametrize("input_has_agent_dim", [True, False])
+@pytest.mark.parametrize("centralised", [True, False])
+@pytest.mark.parametrize("share_params", [True, False])
+@pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        *model_config_registry.keys(),
+        ["cnn", "gnn", "mlp"],
+        ["cnn", "mlp", "gnn"],
+        ["cnn", "mlp"],
+    ],
+)
+def test_models_forward_shape(
+    share_params, centralised, input_has_agent_dim, model_name, batch_size
+):
+    if not input_has_agent_dim and not centralised:
+        pytest.skip()  # this combination should never happen
+    if ("gnn" in model_name) and centralised:
+        pytest.skip("gnn model is always decentralized")
+    if importlib.metadata.version("torchrl") <= "0.3.1" and "cnn" in model_name:
+        pytest.skip("TorchRL <= 0.3.1 does not support MultiAgentCNN")
+
+    torch.manual_seed(0)
+
+    if isinstance(model_name, List):
+        config = SequenceModelConfig(
+            model_configs=[
+                model_config_registry[config].get_from_yaml() for config in model_name
+            ],
+            intermediate_sizes=[4] * (len(model_name) - 1),
+        )
+    else:
+        config = model_config_registry[model_name].get_from_yaml()
+
+    n_agents = 2
+    x = 12
+    y = 12
+    channels = 3
+    out_features = 4
+
+    if "cnn" in model_name:
+        multi_agent_tensor = torch.rand((*batch_size, n_agents, x, y, channels))
+        single_agent_tensor = torch.rand((*batch_size, x, y, channels))
+    else:
+        multi_agent_tensor = torch.rand((*batch_size, n_agents, channels))
+        single_agent_tensor = torch.rand((*batch_size, channels))
+
+    if input_has_agent_dim:
+        input_spec = CompositeSpec(
+            {
+                "agents": CompositeSpec(
+                    {
+                        "observation": UnboundedContinuousTensorSpec(
+                            shape=multi_agent_tensor.shape[len(batch_size) :]
+                        )
+                    },
+                    shape=(n_agents,),
+                )
+            }
+        )
+    else:
+        input_spec = CompositeSpec(
+            {
+                "observation": UnboundedContinuousTensorSpec(
+                    shape=single_agent_tensor.shape[len(batch_size) :]
+                )
+            },
+        )
+
+    if output_has_agent_dim(centralised=centralised, share_params=share_params):
+        output_spec = CompositeSpec(
+            {
+                "agents": CompositeSpec(
+                    {
+                        "out": UnboundedContinuousTensorSpec(
+                            shape=(n_agents, out_features)
+                        )
+                    },
+                    shape=(n_agents,),
+                )
+            },
+        )
+    else:
+        output_spec = CompositeSpec(
+            {"out": UnboundedContinuousTensorSpec(shape=(out_features,))},
+        )
+
+    model = config.get_model(
+        input_spec=input_spec,
+        output_spec=output_spec,
+        share_params=share_params,
+        centralised=centralised,
+        input_has_agent_dim=input_has_agent_dim,
+        n_agents=n_agents,
+        device="cpu",
+        agent_group="agents",
+        action_spec=None,
+    )
+    input_td = input_spec.rand()
+    out_td = model(input_td)
+    assert output_spec.expand(batch_size).is_in(out_td)
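
Usage sketch (illustrative, not part of the change): the snippet below only reuses APIs that appear in this diff (CnnConfig.get_from_yaml(), the ModelConfig.get_model() call signature exercised in test_models.py, and the TorchRL spec classes). The shapes mirror the test; the decentralised, parameter-shared setting is an arbitrary example.

import torch
from benchmarl.models import CnnConfig
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec

n_agents, x, y, channels, out_features = 2, 12, 12, 3, 4

# Per-agent image observations laid out as (..., n_agents, X, Y, C),
# which Cnn._forward transposes to channel-first before the TorchRL conv nets.
input_spec = CompositeSpec(
    {
        "agents": CompositeSpec(
            {
                "observation": UnboundedContinuousTensorSpec(
                    shape=(n_agents, x, y, channels)
                )
            },
            shape=(n_agents,),
        )
    }
)
output_spec = CompositeSpec(
    {
        "agents": CompositeSpec(
            {"out": UnboundedContinuousTensorSpec(shape=(n_agents, out_features))},
            shape=(n_agents,),
        )
    }
)

# Build the CNN from benchmarl/conf/model/layers/cnn.yaml (decentralised, shared parameters).
model = CnnConfig.get_from_yaml().get_model(
    input_spec=input_spec,
    output_spec=output_spec,
    agent_group="agents",
    input_has_agent_dim=True,
    n_agents=n_agents,
    centralised=False,
    share_params=True,
    device="cpu",
    action_spec=None,
)

out_td = model(input_spec.rand())
assert out_td["agents", "out"].shape == (n_agents, out_features)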