diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index a92352ef..00000000 --- a/.codecov.yml +++ /dev/null @@ -1,46 +0,0 @@ -#see https://github.com/codecov/support/wiki/Codecov-Yaml -codecov: - require_ci_to_pass: yes - -coverage: - - # 2 = xx.xx%, 0 = xx% - precision: 2 - - # https://docs.codecov.com/docs/commit-status - status: - - # We want our total main project to always remain above 87% coverage, a - # drop of 0.20% is allowed. It should fail if coverage couldn't be uploaded - # of the CI fails otherwise - project: - default: - target: 10% - threshold: 0.20% - if_not_found: failure - if_ci_failed: error - - # The code changed by a PR should have 90% coverage. This is different from the - # overall number shown above. - # This encourages small PR's as they are easier to test. - patch: - default: - target: 10% - if_not_found: failure - if_ci_failed: failure - -# We upload additional information on branching with pytest-cov `--cov-branch` -# This information can be used by codecov.com to increase analysis of code -parsers: - gcov: - branch_detection: - conditional: true - loop: true - method: true - macro: false - - -comment: - layout: diff, reach - behavior: default - require_changes: false diff --git a/.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE.md new file mode 100644 index 00000000..b79752ce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE.md @@ -0,0 +1,37 @@ +--- +name: Issue Template +about: General template issues +labels: + +--- + +* {{ cookiecutter.project_name }} version: +* Python version: +* Operating System: + + + + +#### Description + + +#### Steps/Code to Reproduce + + +#### Expected Results + + +#### Actual Results + + +#### Additional Info + +- Did you try upgrading to the most current version? yes/no +- Are you using a supported operating system (version)? yes/no +- How did you install this package (e.g. GitHub, pip, etc.)? + + \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..6556baad --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,39 @@ + + +#### Reference Issues/PRs + + +#### What does this implement/fix? Explain your changes. + + + +#### Checklist + +- Are the tests passing locally? yes/no +- Is the pre-commit passing locally? yes/no +- Are all new features documented in code and docs? yes/no +- Are all examples still running? yes/no +- Are the requirements up to date? yes/no +- Did you add yourself to the contributors in the authors file? yes/no + +#### Any other comments? + + \ No newline at end of file diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 26a20757..4649c1ef 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -21,10 +21,6 @@ on: - main - development - schedule: - # Every day at 7AM UTC - - cron: '0 07 * * *' - env: # Arguments used for pytest @@ -47,32 +43,17 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, ubuntu-latest] - python-version: ['3.9', '3.10'] - kind: ['conda', 'source', 'dist'] - - exclude: - # Exclude all configurations *-*-dist, but include one later in `include` - - kind: 'dist' - - # Exclude windows as bash commands wont work in windows runner - - os: windows-latest - - # Exclude macos as there are permission errors using conda as we do - - os: macos-latest + os: [ubuntu-latest] + python-version: ['3.9', '3.10', '3.11'] + kind: ['conda'] include: # Add the tag code-cov to ubuntu-3.7-source - os: ubuntu-latest python-version: 3.9 - kind: 'source' + kind: 'conda' code-cov: true - # Include one config with dist, ubuntu-3.7-dist - - os: ubuntu-latest - python-version: 3.9 - kind: 'dist' - steps: - name: Checkout @@ -91,21 +72,7 @@ jobs: # Miniconda is available in $CONDA env var $CONDA/bin/conda create -n testenv --yes pip wheel gxx_linux-64 gcc_linux-64 python=${{ matrix.python-version }} $CONDA/envs/testenv/bin/python3 -m pip install --upgrade pip - $CONDA/envs/testenv/bin/pip3 install -e ".[dev,box2d,brax,dm_control,mario]" - - - name: Source install - if: matrix.kind == 'source' - run: | - python -m pip install --upgrade pip - pip install -e ".[dev,box2d,brax,dm_control,mario]" - - - name: Dist install - if: matrix.kind == 'dist' - run: | - python -m pip install --upgrade pip - python setup.py sdist - last_dist=$(ls -t dist/carl-*.tar.gz | head -n 1) - pip install $last_dist[dev,box2d,brax,dm_control,mario] + $CONDA/envs/testenv/bin/pip3 install -e .[dev,dm_control,mario,brax,box2d] - name: Tests timeout-minutes: 60 @@ -123,10 +90,3 @@ jobs: else $PYTHON -m pytest ${{ env.pytest-args }} --ignore=test/local_only test fi - - - name: Upload coverage - if: matrix.code-cov && always() - uses: codecov/codecov-action@v2 - with: - fail_ci_if_error: true - verbose: true diff --git a/.readthedocs.yaml b/.readthedocs.yaml deleted file mode 100644 index 98fba55a..00000000 --- a/.readthedocs.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# use version 2, which is now recommended -version: 2 - -build: - os: ubuntu-20.04 - tools: - python: "3.9" - -# Build from the docs/ directory with Sphinx -sphinx: - configuration: docs/conf.py - -# build all -formats: all - -# Explicitly set the version of Python and its requirements -python: - install: - - method: pip - path: . - extra_requirements: - - docs diff --git a/CITATION.bib b/CITATION.bib deleted file mode 100644 index f6a2b599..00000000 --- a/CITATION.bib +++ /dev/null @@ -1,14 +0,0 @@ -@inproceedings { BenEim2023a, - author = {Carolin Benjamins and - Theresa Eimer and - Frederik Schubert and - Aditya Mohan and - Sebastian Döhler and - André Biedenkapp and - Bodo Rosenhahn and - Frank Hutter and - Marius Lindauer}, - title = {Contextualize Me - The Case for Context in Reinforcement Learning}, - journal = {Transactions on Machine Learning Research}, - year = {2023}, -} diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..509cd5f7 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,69 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite our paper:" + +url: "https://automl.github.io/CARL/" +repository-code: "https://github.com/automl/CARL" +title: "CARL - Context Adaptive Reinforcement Learning" + +authors: + - family-names: "Benjamins" + given-names: "Carolin" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Eimer" + given-names: "Theresa" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Schubert" + given-names: "Frederik" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Mohan" + given-names: "Aditya" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Döhler" + given-names: "Sebastian" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Biedenkapp" + given-names: "André" + affiliation: "Albert-Ludwigs University Freiburg, Germany" + - family-names: "Rosenhahn" + given-names: "Bodo" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Hutter" + given-names: "Frank" + affiliation: "Albert-Ludwigs University Freiburg, Germany" + - family-names: "Lindauer" + given-names: "Marius" + affiliation: "Leibniz University Hannover, Germany" + +preferred-citation: + type: "article" + title: "Contextualize Me - The Case for Context in Reinforcement Learning" + year: 2023 + journal: "Transactions on Machine Learning Research" + authors: + - family-names: "Benjamins" + given-names: "Carolin" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Eimer" + given-names: "Theresa" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Schubert" + given-names: "Frederik" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Mohan" + given-names: "Aditya" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Döhler" + given-names: "Sebastian" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Biedenkapp" + given-names: "André" + affiliation: "Albert-Ludwigs University Freiburg, Germany" + - family-names: "Rosenhahn" + given-names: "Bodo" + affiliation: "Leibniz University Hannover, Germany" + - family-names: "Hutter" + given-names: "Frank" + affiliation: "Albert-Ludwigs University Freiburg, Germany" + - family-names: "Lindauer" + given-names: "Marius" + affiliation: "Leibniz University Hannover, Germany" \ No newline at end of file diff --git a/Makefile b/Makefile index 995451d5..968def34 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,9 @@ install-dev: $(PIP) install -e ".[dev, docs]" pre-commit install +install: + $(PIP) install -e . + check-black: $(BLACK) carl test --check || : @@ -63,7 +66,7 @@ pre-commit: $(PRECOMMIT) run --all-files format-black: - $(BLACK) carl test + $(BLACK) carl test examples format-isort: $(ISORT) carl test @@ -71,7 +74,10 @@ format-isort: format: format-black format-isort test: - $(PYTEST) test + $(PYTEST) --disable-warnings test + +cov-report: + coverage html -d coverage_html clean-doc: $(MAKE) -C ${DOCDIR} clean diff --git a/README.md b/README.md index 2a5ed021..c379ec96 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,11 @@ CARL # – The Benchmark Library +[![PyPI Version](https://img.shields.io/pypi/v/carl-bench.svg)](https://pypi.python.org/pypi/carl-bench) +[![Test](https://github.com/automl/carl/actions/workflows/tests.yaml/badge.svg)](https://github.com/automl/carl/actions/workflows/tests.yaml) +[![Doc Status](https://github.com/automl/carl/actions/workflows/docs.yaml/badge.svg)](https://github.com/automl/carl/actions/workflows/docs.yaml) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + CARL (context adaptive RL) provides highly configurable contextual extensions to several well-known RL environments. It's designed to test your agent's generalization capabilities diff --git a/carl/__init__.py b/carl/__init__.py index f1391313..a7bb344f 100644 --- a/carl/__init__.py +++ b/carl/__init__.py @@ -1,9 +1,11 @@ __license__ = "Apache-2.0 License" -__version__ = "1.0.0" +__version__ = "1.1.0" __author__ = "Carolin Benjamins, Theresa Eimer, Frederik Schubert, André Biedenkapp, Aditya Mohan, Sebastian Döhler" import datetime +import importlib.util as iutil +import warnings name = "CARL" package_name = "carl-bench" @@ -20,3 +22,66 @@ Copyright {datetime.date.today().strftime('%Y')}, AutoML.org Freiburg-Hannover """ version = __version__ + +try: + from gymnasium.envs.registration import register + + from carl import envs + + for e in envs.gymnasium.classic_control.__all__: + register( + id=f"carl/{e}-v0", + entry_point=f"carl.envs.gymnasium.classic_control:{e}", + ) + + def check_spec(spec_name: str) -> bool: + """Check if the spec is installed + + Parameters + ---------- + spec_name : str + Name of package that is necessary for the environment suite. + + Returns + ------- + bool + Whether the spec was found. + """ + spec = iutil.find_spec(spec_name) + found = spec is not None + if not found: + with warnings.catch_warnings(): + warnings.simplefilter("once") + warnings.warn( + f"""Module {spec_name} not found. If you want to use these environments, + please follow the installation guide.""" + ) + return found + + found = check_spec("Box2D") + if found: + for e in envs.gymnasium.box2d.__all__: + register( + id=f"carl/{e}-v0", + entry_point=f"carl.envs.gymnasium.box2d:{e}", + ) + + found = check_spec("dm_control") + if found: + for e in envs.dmc.__all__: + register( + id=f"carl/{e}-v0", + entry_point=f"carl.envs.dmc:{e}", + ) + + found = check_spec("py4j") + if found: + register( + id="carl/CARLMarioEnv-v0", + entry_point="carl.envs.mario:CARLMarioEnv", + ) +except: + print( + """Gym registration failed - this is normal during installation. + After that, please check that gymnasium is installed correctly.""" + ) diff --git a/carl/context/context_space.py b/carl/context/context_space.py index 4409a370..8a8cf5ce 100644 --- a/carl/context/context_space.py +++ b/carl/context/context_space.py @@ -219,7 +219,7 @@ def sample_contexts( contexts = [] for _ in range(size): - context = {cf.name: cf.sample() for cf in self.context_space.values()} + context = {cf.name: cf.rvs() for cf in self.context_space.values()} context = self.insert_defaults(context, context_keys) contexts += [context] diff --git a/carl/context/search_space_encoding.py b/carl/context/search_space_encoding.py index 893a71d4..ae955d3a 100644 --- a/carl/context/search_space_encoding.py +++ b/carl/context/search_space_encoding.py @@ -106,11 +106,11 @@ def search_space_to_config_space( ------- ConfigurationSpace """ - if type(search_space) == str: + if isinstance(search_space, str): with open(search_space, "r") as f: jason_string = f.read() cs = csjson.read(jason_string) - elif type(search_space) == DictConfig: + elif isinstance(search_space, DictConfig): # reorder hyperparameters as List[Dict] hyperparameters = [] for name, cfg in search_space.hyperparameters.items(): @@ -130,8 +130,10 @@ def search_space_to_config_space( jason_string = json.dumps(search_space, cls=JSONCfgEncoder) cs = csjson.read(jason_string) - elif type(search_space) == ConfigurationSpace: + elif isinstance(search_space, ConfigurationSpace): cs = search_space + elif isinstance(search_space, dict): + cs = csjson.read(json.dumps(search_space)) else: raise ValueError( f"search_space must be of type str or DictConfig. Got {type(search_space)}." diff --git a/carl/envs/__init__.py b/carl/envs/__init__.py index 7d34ffc1..0578a7a5 100644 --- a/carl/envs/__init__.py +++ b/carl/envs/__init__.py @@ -6,6 +6,14 @@ # Classic control is in gym and thus necessary for the base version to run from carl.envs.gymnasium import * +__all__ = [ + "CARLAcrobot", + "CARLCartPole", + "CARLMountainCar", + "CARLMountainCarContinuous", + "CARLPendulum", +] + def check_spec(spec_name: str) -> bool: """Check if the spec is installed @@ -36,18 +44,44 @@ def check_spec(spec_name: str) -> bool: if found: from carl.envs.gymnasium.box2d import * + __all__ += ["CARLBipedalWalker", "CARLLunarLander", "CARLVehicleRacing"] + found = check_spec("brax") if found: from carl.envs.brax import * + __all__ += [ + "CARLBraxAnt", + "CARLBraxHalfcheetah", + "CARLBraxHopper", + "CARLBraxHumanoid", + "CARLBraxHumanoidStandup", + "CARLBraxInvertedDoublePendulum", + "CARLBraxInvertedPendulum", + "CARLBraxPusher", + "CARLBraxReacher", + "CARLBraxWalker2d", + ] + found = check_spec("py4j") if found: from carl.envs.mario import * + __all__ += ["CARLMarioEnv"] + found = check_spec("dm_control") if found: from carl.envs.dmc import * -# found = check_spec("distance") -# if found: -# from carl.envs.rna import * + __all__ += [ + "CARLDmcFingerEnv", + "CARLDmcFishEnv", + "CARLDmcQuadrupedEnv", + "CARLDmcWalkerEnv", + ] + +found = check_spec("distance") +if found: + from carl.envs.rna import * + + __all__ += ["CARLRnaDesignEnv"] diff --git a/carl/envs/brax/brax_walker_goal_wrapper.py b/carl/envs/brax/brax_walker_goal_wrapper.py new file mode 100644 index 00000000..9c7a49c8 --- /dev/null +++ b/carl/envs/brax/brax_walker_goal_wrapper.py @@ -0,0 +1,180 @@ +import gym +import numpy as np +from brax.io import mjcf +from etils import epath + +STATE_INDICES = { + "ant": [13, 14], + "humanoid": [22, 23], + "halfcheetah": [14, 15], + "hopper": [5, 6], + "walker2d": [8, 9], +} + +DIRECTION_NAMES = { + 1: "north", + 3: "south", + 2: "east", + 4: "west", + 12: "north east", + 32: "south east", + 14: "north west", + 34: "south west", + 112: "north north east", + 332: "south south east", + 114: "north north west", + 334: "south south west", + 212: "east north east", + 232: "east south east", + 414: "west north west", + 434: "west south west", +} + +directions = [ + 1, # north + 3, # south + 2, # east + 4, # west + 12, + 32, + 14, + 34, + 112, + 332, + 114, + 334, + 212, + 232, + 414, + 434, +] + + +class BraxWalkerGoalWrapper(gym.Wrapper): + """Adds a positional goal to brax walker envs""" + + def __init__(self, env: gym.Env, env_name: str, asset_path: str) -> None: + super().__init__(env) + self.env_name = env_name + if ( + self.env_name == "humanoid" + or self.env_name == "halfcheetah" + or self.env_name == "hopper" + or self.env_name == "walker2d" + ): + self.env._forward_reward_weight = 0 + self.context = None + self.position = None + self.goal_position = None + self.goal_radius = None + self.direction_values = { + 3: [0, -1], + 1: [0, 1], + 2: [1, 0], + 4: [-1, 0], + 34: [-np.sqrt(0.5), -np.sqrt(0.5)], + 14: [-np.sqrt(0.5), np.sqrt(0.5)], + 32: [np.sqrt(0.5), -np.sqrt(0.5)], + 12: [np.sqrt(0.5), np.sqrt(0.5)], + 334: [ + -np.cos(22.5 * np.pi / 180), + -np.sin(22.5 * np.pi / 180), + ], + 434: [ + -np.sin(22.5 * np.pi / 180), + -np.cos(22.5 * np.pi / 180), + ], + 114: [ + -np.cos(22.5 * np.pi / 180), + np.sin(22.5 * np.pi / 180), + ], + 414: [ + -np.sin(22.5 * np.pi / 180), + np.cos(22.5 * np.pi / 180), + ], + 332: [ + np.cos(22.5 * np.pi / 180), + -np.sin(22.5 * np.pi / 180), + ], + 232: [ + np.sin(22.5 * np.pi / 180), + -np.cos(22.5 * np.pi / 180), + ], + 112: [ + np.cos(22.5 * np.pi / 180), + np.sin(22.5 * np.pi / 180), + ], + 212: [np.sin(22.5 * np.pi / 180), np.cos(22.5 * np.pi / 180)], + } + path = epath.resource_path("brax") / asset_path + sys = mjcf.load(path) + self.dt = sys.dt + + def reset(self, seed=None, options={}): + state, info = self.env.reset(seed=seed, options=options) + self.position = (0, 0) + self.goal_position = ( + np.array(self.direction_values[self.context["target_direction"]]) + * self.context["target_distance"] + ) + self.goal_radius = self.context["target_radius"] + info["success"] = 0 + return state, info + + def step(self, action): + state, _, te, tr, info = self.env.step(action) + indices = STATE_INDICES[self.env_name] + new_position = ( + np.array(list(self.position)) + + np.array([state[indices[0]], state[indices[1]]]) * self.dt + ) + current_distance_to_goal = np.linalg.norm(self.goal_position - new_position) + previous_distance_to_goal = np.linalg.norm(self.goal_position - self.position) + direction_reward = max(0, previous_distance_to_goal - current_distance_to_goal) + self.position = new_position + if abs(current_distance_to_goal) <= self.goal_radius: + te = True + info["success"] = 1 + else: + info["success"] = 0 + return state, direction_reward, te, tr, info + + +class BraxLanguageWrapper(gym.Wrapper): + """Translates the context features target distance and target radius into language""" + + def __init__(self, env) -> None: + super().__init__(env) + self.context = None + + def reset(self, seed=None, options={}): + self.env.context = self.context + state, info = self.env.reset(seed=seed, options=options) + goal_str = self.get_goal_desc(self.context) + if isinstance(state, dict): + state["goal"] = goal_str + else: + state = {"obs": state, "goal": goal_str} + return state, info + + def step(self, action): + state, reward, te, tr, info = self.env.step(action) + goal_str = self.get_goal_desc(self.context) + if isinstance(state, dict): + state["goal"] = goal_str + else: + state = {"obs": state, "goal": goal_str} + return state, reward, te, tr, info + + def get_goal_desc(self, context): + if "target_radius" in context.keys(): + target_distance = context["target_distance"] + target_direction = context["target_direction"] + target_radius = context["target_radius"] + return f"""The distance to the goal is {target_distance}m + {DIRECTION_NAMES[target_direction]}. + Move within {target_radius} steps of the goal.""" + else: + target_distance = context["target_distance"] + target_direction = context["target_direction"] + return f"Move {target_distance}m {DIRECTION_NAMES[target_direction]}." diff --git a/carl/envs/brax/carl_ant.py b/carl/envs/brax/carl_ant.py index e8bb6d7c..ee181ecc 100644 --- a/carl/envs/brax/carl_ant.py +++ b/carl/envs/brax/carl_ant.py @@ -2,13 +2,19 @@ import numpy as np -from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.context.context_space import ( + CategoricalContextFeature, + ContextFeature, + UniformFloatContextFeature, +) +from carl.envs.brax.brax_walker_goal_wrapper import directions from carl.envs.brax.carl_brax_env import CARLBraxEnv class CARLBraxAnt(CARLBraxEnv): env_name: str = "ant" asset_path: str = "envs/assets/ant.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: @@ -31,4 +37,13 @@ def get_context_features() -> dict[str, ContextFeature]: "viscosity": UniformFloatContextFeature( "viscosity", lower=0, upper=np.inf, default_value=0 ), + "target_distance": UniformFloatContextFeature( + "target_distance", lower=0, upper=np.inf, default_value=100 + ), + "target_direction": CategoricalContextFeature( + "target_direction", choices=directions, default_value=1 + ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_brax_env.py b/carl/envs/brax/carl_brax_env.py index dffc0573..3b774c6e 100644 --- a/carl/envs/brax/carl_brax_env.py +++ b/carl/envs/brax/carl_brax_env.py @@ -13,9 +13,13 @@ from jax import numpy as jp from carl.context.selection import AbstractSelector +from carl.envs.brax.brax_walker_goal_wrapper import ( + BraxLanguageWrapper, + BraxWalkerGoalWrapper, +) from carl.envs.brax.wrappers import GymWrapper, VectorGymWrapper from carl.envs.carl_env import CARLEnv -from carl.utils.types import Contexts +from carl.utils.types import Context, Contexts def set_geom_attr( @@ -152,6 +156,7 @@ def __init__( obs_context_as_dict: bool = True, context_selector: AbstractSelector | type[AbstractSelector] | None = None, context_selector_kwargs: dict = None, + use_language_goals: bool = False, **kwargs, ) -> None: """ @@ -204,6 +209,37 @@ def __init__( dtype=np.float32, ) + if contexts is not None: + if ( + "target_distance" in contexts[list(contexts.keys())[0]].keys() + or "target_direction" in contexts[list(contexts.keys())[0]].keys() + ): + assert all( + [ + "target_direction" in contexts[list(contexts.keys())[i]].keys() + for i in range(len(contexts)) + ] + ), "All contexts must have a 'target_direction' key" + assert all( + [ + "target_distance" in contexts[list(contexts.keys())[i]].keys() + for i in range(len(contexts)) + ] + ), "All contexts must have a 'target_distance' key" + base_dir = contexts[list(contexts.keys())[0]]["target_direction"] + base_dist = contexts[list(contexts.keys())[0]]["target_distance"] + max_diff_dir = max( + [c["target_direction"] - base_dir for c in contexts.values()] + ) + max_diff_dist = max( + [c["target_distance"] - base_dist for c in contexts.values()] + ) + if max_diff_dir > 0.1 or max_diff_dist > 0.1: + env = BraxWalkerGoalWrapper(env, self.env_name, self.asset_path) + if use_language_goals: + env = BraxLanguageWrapper(env) + self.use_language_goals = use_language_goals + super().__init__( env=env, contexts=contexts, @@ -213,6 +249,7 @@ def __init__( context_selector_kwargs=context_selector_kwargs, **kwargs, ) + self.env.context = self.context def _update_context(self) -> None: context = self.context @@ -224,6 +261,9 @@ def _update_context(self) -> None: "gravity", "viscosity", "elasticity", + "target_distance", + "target_direction", + "target_radius", ] check_context(context, registered_cfs) @@ -252,3 +292,47 @@ def _update_context(self) -> None: sys = sys.replace(geoms=updated_geoms) self.env.unwrapped.sys = sys + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Any, dict[str, Any]]: + """Overwrites reset in super to update context in wrapper.""" + last_context_id = self.context_id + self._progress_instance() + if self.context_id != last_context_id: + self._update_context() + self.env.context = self.context + state, info = self.env.reset(seed=seed, options=options) + state = self._add_context_to_state(state) + info["context_id"] = self.context_id + return state, info + + @classmethod + def get_default_context(cls) -> Context: + """Get the default context (without any goal features) + + Returns + ------- + Context + Default context. + """ + default_context = cls.get_context_space().get_default_context() + if "target_distance" in default_context: + del default_context["target_distance"] + if "target_direction" in default_context: + del default_context["target_direction"] + if "target_radius" in default_context: + del default_context["target_radius"] + return default_context + + @classmethod + def get_default_goal_context(cls) -> Context: + """Get the default context (with goal features) + + Returns + ------- + Context + Default context. + """ + default_context = cls.get_context_space().get_default_context() + return default_context diff --git a/carl/envs/brax/carl_halfcheetah.py b/carl/envs/brax/carl_halfcheetah.py index c1a69e46..97c8d7ab 100644 --- a/carl/envs/brax/carl_halfcheetah.py +++ b/carl/envs/brax/carl_halfcheetah.py @@ -2,13 +2,19 @@ import numpy as np -from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.context.context_space import ( + CategoricalContextFeature, + ContextFeature, + UniformFloatContextFeature, +) +from carl.envs.brax.brax_walker_goal_wrapper import directions from carl.envs.brax.carl_brax_env import CARLBraxEnv class CARLBraxHalfcheetah(CARLBraxEnv): env_name: str = "halfcheetah" asset_path: str = "envs/assets/half_cheetah.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: @@ -49,4 +55,13 @@ def get_context_features() -> dict[str, ContextFeature]: "mass_ffoot": UniformFloatContextFeature( "mass_ffoot", lower=1e-6, upper=np.inf, default_value=0.8845188 ), + "target_distance": UniformFloatContextFeature( + "target_distance", lower=0, upper=np.inf, default_value=100 + ), + "target_direction": CategoricalContextFeature( + "target_direction", choices=directions, default_value=1 + ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_hopper.py b/carl/envs/brax/carl_hopper.py index be9c1699..1b042e1b 100644 --- a/carl/envs/brax/carl_hopper.py +++ b/carl/envs/brax/carl_hopper.py @@ -2,13 +2,19 @@ import numpy as np -from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.context.context_space import ( + CategoricalContextFeature, + ContextFeature, + UniformFloatContextFeature, +) +from carl.envs.brax.brax_walker_goal_wrapper import directions from carl.envs.brax.carl_brax_env import CARLBraxEnv class CARLBraxHopper(CARLBraxEnv): env_name: str = "hopper" asset_path: str = "envs/assets/hopper.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: @@ -40,4 +46,13 @@ def get_context_features() -> dict[str, ContextFeature]: "mass_foot": UniformFloatContextFeature( "mass_foot", lower=1e-6, upper=np.inf, default_value=5.3155746 ), + "target_distance": UniformFloatContextFeature( + "target_distance", lower=0, upper=np.inf, default_value=100 + ), + "target_direction": CategoricalContextFeature( + "target_direction", choices=directions, default_value=1 + ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_humanoid.py b/carl/envs/brax/carl_humanoid.py index 27a57146..4ddbe3b0 100644 --- a/carl/envs/brax/carl_humanoid.py +++ b/carl/envs/brax/carl_humanoid.py @@ -2,13 +2,19 @@ import numpy as np -from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.context.context_space import ( + CategoricalContextFeature, + ContextFeature, + UniformFloatContextFeature, +) +from carl.envs.brax.brax_walker_goal_wrapper import directions from carl.envs.brax.carl_brax_env import CARLBraxEnv class CARLBraxHumanoid(CARLBraxEnv): env_name: str = "humanoid" asset_path: str = "envs/assets/humanoid.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: @@ -67,4 +73,13 @@ def get_context_features() -> dict[str, ContextFeature]: "mass_left_lower_arm": UniformFloatContextFeature( "mass_left_lower_arm", lower=1e-6, upper=np.inf, default_value=1.2295402 ), + "target_distance": UniformFloatContextFeature( + "target_distance", lower=0, upper=np.inf, default_value=100 + ), + "target_direction": CategoricalContextFeature( + "target_direction", choices=directions, default_value=1 + ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_humanoidstandup.py b/carl/envs/brax/carl_humanoidstandup.py index 7edb6ef6..1d923bbd 100644 --- a/carl/envs/brax/carl_humanoidstandup.py +++ b/carl/envs/brax/carl_humanoidstandup.py @@ -9,6 +9,7 @@ class CARLBraxHumanoidStandup(CARLBraxEnv): env_name: str = "humanoidstandup" asset_path: str = "envs/assets/humanoidstandup.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/brax/carl_inverted_double_pendulum.py b/carl/envs/brax/carl_inverted_double_pendulum.py index 07976e49..ea467ae0 100644 --- a/carl/envs/brax/carl_inverted_double_pendulum.py +++ b/carl/envs/brax/carl_inverted_double_pendulum.py @@ -9,6 +9,7 @@ class CARLBraxInvertedDoublePendulum(CARLBraxEnv): env_name: str = "inverted_double_pendulum" asset_path: str = "envs/assets/inverted_double_pendulum.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/brax/carl_inverted_pendulum.py b/carl/envs/brax/carl_inverted_pendulum.py index 280d81f5..831330c6 100644 --- a/carl/envs/brax/carl_inverted_pendulum.py +++ b/carl/envs/brax/carl_inverted_pendulum.py @@ -9,6 +9,7 @@ class CARLBraxInvertedPendulum(CARLBraxEnv): env_name: str = "inverted_pendulum" asset_path: str = "envs/assets/inverted_pendulum.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/brax/carl_pusher.py b/carl/envs/brax/carl_pusher.py index d7de1599..19cdec86 100644 --- a/carl/envs/brax/carl_pusher.py +++ b/carl/envs/brax/carl_pusher.py @@ -1,5 +1,7 @@ from __future__ import annotations +from copy import deepcopy + import numpy as np from carl.context.context_space import ContextFeature, UniformFloatContextFeature @@ -9,6 +11,7 @@ class CARLBraxPusher(CARLBraxEnv): env_name: str = "pusher" asset_path: str = "envs/assets/pusher.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: @@ -76,4 +79,25 @@ def get_context_features() -> dict[str, ContextFeature]: "mass_object": UniformFloatContextFeature( "mass_object", lower=1e-6, upper=np.inf, default_value=1.8325957e-03 ), + "goal_position_x": UniformFloatContextFeature( + "goal_position_x", lower=0, upper=np.inf, default_value=0.45 + ), + "goal_position_y": UniformFloatContextFeature( + "goal_position_y", lower=0, upper=np.inf, default_value=0.05 + ), + "goal_position_z": UniformFloatContextFeature( + "goal_position_z", lower=0, upper=np.inf, default_value=0.05 + ), } + + def _update_context(self) -> None: + goal_x = self.context["goal_position_x"] + goal_y = self.context["goal_position_y"] + goal_z = self.context["goal_position_z"] + context = deepcopy(self.context) + del self.context["goal_position_x"] + del self.context["goal_position_y"] + del self.context["goal_position_z"] + super()._update_context() + self.env._goal_pos = np.array([goal_x, goal_y, goal_z]) + self.context = context diff --git a/carl/envs/brax/carl_reacher.py b/carl/envs/brax/carl_reacher.py index a6d75b62..5d35d68d 100644 --- a/carl/envs/brax/carl_reacher.py +++ b/carl/envs/brax/carl_reacher.py @@ -9,6 +9,7 @@ class CARLBraxReacher(CARLBraxEnv): env_name: str = "reacher" asset_path: str = "envs/assets/reacher.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/brax/carl_walker2d.py b/carl/envs/brax/carl_walker2d.py index 3aa66b89..8d927da0 100644 --- a/carl/envs/brax/carl_walker2d.py +++ b/carl/envs/brax/carl_walker2d.py @@ -2,13 +2,19 @@ import numpy as np -from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.context.context_space import ( + CategoricalContextFeature, + ContextFeature, + UniformFloatContextFeature, +) +from carl.envs.brax.brax_walker_goal_wrapper import directions from carl.envs.brax.carl_brax_env import CARLBraxEnv class CARLBraxWalker2d(CARLBraxEnv): env_name: str = "walker2d" asset_path: str = "envs/assets/walker2d.xml" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: @@ -49,4 +55,13 @@ def get_context_features() -> dict[str, ContextFeature]: "mass_foot_left": UniformFloatContextFeature( "mass_foot_left", lower=1e-6, upper=np.inf, default_value=3.1667254 ), + "target_distance": UniformFloatContextFeature( + "target_distance", lower=0, upper=np.inf, default_value=100 + ), + "target_direction": CategoricalContextFeature( + "target_direction", choices=directions, default_value=1 + ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/dmc/__init__.py b/carl/envs/dmc/__init__.py index 430665fe..1b3d526a 100644 --- a/carl/envs/dmc/__init__.py +++ b/carl/envs/dmc/__init__.py @@ -2,6 +2,7 @@ # Contexts and bounds by name from carl.envs.dmc.carl_dm_finger import CARLDmcFingerEnv from carl.envs.dmc.carl_dm_fish import CARLDmcFishEnv +from carl.envs.dmc.carl_dm_pointmass import CARLDmcPointMassEnv from carl.envs.dmc.carl_dm_quadruped import CARLDmcQuadrupedEnv from carl.envs.dmc.carl_dm_walker import CARLDmcWalkerEnv @@ -10,4 +11,5 @@ "CARLDmcFishEnv", "CARLDmcQuadrupedEnv", "CARLDmcWalkerEnv", + "CARLDmcPointMassEnv", ] diff --git a/carl/envs/dmc/carl_dm_finger.py b/carl/envs/dmc/carl_dm_finger.py index b2604fec..ce1ef575 100644 --- a/carl/envs/dmc/carl_dm_finger.py +++ b/carl/envs/dmc/carl_dm_finger.py @@ -7,12 +7,13 @@ class CARLDmcFingerEnv(CARLDmcEnv): domain = "finger" task = "spin_context" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: return { "gravity": UniformFloatContextFeature( - "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81 + "gravity", lower=0.1, upper=np.inf, default_value=9.81 ), "friction_torsional": UniformFloatContextFeature( "friction_torsional", lower=0, upper=np.inf, default_value=1.0 diff --git a/carl/envs/dmc/carl_dm_fish.py b/carl/envs/dmc/carl_dm_fish.py index 619364a8..669c88dc 100644 --- a/carl/envs/dmc/carl_dm_fish.py +++ b/carl/envs/dmc/carl_dm_fish.py @@ -7,12 +7,13 @@ class CARLDmcFishEnv(CARLDmcEnv): domain = "fish" task = "swim_context" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: return { "gravity": UniformFloatContextFeature( - "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81 + "gravity", lower=0.1, upper=np.inf, default_value=9.81 ), "friction_torsional": UniformFloatContextFeature( "friction_torsional", lower=0, upper=np.inf, default_value=1.0 diff --git a/carl/envs/dmc/carl_dm_pointmass.py b/carl/envs/dmc/carl_dm_pointmass.py new file mode 100644 index 00000000..dc02d434 --- /dev/null +++ b/carl/envs/dmc/carl_dm_pointmass.py @@ -0,0 +1,74 @@ +import numpy as np + +from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.envs.dmc.carl_dmcontrol import CARLDmcEnv + + +class CARLDmcPointMassEnv(CARLDmcEnv): + domain = "pointmass" + task = "easy_pointmass" + + @staticmethod + def get_context_features() -> dict[str, ContextFeature]: + return { + "gravity": UniformFloatContextFeature( + "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81 + ), + "friction_torsional": UniformFloatContextFeature( + "friction_torsional", lower=0, upper=np.inf, default_value=1.0 + ), + "friction_rolling": UniformFloatContextFeature( + "friction_rolling", lower=0, upper=np.inf, default_value=1.0 + ), + "friction_tangential": UniformFloatContextFeature( + "friction_tangential", lower=0, upper=np.inf, default_value=1.0 + ), + "timestep": UniformFloatContextFeature( + "timestep", lower=0.001, upper=0.1, default_value=0.004 + ), + "joint_damping": UniformFloatContextFeature( + "joint_damping", lower=0.0, upper=np.inf, default_value=1.0 + ), + "joint_stiffness": UniformFloatContextFeature( + "joint_stiffness", lower=0.0, upper=np.inf, default_value=0.0 + ), + "actuator_strength": UniformFloatContextFeature( + "actuator_strength", lower=0.0, upper=np.inf, default_value=1.0 + ), + "density": UniformFloatContextFeature( + "density", lower=0.0, upper=np.inf, default_value=5000.0 + ), + "viscosity": UniformFloatContextFeature( + "viscosity", lower=0.0, upper=np.inf, default_value=0.0 + ), + "geom_density": UniformFloatContextFeature( + "geom_density", lower=0.0, upper=np.inf, default_value=1.0 + ), + "wind_x": UniformFloatContextFeature( + "wind_x", lower=-np.inf, upper=np.inf, default_value=0.0 + ), + "wind_y": UniformFloatContextFeature( + "wind_y", lower=-np.inf, upper=np.inf, default_value=0.0 + ), + "wind_z": UniformFloatContextFeature( + "wind_z", lower=-np.inf, upper=np.inf, default_value=0.0 + ), + "mass": UniformFloatContextFeature( + "mass", lower=0.0, upper=np.inf, default_value=0.3 + ), + "starting_x": UniformFloatContextFeature( + "starting_x", lower=-np.inf, upper=np.inf, default_value=0.14 + ), + "starting_y": UniformFloatContextFeature( + "starting_y", lower=-np.inf, upper=np.inf, default_value=0.14 + ), + "target_x": UniformFloatContextFeature( + "target_x", lower=-np.inf, upper=np.inf, default_value=0.0 + ), + "target_y": UniformFloatContextFeature( + "target_y", lower=-np.inf, upper=np.inf, default_value=0.0 + ), + "area_size": UniformFloatContextFeature( + "area_size", lower=-np.inf, upper=np.inf, default_value=0.6 + ), + } diff --git a/carl/envs/dmc/carl_dm_quadruped.py b/carl/envs/dmc/carl_dm_quadruped.py index 8d1fbdd2..697750d0 100644 --- a/carl/envs/dmc/carl_dm_quadruped.py +++ b/carl/envs/dmc/carl_dm_quadruped.py @@ -7,12 +7,13 @@ class CARLDmcQuadrupedEnv(CARLDmcEnv): domain = "quadruped" task = "walk_context" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: return { "gravity": UniformFloatContextFeature( - "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81 + "gravity", lower=0.1, upper=np.inf, default_value=9.81 ), "friction_torsional": UniformFloatContextFeature( "friction_torsional", lower=0, upper=np.inf, default_value=1.0 diff --git a/carl/envs/dmc/carl_dm_walker.py b/carl/envs/dmc/carl_dm_walker.py index 97b6d9b1..9e88e051 100644 --- a/carl/envs/dmc/carl_dm_walker.py +++ b/carl/envs/dmc/carl_dm_walker.py @@ -7,12 +7,13 @@ class CARLDmcWalkerEnv(CARLDmcEnv): domain = "walker" task = "walk_context" + metadata = {"render_modes": []} @staticmethod def get_context_features() -> dict[str, ContextFeature]: return { "gravity": UniformFloatContextFeature( - "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81 + "gravity", lower=0.1, upper=np.inf, default_value=9.81 ), "friction_torsional": UniformFloatContextFeature( "friction_torsional", lower=0, upper=np.inf, default_value=1.0 diff --git a/carl/envs/dmc/dmc_tasks/pointmass.py b/carl/envs/dmc/dmc_tasks/pointmass.py new file mode 100644 index 00000000..1eb80db5 --- /dev/null +++ b/carl/envs/dmc/dmc_tasks/pointmass.py @@ -0,0 +1,271 @@ +# flake8: noqa: E501 +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Finger Domain.""" +from __future__ import annotations + +from typing import Any + +from multiprocessing.sharedctypes import Value + +import numpy as np +from dm_control.mujoco.wrapper import mjbindings +from dm_control.rl import control # type: ignore +from dm_control.suite.point_mass import ( # type: ignore + _DEFAULT_TIME_LIMIT, + SUITE, + Physics, + PointMass, + get_model_and_assets, +) + +from carl.envs.dmc.dmc_tasks.utils import adapt_context # type: ignore +from carl.utils.types import Context + + +def check_constraints( + mass, + starting_x, + starting_y, + target_x, + target_y, + area_size, +) -> None: + if ( + starting_x >= area_size / 4 + or starting_y >= area_size / 4 + or starting_x <= -area_size / 4 + or starting_y <= -area_size / 4 + ): + raise ValueError( + f"The starting points are located outside of the grid. Choose a value lower than {area_size/4}." + ) + + if ( + target_x >= area_size / 4 + or target_y >= area_size / 4 + or target_x <= -area_size / 4 + or target_y <= -area_size / 4 + ): + raise ValueError( + f"The target points are located outside of the grid. Choose a value lower than {area_size/4}." + ) + + +def make_model( + mass: float = 0.3, + starting_x: float = 0.0, + starting_y: float = 0.0, + target_x: float = 0.0, + target_y: float = 0.0, + area_size: float = 0.6, + **kwargs: Any, +) -> bytes: + check_constraints( + mass=mass, + starting_x=starting_x, + starting_y=starting_y, + target_x=target_x, + target_y=target_y, + area_size=area_size, + ) + + xml_string = f""" + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + xml_string_bytes = xml_string.encode() + return xml_string_bytes + + +def random_limited_quaternion(random, limit): + """Generates a random quaternion limited to the specified rotations.""" + axis = random.randn(3) + axis /= np.linalg.norm(axis) + angle = random.rand() * limit + + quaternion = np.zeros(4) + mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle) + + return quaternion + + +class ContextualPointMass(PointMass): + starting_x: float = 0.2 + starting_y: float = 0.2 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + If _randomize_gains is True, the relationship between the controls and + the joints is randomized, so that each control actuates a random linear + combination of joints. + + Args: + physics: An instance of `mujoco.Physics`. + """ + if self._randomize_gains: + dir1 = self.random.randn(2) + dir1 /= np.linalg.norm(dir1) + # Find another actuation direction that is not 'too parallel' to dir1. + parallel = True + while parallel: + dir2 = self.random.randn(2) + dir2 /= np.linalg.norm(dir2) + parallel = abs(np.dot(dir1, dir2)) > 0.9 + physics.model.wrap_prm[[0, 1]] = dir1 + physics.model.wrap_prm[[2, 3]] = dir2 + super().initialize_episode(physics) + self.randomize_limited_and_rotational_joints(physics, self.random) + + def randomize_limited_and_rotational_joints(self, physics, random=None): + random = random or np.random + + hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE + slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE + ball = mjbindings.enums.mjtJoint.mjJNT_BALL + free = mjbindings.enums.mjtJoint.mjJNT_FREE + + qpos = physics.named.data.qpos + + for joint_id in range(physics.model.njnt): + joint_name = physics.model.id2name(joint_id, "joint") + joint_type = physics.model.jnt_type[joint_id] + is_limited = physics.model.jnt_limited[joint_id] + range_min, range_max = physics.model.jnt_range[joint_id] + + if is_limited: + if joint_type == hinge or joint_type == slide: + if "root_x" in joint_name: + qpos[joint_name] = self.starting_x + elif "root_y" in joint_name: + qpos[joint_name] = self.starting_y + else: + qpos[joint_name] = random.uniform(range_min, range_max) + + elif joint_type == ball: + qpos[joint_name] = random_limited_quaternion(random, range_max) + + else: + if joint_type == hinge: + qpos[joint_name] = random.uniform(-np.pi, np.pi) + + elif joint_type == ball: + quat = random.randn(4) + quat /= np.linalg.norm(quat) + qpos[joint_name] = quat + + elif joint_type == free: + # this should be random.randn, but changing it now could significantly + # affect benchmark results. + quat = random.rand(4) + quat /= np.linalg.norm(quat) + qpos[joint_name][3:] = quat + + +@SUITE.add("benchmarking") # type: ignore[misc] +def easy_pointmass( + context: Context = {}, + context_mask: list = [], + time_limit: float = _DEFAULT_TIME_LIMIT, + random: np.random.RandomState | int | None = None, + environment_kwargs: dict | None = None, +) -> control.Environment: + """No randomization.""" + xml_string, assets = get_model_and_assets() + xml_string = make_model(**context) + if context != {}: + xml_string = adapt_context(xml_string=xml_string, context=context) + physics = Physics.from_xml_string(xml_string, assets) + task = ContextualPointMass(randomize_gains=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + **environment_kwargs, + ) + + +@SUITE.add("benchmarking") # type: ignore[misc] +def hard_pointmass( + context: Context = {}, + context_mask: list = [], + time_limit: float = _DEFAULT_TIME_LIMIT, + random: np.random.RandomState | int | None = None, + environment_kwargs: dict | None = None, +) -> control.Environment: + """Randomized initializations.""" + xml_string, assets = get_model_and_assets() + xml_string = make_model(**context) + if context != {}: + xml_string = adapt_context(xml_string=xml_string, context=context) + physics = Physics.from_xml_string(xml_string, assets) + task = ContextualPointMass(randomize_gains=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + **environment_kwargs, + ) diff --git a/carl/envs/dmc/dmc_tasks/utils.py b/carl/envs/dmc/dmc_tasks/utils.py index 7482dfbe..08abc366 100644 --- a/carl/envs/dmc/dmc_tasks/utils.py +++ b/carl/envs/dmc/dmc_tasks/utils.py @@ -122,17 +122,21 @@ def adapt_context(xml_string: bytes, context: Context) -> bytes: # find option settings and override them if they exist, otherwise create new option option = mjcf.find(".//option") + import logging + if option is None: option = etree.Element("option") mjcf.append(option) if "gravity" in context: gravity = option.get("gravity") + logging.info(gravity) if gravity is not None: g = gravity.split(" ") gravity = " ".join([g[0], g[1], str(-context["gravity"])]) else: - gravity = " ".join(["0", "0", str(-context["gravity"])]) + gravity = " ".join(["0", "0", f"{str(-context['gravity'])}"]) + logging.info(gravity) option.set("gravity", gravity) if "wind_x" in context and "wind_y" in context and "wind_z" in context: diff --git a/carl/envs/dmc/loader.py b/carl/envs/dmc/loader.py index 0633fabd..709ab568 100644 --- a/carl/envs/dmc/loader.py +++ b/carl/envs/dmc/loader.py @@ -8,6 +8,7 @@ from carl.envs.dmc.dmc_tasks import ( # type: ignore [import] # noqa: F401 finger, fish, + pointmass, quadruped, walker, ) diff --git a/carl/envs/dmc/wrappers.py b/carl/envs/dmc/wrappers.py index 7ac9a059..665e5c26 100644 --- a/carl/envs/dmc/wrappers.py +++ b/carl/envs/dmc/wrappers.py @@ -1,10 +1,10 @@ from typing import Any, Optional, Tuple, TypeVar, Union import dm_env # type: ignore -import gym +import gymnasium as gym import numpy as np from dm_env import StepType -from gym import spaces +from gymnasium import spaces ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") diff --git a/carl/envs/gymnasium/__init__.py b/carl/envs/gymnasium/__init__.py index 62ba5092..7df661e1 100644 --- a/carl/envs/gymnasium/__init__.py +++ b/carl/envs/gymnasium/__init__.py @@ -1,3 +1,8 @@ +# flake8: noqa: F401 +# Modular imports +import importlib.util as iutil +import warnings + from carl.envs.gymnasium.classic_control import ( CARLAcrobot, CARLCartPole, @@ -13,3 +18,35 @@ "CARLMountainCarContinuous", "CARLPendulum", ] + + +def check_spec(spec_name: str) -> bool: + """Check if the spec is installed + + Parameters + ---------- + spec_name : str + Name of package that is necessary for the environment suite. + + Returns + ------- + bool + Whether the spec was found. + """ + spec = iutil.find_spec(spec_name) + found = spec is not None + if not found: + with warnings.catch_warnings(): + warnings.simplefilter("once") + warnings.warn( + f"Module {spec_name} not found. If you want to use these environments, please follow the installation guide." + ) + return found + + +# Environment loading +found = check_spec("Box2D") +if found: + from carl.envs.gymnasium.box2d import * + + __all__ += ["CARLBipedalWalker", "CARLLunarLander", "CARLVehicleRacing"] diff --git a/carl/envs/gymnasium/box2d/carl_bipedal_walker.py b/carl/envs/gymnasium/box2d/carl_bipedal_walker.py index 2ea9f4dc..a36cd62c 100644 --- a/carl/envs/gymnasium/box2d/carl_bipedal_walker.py +++ b/carl/envs/gymnasium/box2d/carl_bipedal_walker.py @@ -14,6 +14,7 @@ class CARLBipedalWalker(CARLGymnasiumEnv): env_name: str = "BipedalWalker-v3" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/box2d/carl_lunarlander.py b/carl/envs/gymnasium/box2d/carl_lunarlander.py index 9f3c59f5..52dfcfda 100644 --- a/carl/envs/gymnasium/box2d/carl_lunarlander.py +++ b/carl/envs/gymnasium/box2d/carl_lunarlander.py @@ -14,6 +14,7 @@ class CARLLunarLander(CARLGymnasiumEnv): env_name: str = "LunarLander-v2" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/box2d/carl_vehicle_racing.py b/carl/envs/gymnasium/box2d/carl_vehicle_racing.py index dcf61d75..0db60a60 100644 --- a/carl/envs/gymnasium/box2d/carl_vehicle_racing.py +++ b/carl/envs/gymnasium/box2d/carl_vehicle_racing.py @@ -209,6 +209,7 @@ def render_if_min(value, points, color): class CARLVehicleRacing(CARLGymnasiumEnv): env_name: str = "CustomCarRacing-v2" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/carl_gymnasium_env.py b/carl/envs/gymnasium/carl_gymnasium_env.py index 856e79b7..fb96d6d8 100644 --- a/carl/envs/gymnasium/carl_gymnasium_env.py +++ b/carl/envs/gymnasium/carl_gymnasium_env.py @@ -10,10 +10,10 @@ try: pygame.display.init() -except: - import os +except: # pragma: no cover + import os # pragma: no cover - os.environ["SDL_VIDEODRIVER"] = "dummy" + os.environ["SDL_VIDEODRIVER"] = "dummy" # pragma: no cover class CARLGymnasiumEnv(CARLEnv): diff --git a/carl/envs/gymnasium/classic_control/carl_acrobot.py b/carl/envs/gymnasium/classic_control/carl_acrobot.py index d81a7534..01d66b9a 100644 --- a/carl/envs/gymnasium/classic_control/carl_acrobot.py +++ b/carl/envs/gymnasium/classic_control/carl_acrobot.py @@ -10,6 +10,7 @@ class CARLAcrobot(CARLGymnasiumEnv): env_name: str = "Acrobot-v1" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/classic_control/carl_cartpole.py b/carl/envs/gymnasium/classic_control/carl_cartpole.py index ff8c7a31..93e44b9f 100644 --- a/carl/envs/gymnasium/classic_control/carl_cartpole.py +++ b/carl/envs/gymnasium/classic_control/carl_cartpole.py @@ -10,6 +10,7 @@ class CARLCartPole(CARLGymnasiumEnv): env_name: str = "CartPole-v1" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/classic_control/carl_mountaincar.py b/carl/envs/gymnasium/classic_control/carl_mountaincar.py index dcde2e77..2ea59621 100644 --- a/carl/envs/gymnasium/classic_control/carl_mountaincar.py +++ b/carl/envs/gymnasium/classic_control/carl_mountaincar.py @@ -10,6 +10,7 @@ class CARLMountainCar(CARLGymnasiumEnv): env_name: str = "MountainCar-v0" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py b/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py index 155823b2..3aeab6d9 100644 --- a/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py +++ b/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py @@ -10,6 +10,7 @@ class CARLMountainCarContinuous(CARLGymnasiumEnv): env_name: str = "MountainCarContinuous-v0" + metadata = {"render.modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/gymnasium/classic_control/carl_pendulum.py b/carl/envs/gymnasium/classic_control/carl_pendulum.py index 4148886a..10226dd8 100644 --- a/carl/envs/gymnasium/classic_control/carl_pendulum.py +++ b/carl/envs/gymnasium/classic_control/carl_pendulum.py @@ -10,6 +10,7 @@ class CARLPendulum(CARLGymnasiumEnv): env_name: str = "Pendulum-v1" + metadata = {"render_modes": ["human", "rgb_array"]} @staticmethod def get_context_features() -> dict[str, ContextFeature]: diff --git a/carl/envs/mario/carl_mario.py b/carl/envs/mario/carl_mario.py index 75fae6ca..2b4c5d09 100644 --- a/carl/envs/mario/carl_mario.py +++ b/carl/envs/mario/carl_mario.py @@ -22,6 +22,11 @@ class CARLMarioEnv(CARLEnv): + metadata = { + "render_modes": ["rgb_array", "tiny_rgb_array"], + "render_fps": 24, + } + def __init__( self, env: MarioEnv = None, diff --git a/changelog.md b/changelog.md index 76d89dae..d8e60288 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,9 @@ +# 1.1.0 +- increased test coverage +- smaller bug fixes +- added DMC pointmass env +- added goal & language goal options for Brax + # 1.0.0 Major overhaul of the CARL environment - Contexts are stored in each environment's class diff --git a/docs/source/environments/data/screenshots/pointmass.jpg b/docs/source/environments/data/screenshots/pointmass.jpg new file mode 100644 index 00000000..216eb985 Binary files /dev/null and b/docs/source/environments/data/screenshots/pointmass.jpg differ diff --git a/examples/brax_with_goals.ipynb b/examples/brax_with_goals.ipynb new file mode 100644 index 00000000..c3a68c1f --- /dev/null +++ b/examples/brax_with_goals.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/theeimer/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/__init__.py:31: UserWarning: Module py4j not found. If you want to use these environments, please follow the installation guide.\n", + " warnings.warn(\n", + "/Users/theeimer/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/__init__.py:31: UserWarning: Module distance not found. If you want to use these environments, please follow the installation guide.\n", + " warnings.warn(\n", + "/Users/theeimer/anaconda3/envs/carl/lib/python3.9/site-packages/carl/__init__.py:55: UserWarning: Module py4j not found. If you want to use these environments,\n", + " please follow the installation guide.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "from carl.context.context_space import NormalFloatContextFeature, CategoricalContextFeature\n", + "from carl.context.sampler import ContextSampler\n", + "from carl.envs import CARLBraxAnt, CARLBraxPusher\n", + "from carl.envs.brax.brax_walker_goal_wrapper import directions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{0: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 112, 'target_distance': 8.957275946170714}, 1: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 334, 'target_distance': 11.769924447869895}, 2: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 332, 'target_distance': 11.066118529857778}, 3: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 112, 'target_distance': 9.294123460239488}, 4: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 14, 'target_distance': 12.345200778471055}}\n" + ] + } + ], + "source": [ + "seed = 0\n", + "context_distributions = [NormalFloatContextFeature(\"target_distance\", mu=9.8, sigma=1), CategoricalContextFeature(\"target_direction\", choices=directions)]\n", + "context_sampler = ContextSampler(\n", + " context_distributions=context_distributions,\n", + " context_space=CARLBraxAnt.get_context_space(),\n", + " seed=seed,\n", + " )\n", + "contexts = context_sampler.sample_contexts(n_contexts=5)\n", + "print(contexts)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'target_direction'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/theeimer/Documents/git/CARL/examples/brax_with_goals.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m env \u001b[39m=\u001b[39m CARLBraxAnt(contexts\u001b[39m=\u001b[39;49mcontexts, use_language_goals\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 2\u001b[0m env\u001b[39m.\u001b[39mreset()\n\u001b[1;32m 3\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCurrent context ID: \u001b[39m\u001b[39m{\u001b[39;00menv\u001b[39m.\u001b[39mcontext_id\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/brax/carl_brax_env.py:207\u001b[0m, in \u001b[0;36mCARLBraxEnv.__init__\u001b[0;34m(self, env, batch_size, contexts, obs_context_features, obs_context_as_dict, context_selector, context_selector_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[39m# The observation space also needs to from gymnasium\u001b[39;00m\n\u001b[1;32m 201\u001b[0m env\u001b[39m.\u001b[39mobservation_space \u001b[39m=\u001b[39m gymnasium\u001b[39m.\u001b[39mspaces\u001b[39m.\u001b[39mBox(\n\u001b[1;32m 202\u001b[0m low\u001b[39m=\u001b[39menv\u001b[39m.\u001b[39mobservation_space\u001b[39m.\u001b[39mlow,\n\u001b[1;32m 203\u001b[0m high\u001b[39m=\u001b[39menv\u001b[39m.\u001b[39mobservation_space\u001b[39m.\u001b[39mhigh,\n\u001b[1;32m 204\u001b[0m dtype\u001b[39m=\u001b[39mnp\u001b[39m.\u001b[39mfloat32,\n\u001b[1;32m 205\u001b[0m )\n\u001b[0;32m--> 207\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 208\u001b[0m env\u001b[39m=\u001b[39;49menv,\n\u001b[1;32m 209\u001b[0m contexts\u001b[39m=\u001b[39;49mcontexts,\n\u001b[1;32m 210\u001b[0m obs_context_features\u001b[39m=\u001b[39;49mobs_context_features,\n\u001b[1;32m 211\u001b[0m obs_context_as_dict\u001b[39m=\u001b[39;49mobs_context_as_dict,\n\u001b[1;32m 212\u001b[0m context_selector\u001b[39m=\u001b[39;49mcontext_selector,\n\u001b[1;32m 213\u001b[0m context_selector_kwargs\u001b[39m=\u001b[39;49mcontext_selector_kwargs,\n\u001b[1;32m 214\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 215\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/carl_env.py:110\u001b[0m, in \u001b[0;36mCARLEnv.__init__\u001b[0;34m(self, env, contexts, obs_context_features, obs_context_as_dict, context_selector, context_selector_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 105\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 106\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mContext selector must be None or an AbstractSelector class or instance. \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 107\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mGot type \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(context_selector)\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m )\n\u001b[0;32m--> 110\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobservation_space: gymnasium\u001b[39m.\u001b[39mspaces\u001b[39m.\u001b[39mDict \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_observation_space(\n\u001b[1;32m 111\u001b[0m obs_context_feature_names\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobs_context_features\n\u001b[1;32m 112\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/carl_env.py:177\u001b[0m, in \u001b[0;36mCARLEnv.get_observation_space\u001b[0;34m(self, obs_context_feature_names)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Get the observation space for the context.\u001b[39;00m\n\u001b[1;32m 163\u001b[0m \n\u001b[1;32m 164\u001b[0m \u001b[39mParameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[39m underlying environment (\"state\") and for the context (\"context\").\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 176\u001b[0m context_space \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_context_space()\n\u001b[0;32m--> 177\u001b[0m obs_space_context \u001b[39m=\u001b[39m context_space\u001b[39m.\u001b[39;49mto_gymnasium_space(\n\u001b[1;32m 178\u001b[0m context_feature_names\u001b[39m=\u001b[39;49mobs_context_feature_names,\n\u001b[1;32m 179\u001b[0m as_dict\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobs_context_as_dict,\n\u001b[1;32m 180\u001b[0m )\n\u001b[1;32m 182\u001b[0m obs_space \u001b[39m=\u001b[39m spaces\u001b[39m.\u001b[39mDict(\n\u001b[1;32m 183\u001b[0m {\n\u001b[1;32m 184\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mobs\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_observation_space,\n\u001b[1;32m 185\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcontext\u001b[39m\u001b[39m\"\u001b[39m: obs_space_context,\n\u001b[1;32m 186\u001b[0m }\n\u001b[1;32m 187\u001b[0m )\n\u001b[1;32m 188\u001b[0m \u001b[39mreturn\u001b[39;00m obs_space\n", + "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/context/context_space.py:170\u001b[0m, in \u001b[0;36mContextSpace.to_gymnasium_space\u001b[0;34m(self, context_feature_names, as_dict)\u001b[0m\n\u001b[1;32m 167\u001b[0m context_space \u001b[39m=\u001b[39m {}\n\u001b[1;32m 169\u001b[0m \u001b[39mfor\u001b[39;00m cf_name \u001b[39min\u001b[39;00m context_feature_names:\n\u001b[0;32m--> 170\u001b[0m context_feature \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcontext_space[cf_name]\n\u001b[1;32m 171\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(context_feature, NumericalContextFeature):\n\u001b[1;32m 172\u001b[0m context_space[context_feature\u001b[39m.\u001b[39mname] \u001b[39m=\u001b[39m spaces\u001b[39m.\u001b[39mBox(\n\u001b[1;32m 173\u001b[0m low\u001b[39m=\u001b[39mcontext_feature\u001b[39m.\u001b[39mlower, high\u001b[39m=\u001b[39mcontext_feature\u001b[39m.\u001b[39mupper\n\u001b[1;32m 174\u001b[0m )\n", + "\u001b[0;31mKeyError\u001b[0m: 'target_direction'" + ] + } + ], + "source": [ + "env = CARLBraxAnt(contexts=contexts, use_language_goals=True)\n", + "env.reset()\n", + "print(f\"Current context ID: {env.context_id}\")\n", + "print(f\"Current context: {env.context}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "action = env.action_space.sample()\n", + "state, reward, terminated, truncated, info = env.step(action)\n", + "done = terminated or truncated\n", + "plt.imshow(env.render())\n", + "print(state)\n", + "print(reward)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "context_distributions = [NormalFloatContextFeature(\"goal_position_x\", mu=9.8, sigma=1), NormalFloatContextFeature(\"goal_position_y\", mu=9.8, sigma=1)]\n", + "context_sampler = ContextSampler(\n", + " context_distributions=context_distributions,\n", + " context_space=CARLBraxPusher.get_context_space(),\n", + " seed=seed,\n", + " )\n", + "contexts = context_sampler.sample_contexts(n_contexts=5)\n", + "print(contexts)\n", + "env = CARLBraxPusher(contexts)\n", + "env.reset()\n", + "print(f\"Current context ID: {env.context_id}\")\n", + "print(f\"Current context: {env.context}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "action = env.action_space.sample()\n", + "state, reward, terminated, truncated, info = env.step(action)\n", + "done = terminated or truncated\n", + "plt.imshow(env.render())\n", + "print(state)\n", + "print(reward)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "carl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/carl_with_sb3.py b/examples/carl_with_sb3.py new file mode 100644 index 00000000..01109ef4 --- /dev/null +++ b/examples/carl_with_sb3.py @@ -0,0 +1,38 @@ +import carl +import gymnasium as gym +from gymnasium.wrappers import FlattenObservation +from stable_baselines3 import DQN +from stable_baselines3.common.evaluation import evaluate_policy + +from carl.envs import CARLLunarLander +from carl.context.context_space import NormalFloatContextFeature +from carl.context.sampler import ContextSampler + +# Create environment +context_distributions = [NormalFloatContextFeature("GRAVITY_X", mu=9.8, sigma=1)] +context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLLunarLander.get_context_space(), + seed=42, +) +contexts = context_sampler.sample_contexts(n_contexts=5) + +print("Training contexts are:") +print(contexts) + +env = gym.make("carl/CARLLunarLander-v0", render_mode="rgb_array", contexts=contexts) +env = FlattenObservation(env) + +# Instantiate the agent +model = DQN("MlpPolicy", env, verbose=1) +# Train the agent and display a progress bar +model.learn(total_timesteps=int(2e4), progress_bar=True) +mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10) + +# Enjoy trained agent +vec_env = model.get_env() +obs = vec_env.reset() +for i in range(1000): + action, _states = model.predict(obs, deterministic=True) + obs, rewards, dones, info = vec_env.step(action) + vec_env.render("human") diff --git a/examples/demo_carracing.py b/examples/demo_carracing.py index b6b96927..34bfc4f7 100644 --- a/examples/demo_carracing.py +++ b/examples/demo_carracing.py @@ -42,6 +42,7 @@ def register_input(): contexts = {i: {"VEHICLE_ID": i} for i in range(len(VEHICLE_NAMES))} CARLVehicleRacing.render_mode = "human" + CARLVehicleRacing.render_mode = "human" env = CARLVehicleRacing(contexts=contexts) record_video = False @@ -62,13 +63,14 @@ def register_input(): while True: register_input() s, r, truncated, terminated, info = env.step(a) + s, r, truncated, terminated, info = env.step(a) time.sleep(0.025) total_reward += r - if steps % 200 == 0 or truncated or terminated: + if steps % 200 == 0 or terminated or truncated: print("\naction " + str(["{:+0.2f}".format(x) for x in a])) print("step {} total_reward {:+0.2f}".format(steps, total_reward)) steps += 1 env.render() - if truncated or terminated or restart or not isopen: + if terminated or truncated or restart or not isopen: break env.close() diff --git a/pyproject.toml b/pyproject.toml index 588d7984..9d01dce0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,11 +4,18 @@ [tool.pytest.ini_options] testpaths = ["test"] minversion = "3.9" -addopts = "--cov=carl" +addopts="--cov=carl" [tool.coverage.run] branch = true -context = "carl" +include = ["carl/*"] +omit = [ + "*/mario/pcg_smb_env/*", + "*/rna/*", + "*/utils/doc_building/*", + "*/mario/models/*", + "__init__.py" +] [tool.coverage.report] show_missing = true @@ -19,6 +26,13 @@ exclude_lines = [ "raise NotImplementedError", "if TYPE_CHECKING" ] +omit = [ + "*/mario/pcg_smb_env/*", + "*/rna/*", + "*/utils/doc_building/*", + "*/mario/models/*", + "__init__.py" +] [tool.black] target-version = ['py39'] diff --git a/setup.py b/setup.py index d1f48f8b..fd711596 100644 --- a/setup.py +++ b/setup.py @@ -25,8 +25,9 @@ def read_file(filepath: str) -> str: "gymnasium[box2d]>=0.27.1", ], "brax": [ - "brax>=0.9.1", + "brax==0.9.3", "protobuf>=3.17.3", + "mujoco==3.0.1" ], "dm_control": [ "dm_control>=1.0.3", @@ -56,6 +57,9 @@ def read_file(filepath: str) -> str: "sphinx-autoapi>=1.8.4", "automl-sphinx-theme>=0.1.9", ], + "examples": [ + "stable-baselines3", + ] } setuptools.setup( diff --git a/test/test_all_envs.py b/test/test_all_envs.py index 141b06b4..68f57ccf 100644 --- a/test/test_all_envs.py +++ b/test/test_all_envs.py @@ -16,6 +16,7 @@ def test_init_all_envs(self): env = ( # noqa: F841 local variable is assigned to but never used var() ) + env.reset() except Exception as e: print(f"Cannot instantiate {var} environment.") raise e diff --git a/test/test_box2d_envs.py b/test/test_box2d_envs.py new file mode 100644 index 00000000..2cbabe51 --- /dev/null +++ b/test/test_box2d_envs.py @@ -0,0 +1,26 @@ +import inspect +import unittest + +import carl.envs.gymnasium + + +class TestBox2DEnvs(unittest.TestCase): + def test_envs(self): + envs = inspect.getmembers(carl.envs.gymnasium.box2d) + + for env_name, env_obj in envs: + if inspect.isclass(env_obj) and "CARL" in env_name: + try: + env_obj.get_context_features() + + env = env_obj() + env._progress_instance() + env._update_context() + env.reset() + except Exception as e: + print(f"Cannot instantiate {env_name} environment.") + raise e + + +if __name__ == "__main__": + TestBox2DEnvs().test_envs() diff --git a/test/test_brax_env.py b/test/test_brax_env.py new file mode 100644 index 00000000..16f36de4 --- /dev/null +++ b/test/test_brax_env.py @@ -0,0 +1,27 @@ +import inspect +import unittest + +import carl.envs.gymnasium + + +class TestBraxEnvs(unittest.TestCase): + def test_envs(self): + envs = inspect.getmembers(carl.envs.brax) + + for env_name, env_obj in envs: + if inspect.isclass(env_obj) and "CARL" in env_name: + try: + env_obj.get_context_features() + + env = env_obj() + env._progress_instance() + env._update_context() + env.reset() + + except Exception as e: + print(f"Cannot instantiate {env_name} environment.") + raise e + + +if __name__ == "__main__": + TestBraxEnvs().test_envs() diff --git a/test/test_context_bounds.py b/test/test_context_bounds.py new file mode 100644 index 00000000..6c022b8d --- /dev/null +++ b/test/test_context_bounds.py @@ -0,0 +1,63 @@ +import unittest + +import numpy as np + +from carl.context.utils import get_context_bounds + + +class TestContextBounds(unittest.TestCase): + def test_context_bounds(self): + DEFAULT_CONTEXT = { + "min_position": -1.2, # unit? + "max_position": 0.6, # unit? + "max_speed": 0.07, # unit? + "goal_position": 0.5, # unit? + "goal_velocity": 0, # unit? + "force": 0.001, # unit? + "gravity": 0.0025, # unit? + "min_position_start": -0.6, + "max_position_start": -0.4, + "min_velocity_start": 0.0, + "max_velocity_start": 0.0, + } + + CONTEXT_BOUNDS = { + "min_position": (-np.inf, np.inf, float), + "max_position": (-np.inf, np.inf, float), + "max_speed": (0, np.inf, float), + "goal_position": (-np.inf, np.inf, float), + "goal_velocity": (-np.inf, np.inf, float), + "force": (-np.inf, np.inf, float), + "gravity": (0, np.inf, float), + "min_position_start": (-np.inf, np.inf, float), + "max_position_start": (-np.inf, np.inf, float), + "min_velocity_start": (-np.inf, np.inf, float), + "max_velocity_start": (-np.inf, np.inf, float), + } + + lower, upper = get_context_bounds(list(DEFAULT_CONTEXT.keys()), CONTEXT_BOUNDS) + + self.assertEqual( + lower.all(), + np.array( + [ + -np.inf, + -np.inf, + 0.0, + -np.inf, + -np.inf, + -np.inf, + 0.0, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + ] + ).all(), + ) + + self.assertEqual(upper.all(), np.array([np.inf] * upper.shape[0]).all()) + + +if __name__ == "__main__": + TestContextBounds.test_context_bounds() diff --git a/test/test_context_sampler.py b/test/test_context_sampler.py index 111255ec..7a1b73ac 100644 --- a/test/test_context_sampler.py +++ b/test/test_context_sampler.py @@ -44,11 +44,23 @@ def test_init(self): name="TestSampler", ) + with self.assertRaises(ValueError): + ContextSampler( + context_distributions=0, + context_space=self.cspace, + seed=0, + name="TestSampler", + ) + def test_sample_contexts(self): contexts = self.sampler.sample_contexts(n_contexts=3) self.assertEqual(len(contexts), 3) self.assertEqual(contexts[0]["gravity"], 9.8) + contexts = self.sampler.sample_contexts(n_contexts=1) + self.assertEqual(len(contexts), 1) + self.assertEqual(contexts[0]["gravity"], 9.8) + if __name__ == "__main__": unittest.main() diff --git a/test/test_context_space.py b/test/test_context_space.py index 00478679..10c8c953 100644 --- a/test/test_context_space.py +++ b/test/test_context_space.py @@ -88,6 +88,26 @@ def test_verify_context(self): is_valid = self.context_space.verify_context(context) self.assertEqual(is_valid, False) + def test_sample(self): + context = self.context_space.sample_contexts(["gravity"], size=1) + is_valid = self.context_space.verify_context(context) + self.assertEqual(is_valid, True) + + contexts = self.context_space.sample_contexts(["gravity"], size=10) + self.assertTrue(len(contexts) == 10) + for context in contexts: + is_valid = self.context_space.verify_context(context) + self.assertEqual(is_valid, True) + + contexts = self.context_space.sample_contexts(None, size=10) + self.assertTrue(len(contexts) == 10) + for context in contexts: + is_valid = self.context_space.verify_context(context) + self.assertEqual(is_valid, True) + + with self.assertRaises(ValueError): + self.context_space.sample_contexts(["false_feature"], size=0) + if __name__ == "__main__": unittest.main() diff --git a/test/test_dmc.py b/test/test_dmc.py index 83a9c05c..5556b64c 100644 --- a/test/test_dmc.py +++ b/test/test_dmc.py @@ -1,16 +1,34 @@ -import unittest - +import pytest + +from carl.envs.dmc import ( + CARLDmcFingerEnv, + CARLDmcFishEnv, + CARLDmcPointMassEnv, + CARLDmcQuadrupedEnv, + CARLDmcWalkerEnv, +) +from carl.envs.dmc.dmc_tasks.finger import check_constraints +from carl.envs.dmc.dmc_tasks.finger import ( + get_model_and_assets as get_finger_model_and_assets, +) from carl.envs.dmc.dmc_tasks.finger import ( - check_constraints, spin_context, turn_easy_context, turn_hard_context, ) +from carl.envs.dmc.dmc_tasks.pointmass import ( + check_constraints as check_constraints_pointmass, +) +from carl.envs.dmc.dmc_tasks.pointmass import make_model as make_pointmass_model +from carl.envs.dmc.dmc_tasks.quadruped import make_model as make_quadruped_model from carl.envs.dmc.dmc_tasks.utils import adapt_context +from carl.envs.dmc.dmc_tasks.walker import ( + get_model_and_assets as get_walker_model_and_assets, +) from carl.envs.dmc.loader import load_dmc_env -class TestDMCLoader(unittest.TestCase): +class TestDMCLoader: def test_load_classic_dmc_env(self): _ = load_dmc_env( domain_name="walker", @@ -24,24 +42,24 @@ def test_load_context_dmc_env(self): ) def test_load_unknowntask_dmc_env(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = load_dmc_env( domain_name="walker", task_name="walk_context_blub", ) def test_load_unknowndomain_dmc_env(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = load_dmc_env( domain_name="sdfsdf", task_name="walk", ) -class TestDmcEnvs(unittest.TestCase): +class TestFinger: def test_finger_constraints(self): # Finger can reach spinner? - with self.assertRaises(ValueError): + with pytest.raises(ValueError): check_constraints( limb_length_0=0.17, limb_length_1=0.16, @@ -49,7 +67,7 @@ def test_finger_constraints(self): raise_error=True, ) # Spinner collides with finger hinge? - with self.assertRaises(ValueError): + with pytest.raises(ValueError): check_constraints( limb_length_0=0.17, limb_length_1=0.16, @@ -65,50 +83,40 @@ def test_finger_tasks(self): _ = task(context=context) -class TestDmcUtils(unittest.TestCase): - def setUp(self) -> None: - from carl.envs.dmc.carl_dm_finger import CARLDmcFingerEnv - from carl.envs.dmc.dmc_tasks.finger import get_model_and_assets - - self.xml_string, _ = get_model_and_assets() - self.default_context = CARLDmcFingerEnv.get_default_context() +class TestDmcUtils: + def get_string_and_context(self): + xml_string, _ = get_finger_model_and_assets() + default_context = CARLDmcFingerEnv.get_default_context() + return xml_string, default_context def test_adapt_context_no_context(self): context = {} - _ = adapt_context(xml_string=self.xml_string, context=context) + xml_string, _ = self.get_string_and_context() + _ = adapt_context(xml_string=xml_string, context=context) def test_adapt_context_partialcontext(self): context = {"gravity": 10} - _ = adapt_context(xml_string=self.xml_string, context=context) + xml_string, _ = self.get_string_and_context() + _ = adapt_context(xml_string=xml_string, context=context) def test_adapt_context_fullcontext(self): # only continuous context features - context = self.default_context + xml_string, context = self.get_string_and_context() context["gravity"] *= 1.25 - _ = adapt_context(xml_string=self.xml_string, context=context) + _ = adapt_context(xml_string=xml_string, context=context) def test_adapt_context_friction(self): - from carl.envs.dmc.carl_dm_walker import CARLDmcWalkerEnv - from carl.envs.dmc.dmc_tasks.walker import get_model_and_assets - - xml_string, _ = get_model_and_assets() + xml_string, _ = get_walker_model_and_assets() context = CARLDmcWalkerEnv.get_default_context() context["friction_tangential"] *= 1.3 _ = adapt_context(xml_string=xml_string, context=context) -class TestQuadruped(unittest.TestCase): - def setUp(self) -> None: - pass - +class TestQuadruped: def test_make_model(self): - from carl.envs.dmc.dmc_tasks.quadruped import make_model - - _ = make_model(floor_size=1) + _ = make_quadruped_model(floor_size=1) def test_instantiate_env_with_context(self): - from carl.envs.dmc.carl_dm_quadruped import CARLDmcQuadrupedEnv - tasks = ["escape_context", "run_context", "walk_context", "fetch_context"] for task in tasks: _ = CARLDmcQuadrupedEnv( @@ -119,3 +127,59 @@ def test_instantiate_env_with_context(self): }, task=task, ) + + +class TestFish: + def test_make_model(self): + _ = make_quadruped_model(floor_size=1) + + def test_instantiate_env_with_context(self): + tasks = ["swim_context", "upright_context"] + for task in tasks: + _ = CARLDmcFishEnv( + contexts={ + 0: { + "gravity": -10, + } + }, + task=task, + ) + + +class TestPointmass: + def test_make_model(self): + _ = make_pointmass_model(floor_size=1) + + def test_instantiate_env_with_context(self): + tasks = ["easy_pointmass", "hard_pointmass"] + for task in tasks: + _ = CARLDmcPointMassEnv( + contexts={ + 0: { + "starting_x": 0.3, + } + }, + task=task, + ) + + def test_constraints(self): + # Is starting point inside grid? + with pytest.raises(ValueError): + check_constraints_pointmass( + mass=0.3, + starting_x=0.3, + starting_y=0.3, + target_x=0.0, + target_y=0.0, + area_size=0.6, + ) + # Is target inside grid? + with pytest.raises(ValueError): + check_constraints_pointmass( + mass=0.3, + starting_x=0.0, + starting_y=0.0, + target_x=0.3, + target_y=0.3, + area_size=0.6, + ) diff --git a/test/test_gymnasium_envs.py b/test/test_gymnasium_envs.py index 5be182c2..246f002e 100644 --- a/test/test_gymnasium_envs.py +++ b/test/test_gymnasium_envs.py @@ -1,6 +1,9 @@ import inspect import unittest +import gymnasium as gym + +import carl import carl.envs.gymnasium @@ -12,7 +15,6 @@ def test_envs(self): if inspect.isclass(env_obj) and "CARL" in env_name: try: env_obj.get_context_features() - env = env_obj() env._progress_instance() env._update_context() @@ -21,5 +23,21 @@ def test_envs(self): raise e +class TestGymnasiumRegistration(unittest.TestCase): + def test_registration(self): + registered_envs = gym.envs.registration.registry.keys() + for e in carl.envs.__all__: + if "RNA" not in e and "Brax" not in e: + env_name = f"carl/{e}-v0" + self.assertTrue(env_name in registered_envs) + + def test_make(self): + for e in carl.envs.__all__: + if "RNA" not in e and "Brax" not in e: + env_name = f"carl/{e}-v0" + env = gym.make(env_name) + self.assertTrue(isinstance(env, gym.Env)) + + if __name__ == "__main__": - TestGymnasiumEnvs().test_envs() + unittest.main() diff --git a/test/test_language_goals.py b/test/test_language_goals.py new file mode 100644 index 00000000..0ad5d09b --- /dev/null +++ b/test/test_language_goals.py @@ -0,0 +1,238 @@ +import unittest + +from carl.context.context_space import ( + CategoricalContextFeature, + NormalFloatContextFeature, +) +from carl.context.sampler import ContextSampler +from carl.envs import CARLBraxAnt, CARLBraxHalfcheetah +from carl.envs.brax.brax_walker_goal_wrapper import ( + BraxLanguageWrapper, + BraxWalkerGoalWrapper, +) + +DIRECTIONS = [ + 1, # north + 3, # south + 2, # east + 4, # west + 12, + 32, + 14, + 34, + 112, + 332, + 114, + 334, + 212, + 232, + 414, + 434, +] + + +class TestGoalSampling(unittest.TestCase): + def test_uniform_sampling(self): + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxAnt.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + assert len(contexts.keys()) == 10 + assert "target_distance" in contexts[0].keys(), "target_distance not in context" + assert ( + "target_direction" in contexts[0].keys() + ), "target_direction not in context" + assert all( + [contexts[i]["target_direction"] in DIRECTIONS for i in range(10)] + ), "Not all directions are valid." + assert all( + [contexts[i]["target_distance"] <= 200 for i in range(10)] + ), "Not all distances are valid (too large)." + assert all( + [contexts[i]["target_distance"] >= 4 for i in range(10)] + ), "Not all distances are valid (too small)." + + def test_normal_sampling(self): + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxAnt.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + assert ( + len(contexts.keys()) == 10 + ), "Number of sampled contexts does not match the requested number." + assert "target_distance" in contexts[0].keys(), "target_distance not in context" + assert ( + "target_direction" in contexts[0].keys() + ), "target_direction not in context" + assert all( + [contexts[i]["target_direction"] in DIRECTIONS for i in range(10)] + ), "Not all directions are valid." + assert all( + [contexts[i]["target_distance"] <= 200 for i in range(10)] + ), "Not all distances are valid (too large)." + assert all( + [contexts[i]["target_distance"] >= 4 for i in range(10)] + ), "Not all distances are valid (too small)." + + +class TestGoalWrapper(unittest.TestCase): + def test_reset(self): + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxAnt.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxAnt(contexts=contexts) + + assert isinstance(env.env, BraxWalkerGoalWrapper) + assert env.position is None, "Position set before reset." + + state, info = env.reset() + assert state is not None, "No state returned." + assert info is not None, "No info returned." + assert env.position is not None, "Position not set." + + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxHalfcheetah.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxHalfcheetah(contexts=contexts, use_language_goals=True) + + assert isinstance(env.env, BraxLanguageWrapper), "Language wrapper not used." + assert env.position is None, "Position set before reset." + + state, info = env.reset() + assert state is not None, "No state returned." + assert info is not None, "No info returned." + assert env.position is not None, "Position not set." + + def test_reward_scale(self): + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxAnt.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxAnt(contexts=contexts) + + for _ in range(10): + env.reset() + for _ in range(10): + action = env.action_space.sample() + _, wrapped_reward, _, _, _ = env.step(action) + assert wrapped_reward >= 0, "Negative reward." + + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxHalfcheetah.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxHalfcheetah(contexts=contexts) + + for _ in range(10): + env.reset() + for _ in range(10): + action = env.action_space.sample() + _, wrapped_reward, _, _, _ = env.step(action) + assert wrapped_reward >= 0, "Negative reward." + + +class TestLanguageWrapper(unittest.TestCase): + def test_reset(self) -> None: + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxAnt.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxAnt(contexts=contexts, use_language_goals=True) + state, info = env.reset() + assert type(state) is dict, "State is not a dictionary." + assert "obs" in state.keys(), "Observation not in state." + assert "goal" in state["obs"].keys(), "Goal not in observation." + assert type(state["obs"]["goal"]) is str, "Goal is not a string." + assert ( + str(env.context["target_distance"]) in state["obs"]["goal"] + ), "Distance not in goal." + assert "north north east" in state["obs"]["goal"], "Direction not in goal." + assert info is not None, "No info returned." + + def test_step(self): + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxHalfcheetah.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxHalfcheetah(contexts=contexts, use_language_goals=True) + env.reset() + for _ in range(10): + action = env.action_space.sample() + state, _, _, _, _ = env.step(action) + assert type(state) is dict, "State is not a dictionary." + assert "obs" in state.keys(), "Observation not in state." + assert "goal" in state["obs"].keys(), "Goal not in observation." + assert type(state["obs"]["goal"]) is str, "Goal is not a string." + assert "north north east" in state["obs"]["goal"], "Direction not in goal." + assert ( + str(env.context["target_distance"]) in state["obs"]["goal"] + ), "Distance not in goal." + + context_distributions = [ + NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), + CategoricalContextFeature("target_direction", choices=DIRECTIONS), + ] + context_sampler = ContextSampler( + context_distributions=context_distributions, + context_space=CARLBraxAnt.get_context_space(), + seed=0, + ) + contexts = context_sampler.sample_contexts(n_contexts=10) + env = CARLBraxAnt(contexts=contexts) + env.reset() + for _ in range(10): + action = env.action_space.sample() + state, _, _, _, _ = env.step(action) + assert type(state) is dict, "State is not a dictionary." + assert "obs" in state.keys(), "Observation not in state." + assert "goal" not in state.keys(), "Goal in observation." diff --git a/test/test_search_space_encoding.py b/test/test_search_space_encoding.py new file mode 100644 index 00000000..2c1a4e84 --- /dev/null +++ b/test/test_search_space_encoding.py @@ -0,0 +1,81 @@ +import unittest + +from ConfigSpace import ConfigurationSpace +from omegaconf import DictConfig + +from carl.context.search_space_encoding import search_space_to_config_space + +dict_space = { + "uniform_integer": (1, 10), + "uniform_float": (1.0, 10.0), + "categorical": ["a", "b", "c"], + "constant": 1337, +} + +dict_space_2 = { + "hyperparameters": [ + { + "name": "x0", + "type": "uniform_float", + "log": False, + "lower": -512.0, + "upper": 512.0, + "default": -3.0, + "q": None, + }, + { + "name": "x1", + "type": "uniform_float", + "log": False, + "lower": -512.0, + "upper": 512.0, + "default": -4.0, + "q": None, + }, + ], + "conditions": [], + "forbiddens": [], + "python_module_version": "0.4.17", + "json_format_version": 0.2, +} + +str_space = """{ + "uniform_integer": (1, 10), + "uniform_float": (1.0, 10.0), + "categorical": ["a", "b", "c"], + "constant": 1337, + }""" + + +class TestSearchSpacEncoding(unittest.TestCase): + def setUp(self): + self.test_space = None + self.test_space = ConfigurationSpace(name="myspace", space=dict_space) + return super().setUp() + + def test_ss_as_cs(self): + try: + search_space_to_config_space(self.test_space) + except Exception as e: + print(f"Cannot encode search space -- {self.test_space}.") + raise e + + def test_ss_as_dictconfig(self): + try: + dict_space = DictConfig({"hyperparameters": {}}) + + search_space_to_config_space(dict_space) + except Exception as e: + print(f"Cannot encode search space -- {dict_space}.") + raise e + + def test_ss_as_dict(self): + try: + search_space_to_config_space(dict_space_2) + except Exception as e: + print(f"Cannot encode search space -- {dict_space_2}.") + raise e + + +if __name__ == "__main__": + unittest.main()