Commit 9d9f187
joystick command seems to be working - need to add some nice to haves next
mginoya committed Aug 30, 2024
1 parent 05fd524 commit 9d9f187
Showing 4 changed files with 54 additions and 37 deletions.
52 changes: 31 additions & 21 deletions alfredo/agents/aant/aant.py
@@ -87,11 +87,14 @@ def reset(self, rng: jax.Array) -> State:
low, hi = -self._reset_noise_scale, self._reset_noise_scale

jcmd = self._sample_command(rng3)
- wcmd = self._sample_waypoint(rng3)
+ #wcmd = self._sample_waypoint(rng3)

print(f"init_q: {self.sys.init_q}")
#wcmd = jp.array([0.0, 10.0])
#print(f"init_q: {self.sys.init_q}")
wcmd = jp.array([0.0, 0.0])

+ #q = self.sys.init_q
+ #qd = 0 * jax.random.normal(rng2, (self.sys.qd_size(),))

q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
)
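
The reset path now samples only the joystick command and pins the waypoint command to the origin, consistent with the waypoint reward being zero-weighted in step() below. A minimal sketch of the two command channels threaded through state.info (names taken from the visible lines; the exact State construction is a hypothetical sketch):

import jax
import jax.numpy as jp

def make_command_info(rng: jax.Array, sample_command) -> dict:
    # Hypothetical sketch of the per-episode command channels in state.info.
    return {
        'jcmd': sample_command(rng),      # joystick target [x_vel, y_vel, yaw_vel]
        'wcmd': jp.array([0.0, 0.0]),     # waypoint pinned to origin, i.e. disabled
    }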
@@ -111,8 +114,8 @@ def reset(self, rng: jax.Array) -> State:
'reward_ctrl': zero,
'reward_alive': zero,
'reward_torque': zero,
- #'reward_lin_vel': zero,
- #'reward_yaw_vel': zero,
+ 'reward_lin_vel': zero,
+ 'reward_yaw_vel': zero,
'reward_upright': zero,
'reward_waypoint': zero,
'pos_x_world_abs': zero,
@@ -127,7 +130,7 @@ def reset(self, rng: jax.Array) -> State:

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""

pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

@@ -136,7 +139,7 @@ def step(self, state: State, action: jax.Array) -> State:
waypoint_cost = rTracking_Waypoint(self.sys,
pipeline_state,
state.info['wcmd'],
- weight=100.0,
+ weight=0.0,
focus_idx_range=0)

lin_vel_reward = rTracking_lin_vel(self.sys,
@@ -145,24 +148,24 @@ def step(self, state: State, action: jax.Array) -> State:
jp.array([0, 0, 0]), #dummy values for current CoM
self.dt,
state.info['jcmd'],
- weight=15.5,
+ weight=1.5,
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
pipeline_state,
state.info['jcmd'],
- weight=10.8,
+ weight=0.8,
focus_idx_range=(0,0))

ctrl_cost = rControl_act_ss(self.sys,
pipeline_state,
action,
- weight=-self._ctrl_cost_weight)
+ weight=0.0)

torque_cost = rTorques(self.sys,
pipeline_state,
action,
- weight=-0.003)
+ weight=0.0)

upright_reward = rUpright(self.sys,
pipeline_state,
@@ -174,14 +177,17 @@ def step(self, state: State, action: jax.Array) -> State:
early_terminate=self._terminate_when_unhealthy,
weight=0.0,
focus_idx_range=(0, 2))
- reward = 0.0
+ #reward = 0.0
+ reward = healthy_reward[0]
reward += ctrl_cost
reward += torque_cost
reward += upright_reward
reward += waypoint_cost
- #reward += lin_vel_reward
- #reward += yaw_vel_reward
+ reward += lin_vel_reward
+ reward += yaw_vel_reward

#print(f"lin_tracking_vel: {lin_vel_reward}")
#print(f"yaw_tracking_vel: {yaw_vel_reward}\n")

pos_world = pipeline_state.x.pos[0]
abs_pos_world = jp.abs(pos_world)
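
Net effect of this hunk: the reward is re-centered on joystick tracking. The waypoint term is zeroed (weight 100.0 -> 0.0), the control and torque penalties are switched off, the velocity-tracking weights come down to sane magnitudes (15.5 -> 1.5, 10.8 -> 0.8), and the alive bonus plus both velocity terms now actually feed the scalar reward. Both tracking terms in rControl.py use the same exponential kernel (see the yv_reward line further down); a minimal standalone sketch, with sigma left as a free parameter since its value is outside this diff:

import jax.numpy as jp

def tracking_reward(cmd, measured, sigma, weight):
    # Squared tracking error pushed through an RBF-style kernel:
    # 1.0 at perfect tracking, decaying toward 0.0 as error grows.
    error = jp.sum(jp.square(cmd - measured))
    return weight * jp.exp(-error / sigma)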
@@ -208,8 +214,8 @@ def step(self, state: State, action: jax.Array) -> State:
reward_alive = healthy_reward[0],
reward_torque = torque_cost,
reward_upright = upright_reward,
- #reward_lin_vel = lin_vel_reward,
- #reward_yaw_vel = yaw_vel_reward,
+ reward_lin_vel = lin_vel_reward,
+ reward_yaw_vel = yaw_vel_reward,
reward_waypoint = waypoint_cost,
pos_x_world_abs = abs_pos_world[0],
pos_y_world_abs = abs_pos_world[1],
@@ -227,14 +233,18 @@ def _get_obs(self, pipeline_state, state_info) -> jax.Array:
"""Observe ant body position and velocities."""
qpos = pipeline_state.q
qvel = pipeline_state.qd

+ inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
+ local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)
torso_pos = pipeline_state.x.pos[0]
- #jcmd = state_info['jcmd']

+ jcmd = state_info['jcmd']
wcmd = state_info['wcmd']

if self._exclude_current_positions_from_observation:
qpos = pipeline_state.q[2:]

- return jp.concatenate([qpos] + [qvel] + [torso_pos] + [wcmd]) #[jcmd])
+ return jp.concatenate([qpos] + [qvel] + [local_rpyrate] + [jcmd]) #[jcmd])

def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
x_range = [-25, 25]
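
The _get_obs change above is the crux of the commit: a policy can only track a command it can observe, so the world-frame torso position and waypoint are swapped out for the body-frame angular rate and the joystick command. A self-contained sketch of the new assembly, assuming brax's math helpers and that world x/y stays excluded from the observation:

import jax.numpy as jp
from brax import math

def joystick_obs(pipeline_state, jcmd):
    # Minimal sketch mirroring the diff; not the repo's exact method.
    qpos = pipeline_state.q[2:]                              # drop world x/y
    qvel = pipeline_state.qd
    inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])   # world -> torso rotation
    local_rpyrate = math.rotate(pipeline_state.xd.ang[0],
                                inv_torso_rot)               # body-frame roll/pitch/yaw rates
    return jp.concatenate([qpos, qvel, local_rpyrate, jcmd])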
@@ -260,9 +270,9 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
return wcmd

def _sample_command(self, rng: jax.Array) -> jax.Array:
- lin_vel_x_range = [-0.6, 1.5] #[m/s]
- lin_vel_y_range = [-0.6, 1.5] #[m/s]
- yaw_vel_range = [-0.0, 0.0] #[rad/s]
+ lin_vel_x_range = [-1, 3] #[m/s]
+ lin_vel_y_range = [-1, 3] #[m/s]
+ yaw_vel_range = [-0.7, 0.7] #[rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)

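The widened x/y ranges and the newly nonzero yaw range mean the policy now trains against turning commands at all. The rest of _sample_command is collapsed in this view; a plausible completion, assuming one uniform draw per axis (hypothetical, not the repo's exact code):

import jax
import jax.numpy as jp

def sample_command(rng: jax.Array) -> jax.Array:
    # Hypothetical completion of _sample_command; the body below the
    # key split is collapsed in the diff above.
    lin_vel_x_range = [-1, 3]    # [m/s]
    lin_vel_y_range = [-1, 3]    # [m/s]
    yaw_vel_range = [-0.7, 0.7]  # [rad/s]

    _, key1, key2, key3 = jax.random.split(rng, 4)

    lin_vel_x = jax.random.uniform(key1, minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1])
    lin_vel_y = jax.random.uniform(key2, minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1])
    yaw_vel = jax.random.uniform(key3, minval=yaw_vel_range[0], maxval=yaw_vel_range[1])

    return jp.array([lin_vel_x, lin_vel_y, yaw_vel])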
11 changes: 9 additions & 2 deletions alfredo/rewards/rControl.py
@@ -32,6 +32,9 @@ def rTracking_lin_vel(sys: base.System,
local_vel = math.rotate(pipeline_state.xd.vel[0],
math.quat_inv(pipeline_state.x.rot[0]))

#print(f"rot_quat: {pipeline_state.x.rot[0]}")
#print(f"global_vel: {pipeline_state.xd.vel[0]}")
#print(f"local_vel: {local_vel}\n")
#print(f"com_prev:{CoM_prev}, com_now:{CoM_now}")
#local_vel = (CoM_prev - CoM_now)/dt
#print(f"jcmd[:2]: {jcmd[:2]}")
@@ -53,9 +56,13 @@ def rTracking_yaw_vel(sys: base.System,

#local_yaw_vel = math.rotate(pipeline_state.xd.ang[focus_idx_range[0]:focus_idx_range[1]],
# math.quat_inv(pipeline_state.x.rotate[focus_idx_range[0], focus_idx_range[1]]))
- local_yaw_vel = math.rotate(pipeline_state.xd.vel[0],
+ local_ang_vel = math.rotate(pipeline_state.xd.ang[0],
math.quat_inv(pipeline_state.x.rot[0]))
- yaw_vel_error = jp.square(jcmd[2] - local_yaw_vel[2])

+ #print(f"global_ang_vel: {pipeline_state.xd.vel[0]}")
+ #print(f"local_ang_vel: {local_ang_vel}\n")

+ yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2])
yv_reward = jp.exp(-yaw_vel_error/sigma)

return weight*yv_reward
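
This hunk is the substantive bug fix behind the commit: yaw tracking previously rotated the torso's linear velocity (xd.vel) into the body frame and read its z component, so the yaw command was effectively compared against vertical speed rather than turn rate. The fix rotates the angular velocity (xd.ang) instead. Isolated as a standalone sketch, with sigma assumed to be a parameter of the enclosing function:

import jax.numpy as jp
from brax import math

def yaw_rate_reward(pipeline_state, jcmd, sigma, weight):
    # Express the torso angular velocity in the torso frame and compare
    # its z (yaw) component against the commanded yaw rate jcmd[2].
    local_ang_vel = math.rotate(pipeline_state.xd.ang[0],
                                math.quat_inv(pipeline_state.x.rot[0]))
    yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2])
    return weight * jp.exp(-yaw_vel_error / sigma)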
18 changes: 9 additions & 9 deletions experiments/AAnt-locomotion/training.py
@@ -26,7 +26,7 @@
config={
"env_name": "AAnt",
"backend": "positional",
"seed": 1,
"seed": 13,
"len_training": 1_500_000,
"num_evals": 500,
"num_envs": 2048,
@@ -35,7 +35,7 @@
"updates_per_batch": 8,
"episode_len": 1000,
"unroll_len": 10,
"reward_scaling":0.1,
"reward_scaling":1,
"action_repeat": 1,
"discounting": 0.97,
"learning_rate": 3e-4,
@@ -56,8 +56,8 @@ def progress(num_steps, metrics):
"step": num_steps,
"Total Reward": metrics["eval/episode_reward"]/epi_len,
"Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len,
#"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"],
#"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"],
"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"],
"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"],
"Alive Reward": metrics["eval/episode_reward_alive"]/epi_len,
"Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len,
"Upright Reward": metrics["eval/episode_reward_upright"]/epi_len,
@@ -81,10 +81,10 @@ def progress(num_steps, metrics):

scenes_fp = os.path.dirname(scenes.__file__)

env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
#f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
#f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
#f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]

# make and save initial ppo_network
key = jax.random.PRNGKey(wandb.config.seed)
@@ -118,7 +118,7 @@ def progress(num_steps, metrics):
# ============================
# Training & Saving Params
# ============================
- i = 8
+ i = 0

for p in env_xml_paths:

10 changes: 5 additions & 5 deletions experiments/AAnt-locomotion/vis_traj.py
@@ -74,10 +74,10 @@

jit_inference_fn = jax.jit(inference_fn)

- #x_vel = 0.0 # m/s
- #y_vel = -1.5 # m/s
- #yaw_vel = 0.0 # rad/s
- #jcmd = jp.array([x_vel, y_vel, yaw_vel])
+ x_vel = 0.0 # m/s
+ y_vel = 0.0 # m/s
+ yaw_vel = 0.7 # rad/s
+ jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([10.0, 10.0])

@@ -86,7 +86,7 @@
rollout.append(state.pipeline_state)
act_rng, rng = jax.random.split(rng)

- #state.info['jcmd'] = jcmd
+ state.info['jcmd'] = jcmd
state.info['wcmd'] = wcmd
act, _ = jit_inference_fn(state.obs, act_rng)
state = jit_env_step(state, act)
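
With x_vel = y_vel = 0 and yaw_vel = 0.7, the rollout above should show the ant turning in place at roughly 0.7 rad/s. The command is re-written into state.info on every step before inference, so visualizing another maneuver just means swapping the array, e.g. (assumed usage, same override pattern):

jcmd = jp.array([1.0, 0.0, 0.0])  # 1 m/s forward, no turning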
