diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py index be71c9e..af966c4 100644 --- a/alfredo/agents/aant/aant.py +++ b/alfredo/agents/aant/aant.py @@ -16,6 +16,7 @@ from alfredo.rewards import rTracking_yaw_vel from alfredo.rewards import rUpright from alfredo.rewards import rTracking_Waypoint +from alfredo.rewards import rStand_still class AAnt(PipelineEnv): """ """ @@ -121,8 +122,8 @@ def reset(self, rng: jax.Array) -> State: 'pos_x_world_abs': zero, 'pos_y_world_abs': zero, 'pos_z_world_abs': zero, - 'dist_goal_x': zero, - 'dist_goal_y': zero, + #'dist_goal_x': zero, + #'dist_goal_y': zero, #'dist_goal_z': zero, } @@ -148,13 +149,13 @@ def step(self, state: State, action: jax.Array) -> State: jp.array([0, 0, 0]), #dummy values for current CoM self.dt, state.info['jcmd'], - weight=1.5, + weight=15.0, focus_idx_range=(0,0)) yaw_vel_reward = rTracking_yaw_vel(self.sys, pipeline_state, state.info['jcmd'], - weight=0.8, + weight=1.0, focus_idx_range=(0,0)) ctrl_cost = rControl_act_ss(self.sys, @@ -194,8 +195,8 @@ def step(self, state: State, action: jax.Array) -> State: #print(f"wcmd: {state.info['wcmd']}") #print(f"x.pos[0]: {pipeline_state.x.pos[0]}") - wcmd = state.info['wcmd'] - dist_goal = pos_world[0:2] - wcmd + #wcmd = state.info['wcmd'] + #dist_goal = pos_world[0:2] - wcmd #print(dist_goal) #print(f'true position in world: {pos_world}') @@ -220,8 +221,8 @@ def step(self, state: State, action: jax.Array) -> State: pos_x_world_abs = abs_pos_world[0], pos_y_world_abs = abs_pos_world[1], pos_z_world_abs = abs_pos_world[2], - dist_goal_x = dist_goal[0], - dist_goal_y = dist_goal[1], + #dist_goal_x = dist_goal[0], + #dist_goal_y = dist_goal[1], #dist_goal_z = dist_goal[2], ) @@ -270,9 +271,9 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array: return wcmd def _sample_command(self, rng: jax.Array) -> jax.Array: - lin_vel_x_range = [-1, 3] #[m/s] - lin_vel_y_range = [-1, 3] #[m/s] - yaw_vel_range = [-0.7, 0.7] #[rad/s] + lin_vel_x_range = [0.0, 0.0] #[m/s] + lin_vel_y_range = [0.0, 0.0] #[m/s] + yaw_vel_range = [-1.0, 1.0] #[rad/s] _, key1, key2, key3 = jax.random.split(rng, 4) diff --git a/experiments/AAnt-locomotion/training.py b/experiments/AAnt-locomotion/training.py index cce4940..ef8c1c7 100644 --- a/experiments/AAnt-locomotion/training.py +++ b/experiments/AAnt-locomotion/training.py @@ -56,8 +56,8 @@ def progress(num_steps, metrics): "step": num_steps, "Total Reward": metrics["eval/episode_reward"]/epi_len, "Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len, - "Lin Vel Reward": metrics["eval/episode_reward_lin_vel"], - "Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"], + "Lin Vel Reward": metrics["eval/episode_reward_lin_vel"]/epi_len, + "Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"]/epi_len, "Alive Reward": metrics["eval/episode_reward_alive"]/epi_len, "Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len, "Upright Reward": metrics["eval/episode_reward_upright"]/epi_len, @@ -65,8 +65,8 @@ def progress(num_steps, metrics): "Abs Pos X World": metrics["eval/episode_pos_x_world_abs"]/epi_len, "Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"]/epi_len, "Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"]/epi_len, - "Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len, - "Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len, + #"Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len, + #"Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len, #"Dist Goal Z": metrics["eval/episode_dist_goal_z"]/epi_len, } ) @@ -81,10 +81,10 @@ def progress(num_steps, metrics): scenes_fp = os.path.dirname(scenes.__file__) -env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] - #f"{scenes_fp}/flatworld/flatworld_A1_env.xml", - #f"{scenes_fp}/flatworld/flatworld_A1_env.xml", - #f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] +env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml", + f"{scenes_fp}/flatworld/flatworld_A1_env.xml", + f"{scenes_fp}/flatworld/flatworld_A1_env.xml", + f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] # make and save initial ppo_network key = jax.random.PRNGKey(wandb.config.seed) diff --git a/experiments/AAnt-locomotion/vis_traj.py b/experiments/AAnt-locomotion/vis_traj.py index 22826c2..06a4790 100644 --- a/experiments/AAnt-locomotion/vis_traj.py +++ b/experiments/AAnt-locomotion/vis_traj.py @@ -76,7 +76,7 @@ x_vel = 0.0 # m/s y_vel = 0.0 # m/s -yaw_vel = 0.7 # rad/s +yaw_vel = -0.7 # rad/s jcmd = jp.array([x_vel, y_vel, yaw_vel]) wcmd = jp.array([10.0, 10.0])