Skip to content

Commit

Permalink
tests running
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Jan 9, 2024
1 parent 57ef8e5 commit ce69cc4
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 153 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ format-isort:
format: format-black format-isort

test:
$(PYTEST) test
$(PYTEST) --disable-warnings test

clean-doc:
$(MAKE) -C ${DOCDIR} clean
Expand Down
77 changes: 44 additions & 33 deletions carl/envs/brax/brax_walker_goal_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import gym
import numpy as np
from brax.io import mjcf
from etils import epath

STATE_INDICES = {
"CARLAnt": [13, 14],
"CARLHumanoid": [22, 23],
"CARLHalfcheetah": [14, 15],
"CARLHopper": [5, 6],
"CARLWalker2d": [8, 9],
"ant": [13, 14],
"humanoid": [22, 23],
"halfcheetah": [14, 15],
"hopper": [5, 6],
"walker2d": [8, 9],
}

DIRECTION_NAMES = {
Expand Down Expand Up @@ -51,15 +53,17 @@
class BraxWalkerGoalWrapper(gym.Wrapper):
"""Adds a positional goal to brax walker envs"""

def __init__(self, env) -> None:
def __init__(self, env, env_name, asset_path) -> None:
super().__init__(env)
self.env_name = env_name
if (
self.env.__class__.__name__ == "CARLHumanoid"
or self.env.__class__.__name__ == "CARLHalfcheetah"
or self.env.__class__.__name__ == "CARLHopper"
or self.env.__class__.__name__ == "CARLWalker2d"
self.env_name == "humanoid"
or self.env_name == "halfcheetah"
or self.env_name == "hopper"
or self.env_name == "walker2d"
):
self.env._forward_reward_weight = 0
self.context = None
self.position = None
self.goal_position = None
self.direction_values = {
Expand Down Expand Up @@ -101,60 +105,67 @@ def __init__(self, env) -> None:
],
212: [np.sin(22.5 * np.pi / 180), np.cos(22.5 * np.pi / 180)],
}
path = epath.resource_path("brax") / asset_path
sys = mjcf.load(path)
self.dt = sys.dt

def reset(self, return_info=False):
state, info = self.env.reset(info=True)
def reset(self, seed=None, options={}):
state, info = self.env.reset(seed=seed, options=options)
self.position = (0, 0)
self.goal_position = (
np.array(self.direction_values[self.context["target_direction"]])
* self.context["target_distance"]
)
if return_info:
info["success"] = 0
return state, info
else:
return state
info["success"] = 0
return state, info


def step(self, action):
state, _, done, info = self.env.step(action)
indices = STATE_INDICES[self.env.__class__.__name__]
state, _, te, tr, info = self.env.step(action)
indices = STATE_INDICES[self.env_name]
new_position = (
np.array(list(self.position))
+ np.array([state[indices[0]], state[indices[1]]])
* self.env.env.sys.config.dt
* self.dt
)
current_distance_to_goal = np.linalg.norm(self.goal_position - new_position)
previous_distance_to_goal = np.linalg.norm(self.goal_position - self.position)
direction_reward = max(0, previous_distance_to_goal - current_distance_to_goal)
self.position = new_position
if abs(current_distance_to_goal) <= 5:
done = True
te = True
info["success"] = 1

Check warning on line 137 in carl/envs/brax/brax_walker_goal_wrapper.py

View check run for this annotation

Codecov / codecov/patch

carl/envs/brax/brax_walker_goal_wrapper.py#L136-L137

Added lines #L136 - L137 were not covered by tests
else:
info["success"] = 0
return state, direction_reward, done, info
return state, direction_reward, te, tr, info


class BraxLanguageWrapper(gym.Wrapper):
"""Translates the context features target distance and target radius into language"""

def __init__(self, env) -> None:
super().__init__(env)
self.context = None

def reset(self, return_info=False):
state, info = self.env.reset(info=True)
goal_str = self.get_goal_desc(info["context"])
extended_state = {"env_state": state, "goal": goal_str}
if return_info:
return extended_state, info
def reset(self, seed=None, options={}):
print(self.context)
self.env.context = self.context
state, info = self.env.reset(seed=seed, options=options)
goal_str = self.get_goal_desc(self.context)
if isinstance(state, dict):
state["goal"] = goal_str

