diff --git a/alfredo/agents/A1/__init__.py b/alfredo/agents/A1/__init__.py
deleted file mode 100644
index 3d6835f..0000000
--- a/alfredo/agents/A1/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .alfredo_1 import *
diff --git a/alfredo/agents/A1/a1.xml b/alfredo/agents/A1/a1.xml
deleted file mode 100644
index 3614f37..0000000
--- a/alfredo/agents/A1/a1.xml
+++ /dev/null
@@ -1,102 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/alfredo/agents/A1/alfredo_1.py b/alfredo/agents/A1/alfredo_1.py
deleted file mode 100644
index 1d9efb9..0000000
--- a/alfredo/agents/A1/alfredo_1.py
+++ /dev/null
@@ -1,244 +0,0 @@
-# pylint:disable=g-multiple-import
-"""Trains Alfredo to run/walk/move in the +x direction."""
-from typing import Tuple
-
-import jax
-from brax import actuator, base, math
-from brax.envs import PipelineEnv, State
-from brax.io import mjcf
-from etils import epath
-from jax import numpy as jp
-
-from alfredo.tools import compose_scene
-from alfredo.rewards import rConstant
-from alfredo.rewards import rHealthy_simple_z
-from alfredo.rewards import rSpeed_X
-from alfredo.rewards import rControl_act_ss
-
-class Alfredo(PipelineEnv):
- # pyformat: disable
- """ """
- # pyformat: enable
-
- def __init__(
- self,
- forward_reward_weight=1.25,
- ctrl_cost_weight=0.1,
- healthy_reward=5.0,
- terminate_when_unhealthy=True,
- healthy_z_range=(1.0, 2.0),
- reset_noise_scale=1e-2,
- exclude_current_positions_from_observation=True,
- backend="generalized",
- **kwargs,
- ):
-
- # forcing this model to need an input scene_xml_path or
- # the combination of env_xml_path and agent_xml_path
- # if none of these options are present, an error will be thrown
- path=""
-
- if "env_xml_path" and "agent_xml_path" in kwargs:
- env_xp = kwargs["env_xml_path"]
- agent_xp = kwargs["agent_xml_path"]
- xml_scene = compose_scene(env_xp, agent_xp)
- del kwargs["env_xml_path"]
- del kwargs["agent_xml_path"]
-
- sys = mjcf.loads(xml_scene)
-
- # this is vestigial - get rid of this someday soon
- if "scene_xml_path" in kwargs:
- path = kwargs["scene_xml_path"]
- del kwargs["scene_xml_path"]
-
- sys = mjcf.load(path)
-
- n_frames = 5
-
- if backend in ["spring", "positional"]:
- sys = sys.replace(dt=0.0015)
- n_frames = 10
- gear = jp.array(
- [
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 150.0,
- 100.0,
- 100.0,
- 100.0,
- 100.0,
- 100.0,
- 100.0,
- ]
- ) # pyformat: disable
- sys = sys.replace(actuator=sys.actuator.replace(gear=gear))
-
- kwargs["n_frames"] = kwargs.get("n_frames", n_frames)
-
- super().__init__(sys=sys, backend=backend, **kwargs)
-
- self._forward_reward_weight = forward_reward_weight
- self._ctrl_cost_weight = ctrl_cost_weight
- self._healthy_reward = healthy_reward
-
- self._terminate_when_unhealthy = terminate_when_unhealthy
- self._healthy_z_range = healthy_z_range
- self._reset_noise_scale = reset_noise_scale
-
- self._exclude_current_positions_from_observation = (
- exclude_current_positions_from_observation
- )
-
- def reset(self, rng: jp.ndarray) -> State:
- """Resets the environment to an initial state."""
- rng, rng1, rng2 = jax.random.split(rng, 3)
-
- low, hi = -self._reset_noise_scale, self._reset_noise_scale
-
- qpos = self.sys.init_q + jax.random.uniform(
- rng1, (self.sys.q_size(),), minval=low, maxval=hi
- )
- qvel = jax.random.uniform(rng2, (self.sys.qd_size(),), minval=low, maxval=hi)
-
- pipeline_state = self.pipeline_init(qpos, qvel)
-
- obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
-
- reward, done, zero = jp.zeros(3)
- metrics = {
- "reward_ctrl": zero,
- "reward_alive": zero,
- "reward_velocity": zero,
- "agent_x_position": zero,
- "agent_y_position": zero,
- "agent_x_velocity": zero,
- "agent_y_velocity": zero,
- }
-
- return State(pipeline_state, obs, reward, done, metrics)
-
- def step(self, state: State, action: jp.ndarray) -> State:
- """Runs one timestep of the environment's dynamics."""
- prev_pipeline_state = state.pipeline_state
- pipeline_state = self.pipeline_step(prev_pipeline_state, action)
- obs = self._get_obs(pipeline_state, action)
-
- com_before, *_ = self._com(prev_pipeline_state)
- com_after, *_ = self._com(pipeline_state)
-
- x_speed_reward = rSpeed_X(self.sys,
- state.pipeline_state,
- CoM_prev=com_before,
- CoM_now=com_after,
- dt=self.dt,
- weight=self._forward_reward_weight)
-
- ctrl_cost = rControl_act_ss(self.sys,
- state.pipeline_state,
- action,
- weight=-self._ctrl_cost_weight)
-
- healthy_reward = rHealthy_simple_z(self.sys,
- state.pipeline_state,
- self._healthy_z_range,
- early_terminate=self._terminate_when_unhealthy,
- weight=self._healthy_reward,
- focus_idx_range=(0, 2))
-
- reward = healthy_reward[0] + ctrl_cost + x_speed_reward[0]
-
- done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
-
- state.metrics.update(
- reward_ctrl=ctrl_cost,
- reward_alive=healthy_reward[0],
- reward_velocity=x_speed_reward[0],
- agent_x_position=com_after[0],
- agent_y_position=com_after[1],
- agent_x_velocity=x_speed_reward[1],
- agent_y_velocity=x_speed_reward[2],
- )
-
- return state.replace(
- pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
- )
-
- def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray:
- """Observes Alfredo's body position, velocities, and angles."""
-
- a_positions = pipeline_state.q
- a_velocities = pipeline_state.qd
- #print(f"a_positions = {a_positions}")
- #print(f"a_velocities = {a_velocities}")
-
- if self._exclude_current_positions_from_observation:
- a_positions = a_positions[2:]
-
- com, inertia, mass_sum, x_i = self._com(pipeline_state)
- cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
- com_inertia = jp.hstack(
- [cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
- )
-
- xd_i = (
- base.Transform.create(pos=x_i.pos - pipeline_state.x.pos)
- .vmap()
- .do(pipeline_state.xd)
- )
-
- com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum
- com_ang = xd_i.ang
- com_velocity = jp.hstack([com_vel, com_ang])
-
- qfrc_actuator = actuator.to_tau(
- self.sys, action, pipeline_state.q, pipeline_state.qd
- )
-
- # external_contact_forces are excluded
- return jp.concatenate(
- [
- a_positions,
- a_velocities,
- com_inertia.ravel(),
- com_velocity.ravel(),
- qfrc_actuator,
- ]
- )
-
- def _com(self, pipeline_state: base.State) -> jp.ndarray:
- """Computes Center of Mass of Alfredo"""
-
- inertia = self.sys.link.inertia
-
- if self.backend in ["spring", "positional"]:
- inertia = inertia.replace(
- i=jax.vmap(jp.diag)(
- jax.vmap(jp.diagonal)(inertia.i)
- ** (1 - self.sys.spring_inertia_scale)
- ),
- mass=inertia.mass ** (1 - self.sys.spring_mass_scale),
- )
-
- mass_sum = jp.sum(inertia.mass)
- x_i = pipeline_state.x.vmap().do(inertia.transform)
-
- com = jp.sum(jax.vmap(jp.multiply)(inertia.mass, x_i.pos), axis=0) / mass_sum
-
- return (
- com,
- inertia,
- mass_sum,
- x_i,
- ) # pytype: disable=bad-return-type # jax-ndarray
-
diff --git a/alfredo/agents/__init__.py b/alfredo/agents/__init__.py
index 88a07ac..53cd490 100644
--- a/alfredo/agents/__init__.py
+++ b/alfredo/agents/__init__.py
@@ -1 +1 @@
-from . import A1
+from . import aant
diff --git a/alfredo/agents/aant/__init__.py b/alfredo/agents/aant/__init__.py
new file mode 100644
index 0000000..145872e
--- /dev/null
+++ b/alfredo/agents/aant/__init__.py
@@ -0,0 +1 @@
+from .aant import *
diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py
new file mode 100644
index 0000000..83daa0c
--- /dev/null
+++ b/alfredo/agents/aant/aant.py
@@ -0,0 +1,224 @@
+from brax import base
+from brax import math
+from brax.envs.base import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+import jax
+from jax import numpy as jp
+
+from alfredo.tools import compose_scene
+
+class AAnt(PipelineEnv):
+ """ """
+
+ def __init__(self,
+ rewards = {},
+ env_xml_path = "",
+ agent_xml_path = "",
+ terminate_when_unhealthy=False,
+ reset_noise_scale=0.1,
+ exclude_current_positions_from_observation=False,
+ backend='generalized',
+ **kwargs,):
+
+ # env_xml_path and agent_xml_path must be provided
+ if env_xml_path and agent_xml_path:
+ self._env_xml_path = env_xml_path
+ self._agent_xml_path = agent_xml_path
+
+ xml_scene = compose_scene(self._env_xml_path, self._agent_xml_path)
+ sys = mjcf.loads(xml_scene)
+ else:
+ raise Exception("env_xml_path & agent_xml_path both must be provided")
+
+ # reward dictionary must be provided
+ if rewards:
+ self._rewards = rewards
+ else:
+ raise Exception("reward_Structure must be in kwargs")
+
+ # TODO: clean this up in the future &
+ # make n_frames a function of input dt
+ n_frames = 5
+
+ if backend in ['spring', 'positional']:
+ sys = sys.replace(dt=0.005)
+ n_frames = 10
+
+ if backend == 'positional':
+ # TODO: does the same actuator strength work as in spring
+ sys = sys.replace(
+ actuator=sys.actuator.replace(
+ gear=200 * jp.ones_like(sys.actuator.gear)
+ )
+ )
+
+ kwargs['n_frames'] = kwargs.get('n_frames', n_frames)
+
+ # Initialize the superclass "PipelineEnv"
+ super().__init__(sys=sys, backend=backend, **kwargs)
+
+ # Setting other object parameters based on input params
+ self._terminate_when_unhealthy = terminate_when_unhealthy
+ self._reset_noise_scale = reset_noise_scale
+ self._exclude_current_positions_from_observation = (
+ exclude_current_positions_from_observation
+ )
+
+
+ def reset(self, rng: jax.Array) -> State:
+
+ rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
+ low, hi = -self._reset_noise_scale, self._reset_noise_scale
+
+ # initialize position vector with minor randomization in pose
+ q = self.sys.init_q + jax.random.uniform(
+ rng1, (self.sys.q_size(),), minval=low, maxval=hi
+ )
+
+ # initialize velocity vector with minor randomization
+ qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))
+
+ # generate sample commands
+ jcmd = self._sample_command(rng3)
+ wcmd = jp.array([0.0, 0.0])
+
+ # initialize pipeline_state (the physics state)
+ pipeline_state = self.pipeline_init(q, qd)
+
+ # reset values and metrics
+ reward, done, zero = jp.zeros(3)
+
+ state_info = {
+ 'jcmd':jcmd,
+ 'wcmd':wcmd,
+ 'rewards': {k: 0.0 for k in self._rewards.keys()},
+ 'step': 0,
+ }
+
+ metrics = {'pos_x_world_abs': zero,
+ 'pos_y_world_abs': zero,
+ 'pos_z_world_abs': zero,}
+
+ for rn, r in self._rewards.items():
+ metrics[rn] = state_info['rewards'][rn]
+
+ # get initial observation vector
+ obs = self._get_obs(pipeline_state, state_info)
+
+ return State(pipeline_state, obs, reward, done, metrics, state_info)
+
+ def step(self, state: State, action: jax.Array) -> State:
+ """Run one timestep of the environment's dynamics."""
+
+ # Save the previous physics state and step physics forward
+ pipeline_state0 = state.pipeline_state
+ pipeline_state = self.pipeline_step(pipeline_state0, action)
+
+ # Add all additional parameters to compute rewards
+ self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd'])
+ self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd'])
+
+ # Compute all rewards and accumulate total reward
+ total_reward = 0.0
+ for rn, r in self._rewards.items():
+ r.add_param('sys', self.sys)
+ r.add_param('pipeline_state', pipeline_state)
+
+ reward_value = r.compute()
+ state.info['rewards'][rn] = reward_value[0]
+ total_reward += reward_value[0]
+ # print(f'{rn} reward_val = {reward_value}\n')
+
+ # Computing additional metrics as necessary
+ pos_world = pipeline_state.x.pos[0]
+ abs_pos_world = jp.abs(pos_world)
+
+ # Compute observations
+ obs = self._get_obs(pipeline_state, state.info)
+ done = 0.0
+
+ # State management
+ state.info['step'] += 1
+
+ state.metrics.update(state.info['rewards'])
+
+ state.metrics.update(
+ pos_x_world_abs = abs_pos_world[0],
+ pos_y_world_abs = abs_pos_world[1],
+ pos_z_world_abs = abs_pos_world[2],
+ )
+
+ return state.replace(
+ pipeline_state=pipeline_state, obs=obs, reward=total_reward, done=done
+ )
+
+ def _get_obs(self, pipeline_state, state_info) -> jax.Array:
+ """Observe ant body position and velocities."""
+ qpos = pipeline_state.q
+ qvel = pipeline_state.qd
+
+ inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
+ local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)
+ torso_pos = pipeline_state.x.pos[0]
+
+ jcmd = state_info['jcmd']
+ #wcmd = state_info['wcmd']
+
+ if self._exclude_current_positions_from_observation:
+ qpos = pipeline_state.q[2:]
+
+ obs = jp.concatenate([
+ jp.array(qpos),
+ jp.array(qvel),
+ jp.array(local_rpyrate),
+ jp.array(jcmd),
+ ])
+
+ return obs
+
+ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
+ x_range = [-25, 25]
+ y_range = [-25, 25]
+ z_range = [0, 2]
+
+ _, key1, key2, key3 = jax.random.split(rng, 4)
+
+ x = jax.random.uniform(
+ key1, (1,), minval=x_range[0], maxval=x_range[1]
+ )
+
+ y = jax.random.uniform(
+ key2, (1,), minval=y_range[0], maxval=y_range[1]
+ )
+
+ z = jax.random.uniform(
+ key3, (1,), minval=z_range[0], maxval=z_range[1]
+ )
+
+ wcmd = jp.array([x[0], y[0]])
+
+ return wcmd
+
+ def _sample_command(self, rng: jax.Array) -> jax.Array:
+ lin_vel_x_range = [-3.0, 3.0] #[m/s]
+ lin_vel_y_range = [-3.0, 3.0] #[m/s]
+ yaw_vel_range = [-1.0, 1.0] #[rad/s]
+
+ _, key1, key2, key3 = jax.random.split(rng, 4)
+
+ lin_vel_x = jax.random.uniform(
+ key1, (1,), minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1]
+ )
+
+ lin_vel_y = jax.random.uniform(
+ key2, (1,), minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1]
+ )
+
+ yaw_vel = jax.random.uniform(
+ key3, (1,), minval=yaw_vel_range[0], maxval=yaw_vel_range[1]
+ )
+
+ jcmd = jp.array([lin_vel_x[0], lin_vel_y[0], yaw_vel[0]])
+
+ return jcmd
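
As a quick reference, here is a minimal usage sketch for the new AAnt environment, mirroring the experiment scripts later in this diff; the XML paths are placeholders and the reward scales are illustrative:

import jax
from jax import numpy as jp

from alfredo.agents.aant import AAnt
from alfredo.rewards import Reward, rTracking_lin_vel, rTracking_yaw_vel

# wrap reward functions; scale values here are illustrative
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
           'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}

env = AAnt(backend='positional',
           rewards=rewards,
           env_xml_path='path/to/env.xml',     # placeholder path
           agent_xml_path='path/to/aant.xml')  # placeholder path

state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
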
diff --git a/alfredo/agents/aant/aant.xml b/alfredo/agents/aant/aant.xml
new file mode 100644
index 0000000..43a5fc3
--- /dev/null
+++ b/alfredo/agents/aant/aant.xml
@@ -0,0 +1,94 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/alfredo/rewards/__init__.py b/alfredo/rewards/__init__.py
index d5908fa..d8e5d93 100644
--- a/alfredo/rewards/__init__.py
+++ b/alfredo/rewards/__init__.py
@@ -1,4 +1,7 @@
+from .reward import Reward
+
from .rConstant import *
-from .rSpeed import *
from .rHealthy import *
from .rControl import *
+from .rEnergy import *
+from .rOrientation import *
diff --git a/alfredo/rewards/rConstant.py b/alfredo/rewards/rConstant.py
index 18d4f7b..74f6c06 100644
--- a/alfredo/rewards/rConstant.py
+++ b/alfredo/rewards/rConstant.py
@@ -8,8 +8,7 @@
from jax import numpy as jp
def rConstant(sys: base.System,
- pipeline_state: base.State,
- weight=1.0,
- focus_idx_range=(1, -1)) -> jp.ndarray:
+ pipeline_state: base.State,
+ focus_idx_range=(0, -1)) -> jax.Array:
- return jp.array([weight])
+ return jp.array([1.0])
diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py
index 7e10532..6e332a1 100644
--- a/alfredo/rewards/rControl.py
+++ b/alfredo/rewards/rControl.py
@@ -10,9 +10,59 @@
def rControl_act_ss(sys: base.System,
pipeline_state: base.State,
action: jp.ndarray,
- weight=1.0,
- focus_idx_range=(1, -1)) -> jp.ndarray:
+ focus_idx_range=(0, -1)) -> jax.Array:
-    ctrl_cost = weight * jp.sum(jp.square(action))
+    ctrl_cost = jp.sum(jp.square(action))
- return ctrl_cost
+ return jp.array([ctrl_cost])
+
+def rTracking_lin_vel(sys: base.System,
+ pipeline_state: base.State,
+ jcmd: jax.Array,
+ sigma=0.25,
+ focus_idx_range=(0, 1)) -> jax.Array:
+
+ local_vel = math.rotate(pipeline_state.xd.vel[0],
+ math.quat_inv(pipeline_state.x.rot[0]))
+
+    lv_error = jp.sum(jp.square(jcmd[:2] - local_vel[:2]))  # only the x, y velocity components are tracked
+ lv_reward = jp.exp(-lv_error/sigma)
+
+ return jp.array([lv_reward])
+
+def rTracking_yaw_vel(sys: base.System,
+ pipeline_state: base.State,
+ jcmd: jax.Array,
+ sigma=0.25,
+ focus_idx_range=(0, 1)) -> jax.Array:
+
+ local_ang_vel = math.rotate(pipeline_state.xd.ang[0],
+ math.quat_inv(pipeline_state.x.rot[0]))
+
+ yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2])
+ yv_reward = jp.exp(-yaw_vel_error/sigma)
+
+ return jp.array([yv_reward])
+
+
+def rTracking_Waypoint(sys: base.System,
+ pipeline_state: base.State,
+ wcmd: jax.Array,
+ focus_idx_range=0) -> jax.Array:
+
+ torso_pos = pipeline_state.x.pos[focus_idx_range]
+    pos_goal_diff = torso_pos[0:2] - wcmd[0:2]
+ pos_sum_abs_diff = -jp.sum(jp.abs(pos_goal_diff))
+
+ return jp.array([pos_sum_abs_diff])
+
+def rStand_still(sys: base.System,
+ pipeline_state: base.State,
+ jcmd: jax.Array,
+ joint_angles: jax.Array,
+ default_pose: jax.Array,
+ focus_idx_range=0) -> jax.Array:
+
+    # deviation from the default pose, counted only when the commanded velocity is near zero
+    close_to_still = jp.sum(jp.abs(joint_angles - default_pose)) * (math.normalize(jcmd[:2])[1] < 0.1)
+
+ return jp.array([close_to_still])
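
Both tracking terms above share the same exponential kernel on squared error; a small self-contained sketch of its shape, with illustrative numbers:

from jax import numpy as jp

jcmd_xy = jp.array([1.0, 0.0])       # commanded x/y velocity [m/s]
local_vel_xy = jp.array([0.8, 0.1])  # measured torso-frame velocity [m/s]
sigma = 0.25

lv_error = jp.sum(jp.square(jcmd_xy - local_vel_xy))
lv_reward = jp.exp(-lv_error / sigma)  # ~0.82 here; approaches 1.0 as tracking improves
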
diff --git a/alfredo/rewards/rEnergy.py b/alfredo/rewards/rEnergy.py
new file mode 100644
index 0000000..760d2d8
--- /dev/null
+++ b/alfredo/rewards/rEnergy.py
@@ -0,0 +1,27 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+
+def rTorques(sys: base.System,
+ pipeline_state: base.State,
+ action: jp.ndarray,
+ focus_idx_range=(0, -1)) -> jax.Array:
+
+ s_idx = focus_idx_range[0]
+ e_idx = focus_idx_range[1]
+
+ torque = actuator.to_tau(sys,
+ action,
+ pipeline_state.q[s_idx:e_idx],
+ pipeline_state.qd[s_idx:e_idx])
+
+
+ tr = jp.sqrt(jp.sum(jp.square(torque))) + jp.sum(jp.abs(torque))
+
+ return jp.array([tr])
diff --git a/alfredo/rewards/rHealthy.py b/alfredo/rewards/rHealthy.py
index fc1f0be..2a696ca 100644
--- a/alfredo/rewards/rHealthy.py
+++ b/alfredo/rewards/rHealthy.py
@@ -11,9 +11,8 @@ def rHealthy_simple_z(sys: base.System,
pipeline_state: base.State,
z_range: Tuple,
early_terminate: True,
- weight=1.0,
- focus_idx_range=(1, -1)) -> jp.ndarray:
-
+ focus_idx_range=(0, -1)) -> jax.Array:
+
min_z, max_z = z_range
focus_s = focus_idx_range[0]
focus_e = focus_idx_range[-1]
@@ -24,8 +23,8 @@ def rHealthy_simple_z(sys: base.System,
is_healthy = jp.where(focus_x_pos > max_z, x=0.0, y=is_healthy)
if early_terminate:
- hr = weight
+ hr = 1.0
else:
- hr = weight * is_healthy
+ hr = 1.0 * is_healthy
return jp.array([hr, is_healthy])
diff --git a/alfredo/rewards/rOrientation.py b/alfredo/rewards/rOrientation.py
new file mode 100644
index 0000000..340c171
--- /dev/null
+++ b/alfredo/rewards/rOrientation.py
@@ -0,0 +1,17 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rUpright(sys: base.System,
+ pipeline_state: base.State,
+ focus_idx_range = (0,0)) -> jax.Array:
+
+ up = jp.array([0.0, 0.0, 1.0])
+ rot_up = math.rotate(up, pipeline_state.x.rot[0])
+
+    return jp.array([jp.dot(up, rot_up)])
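
A quick sanity check of the dot product rUpright is built on, assuming brax's (w, x, y, z) quaternion convention: an upright torso scores 1.0 and a fully inverted one scores -1.0:

from brax import math
from jax import numpy as jp

up = jp.array([0.0, 0.0, 1.0])
quat_upright = jp.array([1.0, 0.0, 0.0, 0.0])   # identity rotation
quat_inverted = jp.array([0.0, 1.0, 0.0, 0.0])  # 180 degrees about x

print(jp.dot(up, math.rotate(up, quat_upright)))   # 1.0
print(jp.dot(up, math.rotate(up, quat_inverted)))  # -1.0
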
diff --git a/alfredo/rewards/rSpeed.py b/alfredo/rewards/rSpeed.py
deleted file mode 100644
index f500c52..0000000
--- a/alfredo/rewards/rSpeed.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import Tuple
-
-import jax
-from brax import actuator, base, math
-from brax.envs import PipelineEnv, State
-from brax.io import mjcf
-from etils import epath
-from jax import numpy as jp
-
-def rSpeed_X(sys: base.System,
- pipeline_state: base.State,
- CoM_prev: jp.ndarray,
- CoM_now: jp.ndarray,
- dt,
- weight=1.0,
- focus_idx_range=(1, -1)) -> jp.ndarray:
-
-
- velocity = (CoM_now - CoM_prev) / dt
-
- focus_s = focus_idx_range[0]
- focus_e = focus_idx_range[-1]
-
- sxr = weight * velocity[0]
-
- return jp.array([sxr, velocity[0], velocity[1]])
-
-def rSpeed_Y(sys: base.System,
- pipeline_state: base.State,
- CoM_prev: jp.ndarray,
- CoM_now: jp.ndarray,
- dt,
- weight=1.0,
- focus_idx_range=(1, -1)) -> jp.ndarray:
-
-
- velocity = (CoM_now - CoM_prev) / dt
-
- focus_s = focus_idx_range[0]
- focus_e = focus_idx_range[-1]
-
- syr = weight * velocity[1]
-
- return jp.array([syr, velocity[0], velocity[1]])
diff --git a/alfredo/rewards/reward.py b/alfredo/rewards/reward.py
new file mode 100644
index 0000000..384dadb
--- /dev/null
+++ b/alfredo/rewards/reward.py
@@ -0,0 +1,41 @@
+from brax import base
+from jax import numpy as jp
+
+class Reward:
+
+ def __init__(self, f, sc, ps):
+ """
+        :param f: A function handle (i.e. the function that computes this reward)
+        :param sc: A float scale that multiplies the base value computed by f
+        :param ps: A dictionary of parameters required for the reward computation
+ """
+
+ self.f = f
+ self.scale = sc
+ self.params = ps
+
+ def add_param(self, p_name, p_value):
+ """
+ Updates self.params dictionary with provided key and value
+ """
+
+ self.params[p_name] = p_value
+
+ def compute(self):
+ """
+        Computes the reward by calling self.f with the stored parameters
+        and applying the scale to the leading element. If required
+        parameters have not been set, the call to self.f will raise.
+ """
+
+ res = self.f(**self.params)
+ res = res.at[0].multiply(self.scale) #may not be the best way to do this
+
+ return res
+
+ def __str__(self):
+ """
+ provides a standard string output
+ """
+
+ return f'reward: {self.f}, scale: {self.scale}'
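
A minimal, self-contained sketch of the Reward wrapper in use; rConstant ignores its physics arguments, so None placeholders are enough to exercise the plumbing:

from alfredo.rewards import Reward, rConstant

r = Reward(f=rConstant, sc=2.5, ps={})
r.add_param('sys', None)             # normally the brax System
r.add_param('pipeline_state', None)  # normally the current physics state
print(r.compute())                   # -> [2.5], i.e. the scale applied to jp.array([1.0])
print(r)                             # -> "reward: <function rConstant ...>, scale: 2.5"
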
diff --git a/alfredo/scenes/flatworld/flatworld_A1_env.xml b/alfredo/scenes/flatworld/flatworld_A1_env.xml
index a30eb64..ef9dd55 100644
--- a/alfredo/scenes/flatworld/flatworld_A1_env.xml
+++ b/alfredo/scenes/flatworld/flatworld_A1_env.xml
@@ -8,21 +8,7 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
diff --git a/alfredo/tools/__init__.py b/alfredo/tools/__init__.py
index 590c5ca..1019091 100644
--- a/alfredo/tools/__init__.py
+++ b/alfredo/tools/__init__.py
@@ -1 +1,2 @@
from .tXMLCompose import *
+from .tAnalyzeNetwork import *
diff --git a/alfredo/tools/tAnalyzeNetwork.py b/alfredo/tools/tAnalyzeNetwork.py
new file mode 100644
index 0000000..338ae36
--- /dev/null
+++ b/alfredo/tools/tAnalyzeNetwork.py
@@ -0,0 +1,63 @@
+import jax
+import jax.numpy as jp
+
+def analyze_neural_params(params):
+ """
+    Provides a metric summary of an input neural parameter file.
+    The contents are expected to be a tuple indexed as:
+      0. RunningStatisticsState
+      1. FrozenDict (policy network params)
+      2. FrozenDict (value network params)
+ """
+
+ summary = {}
+
+ summary['Running Statistics'] = params[0]
+ summary['Policy Network'] = analyze_one_neural_network(params[1])
+ summary['Value Network'] = analyze_one_neural_network(params[2])
+
+ return summary
+
+
+def analyze_one_neural_network(sn_params):
+ """
+    Helper function that unpacks the parameters of a single network.
+    The contents are assumed to be provided as:
+
+    FrozenDict({'params': {'hidden_x': {'bias': jax.Array, 'kernel': jax.Array}}})
+
+    where x in hidden_x gives the order of the hidden layer
+    (these appear to include the input and output layers as well)
+ """
+
+ summary = {}
+
+ num_layers = 0
+ per_layer_info = {}
+ total_parameters = 0
+
+ for k, v in sn_params.items():
+
+ for layer_name, layer_data in v.items():
+
+ num_layers += 1
+ param_count = 0
+
+ for param_name, param_data in layer_data.items():
+
+ if param_name == 'bias':
+ per_layer_info[layer_name] = {'size_of_layer': len(param_data)}
+ param_count += len(param_data)
+ # print(len(param_data))
+
+ if param_name == 'kernel':
+ param_count += len(param_data)*len(param_data[0])
+ # print(len(param_data)*len(param_data[0]))
+
+ total_parameters += param_count
+
+ summary['num_layers'] = num_layers
+ summary['per_layer_info'] = per_layer_info
+ summary['total_parameters'] = total_parameters
+
+ return summary
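
A sketch of the parameter layout analyze_one_neural_network expects, using plain dicts in place of flax FrozenDicts; the layer shapes are illustrative:

import jax.numpy as jp
from alfredo.tools import analyze_one_neural_network

fake_policy = {'params': {
    'hidden_0': {'bias': jp.zeros(32), 'kernel': jp.zeros((8, 32))},
    'hidden_1': {'bias': jp.zeros(4),  'kernel': jp.zeros((32, 4))},
}}

summary = analyze_one_neural_network(fake_policy)
print(summary)  # {'num_layers': 2, 'per_layer_info': {...}, 'total_parameters': 420}
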
diff --git a/alfredo/tools/tXMLCompose.py b/alfredo/tools/tXMLCompose.py
index 2d9f149..e0832eb 100644
--- a/alfredo/tools/tXMLCompose.py
+++ b/alfredo/tools/tXMLCompose.py
@@ -15,10 +15,13 @@ def compose_scene(xml_env, xml_agent):
worldbody = env_root.find('worldbody')
ag_root = agent_tree.getroot()
+ ag_custom = ag_root.find('custom')
ag_body = ag_root.find('body')
ag_actuator = ag_root.find('actuator')
+
worldbody.append(ag_body)
+ env_root.append(ag_custom)
env_root.append(ag_actuator)
beautify(env_root)
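
Usage sketch for the updated compose_scene, as in AAnt.__init__ above (paths are placeholders): the agent's custom, body, and actuator elements are merged into the environment XML, and the resulting string loads directly with brax's mjcf module:

from brax.io import mjcf
from alfredo.tools import compose_scene

xml_scene = compose_scene('path/to/env.xml', 'path/to/agent.xml')  # placeholder paths
sys = mjcf.loads(xml_scene)
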
diff --git a/experiments/AAnt-locomotion/one_physics_step.py b/experiments/AAnt-locomotion/one_physics_step.py
new file mode 100644
index 0000000..04ddc72
--- /dev/null
+++ b/experiments/AAnt-locomotion/one_physics_step.py
@@ -0,0 +1,148 @@
+import functools
+import os
+import re
+import sys
+from datetime import datetime
+
+import brax
+import jax
+import matplotlib.pyplot as plt
+from brax import envs, math
+from brax.envs.wrappers import training
+from brax.io import html, json, model
+from brax.training.acme import running_statistics
+from brax.training.agents.ppo import networks as ppo_networks
+from jax import numpy as jp
+
+from alfredo.agents.aant import AAnt
+
+from alfredo.tools import analyze_neural_params
+
+from alfredo.rewards import Reward
+from alfredo.rewards import rConstant
+from alfredo.rewards import rHealthy_simple_z
+from alfredo.rewards import rControl_act_ss
+from alfredo.rewards import rTorques
+from alfredo.rewards import rTracking_lin_vel
+from alfredo.rewards import rTracking_yaw_vel
+from alfredo.rewards import rUpright
+from alfredo.rewards import rTracking_Waypoint
+from alfredo.rewards import rStand_still
+
+backend = "positional"
+
+# Load desired model xml and trained param set
+# get filepaths from commandline args
+cwd = os.getcwd()
+
+# Define reward structure
+rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=15.0, ps={}),
+ 'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.2, ps={})}
+
+print(rewards)
+
+# Get the filepath to the env and agent xmls
+import alfredo.scenes as scenes
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/aant/aant.xml"
+
+scenes_fp = os.path.dirname(scenes.__file__)
+env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
+tpf_path = f"{cwd}/{sys.argv[-1]}"
+
+print(f"agent description file: {agent_xml_path}")
+print(f"environment description file: {env_xml_paths[0]}")
+print(f"neural parameter file: {tpf_path}\n")
+
+# Load neural parameters
+params = model.load_params(tpf_path)
+summary = analyze_neural_params(params)
+print(f"summary: {summary}\n")
+
+# create an env and initial state
+env = AAnt(backend=backend,
+ rewards=rewards,
+ env_xml_path=env_xml_paths[0],
+ agent_xml_path=agent_xml_path)
+
+rng = jax.random.PRNGKey(seed=3)
+state = env.reset(rng=rng)
+#state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
+
+# Initialize inference neural network
+normalize = lambda x, y: x
+normalize = running_statistics.normalize
+
+ppo_network = ppo_networks.make_ppo_networks(
+ state.obs.shape[-1], env.action_size, preprocess_observations_fn=normalize
+)
+
+make_policy = ppo_networks.make_inference_fn(ppo_network)
+policy_params = (params[0], params[1])
+inference_fn = make_policy(policy_params)
+
+# Reset the env
+key_envs, _ = jax.random.split(rng)
+state = env.reset(rng=key_envs)
+#state = jax.jit(env.reset)(rng=key_envs)
+
+# Debug printing
+print(f"q: {state.pipeline_state.q}")
+print(f"\n")
+print(f"qd: {state.pipeline_state.qd}")
+print(f"\n")
+print(f"x: {state.pipeline_state.x}")
+print(f"\n")
+print(f"xd: {state.pipeline_state.xd}")
+print(f"\n")
+print(f"contact: {state.pipeline_state.contact}")
+print(f"\n")
+print(f"reward: {state.reward}")
+print(f"\n")
+print(state.metrics)
+print(f"\n")
+print(f"done: {state.done}")
+print(f"\n")
+
+# Rollout trajectory one physics step at a time
+episode_length = 1
+
+jcmd = jp.array([1.0, 0.0, -0.7])
+
+for _ in range(episode_length):
+ print(f"\n---------------------------------------------------------------\n")
+
+ act_rng, rng = jax.random.split(rng)
+ print(f"rng: {rng}")
+ print(f"act_rng: {act_rng}")
+ print(f"\n")
+
+ state.info['jcmd'] = jcmd
+ print(f"state info: {state.info}")
+ print(f"\n")
+
+ act, _ = inference_fn(state.obs, act_rng)
+ print(f"observation: {state.obs}")
+ print(f"\n")
+ print(f"action: {act}")
+ print(f"\n")
+
+ state = env.step(state, act)
+
+ print(f"q: {state.pipeline_state.q}")
+ print(f"\n")
+ print(f"qd: {state.pipeline_state.qd}")
+ print(f"\n")
+ print(f"x: {state.pipeline_state.x}")
+ print(f"\n")
+ print(f"xd: {state.pipeline_state.xd}")
+ print(f"\n")
+ print(f"contact: {state.pipeline_state.contact}")
+ print(f"\n")
+ print(f"reward: {state.reward}")
+ print(f"\n")
+ print(state.metrics)
+ print(f"\n")
+ print(f"done: {state.done}")
+ print(f"\n")
diff --git a/experiments/AAnt-locomotion/training.py b/experiments/AAnt-locomotion/training.py
new file mode 100644
index 0000000..111b704
--- /dev/null
+++ b/experiments/AAnt-locomotion/training.py
@@ -0,0 +1,183 @@
+import functools
+import os
+import re
+from datetime import datetime
+
+import brax
+import flax
+import jax
+import optax
+import wandb
+from brax import envs
+
+from brax.io import html, json, model
+from brax.training.acme import running_statistics, specs
+from brax.training.agents.ppo import losses as ppo_losses
+from brax.training.agents.ppo import networks as ppo_networks
+from jax import numpy as jp
+
+from alfredo.agents.aant import AAnt
+from alfredo.train import ppo
+
+from alfredo.rewards import Reward
+from alfredo.rewards import rTracking_lin_vel
+from alfredo.rewards import rTracking_yaw_vel
+
+
+# Define Reward Structure
+rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
+ 'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}
+
+# Initialize a new run
+wandb.init(
+ project="aant",
+ config={
+ "env_name": "AAnt",
+ "seed": 13,
+
+ "training_params": {
+ "backend": "positional",
+ "len_training": 1_500_000,
+ "num_evals": 500,
+ "num_envs": 2048,
+ "batch_size": 2048,
+ "num_minibatches": 8,
+ "updates_per_batch": 8,
+ "episode_len": 1000,
+ "unroll_len": 10,
+ "reward_scaling":1,
+ "action_repeat": 1,
+ "discounting": 0.97,
+ "learning_rate": 3e-4,
+ "entropy_cost": 1e-3,
+ "reward_scaling": 0.1,
+ "normalize_obs": True,
+ },
+
+ "rewards":rewards,
+
+ "aux_model_params":{
+
+ }
+ },
+)
+
+# define callback function that will report training progress
+def progress(num_steps, metrics):
+ print(num_steps)
+ print(metrics)
+
+ epi_len = wandb.config.training_params['episode_len']
+
+ log_dict = {'step': num_steps}
+
+ for mn, m in metrics.items():
+ name_in_log = mn.split('/')[-1]
+ log_dict[name_in_log] = m/epi_len
+
+ wandb.log(log_dict)
+
+
+# get the filepath to the env and agent xmls
+cwd = os.getcwd()
+
+import alfredo.scenes as scenes
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/aant/aant.xml"
+
+scenes_fp = os.path.dirname(scenes.__file__)
+
+env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
+ f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
+ f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
+ f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
+
+# make and save initial ppo_network
+key = jax.random.PRNGKey(wandb.config.seed)
+global_key, local_key = jax.random.split(key)
+key_policy, key_value = jax.random.split(global_key)
+
+env = AAnt(backend=wandb.config.training_params['backend'],
+ rewards=rewards,
+ env_xml_path=env_xml_paths[0],
+ agent_xml_path=agent_xml_path)
+
+rng = jax.random.PRNGKey(seed=0)
+state = env.reset(rng)
+
+normalize_fn = running_statistics.normalize
+
+ppo_network = ppo_networks.make_ppo_networks(
+ env.observation_size, env.action_size, normalize_fn
+)
+
+init_params = ppo_losses.PPONetworkParams(
+ policy=ppo_network.policy_network.init(key_policy),
+ value=ppo_network.value_network.init(key_value),
+)
+
+normalizer_params = running_statistics.init_state(
+ specs.Array(env.observation_size, jp.float32)
+)
+
+params_to_save = (normalizer_params, init_params.policy, init_params.value)
+
+model.save_params(f"param-store/AAnt_params_0", params_to_save)
+
+# ============================
+# Training & Saving Params
+# ============================
+i = 0
+
+for p in env_xml_paths:
+
+ d_and_t = datetime.now()
+ print(f"[{d_and_t}] loop start for model: {i}")
+ env = AAnt(backend=wandb.config.training_params['backend'],
+ rewards=rewards,
+ env_xml_path=p,
+ agent_xml_path=agent_xml_path)
+
+ mF = f"{cwd}/param-store/{wandb.config.env_name}_params_{i}"
+ mParams = model.load_params(mF)
+
+ d_and_t = datetime.now()
+ print(f"[{d_and_t}] jitting start for model: {i}")
+ state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=wandb.config.seed))
+ d_and_t = datetime.now()
+ print(f"[{d_and_t}] jitting end for model: {i}")
+
+ # define new training function
+ train_fn = functools.partial(
+ ppo.train,
+ num_timesteps=wandb.config.training_params['len_training'],
+ num_evals=wandb.config.training_params['num_evals'],
+ reward_scaling=wandb.config.training_params['reward_scaling'],
+ episode_length=wandb.config.training_params['episode_len'],
+ normalize_observations=wandb.config.training_params['normalize_obs'],
+ action_repeat=wandb.config.training_params['action_repeat'],
+ unroll_length=wandb.config.training_params['unroll_len'],
+ num_minibatches=wandb.config.training_params['num_minibatches'],
+ num_updates_per_batch=wandb.config.training_params['updates_per_batch'],
+ discounting=wandb.config.training_params['discounting'],
+ learning_rate=wandb.config.training_params['learning_rate'],
+ entropy_cost=wandb.config.training_params['entropy_cost'],
+ num_envs=wandb.config.training_params['num_envs'],
+ batch_size=wandb.config.training_params['batch_size'],
+ seed=wandb.config.seed,
+ in_params=mParams,
+ )
+
+ d_and_t = datetime.now()
+ print(f"[{d_and_t}] training start for model: {i}")
+ _, params, _, ts = train_fn(environment=env, progress_fn=progress)
+ d_and_t = datetime.now()
+ print(f"[{d_and_t}] training end for model: {i}")
+
+ i += 1
+ next_m_name = f"param-store/{wandb.config.env_name}_params_{i}"
+ model.save_params(next_m_name, params)
+
+ d_and_t = datetime.now()
+ print(f"[{d_and_t}] loop end for model: {i}")
diff --git a/experiments/Alfredo-simple-walk/vis_traj.py b/experiments/AAnt-locomotion/vis_traj.py
similarity index 68%
rename from experiments/Alfredo-simple-walk/vis_traj.py
rename to experiments/AAnt-locomotion/vis_traj.py
index 385ae92..6140c61 100644
--- a/experiments/Alfredo-simple-walk/vis_traj.py
+++ b/experiments/AAnt-locomotion/vis_traj.py
@@ -14,7 +14,11 @@
from brax.training.agents.ppo import networks as ppo_networks
from jax import numpy as jp
-from alfredo.agents.A1.alfredo_1 import Alfredo
+from alfredo.agents.aant import AAnt
+
+from alfredo.rewards import Reward
+from alfredo.rewards import rTracking_lin_vel
+from alfredo.rewards import rTracking_yaw_vel
backend = "positional"
@@ -24,10 +28,9 @@
# get the filepath to the env and agent xmls
import alfredo.scenes as scenes
-
import alfredo.agents as agents
agents_fp = os.path.dirname(agents.__file__)
-agent_xml_path = f"{agents_fp}/A1/a1.xml"
+agent_xml_path = f"{agents_fp}/aant/aant.xml"
scenes_fp = os.path.dirname(scenes.__file__)
@@ -40,26 +43,33 @@
params = model.load_params(tpf_path)
+# Define Reward Structure
+# For visualizing, this is just to be able to create the env
+# May want to make this not necessary in the future ..?
+rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
+ 'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}
+
# create an env with auto-reset and load previously trained parameters
-env = Alfredo(backend=backend,
- env_xml_path=env_xml_path,
- agent_xml_path=agent_xml_path)
+env = AAnt(backend=backend,
+ rewards=rewards,
+ env_xml_path=env_xml_path,
+ agent_xml_path=agent_xml_path)
auto_reset = True
episode_length = 1000
action_repeat = 1
-if episode_length is not None:
- env = training.EpisodeWrapper(env, episode_length, action_repeat)
+#if episode_length is not None:
+# env = training.EpisodeWrapper(env, episode_length, action_repeat)
-if auto_reset:
- env = training.AutoResetWrapper(env)
+#if auto_reset:
+# env = training.AutoResetWrapper(env)
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
rollout = []
-rng = jax.random.PRNGKey(seed=1)
+rng = jax.random.PRNGKey(seed=13194)
state = jit_env_reset(rng=rng)
normalize = lambda x, y: x
@@ -75,12 +85,26 @@
jit_inference_fn = jax.jit(inference_fn)
+x_vel = 0.0 # m/s
+y_vel = 3.0 # m/s
+yaw_vel = 0.0 # rad/s
+jcmd = jp.array([x_vel, y_vel, yaw_vel])
+
+wcmd = jp.array([10.0, 10.0])
+
# generate policy rollout
for _ in range(episode_length):
rollout.append(state.pipeline_state)
act_rng, rng = jax.random.split(rng)
+
+ state.info['jcmd'] = jcmd
+ state.info['wcmd'] = wcmd
act, _ = jit_inference_fn(state.obs, act_rng)
state = jit_env_step(state, act)
+ print(state.info)
+
+
+print(rollout[-1])
html_string = html.render(env.sys.replace(dt=env.dt), rollout)
diff --git a/experiments/Alfredo-simple-walk/README.md b/experiments/Alfredo-simple-walk/README.md
deleted file mode 100644
index 0fd1160..0000000
--- a/experiments/Alfredo-simple-walk/README.md
+++ /dev/null
@@ -1,31 +0,0 @@
-# Experimenting with Brax
-
-The following is a self-contained experiment with Brax and PPO.
-
-## Training
-
-To run as background process:
-
-```
-python -u seq_training.py > training.log &
-```
-
-To view progress:
-
-```
-tail -f training.log
-```
-
-## Visualizing Trajectories
-
-```
-python vis_traj.py
-```
-
-eg.
-
-```
-python vis_traj.py flatworld/flatworld.xml param-store/A0_param_0
-```
-
-note: the filepath to the xml file is relative to the alfredo/scenes.
diff --git a/experiments/Alfredo-simple-walk/seq_training.py b/experiments/Alfredo-simple-walk/seq_training.py
deleted file mode 100644
index ef1c90a..0000000
--- a/experiments/Alfredo-simple-walk/seq_training.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# ============================
-# Imports
-# ============================
-
-import functools
-import os
-import re
-from datetime import datetime
-
-import brax
-import flax
-import jax
-import matplotlib.pyplot as plt
-import optax
-import wandb
-from brax import envs
-# from brax.envs.wrappers import training
-from brax.io import html, json, model
-from brax.training.acme import running_statistics, specs
-from brax.training.agents.ppo import losses as ppo_losses
-from brax.training.agents.ppo import networks as ppo_networks
-from jax import numpy as jp
-
-from alfredo.agents.A1 import Alfredo
-from alfredo.train import ppo
-
-# Initialize a new run
-wandb.init(
- project="alfredo",
- config={
- "env_name": "A1",
- "backend": "positional",
- "seed": 0,
- "len_training": 1_500_000,
- "batch_size": 1024,
- },
-)
-
-# ==============================
-# Useful Functions & Data Defs
-# ==============================
-
-normalize_fn = running_statistics.normalize
-
-
-def progress(num_steps, metrics):
- print(num_steps)
- wandb.log(
- {
- "step": num_steps,
- "Total Reward": metrics["eval/episode_reward"],
- "Vel Reward": metrics["eval/episode_reward_velocity"],
- "Alive Reward": metrics["eval/episode_reward_alive"],
- "Ctrl Reward": metrics["eval/episode_reward_ctrl"],
- "a_vel_x": metrics["eval/episode_agent_x_velocity"],
- "a_vel_y": metrics["eval/episode_agent_y_velocity"],
- }
- )
-
-
-# ==============================
-# General Variable Defs
-# ==============================
-cwd = os.getcwd()
-
-# get the filepath to the env and agent xmls
-import alfredo.scenes as scenes
-
-import alfredo.agents as agents
-agents_fp = os.path.dirname(agents.__file__)
-agent_xml_path = f"{agents_fp}/A1/a1.xml"
-
-scenes_fp = os.path.dirname(scenes.__file__)
-
-env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
-
-# ============================
-# Loading and Defining Envs
-# ============================
-
-# make and save initial ppo_network
-key = jax.random.PRNGKey(wandb.config.seed)
-global_key, local_key = jax.random.split(key)
-key_policy, key_value = jax.random.split(global_key)
-
-env = Alfredo(backend=wandb.config.backend,
- env_xml_path=env_xml_paths[0],
- agent_xml_path=agent_xml_path)
-
-rng = jax.random.PRNGKey(seed=1)
-state = env.reset(rng)
-
-ppo_network = ppo_networks.make_ppo_networks(
- env.observation_size, env.action_size, normalize_fn
-)
-
-init_params = ppo_losses.PPONetworkParams(
- policy=ppo_network.policy_network.init(key_policy),
- value=ppo_network.value_network.init(key_value),
-)
-
-normalizer_params = running_statistics.init_state(
- specs.Array(env.observation_size, jp.float32)
-)
-
-params_to_save = (normalizer_params, init_params.policy, init_params.value)
-
-model.save_params(f"param-store/A1_params_0", params_to_save)
-
-# ============================
-# Training & Saving Params
-# ============================
-i = 0
-
-for p in env_xml_paths:
-
- d_and_t = datetime.now()
- print(f"[{d_and_t}] loop start for model: {i}")
- env = Alfredo(backend=wandb.config.backend,
- env_xml_path=p,
- agent_xml_path=agent_xml_path)
-
- mF = f"{cwd}/param-store/{wandb.config.env_name}_params_{i}"
- mParams = model.load_params(mF)
-
- d_and_t = datetime.now()
- print(f"[{d_and_t}] jitting start for model: {i}")
- state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
- d_and_t = datetime.now()
- print(f"[{d_and_t}] jitting end for model: {i}")
-
- # define new training function
- train_fn = functools.partial(
- ppo.train,
- num_timesteps=wandb.config.len_training,
- num_evals=200,
- reward_scaling=0.1,
- episode_length=1000,
- normalize_observations=True,
- action_repeat=1,
- unroll_length=10,
- num_minibatches=8,
- num_updates_per_batch=8,
- discounting=0.97,
- learning_rate=3e-4,
- entropy_cost=1e-3,
- num_envs=2048,
- batch_size=wandb.config.batch_size,
- seed=1,
- in_params=mParams,
- )
-
- d_and_t = datetime.now()
- print(f"[{d_and_t}] training start for model: {i}")
- _, params, _, ts = train_fn(environment=env, progress_fn=progress)
- d_and_t = datetime.now()
- print(f"[{d_and_t}] training end for model: {i}")
-
- i += 1
- next_m_name = f"param-store/{wandb.config.env_name}_params_{i}"
- model.save_params(next_m_name, params)
-
- d_and_t = datetime.now()
- print(f"[{d_and_t}] loop end for model: {i}")
diff --git a/experiments/Alfredo-simple-walk/vis_new_model.py b/experiments/Alfredo-simple-walk/vis_new_model.py
deleted file mode 100644
index c729e12..0000000
--- a/experiments/Alfredo-simple-walk/vis_new_model.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import functools
-import os
-import re
-import sys
-from datetime import datetime
-
-import brax
-import jax
-import matplotlib.pyplot as plt
-from brax import envs
-from brax.envs.wrappers import training
-from brax.io import html, json, model
-from brax.training.acme import running_statistics
-from brax.training.agents.ppo import networks as ppo_networks
-from jax import numpy as jp
-
-from alfredo.agents.A1.alfredo_1 import Alfredo
-
-backend = "positional"
-
-# Load desired model xml and trained param set
-# get filepaths from commandline args
-cwd = os.getcwd()
-
-# get the filepath to the env and agent xmls
-import alfredo.scenes as scenes
-
-import alfredo.agents as agents
-agents_fp = os.path.dirname(agents.__file__)
-agent_xml_path = f"{agents_fp}/A1/a1.xml"
-
-scenes_fp = os.path.dirname(scenes.__file__)
-
-env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
-
-# create an env and initial state
-env = Alfredo(backend=backend,
- env_xml_path=env_xml_paths[0],
- agent_xml_path=agent_xml_path)
-
-state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
-
-# render scene
-html_string = html.render(env.sys.replace(dt=env.dt), [state.pipeline_state])
-
-# save output to html filepaths
-d_and_t = datetime.now()
-html_file_path = f"{cwd}/vis-store/A1_dev_.html"
-
-html_file_path = html_file_path.replace(" ", "_")
-
-with open(html_file_path, "w") as file:
- file.write(html_string)
- print(f"saved visualization to {html_file_path}")
diff --git a/experiments/Alfredo-simulate-step/one_physics_step.py b/experiments/Alfredo-simulate-step/one_physics_step.py
deleted file mode 100644
index ac37a3a..0000000
--- a/experiments/Alfredo-simulate-step/one_physics_step.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import functools
-import os
-import re
-import sys
-from datetime import datetime
-
-import brax
-import jax
-import matplotlib.pyplot as plt
-from brax import envs
-from brax.envs.wrappers import training
-from brax.io import html, json, model
-from brax.training.acme import running_statistics
-from brax.training.agents.ppo import networks as ppo_networks
-from jax import numpy as jp
-
-from alfredo.agents.A1.alfredo_1 import Alfredo
-
-backend = "positional"
-
-# Load desired model xml and trained param set
-# get filepaths from commandline args
-cwd = os.getcwd()
-
-# get the filepath to the env and agent xmls
-import alfredo.scenes as scenes
-
-import alfredo.agents as agents
-agents_fp = os.path.dirname(agents.__file__)
-agent_xml_path = f"{agents_fp}/A1/a1.xml"
-
-scenes_fp = os.path.dirname(scenes.__file__)
-
-env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
-
-# create an env and initial state
-env = Alfredo(backend=backend,
- env_xml_path=env_xml_paths[0],
- agent_xml_path=agent_xml_path)
-
-state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
-
-#print(f"Alfredo brax env dir: {dir(env)}")
-#print(f"state: {state}")
-
-com = env._com(state.pipeline_state)
-obs = env._get_obs(state.pipeline_state, jp.zeros(env.action_size))
-#print(f"CoM = {com}")
-#print(f"pipeline_state: {state.pipeline_state}")
-#print(f"observation: {obs}")
-print(f"\n-----------------------------------------------------------------\n")
-nState = env.step(state, jp.zeros(env.action_size))
-com = env._com(state.pipeline_state)
-obs = env._get_obs(state.pipeline_state, jp.zeros(env.action_size))
-#print(f"CoM = {com}")
-#print(f"pipeline_state: {state.pipeline_state}")
-#print(f"observation: {obs}")