diff --git a/alfredo/agents/A1/a1.xml b/alfredo/agents/A1/a1.xml
new file mode 100644
index 0000000..3614f37
--- /dev/null
+++ b/alfredo/agents/A1/a1.xml
@@ -0,0 +1,102 @@
+<!-- a1.xml (102 lines): MuJoCo MJCF definition of the Alfredo A1 agent; XML body not preserved here -->
diff --git a/alfredo/agents/A1/alfredo_1.py b/alfredo/agents/A1/alfredo_1.py
index 814ae67..1d9efb9 100644
--- a/alfredo/agents/A1/alfredo_1.py
+++ b/alfredo/agents/A1/alfredo_1.py
@@ -9,6 +9,11 @@
from etils import epath
from jax import numpy as jp
+from alfredo.tools import compose_scene
+from alfredo.rewards import rConstant
+from alfredo.rewards import rHealthy_simple_z
+from alfredo.rewards import rSpeed_X
+from alfredo.rewards import rControl_act_ss
class Alfredo(PipelineEnv):
# pyformat: disable
@@ -28,15 +33,26 @@ def __init__(
**kwargs,
):
- # forcing this model to need an input paramFile_path
- # will throw error if this is not included in kwargs
+ # this model requires either a scene_xml_path or the combination
+ # of env_xml_path and agent_xml_path; if neither is supplied in
+ # kwargs, model loading will fail
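+ # e.g. (illustrative paths):
+ #   Alfredo(backend="positional",
+ #           env_xml_path=".../flatworld/flatworld_A1_env.xml",
+ #           agent_xml_path=".../A1/a1.xml")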
path=""
- if "paramFile_path" in kwargs:
- path = kwargs["paramFile_path"]
- del kwargs["paramFile_path"]
-
- sys = mjcf.load(path)
+ if "env_xml_path" and "agent_xml_path" in kwargs:
+ env_xp = kwargs["env_xml_path"]
+ agent_xp = kwargs["agent_xml_path"]
+ xml_scene = compose_scene(env_xp, agent_xp)
+ del kwargs["env_xml_path"]
+ del kwargs["agent_xml_path"]
+
+ sys = mjcf.loads(xml_scene)
+
+ # this is vestigial - get rid of this someday soon
+ if "scene_xml_path" in kwargs:
+ path = kwargs["scene_xml_path"]
+ del kwargs["scene_xml_path"]
+
+ sys = mjcf.load(path)
n_frames = 5
@@ -120,33 +136,38 @@ def step(self, state: State, action: jp.ndarray) -> State:
com_before, *_ = self._com(prev_pipeline_state)
com_after, *_ = self._com(pipeline_state)
- a_velocity = (com_after - com_before) / self.dt
-
- reward_vel = math.safe_norm(a_velocity)
- forward_reward = self._forward_reward_weight * a_velocity[0] # * reward_vel
- ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))
-
- min_z, max_z = self._healthy_z_range
- is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
- is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy)
-
- if self._terminate_when_unhealthy:
- healthy_reward = self._healthy_reward
- else:
- healthy_reward = self._healthy_reward * is_healthy
-
- reward = healthy_reward - ctrl_cost + forward_reward
- done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
+ x_speed_reward = rSpeed_X(self.sys,
+ pipeline_state,
+ CoM_prev=com_before,
+ CoM_now=com_after,
+ dt=self.dt,
+ weight=self._forward_reward_weight)
+
+ ctrl_cost = rControl_act_ss(self.sys,
+ pipeline_state,
+ action,
+ weight=-self._ctrl_cost_weight)
+
+ healthy_reward = rHealthy_simple_z(self.sys,
+ pipeline_state,
+ self._healthy_z_range,
+ early_terminate=self._terminate_when_unhealthy,
+ weight=self._healthy_reward,
+ focus_idx_range=(0, 2))
+
+ reward = healthy_reward[0] + ctrl_cost + x_speed_reward[0]
+
+ done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
state.metrics.update(
- reward_ctrl=-ctrl_cost,
- reward_alive=healthy_reward,
- reward_velocity=forward_reward,
+ reward_ctrl=ctrl_cost,
+ reward_alive=healthy_reward[0],
+ reward_velocity=x_speed_reward[0],
agent_x_position=com_after[0],
agent_y_position=com_after[1],
- agent_x_velocity=a_velocity[0],
- agent_y_velocity=a_velocity[1],
+ agent_x_velocity=x_speed_reward[1],
+ agent_y_velocity=x_speed_reward[2],
)
return state.replace(
@@ -154,10 +175,12 @@ def step(self, state: State, action: jp.ndarray) -> State:
)
def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray:
- """Observes humanoid body position, velocities, and angles."""
+ """Observes Alfredo's body position, velocities, and angles."""
a_positions = pipeline_state.q
a_velocities = pipeline_state.qd
+ #print(f"a_positions = {a_positions}")
+ #print(f"a_velocities = {a_velocities}")
if self._exclude_current_positions_from_observation:
a_positions = a_positions[2:]
@@ -194,7 +217,7 @@ def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray
)
def _com(self, pipeline_state: base.State) -> jp.ndarray:
- """Computes Center of Mass of the Humanoid"""
+ """Computes Center of Mass of Alfredo"""
inertia = self.sys.link.inertia
diff --git a/alfredo/rewards/__init__.py b/alfredo/rewards/__init__.py
index e69de29..d5908fa 100644
--- a/alfredo/rewards/__init__.py
+++ b/alfredo/rewards/__init__.py
@@ -0,0 +1,4 @@
+from .rConstant import *
+from .rSpeed import *
+from .rHealthy import *
+from .rControl import *
diff --git a/alfredo/rewards/rConstant.py b/alfredo/rewards/rConstant.py
new file mode 100644
index 0000000..18d4f7b
--- /dev/null
+++ b/alfredo/rewards/rConstant.py
@@ -0,0 +1,15 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rConstant(sys: base.System,
+ pipeline_state: base.State,
+ weight=1.0,
+ focus_idx_range=(1, -1)) -> jp.ndarray:
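+ """Returns a constant reward of `weight`; sys, pipeline_state, and
+ focus_idx_range are accepted for interface uniformity but unused."""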
+
+ return jp.array([weight])
diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py
new file mode 100644
index 0000000..7e10532
--- /dev/null
+++ b/alfredo/rewards/rControl.py
@@ -0,0 +1,18 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rControl_act_ss(sys: base.System,
+ pipeline_state: base.State,
+ action: jp.ndarray,
+ weight=1.0,
+ focus_idx_range=(1, -1)) -> jp.ndarray:
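+ """Sum-of-squares control cost: weight * sum(action**2);
+ pass a negative weight to make this a penalty."""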
+
+ ctrl_cost = weight * jp.sum(jp.square(action))
+
+ return ctrl_cost
diff --git a/alfredo/rewards/rHealthy.py b/alfredo/rewards/rHealthy.py
new file mode 100644
index 0000000..fc1f0be
--- /dev/null
+++ b/alfredo/rewards/rHealthy.py
@@ -0,0 +1,31 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rHealthy_simple_z(sys: base.System,
+ pipeline_state: base.State,
+ z_range: Tuple,
+ early_terminate: bool,
+ weight=1.0,
+ focus_idx_range=(1, -1)) -> jp.ndarray:
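+ """Health reward based on a z-position check: returns
+ jp.array([reward, is_healthy]), where is_healthy is 1.0 iff the
+ focus position lies within z_range."""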
+
+ min_z, max_z = z_range
+ focus_s = focus_idx_range[0]
+ focus_e = focus_idx_range[-1]
+
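+ # x.pos is indexed [body, coordinate]; e.g. focus_idx_range=(0, 2)
+ # reads the z-position of the root body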
+ focus_x_pos = pipeline_state.x.pos[focus_s, focus_e]
+
+ is_healthy = jp.where(focus_x_pos < min_z, x=0.0, y=1.0)
+ is_healthy = jp.where(focus_x_pos > max_z, x=0.0, y=is_healthy)
+
+ if early_terminate:
+ hr = weight
+ else:
+ hr = weight * is_healthy
+
+ return jp.array([hr, is_healthy])
diff --git a/alfredo/rewards/rSpeed.py b/alfredo/rewards/rSpeed.py
new file mode 100644
index 0000000..f500c52
--- /dev/null
+++ b/alfredo/rewards/rSpeed.py
@@ -0,0 +1,44 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rSpeed_X(sys: base.System,
+ pipeline_state: base.State,
+ CoM_prev: jp.ndarray,
+ CoM_now: jp.ndarray,
+ dt,
+ weight=1.0,
+ focus_idx_range=(1, -1)) -> jp.ndarray:
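+ """X-direction speed reward from CoM displacement over dt:
+ returns jp.array([weight * vel_x, vel_x, vel_y])."""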
+
+ velocity = (CoM_now - CoM_prev) / dt
+
+ focus_s = focus_idx_range[0]
+ focus_e = focus_idx_range[-1]
+
+ sxr = weight * velocity[0]
+
+ return jp.array([sxr, velocity[0], velocity[1]])
+
+def rSpeed_Y(sys: base.System,
+ pipeline_state: base.State,
+ CoM_prev: jp.ndarray,
+ CoM_now: jp.ndarray,
+ dt,
+ weight=1.0,
+ focus_idx_range=(1, -1)) -> jp.ndarray:
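+ """Y-direction speed reward from CoM displacement over dt:
+ returns jp.array([weight * vel_y, vel_x, vel_y])."""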
+
+ velocity = (CoM_now - CoM_prev) / dt
+
+ focus_s = focus_idx_range[0]
+ focus_e = focus_idx_range[-1]
+
+ syr = weight * velocity[1]
+
+ return jp.array([syr, velocity[0], velocity[1]])
diff --git a/alfredo/scenes/flatworld/flatworld_A1_env.xml b/alfredo/scenes/flatworld/flatworld_A1_env.xml
new file mode 100644
index 0000000..a30eb64
--- /dev/null
+++ b/alfredo/scenes/flatworld/flatworld_A1_env.xml
@@ -0,0 +1,48 @@
+<!-- flatworld_A1_env.xml (48 lines): MuJoCo MJCF definition of the flatworld environment; XML body not preserved here -->
diff --git a/alfredo/tools/__init__.py b/alfredo/tools/__init__.py
new file mode 100644
index 0000000..590c5ca
--- /dev/null
+++ b/alfredo/tools/__init__.py
@@ -0,0 +1 @@
+from .tXMLCompose import *
diff --git a/alfredo/tools/tXMLCompose.py b/alfredo/tools/tXMLCompose.py
new file mode 100644
index 0000000..2d9f149
--- /dev/null
+++ b/alfredo/tools/tXMLCompose.py
@@ -0,0 +1,69 @@
+import functools
+import os
+import re
+import sys
+
+import xml.etree.ElementTree as ET
+
+def compose_scene(xml_env, xml_agent):
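+ """Merges an agent MJCF file into an environment MJCF file by
+ appending the agent's <body> to the env's <worldbody> and the
+ agent's <actuator> block to the env root; returns the composed
+ scene as an XML string."""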
+
+ env_tree = ET.parse(xml_env)
+ agent_tree = ET.parse(xml_agent)
+
+ env_root = env_tree.getroot()
+ worldbody = env_root.find('worldbody')
+
+ ag_root = agent_tree.getroot()
+ ag_body = ag_root.find('body')
+ ag_actuator = ag_root.find('actuator')
+
+ worldbody.append(ag_body)
+ env_root.append(ag_actuator)
+
+ beautify(env_root)
+
+ scene_xml_string = ET.tostring(env_root, encoding='utf-8')
+
+ return scene_xml_string.decode('utf-8')
+
+def beautify(element, indent=' '):
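+ """Pretty-prints an ElementTree in place by setting each element's
+ .text and .tail to newline + indentation (ET.indent is only
+ available from Python 3.9)."""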
+ queue = [(0, element)] # (level, element)
+
+ while queue:
+ level, element = queue.pop(0)
+ children = [(level + 1, child) for child in list(element)]
+ if children:
+ element.text = '\n' + indent * (level+1) # for child open
+ if queue:
+ element.tail = '\n' + indent * queue[0][0] # for child close
+ else:
+ element.tail = '\n' + indent * (level-1) # for my close
+
+ queue[0:0] = children # prepend children to process them next
+
+if __name__ == '__main__':
+
+ # example usage of compose_scene function
+ import alfredo.scenes as scenes
+ scene_fp = os.path.dirname(scenes.__file__)
+ env_xml_path = f"{scene_fp}/flatworld/flatworld_A1_env.xml"
+
+ import alfredo.agents as agents
+ agents_fp = os.path.dirname(agents.__file__)
+ agent_xml_path = f"{agents_fp}/A1/a1.xml"
+
+ xml_scene = compose_scene(env_xml_path, agent_xml_path)
+ print(xml_scene)
+
+ from typing import Tuple
+ import jax
+ from brax import actuator, base, math
+ from brax.envs import PipelineEnv, State
+ from brax.io import mjcf
+ from etils import epath
+ from jax import numpy as jp
+
+ sys = mjcf.loads(xml_scene)
+ print(sys)
+
diff --git a/experiments/Alfredo-simple-walk/seq_training.py b/experiments/Alfredo-simple-walk/seq_training.py
index cea5311..ef1c90a 100644
--- a/experiments/Alfredo-simple-walk/seq_training.py
+++ b/experiments/Alfredo-simple-walk/seq_training.py
@@ -63,22 +63,29 @@ def progress(num_steps, metrics):
# ==============================
cwd = os.getcwd()
-# get the filepath to the scene xmls
+# get the filepath to the env and agent xmls
import alfredo.scenes as scenes
-scene_fp = os.path.dirname(scenes.__file__)
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/A1/a1.xml"
+
+scenes_fp = os.path.dirname(scenes.__file__)
+
+env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
# ============================
# Loading and Defining Envs
# ============================
-pf_paths = [f"{scene_fp}/flatworld/flatworld_A1.xml"]
# make and save initial ppo_network
key = jax.random.PRNGKey(wandb.config.seed)
global_key, local_key = jax.random.split(key)
key_policy, key_value = jax.random.split(global_key)
-env = Alfredo(backend=wandb.config.backend, paramFile_path=pf_paths[0])
+env = Alfredo(backend=wandb.config.backend,
+ env_xml_path=env_xml_paths[0],
+ agent_xml_path=agent_xml_path)
rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng)
@@ -105,11 +112,13 @@ def progress(num_steps, metrics):
# ============================
i = 0
-for p in pf_paths:
+for p in env_xml_paths:
d_and_t = datetime.now()
print(f"[{d_and_t}] loop start for model: {i}")
- env = Alfredo(backend=wandb.config.backend, paramFile_path=p)
+ env = Alfredo(backend=wandb.config.backend,
+ env_xml_path=p,
+ agent_xml_path=agent_xml_path)
mF = f"{cwd}/param-store/{wandb.config.env_name}_params_{i}"
mParams = model.load_params(mF)
diff --git a/experiments/Alfredo-simple-walk/vis_new_model.py b/experiments/Alfredo-simple-walk/vis_new_model.py
index 69082a4..c729e12 100644
--- a/experiments/Alfredo-simple-walk/vis_new_model.py
+++ b/experiments/Alfredo-simple-walk/vis_new_model.py
@@ -22,13 +22,22 @@
# get filepaths from commandline args
cwd = os.getcwd()
+# get the filepath to the env and agent xmls
import alfredo.scenes as scenes
-scene_fp = os.path.dirname(scenes.__file__)
-pf_path = f"{scene_fp}/{sys.argv[-1]}"
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/A1/a1.xml"
+
+scenes_fp = os.path.dirname(scenes.__file__)
+
+env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
# create an env and initial state
-env = Alfredo(backend=backend, paramFile_path=pf_path)
+env = Alfredo(backend=backend,
+ env_xml_path=env_xml_paths[0],
+ agent_xml_path=agent_xml_path)
+
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
# render scene
diff --git a/experiments/Alfredo-simple-walk/vis_traj.py b/experiments/Alfredo-simple-walk/vis_traj.py
index 8aa1ca5..385ae92 100644
--- a/experiments/Alfredo-simple-walk/vis_traj.py
+++ b/experiments/Alfredo-simple-walk/vis_traj.py
@@ -22,20 +22,28 @@
# get filepaths from commandline args
cwd = os.getcwd()
+# get the filepath to the env and agent xmls
import alfredo.scenes as scenes
-scene_fp = os.path.dirname(scenes.__file__)
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/A1/a1.xml"
-pf_path = f"{scene_fp}/{sys.argv[-2]}"
+scenes_fp = os.path.dirname(scenes.__file__)
+
+env_xml_path = f"{scenes_fp}/{sys.argv[-2]}"
tpf_path = f"{cwd}/{sys.argv[-1]}"
-print(f"model description file: {pf_path}")
+print(f"agent description file: {agent_xml_path}")
+print(f"environment description file: {env_xml_path}")
print(f"neural parameter file: {tpf_path}")
params = model.load_params(tpf_path)
# create an env with auto-reset and load previously trained parameters
-env = Alfredo(backend=backend, paramFile_path=pf_path)
+env = Alfredo(backend=backend,
+ env_xml_path=env_xml_path,
+ agent_xml_path=agent_xml_path)
auto_reset = True
episode_length = 1000
diff --git a/experiments/Alfredo-simulate-step/one_physics_step.py b/experiments/Alfredo-simulate-step/one_physics_step.py
new file mode 100644
index 0000000..ac37a3a
--- /dev/null
+++ b/experiments/Alfredo-simulate-step/one_physics_step.py
@@ -0,0 +1,57 @@
+import functools
+import os
+import re
+import sys
+from datetime import datetime
+
+import brax
+import jax
+import matplotlib.pyplot as plt
+from brax import envs
+from brax.envs.wrappers import training
+from brax.io import html, json, model
+from brax.training.acme import running_statistics
+from brax.training.agents.ppo import networks as ppo_networks
+from jax import numpy as jp
+
+from alfredo.agents.A1.alfredo_1 import Alfredo
+
+backend = "positional"
+
+# locate the agent and environment xmls
+cwd = os.getcwd()
+
+# get the filepath to the env and agent xmls
+import alfredo.scenes as scenes
+
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/A1/a1.xml"
+
+scenes_fp = os.path.dirname(scenes.__file__)
+
+env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
+
+# create an env and initial state
+env = Alfredo(backend=backend,
+ env_xml_path=env_xml_paths[0],
+ agent_xml_path=agent_xml_path)
+
+state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
+
+#print(f"Alfredo brax env dir: {dir(env)}")
+#print(f"state: {state}")
+
+com = env._com(state.pipeline_state)
+obs = env._get_obs(state.pipeline_state, jp.zeros(env.action_size))
+#print(f"CoM = {com}")
+#print(f"pipeline_state: {state.pipeline_state}")
+#print(f"observation: {obs}")
+print(f"\n-----------------------------------------------------------------\n")
+nState = env.step(state, jp.zeros(env.action_size))
+com = env._com(nState.pipeline_state)
+obs = env._get_obs(nState.pipeline_state, jp.zeros(env.action_size))
+#print(f"CoM = {com}")
+#print(f"pipeline_state: {state.pipeline_state}")
+#print(f"observation: {obs}")