From 1f5f1233132b5ff125cb6f93a66898d5033ed858 Mon Sep 17 00:00:00 2001 From: mginoya Date: Sun, 25 Aug 2024 21:16:47 -0400 Subject: [PATCH] single waypoint reward is now working ... turned out to be a local minima thing? --- alfredo/agents/aant/aant.py | 44 +++++++++++++------ alfredo/rewards/rControl.py | 15 +++++-- .../AAnt-locomotion/one_physics_step.py | 6 ++- experiments/AAnt-locomotion/training.py | 30 ++++++++----- experiments/AAnt-locomotion/vis_traj.py | 4 +- 5 files changed, 65 insertions(+), 34 deletions(-) diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py index ee139b2..9f20292 100644 --- a/alfredo/agents/aant/aant.py +++ b/alfredo/agents/aant/aant.py @@ -90,7 +90,7 @@ def reset(self, rng: jax.Array) -> State: jcmd = self._sample_command(rng3) #wcmd = self._sample_waypoint(rng3) - wcmd = jp.array([10.0, 10.0, 0.5]) + wcmd = jp.array([0.0, 10.0]) q = self.sys.init_q + jax.random.uniform( rng1, (self.sys.q_size(),), minval=low, maxval=hi @@ -118,6 +118,9 @@ def reset(self, rng: jax.Array) -> State: 'pos_x_world_abs': zero, 'pos_y_world_abs': zero, 'pos_z_world_abs': zero, + 'dist_goal_x': zero, + 'dist_goal_y': zero, + #'dist_goal_z': zero, } return State(pipeline_state, obs, reward, done, metrics, state_info) @@ -128,14 +131,16 @@ def step(self, state: State, action: jax.Array) -> State: pipeline_state0 = state.pipeline_state pipeline_state = self.pipeline_step(pipeline_state0, action) + #print(f"wcmd: {state.info['wcmd']}") + #print(f"x.pos[0]: {pipeline_state.x.pos[0]}") waypoint_cost = rTracking_Waypoint(self.sys, - state.pipeline_state, + pipeline_state, state.info['wcmd'], - weight=1.0, + weight=100.0, focus_idx_range=0) lin_vel_reward = rTracking_lin_vel(self.sys, - state.pipeline_state, + pipeline_state, jp.array([0, 0, 0]), #dummy values for previous CoM jp.array([0, 0, 0]), #dummy values for current CoM self.dt, @@ -144,30 +149,30 @@ def step(self, state: State, action: jax.Array) -> State: focus_idx_range=(0,0)) yaw_vel_reward = rTracking_yaw_vel(self.sys, - state.pipeline_state, + pipeline_state, state.info['jcmd'], weight=10.8, focus_idx_range=(0,0)) ctrl_cost = rControl_act_ss(self.sys, - state.pipeline_state, + pipeline_state, action, weight=0.0) torque_cost = rTorques(self.sys, - state.pipeline_state, + pipeline_state, action, weight=0.0) upright_reward = rUpright(self.sys, - state.pipeline_state, + pipeline_state, weight=0.0) healthy_reward = rHealthy_simple_z(self.sys, - state.pipeline_state, + pipeline_state, self._healthy_z_range, early_terminate=self._terminate_when_unhealthy, - weight=1.0, + weight=0.0, focus_idx_range=(0, 2)) reward = 0.0 reward = healthy_reward[0] @@ -181,15 +186,23 @@ def step(self, state: State, action: jax.Array) -> State: pos_world = pipeline_state.x.pos[0] abs_pos_world = jp.abs(pos_world) - print(f'true position in world: {pos_world}') - print(f'absolute position in world: {abs_pos_world}\n') + #print(f"wcmd: {state.info['wcmd']}") + #print(f"x.pos[0]: {pipeline_state.x.pos[0]}") + wcmd = state.info['wcmd'] + dist_goal = pos_world[0:2] - wcmd + #print(dist_goal) + + #print(f'true position in world: {pos_world}') + #print(f'absolute position in world: {abs_pos_world}') + #print(f"dist_goal: {dist_goal}\n") obs = self._get_obs(pipeline_state, state.info) # print(f"\n") # print(f"healthy_reward? {healthy_reward}") # print(f"\n") - done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0 - + #done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0 + done = 0.0 + state.metrics.update( reward_ctrl = ctrl_cost, reward_alive = healthy_reward[0], @@ -201,6 +214,9 @@ def step(self, state: State, action: jax.Array) -> State: pos_x_world_abs = abs_pos_world[0], pos_y_world_abs = abs_pos_world[1], pos_z_world_abs = abs_pos_world[2], + dist_goal_x = dist_goal[0], + dist_goal_y = dist_goal[1], + #dist_goal_z = dist_goal[2], ) return state.replace( diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py index 1de14d4..e2565e5 100644 --- a/alfredo/rewards/rControl.py +++ b/alfredo/rewards/rControl.py @@ -70,8 +70,15 @@ def rTracking_Waypoint(sys: base.System, # x_i = pipeline_state.x.vmap().do( # base.Transform.create(pos=sys.link.inertia.transform.pos) # ) + + #print(f"wcmd: {waypoint}") + #print(f"x.pos[0]: {pipeline_state.x.pos[0]}") + torso_pos = pipeline_state.x.pos[focus_idx_range] + pos_goal_diff = torso_pos[0:2] - waypoint[0:2] + #print(f"pos_goal_diff: {pos_goal_diff}") + pos_sum_abs_diff = -jp.sum(jp.abs(pos_goal_diff)) + #inv_euclid_dist = -math.safe_norm(pos_goal_diff) + #print(f"pos_sum_abs_diff: {pos_sum_abs_diff}") - pos_goal_diff = pipeline_state.x.pos[focus_idx_range] - waypoint - inv_euclid_dist = -math.safe_norm(pos_goal_diff) - - return weight*inv_euclid_dist + #return weight*inv_euclid_dist + return weight*pos_sum_abs_diff diff --git a/experiments/AAnt-locomotion/one_physics_step.py b/experiments/AAnt-locomotion/one_physics_step.py index 16758de..1028943 100644 --- a/experiments/AAnt-locomotion/one_physics_step.py +++ b/experiments/AAnt-locomotion/one_physics_step.py @@ -46,7 +46,7 @@ env_xml_path=env_xml_paths[0], agent_xml_path=agent_xml_path) -rng = jax.random.PRNGKey(seed=0) +rng = jax.random.PRNGKey(seed=3) state = env.reset(rng=rng) #state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0)) normalize = lambda x, y: x @@ -60,7 +60,9 @@ policy_params = (params[0], params[1]) inference_fn = make_policy(policy_params) -wcmd = jp.array([10.0, 10.0, 0.5]) +wcmd = jp.array([0.0, 1000.0]) +key_envs, _ = jax.random.split(rng) +state = env.reset(rng=key_envs) print(f"q: {state.pipeline_state.q}") print(f"\n") diff --git a/experiments/AAnt-locomotion/training.py b/experiments/AAnt-locomotion/training.py index df1476e..e11f8cc 100644 --- a/experiments/AAnt-locomotion/training.py +++ b/experiments/AAnt-locomotion/training.py @@ -28,7 +28,8 @@ "backend": "positional", "seed": 0, "len_training": 1_500_000, - "batch_size": 1024, + "batch_size": 2048, + "episode_len": 1000, }, ) @@ -36,20 +37,25 @@ def progress(num_steps, metrics): print(num_steps) + print(metrics) + epi_len = wandb.config.episode_len wandb.log( { "step": num_steps, - "Total Reward": metrics["eval/episode_reward"], - "Waypoint Reward": metrics["eval/episode_reward_waypoint"], + "Total Reward": metrics["eval/episode_reward"]/epi_len, + "Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len, #"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"], #"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"], - "Alive Reward": metrics["eval/episode_reward_alive"], - "Ctrl Reward": metrics["eval/episode_reward_ctrl"], - "Upright Reward": metrics["eval/episode_reward_upright"], - "Torque Reward": metrics["eval/episode_reward_torque"], - "Abs Pos X World": metrics["eval/episode_pos_x_world_abs"], - "Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"], - "Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"], + "Alive Reward": metrics["eval/episode_reward_alive"]/epi_len, + "Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len, + "Upright Reward": metrics["eval/episode_reward_upright"]/epi_len, + "Torque Reward": metrics["eval/episode_reward_torque"]/epi_len, + "Abs Pos X World": metrics["eval/episode_pos_x_world_abs"]/epi_len, + "Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"]/epi_len, + "Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"]/epi_len, + "Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len, + "Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len, + #"Dist Goal Z": metrics["eval/episode_dist_goal_z"]/epi_len, } ) @@ -120,9 +126,9 @@ def progress(num_steps, metrics): train_fn = functools.partial( ppo.train, num_timesteps=wandb.config.len_training, - num_evals=100, + num_evals=400, reward_scaling=0.1, - episode_length=1000, + episode_length=wandb.config.episode_len, normalize_observations=True, action_repeat=1, unroll_length=10, diff --git a/experiments/AAnt-locomotion/vis_traj.py b/experiments/AAnt-locomotion/vis_traj.py index e227c22..76b645f 100644 --- a/experiments/AAnt-locomotion/vis_traj.py +++ b/experiments/AAnt-locomotion/vis_traj.py @@ -58,7 +58,7 @@ jit_env_step = jax.jit(env.step) rollout = [] -rng = jax.random.PRNGKey(seed=0) +rng = jax.random.PRNGKey(seed=13294) state = jit_env_reset(rng=rng) normalize = lambda x, y: x @@ -79,7 +79,7 @@ #yaw_vel = 0.0 # rad/s #jcmd = jp.array([x_vel, y_vel, yaw_vel]) -wcmd = jp.array([10.0, 10.0, 0.5]) +wcmd = jp.array([0.0, 10.0]) # generate policy rollout for _ in range(episode_length):