[Model] CNN (#74)
* add multiagent cnn implementation and tests

* docs

---------

Co-authored-by: ezhang7423 <[email protected]>
matteobettini and ezhang7423 authored Apr 7, 2024
1 parent e6afcf5 commit e272278
Showing 8 changed files with 405 additions and 10 deletions.
7 changes: 5 additions & 2 deletions .github/unittest/install_dependencies_nightly.sh
@@ -9,10 +9,13 @@ python -m pip install torch
# Not using nightly torch
# python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall

cd ../BenchMARL
pip install -e .
pip uninstall --yes torchrl
pip uninstall --yes tensordict

cd ..
python -m pip install git+https://github.com/pytorch-labs/tensordict.git
git clone https://github.com/pytorch/rl.git
cd rl
python setup.py develop
cd ../BenchMARL
pip install -e .
2 changes: 1 addition & 1 deletion README.md
@@ -258,12 +258,12 @@ agent group. Here is a table of the models implemented in BenchMARL
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | No | No |
+| [CNN](benchmarl/models/cnn.py) | Yes           | Yes                           | Yes                           |

And the ones that are _work in progress_

| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------|:-------------:|:-----------------------------:|:-----------------------------:|
-| CNN                | Yes           | Yes                           | Yes                           |
| RNN (GRU and LSTM) | Yes | Yes | Yes |


18 changes: 18 additions & 0 deletions benchmarl/conf/model/layers/cnn.yaml
@@ -0,0 +1,18 @@

name: cnn

mlp_num_cells: [32]
mlp_layer_class: torch.nn.Linear
mlp_activation_class: torch.nn.Tanh
mlp_activation_kwargs: null
mlp_norm_class: null
mlp_norm_kwargs: null

cnn_num_cells: [32, 32, 32]
cnn_kernel_sizes: 3
cnn_strides: 1
cnn_paddings: 0
cnn_activation_class: torch.nn.Tanh
cnn_activation_kwargs: null
cnn_norm_class: null
cnn_norm_kwargs: null
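
These defaults mirror the fields of the `CnnConfig` dataclass added below in `benchmarl/models/cnn.py`. A minimal sketch of loading them, assuming the packaged YAML is picked up when no path is given, as with the other model configs:

```python
from benchmarl.models import CnnConfig

# Load benchmarl/conf/model/layers/cnn.yaml into the dataclass
cnn_config = CnnConfig.get_from_yaml()
print(cnn_config.cnn_num_cells)  # [32, 32, 32]
```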
5 changes: 3 additions & 2 deletions benchmarl/models/__init__.py
@@ -4,10 +4,11 @@
# LICENSE file in the root directory of this source tree.
#

+from .cnn import Cnn, CnnConfig
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .gnn import Gnn, GnnConfig
from .mlp import Mlp, MlpConfig

-classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig"]
+classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig", "Cnn", "CnnConfig"]

-model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig}
+model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig, "cnn": CnnConfig}
262 changes: 262 additions & 0 deletions benchmarl/models/cnn.py
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING
from typing import List, Optional, Sequence, Tuple, Type, Union

import torch

from tensordict import TensorDictBase
from torch import nn
from torchrl.modules import ConvNet, MLP, MultiAgentConvNet, MultiAgentMLP

from benchmarl.models.common import Model, ModelConfig


def _number_conv_outputs(
    n_conv_inputs: Union[int, Tuple[int, int]],
    paddings: List[Union[int, Tuple[int, int]]],
    kernel_sizes: List[Union[int, Tuple[int, int]]],
    strides: List[Union[int, Tuple[int, int]]],
) -> Tuple[int, int]:
    if not isinstance(n_conv_inputs, int):
        n_conv_inputs_x, n_conv_inputs_y = n_conv_inputs
    else:
        n_conv_inputs_x = n_conv_inputs_y = n_conv_inputs
    for kernel_size, padding, stride in zip(kernel_sizes, paddings, strides):
        if not isinstance(kernel_size, int):
            kernel_size_x, kernel_size_y = kernel_size
        else:
            kernel_size_x = kernel_size_y = kernel_size
        if not isinstance(padding, int):
            padding_x, padding_y = padding
        else:
            padding_x = padding_y = padding
        if not isinstance(stride, int):
            stride_x, stride_y = stride
        else:
            stride_x = stride_y = stride

        n_conv_inputs_x = (
            n_conv_inputs_x + 2 * padding_x - kernel_size_x
        ) // stride_x + 1
        n_conv_inputs_y = (
            n_conv_inputs_y + 2 * padding_y - kernel_size_y
        ) // stride_y + 1

    return n_conv_inputs_x, n_conv_inputs_y
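
# Worked example: with the defaults in cnn.yaml above (three layers,
# kernel_size 3, stride 1, padding 0), a 32x32 input shrinks as
# (32 + 2*0 - 3) // 1 + 1 = 30, then 28, then 26:
# _number_conv_outputs((32, 32), paddings=[0] * 3, kernel_sizes=[3] * 3,
# strides=[1] * 3) == (26, 26)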


class Cnn(Model):
    """Convolutional Neural Network (CNN) model.

    Args:
        cnn_num_cells (int or Sequence of int): number of cells of
            every layer in between the input and output. If an integer is
            provided, every layer will have the same number of cells. If an
            iterable is provided, each layer's output channels will match
            the content of ``num_cells``.
        cnn_kernel_sizes (int or Sequence of int): kernel size(s) of the
            conv network. If iterable, the length must match the depth,
            defined by the ``num_cells`` or depth arguments.
        cnn_strides (int or Sequence of int): stride(s) of the conv network. If
            iterable, the length must match the depth, defined by the
            ``num_cells`` or depth arguments.
        cnn_paddings (int or Sequence of int): padding size for every layer.
        cnn_activation_class (Type[nn.Module] or callable): activation
            class or constructor to be used.
        cnn_activation_kwargs (dict or list of dicts, optional): kwargs to be used
            with the activation class. A list of kwargs of length ``depth``
            can also be passed, with one element per layer.
        cnn_norm_class (Type or callable, optional): normalization class or
            constructor, if any.
        cnn_norm_kwargs (dict or list of dicts, optional): kwargs to be used with
            the normalization layers. A list of kwargs of length ``depth`` can
            also be passed, with one element per layer.
        mlp_num_cells (int or Sequence of int): number of cells of every layer in
            between the input and output. If an integer is provided, every layer
            will have the same number of cells. If an iterable is provided, the
            linear layers' ``out_features`` will match the content of ``num_cells``.
        mlp_layer_class (Type[nn.Module]): class to be used for the linear layers.
        mlp_activation_class (Type[nn.Module]): activation class to be used.
        mlp_activation_kwargs (dict, optional): kwargs to be used with the activation class.
        mlp_norm_class (Type, optional): normalization class, if any.
        mlp_norm_kwargs (dict, optional): kwargs to be used with the normalization layers.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            input_spec=kwargs.pop("input_spec"),
            output_spec=kwargs.pop("output_spec"),
            agent_group=kwargs.pop("agent_group"),
            input_has_agent_dim=kwargs.pop("input_has_agent_dim"),
            n_agents=kwargs.pop("n_agents"),
            centralised=kwargs.pop("centralised"),
            share_params=kwargs.pop("share_params"),
            device=kwargs.pop("device"),
            action_spec=kwargs.pop("action_spec"),
        )

        self.x = self.input_leaf_spec.shape[-3]
        self.y = self.input_leaf_spec.shape[-2]
        self.input_features = self.input_leaf_spec.shape[-1]

        self.output_features = self.output_leaf_spec.shape[-1]

        mlp_net_kwargs = {
            "_".join(k.split("_")[1:]): v
            for k, v in kwargs.items()
            if k.startswith("mlp_")
        }
        cnn_net_kwargs = {
            "_".join(k.split("_")[1:]): v
            for k, v in kwargs.items()
            if k.startswith("cnn_")
        }
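        # The prefixes route each config entry to the right sub-network,
        # e.g. "cnn_kernel_sizes" -> "kernel_sizes" for the ConvNet below.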

        if self.input_has_agent_dim:
            self.cnn = MultiAgentConvNet(
                in_features=self.input_features,
                n_agents=self.n_agents,
                centralised=self.centralised,
                share_params=self.share_params,
                device=self.device,
                **cnn_net_kwargs,
            )
            example_net = self.cnn._empty_net

        else:
            self.cnn = nn.ModuleList(
                [
                    ConvNet(
                        in_features=self.input_features,
                        device=self.device,
                        **cnn_net_kwargs,
                    )
                    for _ in range(self.n_agents if not self.share_params else 1)
                ]
            )
            example_net = self.cnn[0]

        out_features = example_net.out_features
        out_x, out_y = _number_conv_outputs(
            n_conv_inputs=(self.x, self.y),
            kernel_sizes=example_net.kernel_sizes,
            paddings=example_net.paddings,
            strides=example_net.strides,
        )
        cnn_output_size = out_features * out_x * out_y
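        # The conv output is flattened to out_features * out_x * out_y
        # features, which sizes the input of the MLP head below.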

        if self.output_has_agent_dim:
            self.mlp = MultiAgentMLP(
                n_agent_inputs=cnn_output_size,
                n_agent_outputs=self.output_features,
                n_agents=self.n_agents,
                centralised=self.centralised,
                share_params=self.share_params,
                device=self.device,
                **mlp_net_kwargs,
            )
        else:
            self.mlp = nn.ModuleList(
                [
                    MLP(
                        in_features=cnn_output_size,
                        out_features=self.output_features,
                        device=self.device,
                        **mlp_net_kwargs,
                    )
                    for _ in range(self.n_agents if not self.share_params else 1)
                ]
            )

    def _perform_checks(self):
        super()._perform_checks()

        if self.input_has_agent_dim and self.input_leaf_spec.shape[-4] != self.n_agents:
            raise ValueError(
                "If the CNN input has the agent dimension,"
                " the fourth-to-last spec dimension should be the number of agents"
            )
        if (
            self.output_has_agent_dim
            and self.output_leaf_spec.shape[-2] != self.n_agents
        ):
            raise ValueError(
                "If the CNN output has the agent dimension,"
                " the second-to-last spec dimension should be the number of agents"
            )

    def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Gather in_key
        input = tensordict.get(self.in_key)
        # BenchMARL images are X, Y, C -> we convert them to C, X, Y for processing in TorchRL models
        input = input.transpose(-3, -1).transpose(-2, -1)
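        # e.g. (*batch, n_agents, X, Y, C) -> (*batch, n_agents, C, X, Y)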

        # Has multi-agent input dimension
        if self.input_has_agent_dim:
            cnn_out = self.cnn.forward(input)
            if not self.output_has_agent_dim:
                # If we are here the module is centralised and parameter shared.
                # Thus the multi-agent dimension has been expanded,
                # and we can remove it without loss of data.
                cnn_out = cnn_out[..., 0, :]

        # Does not have multi-agent input dimension
        else:
            if not self.share_params:
                cnn_out = torch.stack(
                    [net(input) for net in self.cnn],
                    dim=-2,
                )
            else:
                cnn_out = self.cnn[0](input)

        # CNN output has multi-agent dimension
        if self.output_has_agent_dim:
            res = self.mlp.forward(cnn_out)
        else:
            if not self.share_params:
                res = torch.stack(
                    [net(cnn_out) for net in self.mlp],
                    dim=-2,
                )
            else:
                res = self.mlp[0](cnn_out)

        tensordict.set(self.out_key, res)
        return tensordict


@dataclass
class CnnConfig(ModelConfig):
    """Dataclass config for a :class:`~benchmarl.models.Cnn`."""

    cnn_num_cells: Sequence[int] = MISSING
    cnn_kernel_sizes: Sequence[int] = MISSING
    cnn_strides: Sequence[int] = MISSING
    cnn_paddings: Sequence[int] = MISSING
    cnn_activation_class: Type[nn.Module] = MISSING

    mlp_num_cells: Sequence[int] = MISSING
    mlp_layer_class: Type[nn.Module] = MISSING
    mlp_activation_class: Type[nn.Module] = MISSING

    cnn_activation_kwargs: Optional[dict] = None
    cnn_norm_class: Optional[Type[nn.Module]] = None
    cnn_norm_kwargs: Optional[dict] = None

    mlp_activation_kwargs: Optional[dict] = None
    mlp_norm_class: Optional[Type[nn.Module]] = None
    mlp_norm_kwargs: Optional[dict] = None

    @staticmethod
    def associated_class():
        return Cnn
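
For context, a minimal end-to-end sketch of using the new model, following the `Experiment` pattern from the BenchMARL README; the task here is a placeholder, since the CNN expects image-shaped (X, Y, C) observations:

```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import CnnConfig

experiment = Experiment(
    # Placeholder task: swap in one whose observations are images
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=MappoConfig.get_from_yaml(),
    model_config=CnnConfig.get_from_yaml(),         # policy uses the CNN
    critic_model_config=CnnConfig.get_from_yaml(),  # critic can too
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()
```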
2 changes: 1 addition & 1 deletion benchmarl/models/gnn.py
@@ -170,7 +170,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                        *batch_size,
                        self.n_agents,
                        self.output_features,
-                    )[:, i]
+                    )[..., i, :]
                    for i, gnn in enumerate(self.gnns)
                ],
                dim=-2,
2 changes: 2 additions & 0 deletions docs/source/concepts/components.rst
@@ -111,3 +111,5 @@ agent group. Here is a table of the models implemented in BenchMARL
+---------------------------------+---------------+-------------------------------+-------------------------------+
| :class:`~benchmarl.models.Gnn` | Yes | No | No |
+---------------------------------+---------------+-------------------------------+-------------------------------+
+| :class:`~benchmarl.models.Cnn` | Yes           | Yes                           | Yes                           |
++---------------------------------+---------------+-------------------------------+-------------------------------+