Skip to content

Commit

Permalink
[Errors] Error on unavailable combinations (#142)
Browse files Browse the repository at this point in the history
* bugs

* bugs

* bugs

* bugs

* bugs

* bugs

* bugs

* bugs
  • Loading branch information
matteobettini authored Nov 14, 2024
1 parent 747e0c4 commit 9813807
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@
from torchrl.record.loggers import generate_exp_name
from tqdm import tqdm

from benchmarl.algorithms import IppoConfig, MappoConfig

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task
from benchmarl.experiment.callback import Callback, CallbackNotifier
from benchmarl.experiment.logger import Logger
from benchmarl.models import GnnConfig, SequenceModelConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import _read_yaml_config, seed_everything

Expand Down Expand Up @@ -361,7 +364,30 @@ def _setup(self):
self._on_setup()

def _perfrom_checks(self):
pass
for config in (self.model_config, self.critic_model_config):
if isinstance(config, SequenceModelConfig):
for layer_config in config.model_configs[1:]:
if isinstance(layer_config, GnnConfig) and (
layer_config.position_key is not None
or layer_config.velocity_key is not None
):
raise ValueError(
"GNNs reading position or velocity keys are currently only usable in first"
" layer of sequence models"
)

if self.algorithm_config in (MappoConfig, IppoConfig):
critic_model_config = self.critic_model_config
if isinstance(critic_model_config, SequenceModelConfig):
critic_model_config = self.critic_model_config.model_configs[0]
if (
isinstance(critic_model_config, GnnConfig)
and critic_model_config.topology == "from_pos"
):
raise ValueError(
"GNNs in PPO critics with topology 'from_pos' are currently not available, "
"see https://github.com/pytorch/rl/issues/2537"
)

def _set_action_type(self):
if (
Expand Down

0 comments on commit 9813807

Please sign in to comment.