Skip to content

Commit

Permalink
saving progress to perform a cleanup and maybe try to update brax version?
Browse files Browse the repository at this point in the history
  • Loading branch information
mginoya committed Sep 5, 2024
1 parent 9d9f187 commit 712e18e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
23 changes: 12 additions & 11 deletions alfredo/agents/aant/aant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from alfredo.rewards import rTracking_yaw_vel
from alfredo.rewards import rUpright
from alfredo.rewards import rTracking_Waypoint
from alfredo.rewards import rStand_still

class AAnt(PipelineEnv):
""" """
Expand Down Expand Up @@ -121,8 +122,8 @@ def reset(self, rng: jax.Array) -> State:
'pos_x_world_abs': zero,
'pos_y_world_abs': zero,
'pos_z_world_abs': zero,
'dist_goal_x': zero,
'dist_goal_y': zero,
#'dist_goal_x': zero,
#'dist_goal_y': zero,
#'dist_goal_z': zero,
}

Expand All @@ -148,13 +149,13 @@ def step(self, state: State, action: jax.Array) -> State:
jp.array([0, 0, 0]), #dummy values for current CoM
self.dt,
state.info['jcmd'],
weight=1.5,
weight=15.0,
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
pipeline_state,
state.info['jcmd'],
weight=0.8,
weight=1.0,
focus_idx_range=(0,0))

ctrl_cost = rControl_act_ss(self.sys,
Expand Down Expand Up @@ -194,8 +195,8 @@ def step(self, state: State, action: jax.Array) -> State:

#print(f"wcmd: {state.info['wcmd']}")
#print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
wcmd = state.info['wcmd']
dist_goal = pos_world[0:2] - wcmd
#wcmd = state.info['wcmd']
#dist_goal = pos_world[0:2] - wcmd
#print(dist_goal)

#print(f'true position in world: {pos_world}')
Expand All @@ -220,8 +221,8 @@ def step(self, state: State, action: jax.Array) -> State:
pos_x_world_abs = abs_pos_world[0],
pos_y_world_abs = abs_pos_world[1],
pos_z_world_abs = abs_pos_world[2],
dist_goal_x = dist_goal[0],
dist_goal_y = dist_goal[1],
#dist_goal_x = dist_goal[0],
#dist_goal_y = dist_goal[1],
#dist_goal_z = dist_goal[2],
)

Expand Down Expand Up @@ -270,9 +271,9 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
return wcmd

def _sample_command(self, rng: jax.Array) -> jax.Array:
lin_vel_x_range = [-1, 3] #[m/s]
lin_vel_y_range = [-1, 3] #[m/s]
yaw_vel_range = [-0.7, 0.7] #[rad/s]
lin_vel_x_range = [0.0, 0.0] #[m/s]
lin_vel_y_range = [0.0, 0.0] #[m/s]
yaw_vel_range = [-1.0, 1.0] #[rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)

Expand Down
16 changes: 8 additions & 8 deletions experiments/AAnt-locomotion/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ def progress(num_steps, metrics):
"step": num_steps,
"Total Reward": metrics["eval/episode_reward"]/epi_len,
"Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len,
"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"],
"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"],
"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"]/epi_len,
"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"]/epi_len,
"Alive Reward": metrics["eval/episode_reward_alive"]/epi_len,
"Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len,
"Upright Reward": metrics["eval/episode_reward_upright"]/epi_len,
"Torque Reward": metrics["eval/episode_reward_torque"]/epi_len,
"Abs Pos X World": metrics["eval/episode_pos_x_world_abs"]/epi_len,
"Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"]/epi_len,
"Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"]/epi_len,
"Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len,
"Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len,
#"Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len,
#"Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len,
#"Dist Goal Z": metrics["eval/episode_dist_goal_z"]/epi_len,
}
)
Expand All @@ -81,10 +81,10 @@ def progress(num_steps, metrics):

scenes_fp = os.path.dirname(scenes.__file__)

env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
#f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
#f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
#f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]

# make and save initial ppo_network
key = jax.random.PRNGKey(wandb.config.seed)
Expand Down
2 changes: 1 addition & 1 deletion experiments/AAnt-locomotion/vis_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@

x_vel = 0.0 # m/s
y_vel = 0.0 # m/s
yaw_vel = 0.7 # rad/s
yaw_vel = -0.7 # rad/s
jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([10.0, 10.0])
Expand Down

0 comments on commit 712e18e

Please sign in to comment.