diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py index 25c9368f..ee99d5e7 100644 --- a/benchmarl/environments/meltingpot/common.py +++ b/benchmarl/environments/meltingpot/common.py @@ -11,7 +11,6 @@ from torchrl.data import CompositeSpec from torchrl.envs import DoubleToFloat, DTypeCastTransform, EnvBase, Transform -from torchrl.envs.libs.meltingpot import MeltingpotEnv from benchmarl.environments.common import Task from benchmarl.utils import DEVICE_TYPING @@ -29,6 +28,8 @@ def get_env_fun( seed: Optional[int], device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: + from torchrl.envs.libs.meltingpot import MeltingpotEnv + return lambda: MeltingpotEnv( substrate=self.name.lower(), categorical_actions=True,