From 3a9a40c6dbf6da2c2beb2013d3218ad5234e9772 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Mon, 8 Apr 2024 10:24:08 +0200 Subject: [PATCH] [Feature] Allow multipe inputs to models (#73) * add multiagent cnn implementation and tests * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * docs * mend * mend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend --------- Co-authored-by: ezhang7423 --- benchmarl/algorithms/iddpg.py | 29 +------ benchmarl/algorithms/isac.py | 26 +----- benchmarl/algorithms/maddpg.py | 47 +++-------- benchmarl/algorithms/masac.py | 46 +++-------- benchmarl/models/cnn.py | 100 ++++++++++++++++++++--- benchmarl/models/common.py | 19 +++-- benchmarl/models/gnn.py | 27 +++++- benchmarl/models/mlp.py | 33 ++++++-- examples/extending/model/custom_model.py | 26 ++++-- test/test_models.py | 13 ++- 10 files changed, 208 insertions(+), 158 deletions(-) diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py index b5fe1b2d..a114123b 100644 --- a/benchmarl/algorithms/iddpg.py +++ b/benchmarl/algorithms/iddpg.py @@ -7,7 +7,6 @@ from dataclasses import dataclass, MISSING from typing import Dict, Iterable, Tuple, Type -import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictModule, TensorDictSequential from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec @@ -188,34 +187,12 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase: def get_value_module(self, group: str) -> TensorDictModule: n_agents = len(self.group_map[group]) modules = [] - group_observation_key = list(self.observation_spec[group].keys())[0] - modules.append( - TensorDictModule( - lambda obs, action: torch.cat([obs, action], dim=-1), - in_keys=[ - (group, group_observation_key), - (group, "action"), - ], - out_keys=[(group, "obs_action")], - ) - ) critic_input_spec = CompositeSpec( { - group: CompositeSpec( - { - "obs_action": UnboundedContinuousTensorSpec( - shape=( - n_agents, - self.observation_spec[ - group, group_observation_key - ].shape[-1] - + self.action_spec[group, "action"].shape[-1], - ) - ) - }, - shape=(n_agents,), - ) + group: self.observation_spec[group] + .clone() + .update(self.action_spec[group]) } ) critic_output_spec = CompositeSpec( diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index 3edaf565..74c29ea7 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -7,7 +7,6 @@ from dataclasses import dataclass, MISSING from typing import Dict, Iterable, Optional, Tuple, Type, Union -import torch from tensordict import TensorDictBase from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential from torch.distributions import Categorical @@ -315,31 +314,12 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule: def get_continuous_value_module(self, group: str) -> TensorDictModule: n_agents = len(self.group_map[group]) modules = [] - group_observation_key = list(self.observation_spec[group].keys())[0] - modules.append( - TensorDictModule( - lambda obs, action: torch.cat([obs, action], dim=-1), - in_keys=[(group, group_observation_key), (group, "action")], - out_keys=[(group, "obs_action")], - ) - ) critic_input_spec = CompositeSpec( { - group: CompositeSpec( - { - "obs_action": UnboundedContinuousTensorSpec( - shape=( - n_agents, - self.observation_spec[ - group, group_observation_key - ].shape[-1] - + self.action_spec[group, "action"].shape[-1], - ) - ) - }, - shape=(n_agents,), - ) + group: self.observation_spec[group] + .clone() + .update(self.action_spec[group]) } ) diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py index de351c54..1590f81f 100644 --- a/benchmarl/algorithms/maddpg.py +++ b/benchmarl/algorithms/maddpg.py @@ -7,7 +7,6 @@ from dataclasses import dataclass, MISSING from typing import Dict, Iterable, Tuple, Type -import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictModule, TensorDictSequential from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec @@ -41,7 +40,7 @@ def __init__( loss_function: str, delay_value: bool, use_tanh_mapping: bool, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -188,7 +187,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase: def get_value_module(self, group: str) -> TensorDictModule: n_agents = len(self.group_map[group]) modules = [] - group_observation_key = list(self.observation_spec[group].keys())[0] if self.share_param_critic: critic_output_spec = CompositeSpec( @@ -209,23 +207,18 @@ def get_value_module(self, group: str) -> TensorDictModule: ) if self.state_spec is not None: - global_state_key = list(self.state_spec.keys())[0] modules.append( TensorDictModule( - lambda state, action: torch.cat( - [state, action.reshape(*action.shape[:-2], -1)], dim=-1 - ), - in_keys=[global_state_key, (group, "action")], - out_keys=["state_action"], + lambda action: action.reshape(*action.shape[:-2], -1), + in_keys=[(group, "action")], + out_keys=["global_action"], ) ) - critic_input_spec = CompositeSpec( + + critic_input_spec = self.state_spec.clone().update( { - "state_action": UnboundedContinuousTensorSpec( - shape=( - self.state_spec[global_state_key].shape[-1] - + self.action_spec[group, "action"].shape[-1] * n_agents, - ) + "global_action": UnboundedContinuousTensorSpec( + shape=(self.action_spec[group, "action"].shape[-1] * n_agents,) ) } ) @@ -245,29 +238,11 @@ def get_value_module(self, group: str) -> TensorDictModule: ) else: - modules.append( - TensorDictModule( - lambda obs, action: torch.cat([obs, action], dim=-1), - in_keys=[(group, group_observation_key), (group, "action")], - out_keys=[(group, "obs_action")], - ) - ) critic_input_spec = CompositeSpec( { - group: CompositeSpec( - { - "obs_action": UnboundedContinuousTensorSpec( - shape=( - n_agents, - self.observation_spec[ - group, group_observation_key - ].shape[-1] - + self.action_spec[group, "action"].shape[-1], - ) - ) - }, - shape=(n_agents,), - ) + group: self.observation_spec[group] + .clone() + .update(self.action_spec[group]) } ) diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index 2d628ee5..358010ef 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -7,7 +7,6 @@ from dataclasses import dataclass, MISSING from typing import Dict, Iterable, Optional, Tuple, Type, Union -import torch from tensordict import TensorDictBase from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential from torch.distributions import Categorical @@ -342,7 +341,6 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule: def get_continuous_value_module(self, group: str) -> TensorDictModule: n_agents = len(self.group_map[group]) modules = [] - group_observation_key = list(self.observation_spec[group].keys())[0] if self.share_param_critic: critic_output_spec = CompositeSpec( @@ -363,23 +361,19 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: ) if self.state_spec is not None: - global_state_key = list(self.state_spec.keys())[0] + modules.append( TensorDictModule( - lambda state, action: torch.cat( - [state, action.reshape(*action.shape[:-2], -1)], dim=-1 - ), - in_keys=[global_state_key, (group, "action")], - out_keys=["state_action"], + lambda action: action.reshape(*action.shape[:-2], -1), + in_keys=[(group, "action")], + out_keys=["global_action"], ) ) - critic_input_spec = CompositeSpec( + + critic_input_spec = self.state_spec.clone().update( { - "state_action": UnboundedContinuousTensorSpec( - shape=( - self.state_spec[global_state_key].shape[-1] - + self.action_spec[group, "action"].shape[-1] * n_agents, - ) + "global_action": UnboundedContinuousTensorSpec( + shape=(self.action_spec[group, "action"].shape[-1] * n_agents,) ) } ) @@ -399,29 +393,11 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: ) else: - modules.append( - TensorDictModule( - lambda obs, action: torch.cat([obs, action], dim=-1), - in_keys=[(group, group_observation_key), (group, "action")], - out_keys=[(group, "obs_action")], - ) - ) critic_input_spec = CompositeSpec( { - group: CompositeSpec( - { - "obs_action": UnboundedContinuousTensorSpec( - shape=( - n_agents, - self.observation_spec[ - group, group_observation_key - ].shape[-1] - + self.action_spec[group, "action"].shape[-1], - ) - ) - }, - shape=(n_agents,), - ) + group: self.observation_spec[group] + .clone() + .update(self.action_spec[group]) } ) diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index 0ecb3621..a6956fe3 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -53,8 +53,20 @@ def _number_conv_outputs( class Cnn(Model): """Convolutional Neural Network (CNN) model. - Args: + The BenchMARL CNN accepts multiple inputs of 2 types: + + - images: Tensors of shape (*batch, X,Y,C) + - arrays: Tensors of shape (*batch, F) + + The CNN model will check that all image inputs have the same shape (excluding the last dimension) + and cat them along that dimension before processing them with :class:`torchrl.modules.ConvNet`. + It will check that all array inputs have the same shape (excluding the last dimension) + and cat them along that dimension. + + It will then cat the arrays and processed images and feed them to the MLP together. + + Args: cnn_num_cells (int or Sequence of int): number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an @@ -104,9 +116,17 @@ def __init__( action_spec=kwargs.pop("action_spec"), ) - self.x = self.input_leaf_spec.shape[-3] - self.y = self.input_leaf_spec.shape[-2] - self.input_features = self.input_leaf_spec.shape[-1] + self.x = self.input_spec[self.image_in_keys[0]].shape[-3] + self.y = self.input_spec[self.image_in_keys[0]].shape[-2] + self.input_features_images = sum( + [self.input_spec[key].shape[-1] for key in self.image_in_keys] + ) + self.input_features_tensors = sum( + [self.input_spec[key].shape[-1] for key in self.tensor_in_keys] + ) + if self.input_has_agent_dim and not self.output_has_agent_dim: + # In this case the tensor features will be centralized + self.input_features_tensors *= self.n_agents self.output_features = self.output_leaf_spec.shape[-1] @@ -123,7 +143,7 @@ def __init__( if self.input_has_agent_dim: self.cnn = MultiAgentConvNet( - in_features=self.input_features, + in_features=self.input_features_images, n_agents=self.n_agents, centralised=self.centralised, share_params=self.share_params, @@ -136,7 +156,7 @@ def __init__( self.cnn = nn.ModuleList( [ ConvNet( - in_features=self.input_features, + in_features=self.input_features_images, device=self.device, **cnn_net_kwargs, ) @@ -156,7 +176,7 @@ def __init__( if self.output_has_agent_dim: self.mlp = MultiAgentMLP( - n_agent_inputs=cnn_output_size, + n_agent_inputs=cnn_output_size + self.input_features_tensors, n_agent_outputs=self.output_features, n_agents=self.n_agents, centralised=self.centralised, @@ -168,7 +188,7 @@ def __init__( self.mlp = nn.ModuleList( [ MLP( - in_features=cnn_output_size, + in_features=cnn_output_size + self.input_features_tensors, out_features=self.output_features, device=self.device, **mlp_net_kwargs, @@ -180,10 +200,49 @@ def __init__( def _perform_checks(self): super()._perform_checks() - if self.input_has_agent_dim and self.input_leaf_spec.shape[-4] != self.n_agents: + input_shape_tensor = None + self.image_in_keys = [] + input_shape_image = None + self.tensor_in_keys = [] + for input_key, input_spec in self.input_spec.items(True, True): + if (self.input_has_agent_dim and len(input_spec.shape) == 4) or ( + not self.input_has_agent_dim and len(input_spec.shape) == 3 + ): + self.image_in_keys.append(input_key) + if input_shape_image is None: + input_shape_image = input_spec.shape[:-1] + elif input_spec.shape[:-1] != input_shape_image: + raise ValueError( + f"CNN image inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + elif (self.input_has_agent_dim and len(input_spec.shape) == 2) or ( + not self.input_has_agent_dim and len(input_spec.shape) == 1 + ): + self.tensor_in_keys.append(input_key) + if input_shape_tensor is None: + input_shape_tensor = input_spec.shape[:-1] + elif input_spec.shape[:-1] != input_shape_tensor: + raise ValueError( + f"CNN tensor inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + else: + raise ValueError( + f"CNN input value {input_key} from {self.input_spec} has an invalid shape" + ) + + if self.input_has_agent_dim and input_shape_image[-3] != self.n_agents: + raise ValueError( + "If the CNN input has the agent dimension," + " the forth to last spec dimension of image inputs should be the number of agents" + ) + if ( + self.input_has_agent_dim + and input_shape_tensor is not None + and input_shape_tensor[-1] != self.n_agents + ): raise ValueError( "If the CNN input has the agent dimension," - " the forth to last spec dimension should be the number of agents" + " the second to last spec dimension of tensor inputs should be the number of agents" ) if ( self.output_has_agent_dim @@ -195,11 +254,25 @@ def _perform_checks(self): ) def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: - # Gather in_key - input = tensordict.get(self.in_key) + # Gather images + input = torch.cat( + [tensordict.get(in_key) for in_key in self.image_in_keys], dim=-1 + ) # BenchMARL images are X,Y,C -> we convert them to C, X, Y for processing in TorchRL models input = input.transpose(-3, -1).transpose(-2, -1) + # Gather tensor inputs + if len(self.tensor_in_keys): + tensor_inputs = torch.cat( + [tensordict.get(in_key) for in_key in self.tensor_in_keys], dim=-1 + ) + if self.input_has_agent_dim and not self.output_has_agent_dim: + tensor_inputs = tensor_inputs.reshape((*tensor_inputs.shape[:-2], -1)) + elif not self.input_has_agent_dim and self.output_has_agent_dim: + tensor_inputs = tensor_inputs.unsqueeze(-2).expand( + (*tensor_inputs.shape[:-1], self.n_agents, tensor_inputs.shape[-1]) + ) + # Has multi-agent input dimension if self.input_has_agent_dim: cnn_out = self.cnn.forward(input) @@ -219,6 +292,9 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: cnn_out = self.cnn[0](input) + if len(self.tensor_in_keys): + cnn_out = torch.cat([cnn_out, tensor_inputs], dim=-1) + # Cnn output has multi-agent input dimension if self.output_has_agent_dim: res = self.mlp.forward(cnn_out) diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index cd711821..7b332099 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -12,7 +12,8 @@ from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase, TensorDictSequential -from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec +from tensordict.utils import NestedKey +from torchrl.data import CompositeSpec, TensorSpec, UnboundedContinuousTensorSpec from torchrl.envs import EnvBase from benchmarl.utils import _class_from_name, _read_yaml_config, DEVICE_TYPING @@ -103,9 +104,7 @@ def __init__( self.in_keys = list(self.input_spec.keys(True, True)) self.out_keys = list(self.output_spec.keys(True, True)) - self.in_key = self.in_keys[0] self.out_key = self.out_keys[0] - self.input_leaf_spec = self.input_spec[self.in_key] self.output_leaf_spec = self.output_spec[self.out_key] self._perform_checks() @@ -120,14 +119,22 @@ def output_has_agent_dim(self) -> bool: """ return output_has_agent_dim(self.share_params, self.centralised) + @property + def in_key(self) -> NestedKey: + if len(self.in_keys) > 1: + raise ValueError("Model has more than one input key") + return self.in_keys[0] + + @property + def input_leaf_spec(self) -> TensorSpec: + return self.input_spec[self.in_key] + def _perform_checks(self): if not self.input_has_agent_dim and not self.centralised: raise ValueError( "If input does not have an agent dimension the model should be marked as centralised" ) - if len(self.in_keys) > 1: - raise ValueError("Currently models support just one input key") if len(self.out_keys) > 1: raise ValueError("Currently models support just one output key") @@ -158,7 +165,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: """ Method to implement for the forward pass of the model. - It should read self.in_key, process it and write self.out_key. + It should read self.in_keys, process it and write self.out_key. Args: tensordict (TensorDictBase): the input td diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 6b26674a..76a0754e 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -104,7 +104,9 @@ def __init__( super().__init__(**kwargs) - self.input_features = self.input_leaf_spec.shape[-1] + self.input_features = sum( + [spec.shape[-1] for spec in self.input_spec.values(True, True)] + ) self.output_features = self.output_leaf_spec.shape[-1] if gnn_kwargs is None: @@ -142,9 +144,26 @@ def _perform_checks(self): "if your algorithm has a centralized critic and the task has a global state." ) - if self.input_leaf_spec.shape[-2] != self.n_agents: + input_shape = None + for input_key, input_spec in self.input_spec.items(True, True): + if (self.input_has_agent_dim and len(input_spec.shape) == 2) or ( + not self.input_has_agent_dim and len(input_spec.shape) == 1 + ): + if input_shape is None: + input_shape = input_spec.shape[:-1] + else: + if input_spec.shape[:-1] != input_shape: + raise ValueError( + f"GNN inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + else: + raise ValueError( + f"GNN input value {input_key} from {self.input_spec} has an invalid shape" + ) + + if input_shape[-1] != self.n_agents: raise ValueError( - "The second to last input spec dimension should be the number of agents" + f"The second to last input spec dimension should be the number of agents, got {self.input_spec}" ) if ( self.output_has_agent_dim @@ -157,7 +176,7 @@ def _perform_checks(self): def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key - input = tensordict.get(self.in_key) + input = torch.cat([tensordict.get(in_key) for in_key in self.in_keys], dim=-1) batch_size = input.shape[:-2] diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py index dfd49131..9f98c37b 100644 --- a/benchmarl/models/mlp.py +++ b/benchmarl/models/mlp.py @@ -48,7 +48,9 @@ def __init__( action_spec=kwargs.pop("action_spec"), ) - self.input_features = self.input_leaf_spec.shape[-1] + self.input_features = sum( + [spec.shape[-1] for spec in self.input_spec.values(True, True)] + ) self.output_features = self.output_leaf_spec.shape[-1] if self.input_has_agent_dim: @@ -77,11 +79,28 @@ def __init__( def _perform_checks(self): super()._perform_checks() - if self.input_has_agent_dim and self.input_leaf_spec.shape[-2] != self.n_agents: - raise ValueError( - "If the MLP input has the agent dimension," - " the second to last spec dimension should be the number of agents" - ) + input_shape = None + for input_key, input_spec in self.input_spec.items(True, True): + if (self.input_has_agent_dim and len(input_spec.shape) == 2) or ( + not self.input_has_agent_dim and len(input_spec.shape) == 1 + ): + if input_shape is None: + input_shape = input_spec.shape[:-1] + else: + if input_spec.shape[:-1] != input_shape: + raise ValueError( + f"MLP inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + else: + raise ValueError( + f"MLP input value {input_key} from {self.input_spec} has an invalid shape" + ) + if self.input_has_agent_dim: + if input_shape[-1] != self.n_agents: + raise ValueError( + "If the MLP input has the agent dimension," + " the second to last spec dimension should be the number of agents, got {self.input_spec}" + ) if ( self.output_has_agent_dim and self.output_leaf_spec.shape[-2] != self.n_agents @@ -93,7 +112,7 @@ def _perform_checks(self): def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key - input = tensordict.get(self.in_key) + input = torch.cat([tensordict.get(in_key) for in_key in self.in_keys], dim=-1) # Has multi-agent input dimension if self.input_has_agent_dim: diff --git a/examples/extending/model/custom_model.py b/examples/extending/model/custom_model.py index e4689ec7..7dea32cb 100644 --- a/examples/extending/model/custom_model.py +++ b/examples/extending/model/custom_model.py @@ -75,7 +75,9 @@ def __init__( # and the dimension should be absent otherwise _ = self.output_has_agent_dim - self.input_features = self.input_leaf_spec.shape[-1] + self.input_features = sum( + [spec.shape[-1] for spec in self.input_spec.values(True, True)] + ) self.output_features = self.output_leaf_spec.shape[-1] if self.input_has_agent_dim and not self.centralised: @@ -121,11 +123,21 @@ def _perform_checks(self): super()._perform_checks() # Run some checks - if self.input_has_agent_dim and self.input_leaf_spec.shape[-2] != self.n_agents: - raise ValueError( - "If the MLP input has the agent dimension," - " the second to last spec dimension should be the number of agents" - ) + input_shape = None + for input_spec in self.input_spec.values(True, True): + if input_shape is None: + input_shape = input_spec.shape[:-1] + else: + if input_spec.shape[:-1] != input_shape: + raise ValueError( + f"MLP inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + if self.input_has_agent_dim: + if input_shape[-1] != self.n_agents: + raise ValueError( + "If the MLP input has the agent dimension," + " the second to last spec dimension should be the number of agents, got {self.input_spec}" + ) if ( self.output_has_agent_dim and self.output_leaf_spec.shape[-2] != self.n_agents @@ -137,7 +149,7 @@ def _perform_checks(self): def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key - input = tensordict.get(self.in_key) + input = torch.cat([tensordict.get(in_key) for in_key in self.in_keys], dim=-1) # Input has multi-agent input dimension if self.input_has_agent_dim: diff --git a/test/test_models.py b/test/test_models.py index dfe6b57a..b88435eb 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -105,6 +105,9 @@ def test_models_forward_shape( multi_agent_tensor = torch.rand((*batch_size, n_agents, channels)) single_agent_tensor = torch.rand((*batch_size, channels)) + other_multi_agent_tensor = torch.rand((*batch_size, n_agents, channels)) + other_single_agent_tensor = torch.rand((*batch_size, channels)) + if input_has_agent_dim: input_spec = CompositeSpec( { @@ -112,7 +115,10 @@ def test_models_forward_shape( { "observation": UnboundedContinuousTensorSpec( shape=multi_agent_tensor.shape[len(batch_size) :] - ) + ), + "other": UnboundedContinuousTensorSpec( + shape=other_multi_agent_tensor.shape[len(batch_size) :] + ), }, shape=(n_agents,), ) @@ -123,7 +129,10 @@ def test_models_forward_shape( { "observation": UnboundedContinuousTensorSpec( shape=single_agent_tensor.shape[len(batch_size) :] - ) + ), + "other": UnboundedContinuousTensorSpec( + shape=other_single_agent_tensor.shape[len(batch_size) :] + ), }, )