diff --git a/.github/unittest/install_dependencies_nightly.sh b/.github/unittest/install_dependencies_nightly.sh index a32c38a5..6f90c4f9 100644 --- a/.github/unittest/install_dependencies_nightly.sh +++ b/.github/unittest/install_dependencies_nightly.sh @@ -11,7 +11,8 @@ python -m pip install torch cd ../BenchMARL pip install -e . -pip uninstal torchrl tensordict +pip uninstall --yes torchrl +pip uninstall --yes tensordict cd .. python -m pip install git+https://github.com/pytorch-labs/tensordict.git diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index 91f6c667..29740df6 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -137,10 +137,8 @@ def __init__( device=self.device, **cnn_net_kwargs, ) - if "_empty_net" in self.cnn.__dict__: - example_net = self.cnn._empty_net - else: - example_net = self.cnn.agent_networks[0] + example_net = self.cnn._empty_net + else: self.cnn = nn.ModuleList( [