Check warning on line 156 in carl/envs/brax/brax_walker_goal_wrapper.py

View check run for this annotation

Codecov / codecov/patch

carl/envs/brax/brax_walker_goal_wrapper.py#L156

Added line #L156 was not covered by tests
else:
return extended_state
state = {"obs": state, "goal": goal_str}
return state, info

def step(self, action):
state, reward, done, info = self.env.step(action)
goal_str = self.get_goal_desc(info["context"])
extended_state = {"env_state": state, "goal": goal_str}
return extended_state, reward, done, info
state, reward, te, tr, info = self.env.step(action)
goal_str = self.get_goal_desc(self.context)
if isinstance(state, dict):
state["goal"] = goal_str

Check warning on line 165 in carl/envs/brax/brax_walker_goal_wrapper.py

View check run for this annotation

Codecov / codecov/patch

carl/envs/brax/brax_walker_goal_wrapper.py#L165

Added line #L165 was not covered by tests
else:
state = {"obs": state, "goal": goal_str}
return state, reward, te, tr, info

def get_goal_desc(self, context):
if "target_radius" in context.keys():
Expand Down
30 changes: 24 additions & 6 deletions carl/envs/brax/carl_brax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,21 +211,22 @@ def __init__(

if contexts is not None:
if (
"target_distance" in contexts[contexts.keys()[0]]
or "target_direction" in contexts[contexts.keys()[0]]
"target_distance" in contexts[list(contexts.keys())[0]].keys()
or "target_direction" in contexts[list(contexts.keys())[0]].keys()
):
base_dir = contexts[contexts.keys()[0]]["target_direction"]
base_dist = contexts[contexts.keys()[0]]["target_distance"]
base_dir = contexts[list(contexts.keys())[0]]["target_direction"]
base_dist = contexts[list(contexts.keys())[0]]["target_distance"]
max_diff_dir = max(
[c["target_direction"] - base_dir for c in contexts.values()]
)
max_diff_dist = max(
[c["target_distance"] - base_dist for c in contexts.values()]
)
if max_diff_dir > 0.1 or max_diff_dist > 0.1:
env = BraxWalkerGoalWrapper(env)
env = BraxWalkerGoalWrapper(env, self.env_name, self.asset_path)
if use_language_goals:
env = BraxLanguageWrapper(env, contexts)
env = BraxLanguageWrapper(env)
self.use_language_goals = use_language_goals

super().__init__(
env=env,
Expand All @@ -236,6 +237,7 @@ def __init__(
context_selector_kwargs=context_selector_kwargs,
**kwargs,
)
self.env.context = self.context

def _update_context(self) -> None:
context = self.context
Expand All @@ -247,6 +249,8 @@ def _update_context(self) -> None:
"gravity",
"viscosity",
"elasticity",
"target_distance",
"target_direction",
]
check_context(context, registered_cfs)

Expand Down Expand Up @@ -275,3 +279,17 @@ def _update_context(self) -> None:
sys = sys.replace(geoms=updated_geoms)

self.env.unwrapped.sys = sys

def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
"""Overwrites reset in super to update context in wrapper."""
last_context_id = self.context_id
self._progress_instance()
if self.context_id != last_context_id:
self._update_context()
#if self.use_language_goals:
#self.env.env.context = self.context
self.env.context = self.context
state, info = self.env.reset(seed=seed, options=options)
state = self._add_context_to_state(state)
info["context_id"] = self.context_id
return state, info
2 changes: 1 addition & 1 deletion carl/envs/brax/carl_halfcheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_distance", lower=0, upper=np.inf, default_value=0
),
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=0
"target_direction", choices=directions, default_value=1
),
}
2 changes: 1 addition & 1 deletion carl/envs/brax/carl_hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_distance", lower=0, upper=np.inf, default_value=0
),
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=0
"target_direction", choices=directions, default_value=1
),
}
2 changes: 1 addition & 1 deletion carl/envs/brax/carl_humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,6 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_distance", lower=0, upper=np.inf, default_value=0
),
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=0
"target_direction", choices=directions, default_value=1
),
}
2 changes: 1 addition & 1 deletion carl/envs/brax/carl_walker2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_distance", lower=0, upper=np.inf, default_value=0
),
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=0
"target_direction", choices=directions, default_value=1
),
}
Loading

0 comments on commit ce69cc4

Please sign in to comment.