From 9d9f1877e8ad6a86afde316a3026607cb49788d4 Mon Sep 17 00:00:00 2001 From: mginoya Date: Fri, 30 Aug 2024 13:46:18 -0400 Subject: [PATCH] joystick command seems to be working - need to add some nice to haves next --- alfredo/agents/aant/aant.py | 52 +++++++++++++++---------- alfredo/rewards/rControl.py | 11 +++++- experiments/AAnt-locomotion/training.py | 18 ++++----- experiments/AAnt-locomotion/vis_traj.py | 10 ++--- 4 files changed, 54 insertions(+), 37 deletions(-) diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py index ac1f808..be71c9e 100644 --- a/alfredo/agents/aant/aant.py +++ b/alfredo/agents/aant/aant.py @@ -87,11 +87,14 @@ def reset(self, rng: jax.Array) -> State: low, hi = -self._reset_noise_scale, self._reset_noise_scale jcmd = self._sample_command(rng3) - wcmd = self._sample_waypoint(rng3) + #wcmd = self._sample_waypoint(rng3) - print(f"init_q: {self.sys.init_q}") - #wcmd = jp.array([0.0, 10.0]) + #print(f"init_q: {self.sys.init_q}") + wcmd = jp.array([0.0, 0.0]) + #q = self.sys.init_q + #qd = 0 * jax.random.normal(rng2, (self.sys.qd_size(),)) + q = self.sys.init_q + jax.random.uniform( rng1, (self.sys.q_size(),), minval=low, maxval=hi ) @@ -111,8 +114,8 @@ def reset(self, rng: jax.Array) -> State: 'reward_ctrl': zero, 'reward_alive': zero, 'reward_torque': zero, - #'reward_lin_vel': zero, - #'reward_yaw_vel': zero, + 'reward_lin_vel': zero, + 'reward_yaw_vel': zero, 'reward_upright': zero, 'reward_waypoint': zero, 'pos_x_world_abs': zero, @@ -127,7 +130,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Run one timestep of the environment's dynamics.""" - + pipeline_state0 = state.pipeline_state pipeline_state = self.pipeline_step(pipeline_state0, action) @@ -136,7 +139,7 @@ def step(self, state: State, action: jax.Array) -> State: waypoint_cost = rTracking_Waypoint(self.sys, pipeline_state, state.info['wcmd'], - weight=100.0, + weight=0.0, focus_idx_range=0) lin_vel_reward = rTracking_lin_vel(self.sys, @@ -145,24 +148,24 @@ def step(self, state: State, action: jax.Array) -> State: jp.array([0, 0, 0]), #dummy values for current CoM self.dt, state.info['jcmd'], - weight=15.5, + weight=1.5, focus_idx_range=(0,0)) yaw_vel_reward = rTracking_yaw_vel(self.sys, pipeline_state, state.info['jcmd'], - weight=10.8, + weight=0.8, focus_idx_range=(0,0)) ctrl_cost = rControl_act_ss(self.sys, pipeline_state, action, - weight=-self._ctrl_cost_weight) + weight=0.0) torque_cost = rTorques(self.sys, pipeline_state, action, - weight=-0.003) + weight=0.0) upright_reward = rUpright(self.sys, pipeline_state, @@ -174,14 +177,17 @@ def step(self, state: State, action: jax.Array) -> State: early_terminate=self._terminate_when_unhealthy, weight=0.0, focus_idx_range=(0, 2)) - reward = 0.0 + #reward = 0.0 reward = healthy_reward[0] reward += ctrl_cost reward += torque_cost reward += upright_reward reward += waypoint_cost - #reward += lin_vel_reward - #reward += yaw_vel_reward + reward += lin_vel_reward + reward += yaw_vel_reward + + #print(f"lin_tracking_vel: {lin_vel_reward}") + #print(f"yaw_tracking_vel: {yaw_vel_reward}\n") pos_world = pipeline_state.x.pos[0] abs_pos_world = jp.abs(pos_world) @@ -208,8 +214,8 @@ def step(self, state: State, action: jax.Array) -> State: reward_alive = healthy_reward[0], reward_torque = torque_cost, reward_upright = upright_reward, - #reward_lin_vel = lin_vel_reward, - #reward_yaw_vel = yaw_vel_reward, + reward_lin_vel = lin_vel_reward, + reward_yaw_vel = yaw_vel_reward, reward_waypoint = 
waypoint_cost, pos_x_world_abs = abs_pos_world[0], pos_y_world_abs = abs_pos_world[1], @@ -227,14 +233,18 @@ def _get_obs(self, pipeline_state, state_info) -> jax.Array: """Observe ant body position and velocities.""" qpos = pipeline_state.q qvel = pipeline_state.qd + + inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0]) + local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot) torso_pos = pipeline_state.x.pos[0] - #jcmd = state_info['jcmd'] + + jcmd = state_info['jcmd'] wcmd = state_info['wcmd'] if self._exclude_current_positions_from_observation: qpos = pipeline_state.q[2:] - return jp.concatenate([qpos] + [qvel] + [torso_pos] + [wcmd]) #[jcmd]) + return jp.concatenate([qpos] + [qvel] + [local_rpyrate] + [jcmd]) #[jcmd]) def _sample_waypoint(self, rng: jax.Array) -> jax.Array: x_range = [-25, 25] @@ -260,9 +270,9 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array: return wcmd def _sample_command(self, rng: jax.Array) -> jax.Array: - lin_vel_x_range = [-0.6, 1.5] #[m/s] - lin_vel_y_range = [-0.6, 1.5] #[m/s] - yaw_vel_range = [-0.0, 0.0] #[rad/s] + lin_vel_x_range = [-1, 3] #[m/s] + lin_vel_y_range = [-1, 3] #[m/s] + yaw_vel_range = [-0.7, 0.7] #[rad/s] _, key1, key2, key3 = jax.random.split(rng, 4) diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py index ac3f09c..28d79b3 100644 --- a/alfredo/rewards/rControl.py +++ b/alfredo/rewards/rControl.py @@ -32,6 +32,9 @@ def rTracking_lin_vel(sys: base.System, local_vel = math.rotate(pipeline_state.xd.vel[0], math.quat_inv(pipeline_state.x.rot[0])) + #print(f"rot_quat: {pipeline_state.x.rot[0]}") + #print(f"global_vel: {pipeline_state.xd.vel[0]}") + #print(f"local_vel: {local_vel}\n") #print(f"com_prev:{CoM_prev}, com_now:{CoM_now}") #local_vel = (CoM_prev - CoM_now)/dt #print(f"jcmd[:2]: {jcmd[:2]}") @@ -53,9 +56,13 @@ def rTracking_yaw_vel(sys: base.System, #local_yaw_vel = math.rotate(pipeline_state.xd.ang[focus_idx_range[0]:focus_idx_range[1]], # math.quat_inv(pipeline_state.x.rotate[focus_idx_range[0], focus_idx_range[1]])) - local_yaw_vel = math.rotate(pipeline_state.xd.vel[0], + local_ang_vel = math.rotate(pipeline_state.xd.ang[0], math.quat_inv(pipeline_state.x.rot[0])) - yaw_vel_error = jp.square(jcmd[2] - local_yaw_vel[2]) + + #print(f"global_ang_vel: {pipeline_state.xd.vel[0]}") + #print(f"local_ang_vel: {local_ang_vel}\n") + + yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2]) yv_reward = jp.exp(-yaw_vel_error/sigma) return weight*yv_reward diff --git a/experiments/AAnt-locomotion/training.py b/experiments/AAnt-locomotion/training.py index fe40455..cce4940 100644 --- a/experiments/AAnt-locomotion/training.py +++ b/experiments/AAnt-locomotion/training.py @@ -26,7 +26,7 @@ config={ "env_name": "AAnt", "backend": "positional", - "seed": 1, + "seed": 13, "len_training": 1_500_000, "num_evals": 500, "num_envs": 2048, @@ -35,7 +35,7 @@ "updates_per_batch": 8, "episode_len": 1000, "unroll_len": 10, - "reward_scaling":0.1, + "reward_scaling":1, "action_repeat": 1, "discounting": 0.97, "learning_rate": 3e-4, @@ -56,8 +56,8 @@ def progress(num_steps, metrics): "step": num_steps, "Total Reward": metrics["eval/episode_reward"]/epi_len, "Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len, - #"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"], - #"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"], + "Lin Vel Reward": metrics["eval/episode_reward_lin_vel"], + "Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"], "Alive Reward": metrics["eval/episode_reward_alive"]/epi_len, 
"Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len, "Upright Reward": metrics["eval/episode_reward_upright"]/epi_len, @@ -81,10 +81,10 @@ def progress(num_steps, metrics): scenes_fp = os.path.dirname(scenes.__file__) -env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml", - f"{scenes_fp}/flatworld/flatworld_A1_env.xml", - f"{scenes_fp}/flatworld/flatworld_A1_env.xml", - f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] +env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] + #f"{scenes_fp}/flatworld/flatworld_A1_env.xml", + #f"{scenes_fp}/flatworld/flatworld_A1_env.xml", + #f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] # make and save initial ppo_network key = jax.random.PRNGKey(wandb.config.seed) @@ -118,7 +118,7 @@ def progress(num_steps, metrics): # ============================ # Training & Saving Params # ============================ -i = 8 +i = 0 for p in env_xml_paths: diff --git a/experiments/AAnt-locomotion/vis_traj.py b/experiments/AAnt-locomotion/vis_traj.py index b1df220..22826c2 100644 --- a/experiments/AAnt-locomotion/vis_traj.py +++ b/experiments/AAnt-locomotion/vis_traj.py @@ -74,10 +74,10 @@ jit_inference_fn = jax.jit(inference_fn) -#x_vel = 0.0 # m/s -#y_vel = -1.5 # m/s -#yaw_vel = 0.0 # rad/s -#jcmd = jp.array([x_vel, y_vel, yaw_vel]) +x_vel = 0.0 # m/s +y_vel = 0.0 # m/s +yaw_vel = 0.7 # rad/s +jcmd = jp.array([x_vel, y_vel, yaw_vel]) wcmd = jp.array([10.0, 10.0]) @@ -86,7 +86,7 @@ rollout.append(state.pipeline_state) act_rng, rng = jax.random.split(rng) - #state.info['jcmd'] = jcmd + state.info['jcmd'] = jcmd state.info['wcmd'] = wcmd act, _ = jit_inference_fn(state.obs, act_rng) state = jit_env_step(state, act)