diff --git a/.codecov.yml b/.codecov.yml
deleted file mode 100644
index a92352ef..00000000
--- a/.codecov.yml
+++ /dev/null
@@ -1,46 +0,0 @@
-#see https://github.com/codecov/support/wiki/Codecov-Yaml
-codecov:
- require_ci_to_pass: yes
-
-coverage:
-
- # 2 = xx.xx%, 0 = xx%
- precision: 2
-
- # https://docs.codecov.com/docs/commit-status
- status:
-
- # We want our total main project to always remain above 87% coverage, a
- # drop of 0.20% is allowed. It should fail if coverage couldn't be uploaded
- # of the CI fails otherwise
- project:
- default:
- target: 10%
- threshold: 0.20%
- if_not_found: failure
- if_ci_failed: error
-
- # The code changed by a PR should have 90% coverage. This is different from the
- # overall number shown above.
- # This encourages small PR's as they are easier to test.
- patch:
- default:
- target: 10%
- if_not_found: failure
- if_ci_failed: failure
-
-# We upload additional information on branching with pytest-cov `--cov-branch`
-# This information can be used by codecov.com to increase analysis of code
-parsers:
- gcov:
- branch_detection:
- conditional: true
- loop: true
- method: true
- macro: false
-
-
-comment:
- layout: diff, reach
- behavior: default
- require_changes: false
diff --git a/.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE.md
new file mode 100644
index 00000000..b79752ce
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE.md
@@ -0,0 +1,37 @@
+---
+name: Issue Template
+about: General template issues
+labels:
+
+---
+
+* {{ cookiecutter.project_name }} version:
+* Python version:
+* Operating System:
+
+
+
+
+#### Description
+
+
+#### Steps/Code to Reproduce
+
+
+#### Expected Results
+
+
+#### Actual Results
+
+
+#### Additional Info
+
+- Did you try upgrading to the most current version? yes/no
+- Are you using a supported operating system (version)? yes/no
+- How did you install this package (e.g. GitHub, pip, etc.)?
+
+
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..6556baad
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,39 @@
+
+
+#### Reference Issues/PRs
+
+
+#### What does this implement/fix? Explain your changes.
+
+
+
+#### Checklist
+
+- Are the tests passing locally? yes/no
+- Is the pre-commit passing locally? yes/no
+- Are all new features documented in code and docs? yes/no
+- Are all examples still running? yes/no
+- Are the requirements up to date? yes/no
+- Did you add yourself to the contributors in the authors file? yes/no
+
+#### Any other comments?
+
+
\ No newline at end of file
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 26a20757..4649c1ef 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -21,10 +21,6 @@ on:
- main
- development
- schedule:
- # Every day at 7AM UTC
- - cron: '0 07 * * *'
-
env:
# Arguments used for pytest
@@ -47,32 +43,17 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [windows-latest, macos-latest, ubuntu-latest]
- python-version: ['3.9', '3.10']
- kind: ['conda', 'source', 'dist']
-
- exclude:
- # Exclude all configurations *-*-dist, but include one later in `include`
- - kind: 'dist'
-
- # Exclude windows as bash commands wont work in windows runner
- - os: windows-latest
-
- # Exclude macos as there are permission errors using conda as we do
- - os: macos-latest
+ os: [ubuntu-latest]
+ python-version: ['3.9', '3.10', '3.11']
+ kind: ['conda']
include:
# Add the tag code-cov to ubuntu-3.7-source
- os: ubuntu-latest
python-version: 3.9
- kind: 'source'
+ kind: 'conda'
code-cov: true
- # Include one config with dist, ubuntu-3.7-dist
- - os: ubuntu-latest
- python-version: 3.9
- kind: 'dist'
-
steps:
- name: Checkout
@@ -91,21 +72,7 @@ jobs:
# Miniconda is available in $CONDA env var
$CONDA/bin/conda create -n testenv --yes pip wheel gxx_linux-64 gcc_linux-64 python=${{ matrix.python-version }}
$CONDA/envs/testenv/bin/python3 -m pip install --upgrade pip
- $CONDA/envs/testenv/bin/pip3 install -e ".[dev,box2d,brax,dm_control,mario]"
-
- - name: Source install
- if: matrix.kind == 'source'
- run: |
- python -m pip install --upgrade pip
- pip install -e ".[dev,box2d,brax,dm_control,mario]"
-
- - name: Dist install
- if: matrix.kind == 'dist'
- run: |
- python -m pip install --upgrade pip
- python setup.py sdist
- last_dist=$(ls -t dist/carl-*.tar.gz | head -n 1)
- pip install $last_dist[dev,box2d,brax,dm_control,mario]
+ $CONDA/envs/testenv/bin/pip3 install -e .[dev,dm_control,mario,brax,box2d]
- name: Tests
timeout-minutes: 60
@@ -123,10 +90,3 @@ jobs:
else
$PYTHON -m pytest ${{ env.pytest-args }} --ignore=test/local_only test
fi
-
- - name: Upload coverage
- if: matrix.code-cov && always()
- uses: codecov/codecov-action@v2
- with:
- fail_ci_if_error: true
- verbose: true
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
deleted file mode 100644
index 98fba55a..00000000
--- a/.readthedocs.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-# use version 2, which is now recommended
-version: 2
-
-build:
- os: ubuntu-20.04
- tools:
- python: "3.9"
-
-# Build from the docs/ directory with Sphinx
-sphinx:
- configuration: docs/conf.py
-
-# build all
-formats: all
-
-# Explicitly set the version of Python and its requirements
-python:
- install:
- - method: pip
- path: .
- extra_requirements:
- - docs
diff --git a/CITATION.bib b/CITATION.bib
deleted file mode 100644
index f6a2b599..00000000
--- a/CITATION.bib
+++ /dev/null
@@ -1,14 +0,0 @@
-@inproceedings { BenEim2023a,
- author = {Carolin Benjamins and
- Theresa Eimer and
- Frederik Schubert and
- Aditya Mohan and
- Sebastian Döhler and
- André Biedenkapp and
- Bodo Rosenhahn and
- Frank Hutter and
- Marius Lindauer},
- title = {Contextualize Me - The Case for Context in Reinforcement Learning},
- journal = {Transactions on Machine Learning Research},
- year = {2023},
-}
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 00000000..509cd5f7
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,69 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite our paper:"
+
+url: "https://automl.github.io/CARL/"
+repository-code: "https://github.com/automl/CARL"
+title: "CARL - Context Adaptive Reinforcement Learning"
+
+authors:
+ - family-names: "Benjamins"
+ given-names: "Carolin"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Eimer"
+ given-names: "Theresa"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Schubert"
+ given-names: "Frederik"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Mohan"
+ given-names: "Aditya"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Döhler"
+ given-names: "Sebastian"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Biedenkapp"
+ given-names: "André"
+ affiliation: "Albert-Ludwigs University Freiburg, Germany"
+ - family-names: "Rosenhahn"
+ given-names: "Bodo"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Hutter"
+ given-names: "Frank"
+ affiliation: "Albert-Ludwigs University Freiburg, Germany"
+ - family-names: "Lindauer"
+ given-names: "Marius"
+ affiliation: "Leibniz University Hannover, Germany"
+
+preferred-citation:
+ type: "article"
+ title: "Contextualize Me - The Case for Context in Reinforcement Learning"
+ year: 2023
+ journal: "Transactions on Machine Learning Research"
+ authors:
+ - family-names: "Benjamins"
+ given-names: "Carolin"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Eimer"
+ given-names: "Theresa"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Schubert"
+ given-names: "Frederik"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Mohan"
+ given-names: "Aditya"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Döhler"
+ given-names: "Sebastian"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Biedenkapp"
+ given-names: "André"
+ affiliation: "Albert-Ludwigs University Freiburg, Germany"
+ - family-names: "Rosenhahn"
+ given-names: "Bodo"
+ affiliation: "Leibniz University Hannover, Germany"
+ - family-names: "Hutter"
+ given-names: "Frank"
+ affiliation: "Albert-Ludwigs University Freiburg, Germany"
+ - family-names: "Lindauer"
+ given-names: "Marius"
+ affiliation: "Leibniz University Hannover, Germany"
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 995451d5..968def34 100644
--- a/Makefile
+++ b/Makefile
@@ -40,6 +40,9 @@ install-dev:
$(PIP) install -e ".[dev, docs]"
pre-commit install
+install:
+ $(PIP) install -e .
+
check-black:
$(BLACK) carl test --check || :
@@ -63,7 +66,7 @@ pre-commit:
$(PRECOMMIT) run --all-files
format-black:
- $(BLACK) carl test
+ $(BLACK) carl test examples
format-isort:
$(ISORT) carl test
@@ -71,7 +74,10 @@ format-isort:
format: format-black format-isort
test:
- $(PYTEST) test
+ $(PYTEST) --disable-warnings test
+
+cov-report:
+ coverage html -d coverage_html
clean-doc:
$(MAKE) -C ${DOCDIR} clean
diff --git a/README.md b/README.md
index 2a5ed021..c379ec96 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,11 @@
# – The Benchmark Library
+[![PyPI Version](https://img.shields.io/pypi/v/carl-bench.svg)](https://pypi.python.org/pypi/carl-bench)
+[![Test](https://github.com/automl/carl/actions/workflows/tests.yaml/badge.svg)](https://github.com/automl/carl/actions/workflows/tests.yaml)
+[![Doc Status](https://github.com/automl/carl/actions/workflows/docs.yaml/badge.svg)](https://github.com/automl/carl/actions/workflows/docs.yaml)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
CARL (context adaptive RL) provides highly configurable contextual extensions
to several well-known RL environments.
It's designed to test your agent's generalization capabilities
diff --git a/carl/__init__.py b/carl/__init__.py
index f1391313..a7bb344f 100644
--- a/carl/__init__.py
+++ b/carl/__init__.py
@@ -1,9 +1,11 @@
__license__ = "Apache-2.0 License"
-__version__ = "1.0.0"
+__version__ = "1.1.0"
__author__ = "Carolin Benjamins, Theresa Eimer, Frederik Schubert, André Biedenkapp, Aditya Mohan, Sebastian Döhler"
import datetime
+import importlib.util as iutil
+import warnings
name = "CARL"
package_name = "carl-bench"
@@ -20,3 +22,66 @@
Copyright {datetime.date.today().strftime('%Y')}, AutoML.org Freiburg-Hannover
"""
version = __version__
+
+try:
+ from gymnasium.envs.registration import register
+
+ from carl import envs
+
+ for e in envs.gymnasium.classic_control.__all__:
+ register(
+ id=f"carl/{e}-v0",
+ entry_point=f"carl.envs.gymnasium.classic_control:{e}",
+ )
+
+ def check_spec(spec_name: str) -> bool:
+ """Check if the spec is installed
+
+ Parameters
+ ----------
+ spec_name : str
+ Name of package that is necessary for the environment suite.
+
+ Returns
+ -------
+ bool
+ Whether the spec was found.
+ """
+ spec = iutil.find_spec(spec_name)
+ found = spec is not None
+ if not found:
+ with warnings.catch_warnings():
+ warnings.simplefilter("once")
+ warnings.warn(
+ f"""Module {spec_name} not found. If you want to use these environments,
+ please follow the installation guide."""
+ )
+ return found
+
+ found = check_spec("Box2D")
+ if found:
+ for e in envs.gymnasium.box2d.__all__:
+ register(
+ id=f"carl/{e}-v0",
+ entry_point=f"carl.envs.gymnasium.box2d:{e}",
+ )
+
+ found = check_spec("dm_control")
+ if found:
+ for e in envs.dmc.__all__:
+ register(
+ id=f"carl/{e}-v0",
+ entry_point=f"carl.envs.dmc:{e}",
+ )
+
+ found = check_spec("py4j")
+ if found:
+ register(
+ id="carl/CARLMarioEnv-v0",
+ entry_point="carl.envs.mario:CARLMarioEnv",
+ )
+except:
+ print(
+ """Gym registration failed - this is normal during installation.
+ After that, please check that gymnasium is installed correctly."""
+ )
diff --git a/carl/context/context_space.py b/carl/context/context_space.py
index 4409a370..8a8cf5ce 100644
--- a/carl/context/context_space.py
+++ b/carl/context/context_space.py
@@ -219,7 +219,7 @@ def sample_contexts(
contexts = []
for _ in range(size):
- context = {cf.name: cf.sample() for cf in self.context_space.values()}
+ context = {cf.name: cf.rvs() for cf in self.context_space.values()}
context = self.insert_defaults(context, context_keys)
contexts += [context]
diff --git a/carl/context/search_space_encoding.py b/carl/context/search_space_encoding.py
index 893a71d4..ae955d3a 100644
--- a/carl/context/search_space_encoding.py
+++ b/carl/context/search_space_encoding.py
@@ -106,11 +106,11 @@ def search_space_to_config_space(
-------
ConfigurationSpace
"""
- if type(search_space) == str:
+ if isinstance(search_space, str):
with open(search_space, "r") as f:
jason_string = f.read()
cs = csjson.read(jason_string)
- elif type(search_space) == DictConfig:
+ elif isinstance(search_space, DictConfig):
# reorder hyperparameters as List[Dict]
hyperparameters = []
for name, cfg in search_space.hyperparameters.items():
@@ -130,8 +130,10 @@ def search_space_to_config_space(
jason_string = json.dumps(search_space, cls=JSONCfgEncoder)
cs = csjson.read(jason_string)
- elif type(search_space) == ConfigurationSpace:
+ elif isinstance(search_space, ConfigurationSpace):
cs = search_space
+ elif isinstance(search_space, dict):
+ cs = csjson.read(json.dumps(search_space))
else:
raise ValueError(
f"search_space must be of type str or DictConfig. Got {type(search_space)}."
diff --git a/carl/envs/__init__.py b/carl/envs/__init__.py
index 7d34ffc1..0578a7a5 100644
--- a/carl/envs/__init__.py
+++ b/carl/envs/__init__.py
@@ -6,6 +6,14 @@
# Classic control is in gym and thus necessary for the base version to run
from carl.envs.gymnasium import *
+__all__ = [
+ "CARLAcrobot",
+ "CARLCartPole",
+ "CARLMountainCar",
+ "CARLMountainCarContinuous",
+ "CARLPendulum",
+]
+
def check_spec(spec_name: str) -> bool:
"""Check if the spec is installed
@@ -36,18 +44,44 @@ def check_spec(spec_name: str) -> bool:
if found:
from carl.envs.gymnasium.box2d import *
+ __all__ += ["CARLBipedalWalker", "CARLLunarLander", "CARLVehicleRacing"]
+
found = check_spec("brax")
if found:
from carl.envs.brax import *
+ __all__ += [
+ "CARLBraxAnt",
+ "CARLBraxHalfcheetah",
+ "CARLBraxHopper",
+ "CARLBraxHumanoid",
+ "CARLBraxHumanoidStandup",
+ "CARLBraxInvertedDoublePendulum",
+ "CARLBraxInvertedPendulum",
+ "CARLBraxPusher",
+ "CARLBraxReacher",
+ "CARLBraxWalker2d",
+ ]
+
found = check_spec("py4j")
if found:
from carl.envs.mario import *
+ __all__ += ["CARLMarioEnv"]
+
found = check_spec("dm_control")
if found:
from carl.envs.dmc import *
-# found = check_spec("distance")
-# if found:
-# from carl.envs.rna import *
+ __all__ += [
+ "CARLDmcFingerEnv",
+ "CARLDmcFishEnv",
+ "CARLDmcQuadrupedEnv",
+ "CARLDmcWalkerEnv",
+ ]
+
+found = check_spec("distance")
+if found:
+ from carl.envs.rna import *
+
+ __all__ += ["CARLRnaDesignEnv"]
diff --git a/carl/envs/brax/brax_walker_goal_wrapper.py b/carl/envs/brax/brax_walker_goal_wrapper.py
new file mode 100644
index 00000000..9c7a49c8
--- /dev/null
+++ b/carl/envs/brax/brax_walker_goal_wrapper.py
@@ -0,0 +1,180 @@
+import gym
+import numpy as np
+from brax.io import mjcf
+from etils import epath
+
+STATE_INDICES = {
+ "ant": [13, 14],
+ "humanoid": [22, 23],
+ "halfcheetah": [14, 15],
+ "hopper": [5, 6],
+ "walker2d": [8, 9],
+}
+
+DIRECTION_NAMES = {
+ 1: "north",
+ 3: "south",
+ 2: "east",
+ 4: "west",
+ 12: "north east",
+ 32: "south east",
+ 14: "north west",
+ 34: "south west",
+ 112: "north north east",
+ 332: "south south east",
+ 114: "north north west",
+ 334: "south south west",
+ 212: "east north east",
+ 232: "east south east",
+ 414: "west north west",
+ 434: "west south west",
+}
+
+directions = [
+ 1, # north
+ 3, # south
+ 2, # east
+ 4, # west
+ 12,
+ 32,
+ 14,
+ 34,
+ 112,
+ 332,
+ 114,
+ 334,
+ 212,
+ 232,
+ 414,
+ 434,
+]
+
+
+class BraxWalkerGoalWrapper(gym.Wrapper):
+ """Adds a positional goal to brax walker envs"""
+
+ def __init__(self, env: gym.Env, env_name: str, asset_path: str) -> None:
+ super().__init__(env)
+ self.env_name = env_name
+ if (
+ self.env_name == "humanoid"
+ or self.env_name == "halfcheetah"
+ or self.env_name == "hopper"
+ or self.env_name == "walker2d"
+ ):
+ self.env._forward_reward_weight = 0
+ self.context = None
+ self.position = None
+ self.goal_position = None
+ self.goal_radius = None
+ self.direction_values = {
+ 3: [0, -1],
+ 1: [0, 1],
+ 2: [1, 0],
+ 4: [-1, 0],
+ 34: [-np.sqrt(0.5), -np.sqrt(0.5)],
+ 14: [-np.sqrt(0.5), np.sqrt(0.5)],
+ 32: [np.sqrt(0.5), -np.sqrt(0.5)],
+ 12: [np.sqrt(0.5), np.sqrt(0.5)],
+ 334: [
+ -np.cos(22.5 * np.pi / 180),
+ -np.sin(22.5 * np.pi / 180),
+ ],
+ 434: [
+ -np.sin(22.5 * np.pi / 180),
+ -np.cos(22.5 * np.pi / 180),
+ ],
+ 114: [
+ -np.cos(22.5 * np.pi / 180),
+ np.sin(22.5 * np.pi / 180),
+ ],
+ 414: [
+ -np.sin(22.5 * np.pi / 180),
+ np.cos(22.5 * np.pi / 180),
+ ],
+ 332: [
+ np.cos(22.5 * np.pi / 180),
+ -np.sin(22.5 * np.pi / 180),
+ ],
+ 232: [
+ np.sin(22.5 * np.pi / 180),
+ -np.cos(22.5 * np.pi / 180),
+ ],
+ 112: [
+ np.cos(22.5 * np.pi / 180),
+ np.sin(22.5 * np.pi / 180),
+ ],
+ 212: [np.sin(22.5 * np.pi / 180), np.cos(22.5 * np.pi / 180)],
+ }
+ path = epath.resource_path("brax") / asset_path
+ sys = mjcf.load(path)
+ self.dt = sys.dt
+
+ def reset(self, seed=None, options={}):
+ state, info = self.env.reset(seed=seed, options=options)
+ self.position = (0, 0)
+ self.goal_position = (
+ np.array(self.direction_values[self.context["target_direction"]])
+ * self.context["target_distance"]
+ )
+ self.goal_radius = self.context["target_radius"]
+ info["success"] = 0
+ return state, info
+
+ def step(self, action):
+ state, _, te, tr, info = self.env.step(action)
+ indices = STATE_INDICES[self.env_name]
+ new_position = (
+ np.array(list(self.position))
+ + np.array([state[indices[0]], state[indices[1]]]) * self.dt
+ )
+ current_distance_to_goal = np.linalg.norm(self.goal_position - new_position)
+ previous_distance_to_goal = np.linalg.norm(self.goal_position - self.position)
+ direction_reward = max(0, previous_distance_to_goal - current_distance_to_goal)
+ self.position = new_position
+ if abs(current_distance_to_goal) <= self.goal_radius:
+ te = True
+ info["success"] = 1
+ else:
+ info["success"] = 0
+ return state, direction_reward, te, tr, info
+
+
+class BraxLanguageWrapper(gym.Wrapper):
+ """Translates the context features target distance and target radius into language"""
+
+ def __init__(self, env) -> None:
+ super().__init__(env)
+ self.context = None
+
+ def reset(self, seed=None, options={}):
+ self.env.context = self.context
+ state, info = self.env.reset(seed=seed, options=options)
+ goal_str = self.get_goal_desc(self.context)
+ if isinstance(state, dict):
+ state["goal"] = goal_str
+ else:
+ state = {"obs": state, "goal": goal_str}
+ return state, info
+
+ def step(self, action):
+ state, reward, te, tr, info = self.env.step(action)
+ goal_str = self.get_goal_desc(self.context)
+ if isinstance(state, dict):
+ state["goal"] = goal_str
+ else:
+ state = {"obs": state, "goal": goal_str}
+ return state, reward, te, tr, info
+
+ def get_goal_desc(self, context):
+ if "target_radius" in context.keys():
+ target_distance = context["target_distance"]
+ target_direction = context["target_direction"]
+ target_radius = context["target_radius"]
+ return f"""The distance to the goal is {target_distance}m
+ {DIRECTION_NAMES[target_direction]}.
+ Move within {target_radius} steps of the goal."""
+ else:
+ target_distance = context["target_distance"]
+ target_direction = context["target_direction"]
+ return f"Move {target_distance}m {DIRECTION_NAMES[target_direction]}."
diff --git a/carl/envs/brax/carl_ant.py b/carl/envs/brax/carl_ant.py
index e8bb6d7c..ee181ecc 100644
--- a/carl/envs/brax/carl_ant.py
+++ b/carl/envs/brax/carl_ant.py
@@ -2,13 +2,19 @@
import numpy as np
-from carl.context.context_space import ContextFeature, UniformFloatContextFeature
+from carl.context.context_space import (
+ CategoricalContextFeature,
+ ContextFeature,
+ UniformFloatContextFeature,
+)
+from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv
class CARLBraxAnt(CARLBraxEnv):
env_name: str = "ant"
asset_path: str = "envs/assets/ant.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
@@ -31,4 +37,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"viscosity": UniformFloatContextFeature(
"viscosity", lower=0, upper=np.inf, default_value=0
),
+ "target_distance": UniformFloatContextFeature(
+ "target_distance", lower=0, upper=np.inf, default_value=100
+ ),
+ "target_direction": CategoricalContextFeature(
+ "target_direction", choices=directions, default_value=1
+ ),
+ "target_radius": UniformFloatContextFeature(
+ "target_radius", lower=0.1, upper=np.inf, default_value=5
+ ),
}
diff --git a/carl/envs/brax/carl_brax_env.py b/carl/envs/brax/carl_brax_env.py
index dffc0573..3b774c6e 100644
--- a/carl/envs/brax/carl_brax_env.py
+++ b/carl/envs/brax/carl_brax_env.py
@@ -13,9 +13,13 @@
from jax import numpy as jp
from carl.context.selection import AbstractSelector
+from carl.envs.brax.brax_walker_goal_wrapper import (
+ BraxLanguageWrapper,
+ BraxWalkerGoalWrapper,
+)
from carl.envs.brax.wrappers import GymWrapper, VectorGymWrapper
from carl.envs.carl_env import CARLEnv
-from carl.utils.types import Contexts
+from carl.utils.types import Context, Contexts
def set_geom_attr(
@@ -152,6 +156,7 @@ def __init__(
obs_context_as_dict: bool = True,
context_selector: AbstractSelector | type[AbstractSelector] | None = None,
context_selector_kwargs: dict = None,
+ use_language_goals: bool = False,
**kwargs,
) -> None:
"""
@@ -204,6 +209,37 @@ def __init__(
dtype=np.float32,
)
+ if contexts is not None:
+ if (
+ "target_distance" in contexts[list(contexts.keys())[0]].keys()
+ or "target_direction" in contexts[list(contexts.keys())[0]].keys()
+ ):
+ assert all(
+ [
+ "target_direction" in contexts[list(contexts.keys())[i]].keys()
+ for i in range(len(contexts))
+ ]
+ ), "All contexts must have a 'target_direction' key"
+ assert all(
+ [
+ "target_distance" in contexts[list(contexts.keys())[i]].keys()
+ for i in range(len(contexts))
+ ]
+ ), "All contexts must have a 'target_distance' key"
+ base_dir = contexts[list(contexts.keys())[0]]["target_direction"]
+ base_dist = contexts[list(contexts.keys())[0]]["target_distance"]
+ max_diff_dir = max(
+ [c["target_direction"] - base_dir for c in contexts.values()]
+ )
+ max_diff_dist = max(
+ [c["target_distance"] - base_dist for c in contexts.values()]
+ )
+ if max_diff_dir > 0.1 or max_diff_dist > 0.1:
+ env = BraxWalkerGoalWrapper(env, self.env_name, self.asset_path)
+ if use_language_goals:
+ env = BraxLanguageWrapper(env)
+ self.use_language_goals = use_language_goals
+
super().__init__(
env=env,
contexts=contexts,
@@ -213,6 +249,7 @@ def __init__(
context_selector_kwargs=context_selector_kwargs,
**kwargs,
)
+ self.env.context = self.context
def _update_context(self) -> None:
context = self.context
@@ -224,6 +261,9 @@ def _update_context(self) -> None:
"gravity",
"viscosity",
"elasticity",
+ "target_distance",
+ "target_direction",
+ "target_radius",
]
check_context(context, registered_cfs)
@@ -252,3 +292,47 @@ def _update_context(self) -> None:
sys = sys.replace(geoms=updated_geoms)
self.env.unwrapped.sys = sys
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[Any, dict[str, Any]]:
+ """Overwrites reset in super to update context in wrapper."""
+ last_context_id = self.context_id
+ self._progress_instance()
+ if self.context_id != last_context_id:
+ self._update_context()
+ self.env.context = self.context
+ state, info = self.env.reset(seed=seed, options=options)
+ state = self._add_context_to_state(state)
+ info["context_id"] = self.context_id
+ return state, info
+
+ @classmethod
+ def get_default_context(cls) -> Context:
+ """Get the default context (without any goal features)
+
+ Returns
+ -------
+ Context
+ Default context.
+ """
+ default_context = cls.get_context_space().get_default_context()
+ if "target_distance" in default_context:
+ del default_context["target_distance"]
+ if "target_direction" in default_context:
+ del default_context["target_direction"]
+ if "target_radius" in default_context:
+ del default_context["target_radius"]
+ return default_context
+
+ @classmethod
+ def get_default_goal_context(cls) -> Context:
+ """Get the default context (with goal features)
+
+ Returns
+ -------
+ Context
+ Default context.
+ """
+ default_context = cls.get_context_space().get_default_context()
+ return default_context
diff --git a/carl/envs/brax/carl_halfcheetah.py b/carl/envs/brax/carl_halfcheetah.py
index c1a69e46..97c8d7ab 100644
--- a/carl/envs/brax/carl_halfcheetah.py
+++ b/carl/envs/brax/carl_halfcheetah.py
@@ -2,13 +2,19 @@
import numpy as np
-from carl.context.context_space import ContextFeature, UniformFloatContextFeature
+from carl.context.context_space import (
+ CategoricalContextFeature,
+ ContextFeature,
+ UniformFloatContextFeature,
+)
+from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv
class CARLBraxHalfcheetah(CARLBraxEnv):
env_name: str = "halfcheetah"
asset_path: str = "envs/assets/half_cheetah.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
@@ -49,4 +55,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"mass_ffoot": UniformFloatContextFeature(
"mass_ffoot", lower=1e-6, upper=np.inf, default_value=0.8845188
),
+ "target_distance": UniformFloatContextFeature(
+ "target_distance", lower=0, upper=np.inf, default_value=100
+ ),
+ "target_direction": CategoricalContextFeature(
+ "target_direction", choices=directions, default_value=1
+ ),
+ "target_radius": UniformFloatContextFeature(
+ "target_radius", lower=0.1, upper=np.inf, default_value=5
+ ),
}
diff --git a/carl/envs/brax/carl_hopper.py b/carl/envs/brax/carl_hopper.py
index be9c1699..1b042e1b 100644
--- a/carl/envs/brax/carl_hopper.py
+++ b/carl/envs/brax/carl_hopper.py
@@ -2,13 +2,19 @@
import numpy as np
-from carl.context.context_space import ContextFeature, UniformFloatContextFeature
+from carl.context.context_space import (
+ CategoricalContextFeature,
+ ContextFeature,
+ UniformFloatContextFeature,
+)
+from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv
class CARLBraxHopper(CARLBraxEnv):
env_name: str = "hopper"
asset_path: str = "envs/assets/hopper.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
@@ -40,4 +46,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"mass_foot": UniformFloatContextFeature(
"mass_foot", lower=1e-6, upper=np.inf, default_value=5.3155746
),
+ "target_distance": UniformFloatContextFeature(
+ "target_distance", lower=0, upper=np.inf, default_value=100
+ ),
+ "target_direction": CategoricalContextFeature(
+ "target_direction", choices=directions, default_value=1
+ ),
+ "target_radius": UniformFloatContextFeature(
+ "target_radius", lower=0.1, upper=np.inf, default_value=5
+ ),
}
diff --git a/carl/envs/brax/carl_humanoid.py b/carl/envs/brax/carl_humanoid.py
index 27a57146..4ddbe3b0 100644
--- a/carl/envs/brax/carl_humanoid.py
+++ b/carl/envs/brax/carl_humanoid.py
@@ -2,13 +2,19 @@
import numpy as np
-from carl.context.context_space import ContextFeature, UniformFloatContextFeature
+from carl.context.context_space import (
+ CategoricalContextFeature,
+ ContextFeature,
+ UniformFloatContextFeature,
+)
+from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv
class CARLBraxHumanoid(CARLBraxEnv):
env_name: str = "humanoid"
asset_path: str = "envs/assets/humanoid.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
@@ -67,4 +73,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"mass_left_lower_arm": UniformFloatContextFeature(
"mass_left_lower_arm", lower=1e-6, upper=np.inf, default_value=1.2295402
),
+ "target_distance": UniformFloatContextFeature(
+ "target_distance", lower=0, upper=np.inf, default_value=100
+ ),
+ "target_direction": CategoricalContextFeature(
+ "target_direction", choices=directions, default_value=1
+ ),
+ "target_radius": UniformFloatContextFeature(
+ "target_radius", lower=0.1, upper=np.inf, default_value=5
+ ),
}
diff --git a/carl/envs/brax/carl_humanoidstandup.py b/carl/envs/brax/carl_humanoidstandup.py
index 7edb6ef6..1d923bbd 100644
--- a/carl/envs/brax/carl_humanoidstandup.py
+++ b/carl/envs/brax/carl_humanoidstandup.py
@@ -9,6 +9,7 @@
class CARLBraxHumanoidStandup(CARLBraxEnv):
env_name: str = "humanoidstandup"
asset_path: str = "envs/assets/humanoidstandup.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/brax/carl_inverted_double_pendulum.py b/carl/envs/brax/carl_inverted_double_pendulum.py
index 07976e49..ea467ae0 100644
--- a/carl/envs/brax/carl_inverted_double_pendulum.py
+++ b/carl/envs/brax/carl_inverted_double_pendulum.py
@@ -9,6 +9,7 @@
class CARLBraxInvertedDoublePendulum(CARLBraxEnv):
env_name: str = "inverted_double_pendulum"
asset_path: str = "envs/assets/inverted_double_pendulum.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/brax/carl_inverted_pendulum.py b/carl/envs/brax/carl_inverted_pendulum.py
index 280d81f5..831330c6 100644
--- a/carl/envs/brax/carl_inverted_pendulum.py
+++ b/carl/envs/brax/carl_inverted_pendulum.py
@@ -9,6 +9,7 @@
class CARLBraxInvertedPendulum(CARLBraxEnv):
env_name: str = "inverted_pendulum"
asset_path: str = "envs/assets/inverted_pendulum.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/brax/carl_pusher.py b/carl/envs/brax/carl_pusher.py
index d7de1599..19cdec86 100644
--- a/carl/envs/brax/carl_pusher.py
+++ b/carl/envs/brax/carl_pusher.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from copy import deepcopy
+
import numpy as np
from carl.context.context_space import ContextFeature, UniformFloatContextFeature
@@ -9,6 +11,7 @@
class CARLBraxPusher(CARLBraxEnv):
env_name: str = "pusher"
asset_path: str = "envs/assets/pusher.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
@@ -76,4 +79,25 @@ def get_context_features() -> dict[str, ContextFeature]:
"mass_object": UniformFloatContextFeature(
"mass_object", lower=1e-6, upper=np.inf, default_value=1.8325957e-03
),
+ "goal_position_x": UniformFloatContextFeature(
+ "goal_position_x", lower=0, upper=np.inf, default_value=0.45
+ ),
+ "goal_position_y": UniformFloatContextFeature(
+ "goal_position_y", lower=0, upper=np.inf, default_value=0.05
+ ),
+ "goal_position_z": UniformFloatContextFeature(
+ "goal_position_z", lower=0, upper=np.inf, default_value=0.05
+ ),
}
+
+ def _update_context(self) -> None:
+ goal_x = self.context["goal_position_x"]
+ goal_y = self.context["goal_position_y"]
+ goal_z = self.context["goal_position_z"]
+ context = deepcopy(self.context)
+ del self.context["goal_position_x"]
+ del self.context["goal_position_y"]
+ del self.context["goal_position_z"]
+ super()._update_context()
+ self.env._goal_pos = np.array([goal_x, goal_y, goal_z])
+ self.context = context
diff --git a/carl/envs/brax/carl_reacher.py b/carl/envs/brax/carl_reacher.py
index a6d75b62..5d35d68d 100644
--- a/carl/envs/brax/carl_reacher.py
+++ b/carl/envs/brax/carl_reacher.py
@@ -9,6 +9,7 @@
class CARLBraxReacher(CARLBraxEnv):
env_name: str = "reacher"
asset_path: str = "envs/assets/reacher.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/brax/carl_walker2d.py b/carl/envs/brax/carl_walker2d.py
index 3aa66b89..8d927da0 100644
--- a/carl/envs/brax/carl_walker2d.py
+++ b/carl/envs/brax/carl_walker2d.py
@@ -2,13 +2,19 @@
import numpy as np
-from carl.context.context_space import ContextFeature, UniformFloatContextFeature
+from carl.context.context_space import (
+ CategoricalContextFeature,
+ ContextFeature,
+ UniformFloatContextFeature,
+)
+from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv
class CARLBraxWalker2d(CARLBraxEnv):
env_name: str = "walker2d"
asset_path: str = "envs/assets/walker2d.xml"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
@@ -49,4 +55,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"mass_foot_left": UniformFloatContextFeature(
"mass_foot_left", lower=1e-6, upper=np.inf, default_value=3.1667254
),
+ "target_distance": UniformFloatContextFeature(
+ "target_distance", lower=0, upper=np.inf, default_value=100
+ ),
+ "target_direction": CategoricalContextFeature(
+ "target_direction", choices=directions, default_value=1
+ ),
+ "target_radius": UniformFloatContextFeature(
+ "target_radius", lower=0.1, upper=np.inf, default_value=5
+ ),
}
diff --git a/carl/envs/dmc/__init__.py b/carl/envs/dmc/__init__.py
index 430665fe..1b3d526a 100644
--- a/carl/envs/dmc/__init__.py
+++ b/carl/envs/dmc/__init__.py
@@ -2,6 +2,7 @@
# Contexts and bounds by name
from carl.envs.dmc.carl_dm_finger import CARLDmcFingerEnv
from carl.envs.dmc.carl_dm_fish import CARLDmcFishEnv
+from carl.envs.dmc.carl_dm_pointmass import CARLDmcPointMassEnv
from carl.envs.dmc.carl_dm_quadruped import CARLDmcQuadrupedEnv
from carl.envs.dmc.carl_dm_walker import CARLDmcWalkerEnv
@@ -10,4 +11,5 @@
"CARLDmcFishEnv",
"CARLDmcQuadrupedEnv",
"CARLDmcWalkerEnv",
+ "CARLDmcPointMassEnv",
]
diff --git a/carl/envs/dmc/carl_dm_finger.py b/carl/envs/dmc/carl_dm_finger.py
index b2604fec..ce1ef575 100644
--- a/carl/envs/dmc/carl_dm_finger.py
+++ b/carl/envs/dmc/carl_dm_finger.py
@@ -7,12 +7,13 @@
class CARLDmcFingerEnv(CARLDmcEnv):
domain = "finger"
task = "spin_context"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
- "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
+ "gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
diff --git a/carl/envs/dmc/carl_dm_fish.py b/carl/envs/dmc/carl_dm_fish.py
index 619364a8..669c88dc 100644
--- a/carl/envs/dmc/carl_dm_fish.py
+++ b/carl/envs/dmc/carl_dm_fish.py
@@ -7,12 +7,13 @@
class CARLDmcFishEnv(CARLDmcEnv):
domain = "fish"
task = "swim_context"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
- "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
+ "gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
diff --git a/carl/envs/dmc/carl_dm_pointmass.py b/carl/envs/dmc/carl_dm_pointmass.py
new file mode 100644
index 00000000..dc02d434
--- /dev/null
+++ b/carl/envs/dmc/carl_dm_pointmass.py
@@ -0,0 +1,74 @@
+import numpy as np
+
+from carl.context.context_space import ContextFeature, UniformFloatContextFeature
+from carl.envs.dmc.carl_dmcontrol import CARLDmcEnv
+
+
+class CARLDmcPointMassEnv(CARLDmcEnv):
+ domain = "pointmass"
+ task = "easy_pointmass"
+
+ @staticmethod
+ def get_context_features() -> dict[str, ContextFeature]:
+ return {
+ "gravity": UniformFloatContextFeature(
+ "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
+ ),
+ "friction_torsional": UniformFloatContextFeature(
+ "friction_torsional", lower=0, upper=np.inf, default_value=1.0
+ ),
+ "friction_rolling": UniformFloatContextFeature(
+ "friction_rolling", lower=0, upper=np.inf, default_value=1.0
+ ),
+ "friction_tangential": UniformFloatContextFeature(
+ "friction_tangential", lower=0, upper=np.inf, default_value=1.0
+ ),
+ "timestep": UniformFloatContextFeature(
+ "timestep", lower=0.001, upper=0.1, default_value=0.004
+ ),
+ "joint_damping": UniformFloatContextFeature(
+ "joint_damping", lower=0.0, upper=np.inf, default_value=1.0
+ ),
+ "joint_stiffness": UniformFloatContextFeature(
+ "joint_stiffness", lower=0.0, upper=np.inf, default_value=0.0
+ ),
+ "actuator_strength": UniformFloatContextFeature(
+ "actuator_strength", lower=0.0, upper=np.inf, default_value=1.0
+ ),
+ "density": UniformFloatContextFeature(
+ "density", lower=0.0, upper=np.inf, default_value=5000.0
+ ),
+ "viscosity": UniformFloatContextFeature(
+ "viscosity", lower=0.0, upper=np.inf, default_value=0.0
+ ),
+ "geom_density": UniformFloatContextFeature(
+ "geom_density", lower=0.0, upper=np.inf, default_value=1.0
+ ),
+ "wind_x": UniformFloatContextFeature(
+ "wind_x", lower=-np.inf, upper=np.inf, default_value=0.0
+ ),
+ "wind_y": UniformFloatContextFeature(
+ "wind_y", lower=-np.inf, upper=np.inf, default_value=0.0
+ ),
+ "wind_z": UniformFloatContextFeature(
+ "wind_z", lower=-np.inf, upper=np.inf, default_value=0.0
+ ),
+ "mass": UniformFloatContextFeature(
+ "mass", lower=0.0, upper=np.inf, default_value=0.3
+ ),
+ "starting_x": UniformFloatContextFeature(
+ "starting_x", lower=-np.inf, upper=np.inf, default_value=0.14
+ ),
+ "starting_y": UniformFloatContextFeature(
+ "starting_y", lower=-np.inf, upper=np.inf, default_value=0.14
+ ),
+ "target_x": UniformFloatContextFeature(
+ "target_x", lower=-np.inf, upper=np.inf, default_value=0.0
+ ),
+ "target_y": UniformFloatContextFeature(
+ "target_y", lower=-np.inf, upper=np.inf, default_value=0.0
+ ),
+ "area_size": UniformFloatContextFeature(
+ "area_size", lower=-np.inf, upper=np.inf, default_value=0.6
+ ),
+ }
diff --git a/carl/envs/dmc/carl_dm_quadruped.py b/carl/envs/dmc/carl_dm_quadruped.py
index 8d1fbdd2..697750d0 100644
--- a/carl/envs/dmc/carl_dm_quadruped.py
+++ b/carl/envs/dmc/carl_dm_quadruped.py
@@ -7,12 +7,13 @@
class CARLDmcQuadrupedEnv(CARLDmcEnv):
domain = "quadruped"
task = "walk_context"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
- "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
+ "gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
diff --git a/carl/envs/dmc/carl_dm_walker.py b/carl/envs/dmc/carl_dm_walker.py
index 97b6d9b1..9e88e051 100644
--- a/carl/envs/dmc/carl_dm_walker.py
+++ b/carl/envs/dmc/carl_dm_walker.py
@@ -7,12 +7,13 @@
class CARLDmcWalkerEnv(CARLDmcEnv):
domain = "walker"
task = "walk_context"
+ metadata = {"render_modes": []}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
- "gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
+ "gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
diff --git a/carl/envs/dmc/dmc_tasks/pointmass.py b/carl/envs/dmc/dmc_tasks/pointmass.py
new file mode 100644
index 00000000..1eb80db5
--- /dev/null
+++ b/carl/envs/dmc/dmc_tasks/pointmass.py
@@ -0,0 +1,271 @@
+# flake8: noqa: E501
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Finger Domain."""
+from __future__ import annotations
+
+from typing import Any
+
+from multiprocessing.sharedctypes import Value
+
+import numpy as np
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.rl import control # type: ignore
+from dm_control.suite.point_mass import ( # type: ignore
+ _DEFAULT_TIME_LIMIT,
+ SUITE,
+ Physics,
+ PointMass,
+ get_model_and_assets,
+)
+
+from carl.envs.dmc.dmc_tasks.utils import adapt_context # type: ignore
+from carl.utils.types import Context
+
+
+def check_constraints(
+ mass,
+ starting_x,
+ starting_y,
+ target_x,
+ target_y,
+ area_size,
+) -> None:
+ if (
+ starting_x >= area_size / 4
+ or starting_y >= area_size / 4
+ or starting_x <= -area_size / 4
+ or starting_y <= -area_size / 4
+ ):
+ raise ValueError(
+ f"The starting points are located outside of the grid. Choose a value lower than {area_size/4}."
+ )
+
+ if (
+ target_x >= area_size / 4
+ or target_y >= area_size / 4
+ or target_x <= -area_size / 4
+ or target_y <= -area_size / 4
+ ):
+ raise ValueError(
+ f"The target points are located outside of the grid. Choose a value lower than {area_size/4}."
+ )
+
+
+def make_model(
+ mass: float = 0.3,
+ starting_x: float = 0.0,
+ starting_y: float = 0.0,
+ target_x: float = 0.0,
+ target_y: float = 0.0,
+ area_size: float = 0.6,
+ **kwargs: Any,
+) -> bytes:
+ check_constraints(
+ mass=mass,
+ starting_x=starting_x,
+ starting_y=starting_y,
+ target_x=target_x,
+ target_y=target_y,
+ area_size=area_size,
+ )
+
+ xml_string = f"""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ xml_string_bytes = xml_string.encode()
+ return xml_string_bytes
+
+
+def random_limited_quaternion(random, limit):
+ """Generates a random quaternion limited to the specified rotations."""
+ axis = random.randn(3)
+ axis /= np.linalg.norm(axis)
+ angle = random.rand() * limit
+
+ quaternion = np.zeros(4)
+ mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
+
+ return quaternion
+
+
+class ContextualPointMass(PointMass):
+ starting_x: float = 0.2
+ starting_y: float = 0.2
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ If _randomize_gains is True, the relationship between the controls and
+ the joints is randomized, so that each control actuates a random linear
+ combination of joints.
+
+ Args:
+ physics: An instance of `mujoco.Physics`.
+ """
+ if self._randomize_gains:
+ dir1 = self.random.randn(2)
+ dir1 /= np.linalg.norm(dir1)
+ # Find another actuation direction that is not 'too parallel' to dir1.
+ parallel = True
+ while parallel:
+ dir2 = self.random.randn(2)
+ dir2 /= np.linalg.norm(dir2)
+ parallel = abs(np.dot(dir1, dir2)) > 0.9
+ physics.model.wrap_prm[[0, 1]] = dir1
+ physics.model.wrap_prm[[2, 3]] = dir2
+ super().initialize_episode(physics)
+ self.randomize_limited_and_rotational_joints(physics, self.random)
+
+ def randomize_limited_and_rotational_joints(self, physics, random=None):
+ random = random or np.random
+
+ hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
+ slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
+ ball = mjbindings.enums.mjtJoint.mjJNT_BALL
+ free = mjbindings.enums.mjtJoint.mjJNT_FREE
+
+ qpos = physics.named.data.qpos
+
+ for joint_id in range(physics.model.njnt):
+ joint_name = physics.model.id2name(joint_id, "joint")
+ joint_type = physics.model.jnt_type[joint_id]
+ is_limited = physics.model.jnt_limited[joint_id]
+ range_min, range_max = physics.model.jnt_range[joint_id]
+
+ if is_limited:
+ if joint_type == hinge or joint_type == slide:
+ if "root_x" in joint_name:
+ qpos[joint_name] = self.starting_x
+ elif "root_y" in joint_name:
+ qpos[joint_name] = self.starting_y
+ else:
+ qpos[joint_name] = random.uniform(range_min, range_max)
+
+ elif joint_type == ball:
+ qpos[joint_name] = random_limited_quaternion(random, range_max)
+
+ else:
+ if joint_type == hinge:
+ qpos[joint_name] = random.uniform(-np.pi, np.pi)
+
+ elif joint_type == ball:
+ quat = random.randn(4)
+ quat /= np.linalg.norm(quat)
+ qpos[joint_name] = quat
+
+ elif joint_type == free:
+ # this should be random.randn, but changing it now could significantly
+ # affect benchmark results.
+ quat = random.rand(4)
+ quat /= np.linalg.norm(quat)
+ qpos[joint_name][3:] = quat
+
+
+@SUITE.add("benchmarking") # type: ignore[misc]
+def easy_pointmass(
+ context: Context = {},
+ context_mask: list = [],
+ time_limit: float = _DEFAULT_TIME_LIMIT,
+ random: np.random.RandomState | int | None = None,
+ environment_kwargs: dict | None = None,
+) -> control.Environment:
+ """No randomization."""
+ xml_string, assets = get_model_and_assets()
+ xml_string = make_model(**context)
+ if context != {}:
+ xml_string = adapt_context(xml_string=xml_string, context=context)
+ physics = Physics.from_xml_string(xml_string, assets)
+ task = ContextualPointMass(randomize_gains=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs,
+ )
+
+
+@SUITE.add("benchmarking") # type: ignore[misc]
+def hard_pointmass(
+ context: Context = {},
+ context_mask: list = [],
+ time_limit: float = _DEFAULT_TIME_LIMIT,
+ random: np.random.RandomState | int | None = None,
+ environment_kwargs: dict | None = None,
+) -> control.Environment:
+ """Randomized initializations."""
+ xml_string, assets = get_model_and_assets()
+ xml_string = make_model(**context)
+ if context != {}:
+ xml_string = adapt_context(xml_string=xml_string, context=context)
+ physics = Physics.from_xml_string(xml_string, assets)
+ task = ContextualPointMass(randomize_gains=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs,
+ )
diff --git a/carl/envs/dmc/dmc_tasks/utils.py b/carl/envs/dmc/dmc_tasks/utils.py
index 7482dfbe..08abc366 100644
--- a/carl/envs/dmc/dmc_tasks/utils.py
+++ b/carl/envs/dmc/dmc_tasks/utils.py
@@ -122,17 +122,21 @@ def adapt_context(xml_string: bytes, context: Context) -> bytes:
# find option settings and override them if they exist, otherwise create new option
option = mjcf.find(".//option")
+ import logging
+
if option is None:
option = etree.Element("option")
mjcf.append(option)
if "gravity" in context:
gravity = option.get("gravity")
+ logging.info(gravity)
if gravity is not None:
g = gravity.split(" ")
gravity = " ".join([g[0], g[1], str(-context["gravity"])])
else:
- gravity = " ".join(["0", "0", str(-context["gravity"])])
+ gravity = " ".join(["0", "0", f"{str(-context['gravity'])}"])
+ logging.info(gravity)
option.set("gravity", gravity)
if "wind_x" in context and "wind_y" in context and "wind_z" in context:
diff --git a/carl/envs/dmc/loader.py b/carl/envs/dmc/loader.py
index 0633fabd..709ab568 100644
--- a/carl/envs/dmc/loader.py
+++ b/carl/envs/dmc/loader.py
@@ -8,6 +8,7 @@
from carl.envs.dmc.dmc_tasks import ( # type: ignore [import] # noqa: F401
finger,
fish,
+ pointmass,
quadruped,
walker,
)
diff --git a/carl/envs/dmc/wrappers.py b/carl/envs/dmc/wrappers.py
index 7ac9a059..665e5c26 100644
--- a/carl/envs/dmc/wrappers.py
+++ b/carl/envs/dmc/wrappers.py
@@ -1,10 +1,10 @@
from typing import Any, Optional, Tuple, TypeVar, Union
import dm_env # type: ignore
-import gym
+import gymnasium as gym
import numpy as np
from dm_env import StepType
-from gym import spaces
+from gymnasium import spaces
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
diff --git a/carl/envs/gymnasium/__init__.py b/carl/envs/gymnasium/__init__.py
index 62ba5092..7df661e1 100644
--- a/carl/envs/gymnasium/__init__.py
+++ b/carl/envs/gymnasium/__init__.py
@@ -1,3 +1,8 @@
+# flake8: noqa: F401
+# Modular imports
+import importlib.util as iutil
+import warnings
+
from carl.envs.gymnasium.classic_control import (
CARLAcrobot,
CARLCartPole,
@@ -13,3 +18,35 @@
"CARLMountainCarContinuous",
"CARLPendulum",
]
+
+
+def check_spec(spec_name: str) -> bool:
+ """Check if the spec is installed
+
+ Parameters
+ ----------
+ spec_name : str
+ Name of package that is necessary for the environment suite.
+
+ Returns
+ -------
+ bool
+ Whether the spec was found.
+ """
+ spec = iutil.find_spec(spec_name)
+ found = spec is not None
+ if not found:
+ with warnings.catch_warnings():
+ warnings.simplefilter("once")
+ warnings.warn(
+ f"Module {spec_name} not found. If you want to use these environments, please follow the installation guide."
+ )
+ return found
+
+
+# Environment loading
+found = check_spec("Box2D")
+if found:
+ from carl.envs.gymnasium.box2d import *
+
+ __all__ += ["CARLBipedalWalker", "CARLLunarLander", "CARLVehicleRacing"]
diff --git a/carl/envs/gymnasium/box2d/carl_bipedal_walker.py b/carl/envs/gymnasium/box2d/carl_bipedal_walker.py
index 2ea9f4dc..a36cd62c 100644
--- a/carl/envs/gymnasium/box2d/carl_bipedal_walker.py
+++ b/carl/envs/gymnasium/box2d/carl_bipedal_walker.py
@@ -14,6 +14,7 @@
class CARLBipedalWalker(CARLGymnasiumEnv):
env_name: str = "BipedalWalker-v3"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/box2d/carl_lunarlander.py b/carl/envs/gymnasium/box2d/carl_lunarlander.py
index 9f3c59f5..52dfcfda 100644
--- a/carl/envs/gymnasium/box2d/carl_lunarlander.py
+++ b/carl/envs/gymnasium/box2d/carl_lunarlander.py
@@ -14,6 +14,7 @@
class CARLLunarLander(CARLGymnasiumEnv):
env_name: str = "LunarLander-v2"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/box2d/carl_vehicle_racing.py b/carl/envs/gymnasium/box2d/carl_vehicle_racing.py
index dcf61d75..0db60a60 100644
--- a/carl/envs/gymnasium/box2d/carl_vehicle_racing.py
+++ b/carl/envs/gymnasium/box2d/carl_vehicle_racing.py
@@ -209,6 +209,7 @@ def render_if_min(value, points, color):
class CARLVehicleRacing(CARLGymnasiumEnv):
env_name: str = "CustomCarRacing-v2"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/carl_gymnasium_env.py b/carl/envs/gymnasium/carl_gymnasium_env.py
index 856e79b7..fb96d6d8 100644
--- a/carl/envs/gymnasium/carl_gymnasium_env.py
+++ b/carl/envs/gymnasium/carl_gymnasium_env.py
@@ -10,10 +10,10 @@
try:
pygame.display.init()
-except:
- import os
+except: # pragma: no cover
+ import os # pragma: no cover
- os.environ["SDL_VIDEODRIVER"] = "dummy"
+ os.environ["SDL_VIDEODRIVER"] = "dummy" # pragma: no cover
class CARLGymnasiumEnv(CARLEnv):
diff --git a/carl/envs/gymnasium/classic_control/carl_acrobot.py b/carl/envs/gymnasium/classic_control/carl_acrobot.py
index d81a7534..01d66b9a 100644
--- a/carl/envs/gymnasium/classic_control/carl_acrobot.py
+++ b/carl/envs/gymnasium/classic_control/carl_acrobot.py
@@ -10,6 +10,7 @@
class CARLAcrobot(CARLGymnasiumEnv):
env_name: str = "Acrobot-v1"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/classic_control/carl_cartpole.py b/carl/envs/gymnasium/classic_control/carl_cartpole.py
index ff8c7a31..93e44b9f 100644
--- a/carl/envs/gymnasium/classic_control/carl_cartpole.py
+++ b/carl/envs/gymnasium/classic_control/carl_cartpole.py
@@ -10,6 +10,7 @@
class CARLCartPole(CARLGymnasiumEnv):
env_name: str = "CartPole-v1"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/classic_control/carl_mountaincar.py b/carl/envs/gymnasium/classic_control/carl_mountaincar.py
index dcde2e77..2ea59621 100644
--- a/carl/envs/gymnasium/classic_control/carl_mountaincar.py
+++ b/carl/envs/gymnasium/classic_control/carl_mountaincar.py
@@ -10,6 +10,7 @@
class CARLMountainCar(CARLGymnasiumEnv):
env_name: str = "MountainCar-v0"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py b/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py
index 155823b2..3aeab6d9 100644
--- a/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py
+++ b/carl/envs/gymnasium/classic_control/carl_mountaincarcontinuous.py
@@ -10,6 +10,7 @@
class CARLMountainCarContinuous(CARLGymnasiumEnv):
env_name: str = "MountainCarContinuous-v0"
+ metadata = {"render.modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/gymnasium/classic_control/carl_pendulum.py b/carl/envs/gymnasium/classic_control/carl_pendulum.py
index 4148886a..10226dd8 100644
--- a/carl/envs/gymnasium/classic_control/carl_pendulum.py
+++ b/carl/envs/gymnasium/classic_control/carl_pendulum.py
@@ -10,6 +10,7 @@
class CARLPendulum(CARLGymnasiumEnv):
env_name: str = "Pendulum-v1"
+ metadata = {"render_modes": ["human", "rgb_array"]}
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
diff --git a/carl/envs/mario/carl_mario.py b/carl/envs/mario/carl_mario.py
index 75fae6ca..2b4c5d09 100644
--- a/carl/envs/mario/carl_mario.py
+++ b/carl/envs/mario/carl_mario.py
@@ -22,6 +22,11 @@
class CARLMarioEnv(CARLEnv):
+ metadata = {
+ "render_modes": ["rgb_array", "tiny_rgb_array"],
+ "render_fps": 24,
+ }
+
def __init__(
self,
env: MarioEnv = None,
diff --git a/changelog.md b/changelog.md
index 76d89dae..d8e60288 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,3 +1,9 @@
+# 1.1.0
+- increased test coverage
+- smaller bug fixes
+- added DMC pointmass env
+- added goal & language goal options for Brax
+
# 1.0.0
Major overhaul of the CARL environment
- Contexts are stored in each environment's class
diff --git a/docs/source/environments/data/screenshots/pointmass.jpg b/docs/source/environments/data/screenshots/pointmass.jpg
new file mode 100644
index 00000000..216eb985
Binary files /dev/null and b/docs/source/environments/data/screenshots/pointmass.jpg differ
diff --git a/examples/brax_with_goals.ipynb b/examples/brax_with_goals.ipynb
new file mode 100644
index 00000000..c3a68c1f
--- /dev/null
+++ b/examples/brax_with_goals.ipynb
@@ -0,0 +1,153 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/theeimer/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/__init__.py:31: UserWarning: Module py4j not found. If you want to use these environments, please follow the installation guide.\n",
+ " warnings.warn(\n",
+ "/Users/theeimer/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/__init__.py:31: UserWarning: Module distance not found. If you want to use these environments, please follow the installation guide.\n",
+ " warnings.warn(\n",
+ "/Users/theeimer/anaconda3/envs/carl/lib/python3.9/site-packages/carl/__init__.py:55: UserWarning: Module py4j not found. If you want to use these environments,\n",
+ " please follow the installation guide.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "from carl.context.context_space import NormalFloatContextFeature, CategoricalContextFeature\n",
+ "from carl.context.sampler import ContextSampler\n",
+ "from carl.envs import CARLBraxAnt, CARLBraxPusher\n",
+ "from carl.envs.brax.brax_walker_goal_wrapper import directions\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{0: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 112, 'target_distance': 8.957275946170714}, 1: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 334, 'target_distance': 11.769924447869895}, 2: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 332, 'target_distance': 11.066118529857778}, 3: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 112, 'target_distance': 9.294123460239488}, 4: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0, 'target_direction': 14, 'target_distance': 12.345200778471055}}\n"
+ ]
+ }
+ ],
+ "source": [
+ "seed = 0\n",
+ "context_distributions = [NormalFloatContextFeature(\"target_distance\", mu=9.8, sigma=1), CategoricalContextFeature(\"target_direction\", choices=directions)]\n",
+ "context_sampler = ContextSampler(\n",
+ " context_distributions=context_distributions,\n",
+ " context_space=CARLBraxAnt.get_context_space(),\n",
+ " seed=seed,\n",
+ " )\n",
+ "contexts = context_sampler.sample_contexts(n_contexts=5)\n",
+ "print(contexts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "KeyError",
+ "evalue": "'target_direction'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m/Users/theeimer/Documents/git/CARL/examples/brax_with_goals.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m env \u001b[39m=\u001b[39m CARLBraxAnt(contexts\u001b[39m=\u001b[39;49mcontexts, use_language_goals\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 2\u001b[0m env\u001b[39m.\u001b[39mreset()\n\u001b[1;32m 3\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCurrent context ID: \u001b[39m\u001b[39m{\u001b[39;00menv\u001b[39m.\u001b[39mcontext_id\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
+ "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/brax/carl_brax_env.py:207\u001b[0m, in \u001b[0;36mCARLBraxEnv.__init__\u001b[0;34m(self, env, batch_size, contexts, obs_context_features, obs_context_as_dict, context_selector, context_selector_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[39m# The observation space also needs to from gymnasium\u001b[39;00m\n\u001b[1;32m 201\u001b[0m env\u001b[39m.\u001b[39mobservation_space \u001b[39m=\u001b[39m gymnasium\u001b[39m.\u001b[39mspaces\u001b[39m.\u001b[39mBox(\n\u001b[1;32m 202\u001b[0m low\u001b[39m=\u001b[39menv\u001b[39m.\u001b[39mobservation_space\u001b[39m.\u001b[39mlow,\n\u001b[1;32m 203\u001b[0m high\u001b[39m=\u001b[39menv\u001b[39m.\u001b[39mobservation_space\u001b[39m.\u001b[39mhigh,\n\u001b[1;32m 204\u001b[0m dtype\u001b[39m=\u001b[39mnp\u001b[39m.\u001b[39mfloat32,\n\u001b[1;32m 205\u001b[0m )\n\u001b[0;32m--> 207\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 208\u001b[0m env\u001b[39m=\u001b[39;49menv,\n\u001b[1;32m 209\u001b[0m contexts\u001b[39m=\u001b[39;49mcontexts,\n\u001b[1;32m 210\u001b[0m obs_context_features\u001b[39m=\u001b[39;49mobs_context_features,\n\u001b[1;32m 211\u001b[0m obs_context_as_dict\u001b[39m=\u001b[39;49mobs_context_as_dict,\n\u001b[1;32m 212\u001b[0m context_selector\u001b[39m=\u001b[39;49mcontext_selector,\n\u001b[1;32m 213\u001b[0m context_selector_kwargs\u001b[39m=\u001b[39;49mcontext_selector_kwargs,\n\u001b[1;32m 214\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 215\u001b[0m )\n",
+ "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/carl_env.py:110\u001b[0m, in \u001b[0;36mCARLEnv.__init__\u001b[0;34m(self, env, contexts, obs_context_features, obs_context_as_dict, context_selector, context_selector_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 105\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 106\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mContext selector must be None or an AbstractSelector class or instance. \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 107\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mGot type \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(context_selector)\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m )\n\u001b[0;32m--> 110\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobservation_space: gymnasium\u001b[39m.\u001b[39mspaces\u001b[39m.\u001b[39mDict \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_observation_space(\n\u001b[1;32m 111\u001b[0m obs_context_feature_names\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobs_context_features\n\u001b[1;32m 112\u001b[0m )\n",
+ "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/envs/carl_env.py:177\u001b[0m, in \u001b[0;36mCARLEnv.get_observation_space\u001b[0;34m(self, obs_context_feature_names)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Get the observation space for the context.\u001b[39;00m\n\u001b[1;32m 163\u001b[0m \n\u001b[1;32m 164\u001b[0m \u001b[39mParameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[39m underlying environment (\"state\") and for the context (\"context\").\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 176\u001b[0m context_space \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_context_space()\n\u001b[0;32m--> 177\u001b[0m obs_space_context \u001b[39m=\u001b[39m context_space\u001b[39m.\u001b[39;49mto_gymnasium_space(\n\u001b[1;32m 178\u001b[0m context_feature_names\u001b[39m=\u001b[39;49mobs_context_feature_names,\n\u001b[1;32m 179\u001b[0m as_dict\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobs_context_as_dict,\n\u001b[1;32m 180\u001b[0m )\n\u001b[1;32m 182\u001b[0m obs_space \u001b[39m=\u001b[39m spaces\u001b[39m.\u001b[39mDict(\n\u001b[1;32m 183\u001b[0m {\n\u001b[1;32m 184\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mobs\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_observation_space,\n\u001b[1;32m 185\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcontext\u001b[39m\u001b[39m\"\u001b[39m: obs_space_context,\n\u001b[1;32m 186\u001b[0m }\n\u001b[1;32m 187\u001b[0m )\n\u001b[1;32m 188\u001b[0m \u001b[39mreturn\u001b[39;00m obs_space\n",
+ "File \u001b[0;32m~/anaconda3/envs/carl/lib/python3.9/site-packages/carl/context/context_space.py:170\u001b[0m, in \u001b[0;36mContextSpace.to_gymnasium_space\u001b[0;34m(self, context_feature_names, as_dict)\u001b[0m\n\u001b[1;32m 167\u001b[0m context_space \u001b[39m=\u001b[39m {}\n\u001b[1;32m 169\u001b[0m \u001b[39mfor\u001b[39;00m cf_name \u001b[39min\u001b[39;00m context_feature_names:\n\u001b[0;32m--> 170\u001b[0m context_feature \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcontext_space[cf_name]\n\u001b[1;32m 171\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(context_feature, NumericalContextFeature):\n\u001b[1;32m 172\u001b[0m context_space[context_feature\u001b[39m.\u001b[39mname] \u001b[39m=\u001b[39m spaces\u001b[39m.\u001b[39mBox(\n\u001b[1;32m 173\u001b[0m low\u001b[39m=\u001b[39mcontext_feature\u001b[39m.\u001b[39mlower, high\u001b[39m=\u001b[39mcontext_feature\u001b[39m.\u001b[39mupper\n\u001b[1;32m 174\u001b[0m )\n",
+ "\u001b[0;31mKeyError\u001b[0m: 'target_direction'"
+ ]
+ }
+ ],
+ "source": [
+ "env = CARLBraxAnt(contexts=contexts, use_language_goals=True)\n",
+ "env.reset()\n",
+ "print(f\"Current context ID: {env.context_id}\")\n",
+ "print(f\"Current context: {env.context}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "action = env.action_space.sample()\n",
+ "state, reward, terminated, truncated, info = env.step(action)\n",
+ "done = terminated or truncated\n",
+ "plt.imshow(env.render())\n",
+ "print(state)\n",
+ "print(reward)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "context_distributions = [NormalFloatContextFeature(\"goal_position_x\", mu=9.8, sigma=1), NormalFloatContextFeature(\"goal_position_y\", mu=9.8, sigma=1)]\n",
+ "context_sampler = ContextSampler(\n",
+ " context_distributions=context_distributions,\n",
+ " context_space=CARLBraxPusher.get_context_space(),\n",
+ " seed=seed,\n",
+ " )\n",
+ "contexts = context_sampler.sample_contexts(n_contexts=5)\n",
+ "print(contexts)\n",
+ "env = CARLBraxPusher(contexts)\n",
+ "env.reset()\n",
+ "print(f\"Current context ID: {env.context_id}\")\n",
+ "print(f\"Current context: {env.context}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "action = env.action_space.sample()\n",
+ "state, reward, terminated, truncated, info = env.step(action)\n",
+ "done = terminated or truncated\n",
+ "plt.imshow(env.render())\n",
+ "print(state)\n",
+ "print(reward)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "carl",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/carl_with_sb3.py b/examples/carl_with_sb3.py
new file mode 100644
index 00000000..01109ef4
--- /dev/null
+++ b/examples/carl_with_sb3.py
@@ -0,0 +1,38 @@
+import carl
+import gymnasium as gym
+from gymnasium.wrappers import FlattenObservation
+from stable_baselines3 import DQN
+from stable_baselines3.common.evaluation import evaluate_policy
+
+from carl.envs import CARLLunarLander
+from carl.context.context_space import NormalFloatContextFeature
+from carl.context.sampler import ContextSampler
+
+# Create environment
+context_distributions = [NormalFloatContextFeature("GRAVITY_X", mu=9.8, sigma=1)]
+context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLLunarLander.get_context_space(),
+ seed=42,
+)
+contexts = context_sampler.sample_contexts(n_contexts=5)
+
+print("Training contexts are:")
+print(contexts)
+
+env = gym.make("carl/CARLLunarLander-v0", render_mode="rgb_array", contexts=contexts)
+env = FlattenObservation(env)
+
+# Instantiate the agent
+model = DQN("MlpPolicy", env, verbose=1)
+# Train the agent and display a progress bar
+model.learn(total_timesteps=int(2e4), progress_bar=True)
+mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
+
+# Enjoy trained agent
+vec_env = model.get_env()
+obs = vec_env.reset()
+for i in range(1000):
+ action, _states = model.predict(obs, deterministic=True)
+ obs, rewards, dones, info = vec_env.step(action)
+ vec_env.render("human")
diff --git a/examples/demo_carracing.py b/examples/demo_carracing.py
index b6b96927..34bfc4f7 100644
--- a/examples/demo_carracing.py
+++ b/examples/demo_carracing.py
@@ -42,6 +42,7 @@ def register_input():
contexts = {i: {"VEHICLE_ID": i} for i in range(len(VEHICLE_NAMES))}
CARLVehicleRacing.render_mode = "human"
+ CARLVehicleRacing.render_mode = "human"
env = CARLVehicleRacing(contexts=contexts)
record_video = False
@@ -62,13 +63,14 @@ def register_input():
while True:
register_input()
s, r, truncated, terminated, info = env.step(a)
+ s, r, truncated, terminated, info = env.step(a)
time.sleep(0.025)
total_reward += r
- if steps % 200 == 0 or truncated or terminated:
+ if steps % 200 == 0 or terminated or truncated:
print("\naction " + str(["{:+0.2f}".format(x) for x in a]))
print("step {} total_reward {:+0.2f}".format(steps, total_reward))
steps += 1
env.render()
- if truncated or terminated or restart or not isopen:
+ if terminated or truncated or restart or not isopen:
break
env.close()
diff --git a/pyproject.toml b/pyproject.toml
index 588d7984..9d01dce0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,11 +4,18 @@
[tool.pytest.ini_options]
testpaths = ["test"]
minversion = "3.9"
-addopts = "--cov=carl"
+addopts="--cov=carl"
[tool.coverage.run]
branch = true
-context = "carl"
+include = ["carl/*"]
+omit = [
+ "*/mario/pcg_smb_env/*",
+ "*/rna/*",
+ "*/utils/doc_building/*",
+ "*/mario/models/*",
+ "__init__.py"
+]
[tool.coverage.report]
show_missing = true
@@ -19,6 +26,13 @@ exclude_lines = [
"raise NotImplementedError",
"if TYPE_CHECKING"
]
+omit = [
+ "*/mario/pcg_smb_env/*",
+ "*/rna/*",
+ "*/utils/doc_building/*",
+ "*/mario/models/*",
+ "__init__.py"
+]
[tool.black]
target-version = ['py39']
diff --git a/setup.py b/setup.py
index d1f48f8b..fd711596 100644
--- a/setup.py
+++ b/setup.py
@@ -25,8 +25,9 @@ def read_file(filepath: str) -> str:
"gymnasium[box2d]>=0.27.1",
],
"brax": [
- "brax>=0.9.1",
+ "brax==0.9.3",
"protobuf>=3.17.3",
+ "mujoco==3.0.1"
],
"dm_control": [
"dm_control>=1.0.3",
@@ -56,6 +57,9 @@ def read_file(filepath: str) -> str:
"sphinx-autoapi>=1.8.4",
"automl-sphinx-theme>=0.1.9",
],
+ "examples": [
+ "stable-baselines3",
+ ]
}
setuptools.setup(
diff --git a/test/test_all_envs.py b/test/test_all_envs.py
index 141b06b4..68f57ccf 100644
--- a/test/test_all_envs.py
+++ b/test/test_all_envs.py
@@ -16,6 +16,7 @@ def test_init_all_envs(self):
env = ( # noqa: F841 local variable is assigned to but never used
var()
)
+ env.reset()
except Exception as e:
print(f"Cannot instantiate {var} environment.")
raise e
diff --git a/test/test_box2d_envs.py b/test/test_box2d_envs.py
new file mode 100644
index 00000000..2cbabe51
--- /dev/null
+++ b/test/test_box2d_envs.py
@@ -0,0 +1,26 @@
+import inspect
+import unittest
+
+import carl.envs.gymnasium
+
+
+class TestBox2DEnvs(unittest.TestCase):
+ def test_envs(self):
+ envs = inspect.getmembers(carl.envs.gymnasium.box2d)
+
+ for env_name, env_obj in envs:
+ if inspect.isclass(env_obj) and "CARL" in env_name:
+ try:
+ env_obj.get_context_features()
+
+ env = env_obj()
+ env._progress_instance()
+ env._update_context()
+ env.reset()
+ except Exception as e:
+ print(f"Cannot instantiate {env_name} environment.")
+ raise e
+
+
+if __name__ == "__main__":
+ TestBox2DEnvs().test_envs()
diff --git a/test/test_brax_env.py b/test/test_brax_env.py
new file mode 100644
index 00000000..16f36de4
--- /dev/null
+++ b/test/test_brax_env.py
@@ -0,0 +1,27 @@
+import inspect
+import unittest
+
+import carl.envs.gymnasium
+
+
+class TestBraxEnvs(unittest.TestCase):
+ def test_envs(self):
+ envs = inspect.getmembers(carl.envs.brax)
+
+ for env_name, env_obj in envs:
+ if inspect.isclass(env_obj) and "CARL" in env_name:
+ try:
+ env_obj.get_context_features()
+
+ env = env_obj()
+ env._progress_instance()
+ env._update_context()
+ env.reset()
+
+ except Exception as e:
+ print(f"Cannot instantiate {env_name} environment.")
+ raise e
+
+
+if __name__ == "__main__":
+ TestBraxEnvs().test_envs()
diff --git a/test/test_context_bounds.py b/test/test_context_bounds.py
new file mode 100644
index 00000000..6c022b8d
--- /dev/null
+++ b/test/test_context_bounds.py
@@ -0,0 +1,63 @@
+import unittest
+
+import numpy as np
+
+from carl.context.utils import get_context_bounds
+
+
+class TestContextBounds(unittest.TestCase):
+ def test_context_bounds(self):
+ DEFAULT_CONTEXT = {
+ "min_position": -1.2, # unit?
+ "max_position": 0.6, # unit?
+ "max_speed": 0.07, # unit?
+ "goal_position": 0.5, # unit?
+ "goal_velocity": 0, # unit?
+ "force": 0.001, # unit?
+ "gravity": 0.0025, # unit?
+ "min_position_start": -0.6,
+ "max_position_start": -0.4,
+ "min_velocity_start": 0.0,
+ "max_velocity_start": 0.0,
+ }
+
+ CONTEXT_BOUNDS = {
+ "min_position": (-np.inf, np.inf, float),
+ "max_position": (-np.inf, np.inf, float),
+ "max_speed": (0, np.inf, float),
+ "goal_position": (-np.inf, np.inf, float),
+ "goal_velocity": (-np.inf, np.inf, float),
+ "force": (-np.inf, np.inf, float),
+ "gravity": (0, np.inf, float),
+ "min_position_start": (-np.inf, np.inf, float),
+ "max_position_start": (-np.inf, np.inf, float),
+ "min_velocity_start": (-np.inf, np.inf, float),
+ "max_velocity_start": (-np.inf, np.inf, float),
+ }
+
+ lower, upper = get_context_bounds(list(DEFAULT_CONTEXT.keys()), CONTEXT_BOUNDS)
+
+ self.assertEqual(
+ lower.all(),
+ np.array(
+ [
+ -np.inf,
+ -np.inf,
+ 0.0,
+ -np.inf,
+ -np.inf,
+ -np.inf,
+ 0.0,
+ -np.inf,
+ -np.inf,
+ -np.inf,
+ -np.inf,
+ ]
+ ).all(),
+ )
+
+ self.assertEqual(upper.all(), np.array([np.inf] * upper.shape[0]).all())
+
+
+if __name__ == "__main__":
+ TestContextBounds.test_context_bounds()
diff --git a/test/test_context_sampler.py b/test/test_context_sampler.py
index 111255ec..7a1b73ac 100644
--- a/test/test_context_sampler.py
+++ b/test/test_context_sampler.py
@@ -44,11 +44,23 @@ def test_init(self):
name="TestSampler",
)
+ with self.assertRaises(ValueError):
+ ContextSampler(
+ context_distributions=0,
+ context_space=self.cspace,
+ seed=0,
+ name="TestSampler",
+ )
+
def test_sample_contexts(self):
contexts = self.sampler.sample_contexts(n_contexts=3)
self.assertEqual(len(contexts), 3)
self.assertEqual(contexts[0]["gravity"], 9.8)
+ contexts = self.sampler.sample_contexts(n_contexts=1)
+ self.assertEqual(len(contexts), 1)
+ self.assertEqual(contexts[0]["gravity"], 9.8)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/test_context_space.py b/test/test_context_space.py
index 00478679..10c8c953 100644
--- a/test/test_context_space.py
+++ b/test/test_context_space.py
@@ -88,6 +88,26 @@ def test_verify_context(self):
is_valid = self.context_space.verify_context(context)
self.assertEqual(is_valid, False)
+ def test_sample(self):
+ context = self.context_space.sample_contexts(["gravity"], size=1)
+ is_valid = self.context_space.verify_context(context)
+ self.assertEqual(is_valid, True)
+
+ contexts = self.context_space.sample_contexts(["gravity"], size=10)
+ self.assertTrue(len(contexts) == 10)
+ for context in contexts:
+ is_valid = self.context_space.verify_context(context)
+ self.assertEqual(is_valid, True)
+
+ contexts = self.context_space.sample_contexts(None, size=10)
+ self.assertTrue(len(contexts) == 10)
+ for context in contexts:
+ is_valid = self.context_space.verify_context(context)
+ self.assertEqual(is_valid, True)
+
+ with self.assertRaises(ValueError):
+ self.context_space.sample_contexts(["false_feature"], size=0)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/test_dmc.py b/test/test_dmc.py
index 83a9c05c..5556b64c 100644
--- a/test/test_dmc.py
+++ b/test/test_dmc.py
@@ -1,16 +1,34 @@
-import unittest
-
+import pytest
+
+from carl.envs.dmc import (
+ CARLDmcFingerEnv,
+ CARLDmcFishEnv,
+ CARLDmcPointMassEnv,
+ CARLDmcQuadrupedEnv,
+ CARLDmcWalkerEnv,
+)
+from carl.envs.dmc.dmc_tasks.finger import check_constraints
+from carl.envs.dmc.dmc_tasks.finger import (
+ get_model_and_assets as get_finger_model_and_assets,
+)
from carl.envs.dmc.dmc_tasks.finger import (
- check_constraints,
spin_context,
turn_easy_context,
turn_hard_context,
)
+from carl.envs.dmc.dmc_tasks.pointmass import (
+ check_constraints as check_constraints_pointmass,
+)
+from carl.envs.dmc.dmc_tasks.pointmass import make_model as make_pointmass_model
+from carl.envs.dmc.dmc_tasks.quadruped import make_model as make_quadruped_model
from carl.envs.dmc.dmc_tasks.utils import adapt_context
+from carl.envs.dmc.dmc_tasks.walker import (
+ get_model_and_assets as get_walker_model_and_assets,
+)
from carl.envs.dmc.loader import load_dmc_env
-class TestDMCLoader(unittest.TestCase):
+class TestDMCLoader:
def test_load_classic_dmc_env(self):
_ = load_dmc_env(
domain_name="walker",
@@ -24,24 +42,24 @@ def test_load_context_dmc_env(self):
)
def test_load_unknowntask_dmc_env(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
_ = load_dmc_env(
domain_name="walker",
task_name="walk_context_blub",
)
def test_load_unknowndomain_dmc_env(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
_ = load_dmc_env(
domain_name="sdfsdf",
task_name="walk",
)
-class TestDmcEnvs(unittest.TestCase):
+class TestFinger:
def test_finger_constraints(self):
# Finger can reach spinner?
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
check_constraints(
limb_length_0=0.17,
limb_length_1=0.16,
@@ -49,7 +67,7 @@ def test_finger_constraints(self):
raise_error=True,
)
# Spinner collides with finger hinge?
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
check_constraints(
limb_length_0=0.17,
limb_length_1=0.16,
@@ -65,50 +83,40 @@ def test_finger_tasks(self):
_ = task(context=context)
-class TestDmcUtils(unittest.TestCase):
- def setUp(self) -> None:
- from carl.envs.dmc.carl_dm_finger import CARLDmcFingerEnv
- from carl.envs.dmc.dmc_tasks.finger import get_model_and_assets
-
- self.xml_string, _ = get_model_and_assets()
- self.default_context = CARLDmcFingerEnv.get_default_context()
+class TestDmcUtils:
+ def get_string_and_context(self):
+ xml_string, _ = get_finger_model_and_assets()
+ default_context = CARLDmcFingerEnv.get_default_context()
+ return xml_string, default_context
def test_adapt_context_no_context(self):
context = {}
- _ = adapt_context(xml_string=self.xml_string, context=context)
+ xml_string, _ = self.get_string_and_context()
+ _ = adapt_context(xml_string=xml_string, context=context)
def test_adapt_context_partialcontext(self):
context = {"gravity": 10}
- _ = adapt_context(xml_string=self.xml_string, context=context)
+ xml_string, _ = self.get_string_and_context()
+ _ = adapt_context(xml_string=xml_string, context=context)
def test_adapt_context_fullcontext(self):
# only continuous context features
- context = self.default_context
+ xml_string, context = self.get_string_and_context()
context["gravity"] *= 1.25
- _ = adapt_context(xml_string=self.xml_string, context=context)
+ _ = adapt_context(xml_string=xml_string, context=context)
def test_adapt_context_friction(self):
- from carl.envs.dmc.carl_dm_walker import CARLDmcWalkerEnv
- from carl.envs.dmc.dmc_tasks.walker import get_model_and_assets
-
- xml_string, _ = get_model_and_assets()
+ xml_string, _ = get_walker_model_and_assets()
context = CARLDmcWalkerEnv.get_default_context()
context["friction_tangential"] *= 1.3
_ = adapt_context(xml_string=xml_string, context=context)
-class TestQuadruped(unittest.TestCase):
- def setUp(self) -> None:
- pass
-
+class TestQuadruped:
def test_make_model(self):
- from carl.envs.dmc.dmc_tasks.quadruped import make_model
-
- _ = make_model(floor_size=1)
+ _ = make_quadruped_model(floor_size=1)
def test_instantiate_env_with_context(self):
- from carl.envs.dmc.carl_dm_quadruped import CARLDmcQuadrupedEnv
-
tasks = ["escape_context", "run_context", "walk_context", "fetch_context"]
for task in tasks:
_ = CARLDmcQuadrupedEnv(
@@ -119,3 +127,59 @@ def test_instantiate_env_with_context(self):
},
task=task,
)
+
+
+class TestFish:
+ def test_make_model(self):
+ _ = make_quadruped_model(floor_size=1)
+
+ def test_instantiate_env_with_context(self):
+ tasks = ["swim_context", "upright_context"]
+ for task in tasks:
+ _ = CARLDmcFishEnv(
+ contexts={
+ 0: {
+ "gravity": -10,
+ }
+ },
+ task=task,
+ )
+
+
+class TestPointmass:
+ def test_make_model(self):
+ _ = make_pointmass_model(floor_size=1)
+
+ def test_instantiate_env_with_context(self):
+ tasks = ["easy_pointmass", "hard_pointmass"]
+ for task in tasks:
+ _ = CARLDmcPointMassEnv(
+ contexts={
+ 0: {
+ "starting_x": 0.3,
+ }
+ },
+ task=task,
+ )
+
+ def test_constraints(self):
+ # Is starting point inside grid?
+ with pytest.raises(ValueError):
+ check_constraints_pointmass(
+ mass=0.3,
+ starting_x=0.3,
+ starting_y=0.3,
+ target_x=0.0,
+ target_y=0.0,
+ area_size=0.6,
+ )
+ # Is target inside grid?
+ with pytest.raises(ValueError):
+ check_constraints_pointmass(
+ mass=0.3,
+ starting_x=0.0,
+ starting_y=0.0,
+ target_x=0.3,
+ target_y=0.3,
+ area_size=0.6,
+ )
diff --git a/test/test_gymnasium_envs.py b/test/test_gymnasium_envs.py
index 5be182c2..246f002e 100644
--- a/test/test_gymnasium_envs.py
+++ b/test/test_gymnasium_envs.py
@@ -1,6 +1,9 @@
import inspect
import unittest
+import gymnasium as gym
+
+import carl
import carl.envs.gymnasium
@@ -12,7 +15,6 @@ def test_envs(self):
if inspect.isclass(env_obj) and "CARL" in env_name:
try:
env_obj.get_context_features()
-
env = env_obj()
env._progress_instance()
env._update_context()
@@ -21,5 +23,21 @@ def test_envs(self):
raise e
+class TestGymnasiumRegistration(unittest.TestCase):
+ def test_registration(self):
+ registered_envs = gym.envs.registration.registry.keys()
+ for e in carl.envs.__all__:
+ if "RNA" not in e and "Brax" not in e:
+ env_name = f"carl/{e}-v0"
+ self.assertTrue(env_name in registered_envs)
+
+ def test_make(self):
+ for e in carl.envs.__all__:
+ if "RNA" not in e and "Brax" not in e:
+ env_name = f"carl/{e}-v0"
+ env = gym.make(env_name)
+ self.assertTrue(isinstance(env, gym.Env))
+
+
if __name__ == "__main__":
- TestGymnasiumEnvs().test_envs()
+ unittest.main()
diff --git a/test/test_language_goals.py b/test/test_language_goals.py
new file mode 100644
index 00000000..0ad5d09b
--- /dev/null
+++ b/test/test_language_goals.py
@@ -0,0 +1,238 @@
+import unittest
+
+from carl.context.context_space import (
+ CategoricalContextFeature,
+ NormalFloatContextFeature,
+)
+from carl.context.sampler import ContextSampler
+from carl.envs import CARLBraxAnt, CARLBraxHalfcheetah
+from carl.envs.brax.brax_walker_goal_wrapper import (
+ BraxLanguageWrapper,
+ BraxWalkerGoalWrapper,
+)
+
+DIRECTIONS = [
+ 1, # north
+ 3, # south
+ 2, # east
+ 4, # west
+ 12,
+ 32,
+ 14,
+ 34,
+ 112,
+ 332,
+ 114,
+ 334,
+ 212,
+ 232,
+ 414,
+ 434,
+]
+
+
+class TestGoalSampling(unittest.TestCase):
+ def test_uniform_sampling(self):
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxAnt.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ assert len(contexts.keys()) == 10
+ assert "target_distance" in contexts[0].keys(), "target_distance not in context"
+ assert (
+ "target_direction" in contexts[0].keys()
+ ), "target_direction not in context"
+ assert all(
+ [contexts[i]["target_direction"] in DIRECTIONS for i in range(10)]
+ ), "Not all directions are valid."
+ assert all(
+ [contexts[i]["target_distance"] <= 200 for i in range(10)]
+ ), "Not all distances are valid (too large)."
+ assert all(
+ [contexts[i]["target_distance"] >= 4 for i in range(10)]
+ ), "Not all distances are valid (too small)."
+
+ def test_normal_sampling(self):
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxAnt.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ assert (
+ len(contexts.keys()) == 10
+ ), "Number of sampled contexts does not match the requested number."
+ assert "target_distance" in contexts[0].keys(), "target_distance not in context"
+ assert (
+ "target_direction" in contexts[0].keys()
+ ), "target_direction not in context"
+ assert all(
+ [contexts[i]["target_direction"] in DIRECTIONS for i in range(10)]
+ ), "Not all directions are valid."
+ assert all(
+ [contexts[i]["target_distance"] <= 200 for i in range(10)]
+ ), "Not all distances are valid (too large)."
+ assert all(
+ [contexts[i]["target_distance"] >= 4 for i in range(10)]
+ ), "Not all distances are valid (too small)."
+
+
+class TestGoalWrapper(unittest.TestCase):
+ def test_reset(self):
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxAnt.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxAnt(contexts=contexts)
+
+ assert isinstance(env.env, BraxWalkerGoalWrapper)
+ assert env.position is None, "Position set before reset."
+
+ state, info = env.reset()
+ assert state is not None, "No state returned."
+ assert info is not None, "No info returned."
+ assert env.position is not None, "Position not set."
+
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxHalfcheetah.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxHalfcheetah(contexts=contexts, use_language_goals=True)
+
+ assert isinstance(env.env, BraxLanguageWrapper), "Language wrapper not used."
+ assert env.position is None, "Position set before reset."
+
+ state, info = env.reset()
+ assert state is not None, "No state returned."
+ assert info is not None, "No info returned."
+ assert env.position is not None, "Position not set."
+
+ def test_reward_scale(self):
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxAnt.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxAnt(contexts=contexts)
+
+ for _ in range(10):
+ env.reset()
+ for _ in range(10):
+ action = env.action_space.sample()
+ _, wrapped_reward, _, _, _ = env.step(action)
+ assert wrapped_reward >= 0, "Negative reward."
+
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxHalfcheetah.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxHalfcheetah(contexts=contexts)
+
+ for _ in range(10):
+ env.reset()
+ for _ in range(10):
+ action = env.action_space.sample()
+ _, wrapped_reward, _, _, _ = env.step(action)
+ assert wrapped_reward >= 0, "Negative reward."
+
+
+class TestLanguageWrapper(unittest.TestCase):
+ def test_reset(self) -> None:
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxAnt.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxAnt(contexts=contexts, use_language_goals=True)
+ state, info = env.reset()
+ assert type(state) is dict, "State is not a dictionary."
+ assert "obs" in state.keys(), "Observation not in state."
+ assert "goal" in state["obs"].keys(), "Goal not in observation."
+ assert type(state["obs"]["goal"]) is str, "Goal is not a string."
+ assert (
+ str(env.context["target_distance"]) in state["obs"]["goal"]
+ ), "Distance not in goal."
+ assert "north north east" in state["obs"]["goal"], "Direction not in goal."
+ assert info is not None, "No info returned."
+
+ def test_step(self):
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxHalfcheetah.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxHalfcheetah(contexts=contexts, use_language_goals=True)
+ env.reset()
+ for _ in range(10):
+ action = env.action_space.sample()
+ state, _, _, _, _ = env.step(action)
+ assert type(state) is dict, "State is not a dictionary."
+ assert "obs" in state.keys(), "Observation not in state."
+ assert "goal" in state["obs"].keys(), "Goal not in observation."
+ assert type(state["obs"]["goal"]) is str, "Goal is not a string."
+ assert "north north east" in state["obs"]["goal"], "Direction not in goal."
+ assert (
+ str(env.context["target_distance"]) in state["obs"]["goal"]
+ ), "Distance not in goal."
+
+ context_distributions = [
+ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
+ CategoricalContextFeature("target_direction", choices=DIRECTIONS),
+ ]
+ context_sampler = ContextSampler(
+ context_distributions=context_distributions,
+ context_space=CARLBraxAnt.get_context_space(),
+ seed=0,
+ )
+ contexts = context_sampler.sample_contexts(n_contexts=10)
+ env = CARLBraxAnt(contexts=contexts)
+ env.reset()
+ for _ in range(10):
+ action = env.action_space.sample()
+ state, _, _, _, _ = env.step(action)
+ assert type(state) is dict, "State is not a dictionary."
+ assert "obs" in state.keys(), "Observation not in state."
+ assert "goal" not in state.keys(), "Goal in observation."
diff --git a/test/test_search_space_encoding.py b/test/test_search_space_encoding.py
new file mode 100644
index 00000000..2c1a4e84
--- /dev/null
+++ b/test/test_search_space_encoding.py
@@ -0,0 +1,81 @@
+import unittest
+
+from ConfigSpace import ConfigurationSpace
+from omegaconf import DictConfig
+
+from carl.context.search_space_encoding import search_space_to_config_space
+
+dict_space = {
+ "uniform_integer": (1, 10),
+ "uniform_float": (1.0, 10.0),
+ "categorical": ["a", "b", "c"],
+ "constant": 1337,
+}
+
+dict_space_2 = {
+ "hyperparameters": [
+ {
+ "name": "x0",
+ "type": "uniform_float",
+ "log": False,
+ "lower": -512.0,
+ "upper": 512.0,
+ "default": -3.0,
+ "q": None,
+ },
+ {
+ "name": "x1",
+ "type": "uniform_float",
+ "log": False,
+ "lower": -512.0,
+ "upper": 512.0,
+ "default": -4.0,
+ "q": None,
+ },
+ ],
+ "conditions": [],
+ "forbiddens": [],
+ "python_module_version": "0.4.17",
+ "json_format_version": 0.2,
+}
+
+str_space = """{
+ "uniform_integer": (1, 10),
+ "uniform_float": (1.0, 10.0),
+ "categorical": ["a", "b", "c"],
+ "constant": 1337,
+ }"""
+
+
+class TestSearchSpacEncoding(unittest.TestCase):
+ def setUp(self):
+ self.test_space = None
+ self.test_space = ConfigurationSpace(name="myspace", space=dict_space)
+ return super().setUp()
+
+ def test_ss_as_cs(self):
+ try:
+ search_space_to_config_space(self.test_space)
+ except Exception as e:
+ print(f"Cannot encode search space -- {self.test_space}.")
+ raise e
+
+ def test_ss_as_dictconfig(self):
+ try:
+ dict_space = DictConfig({"hyperparameters": {}})
+
+ search_space_to_config_space(dict_space)
+ except Exception as e:
+ print(f"Cannot encode search space -- {dict_space}.")
+ raise e
+
+ def test_ss_as_dict(self):
+ try:
+ search_space_to_config_space(dict_space_2)
+ except Exception as e:
+ print(f"Cannot encode search space -- {dict_space_2}.")
+ raise e
+
+
+if __name__ == "__main__":
+ unittest.main()