From 9762b488133f80772b1c23686ed78c1ee2516cc9 Mon Sep 17 00:00:00 2001 From: mginoya Date: Tue, 6 Aug 2024 16:46:40 -0400 Subject: [PATCH] some progress - still a bit confusing --- alfredo/agents/A1/a1.xml | 6 ++-- alfredo/agents/A1/alfredo_1.py | 34 +++++++------------ alfredo/rewards/rControl.py | 8 ++--- .../Alfredo-simple-walk/seq_training.py | 2 +- experiments/Alfredo-simple-walk/vis_traj.py | 2 +- 5 files changed, 21 insertions(+), 31 deletions(-) diff --git a/alfredo/agents/A1/a1.xml b/alfredo/agents/A1/a1.xml index 5cae861..b293046 100644 --- a/alfredo/agents/A1/a1.xml +++ b/alfredo/agents/A1/a1.xml @@ -18,8 +18,8 @@ - - + + @@ -50,7 +50,7 @@ - + diff --git a/alfredo/agents/A1/alfredo_1.py b/alfredo/agents/A1/alfredo_1.py index 521a4f8..c94875a 100644 --- a/alfredo/agents/A1/alfredo_1.py +++ b/alfredo/agents/A1/alfredo_1.py @@ -128,7 +128,7 @@ def reset(self, rng: jp.ndarray) -> State: 'CoM': com, } - obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()), state_info) + obs = self._get_obs(pipeline_state, state_info) reward, done, zero = jp.zeros(3) metrics = { @@ -149,7 +149,7 @@ def step(self, state: State, action: jp.ndarray) -> State: waypoint_cost = rTracking_Waypoint(self.sys, state.pipeline_state, state.info['wcmd'], - weight=1.0, + weight=2.0, focus_idx_range=0) ctrl_cost = rControl_act_ss(self.sys, @@ -176,7 +176,7 @@ def step(self, state: State, action: jp.ndarray) -> State: reward += waypoint_cost - obs = self._get_obs(pipeline_state, action, state.info) + obs = self._get_obs(pipeline_state, state.info) done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0 state.metrics.update( @@ -190,28 +190,18 @@ def step(self, state: State, action: jp.ndarray) -> State: pipeline_state=pipeline_state, obs=obs, reward=reward, done=done ) - def _get_obs(self, pipeline_state: base.State, action: jp.ndarray, state_info) -> jp.ndarray: - """Observes Alfredo's body position, velocities, and angles.""" - - a_positions = pipeline_state.q - a_velocities = pipeline_state.qd + def _get_obs(self, pipeline_state: base.State, state_info) -> jax.Array: + """Observes Alfredo's body position, and velocities""" + qpos = pipeline_state.q + qvel = pipeline_state.qd wcmd = state_info['wcmd'] + + if self._exclude_current_positions_from_observation: + qpos = pipeline_state.q[2:] - qfrc_actuator = actuator.to_tau( - self.sys, action, pipeline_state.q, pipeline_state.qd - ) - - # external_contact_forces are excluded - return jp.concatenate( - [ - a_positions, - a_velocities, - qfrc_actuator, - wcmd - ] - ) - + return jp.concatenate([qpos] + [qvel] + [wcmd]) + def _com(self, pipeline_state: base.State) -> jp.ndarray: """Computes Center of Mass of Alfredo""" diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py index 228b34b..1de14d4 100644 --- a/alfredo/rewards/rControl.py +++ b/alfredo/rewards/rControl.py @@ -67,11 +67,11 @@ def rTracking_Waypoint(sys: base.System, weight=1.0, focus_idx_range=0) -> jp.ndarray: - x_i = pipeline_state.x.vmap().do( - base.Transform.create(pos=sys.link.inertia.transform.pos) - ) + # x_i = pipeline_state.x.vmap().do( + # base.Transform.create(pos=sys.link.inertia.transform.pos) + # ) - pos_goal_diff = x_i.pos[focus_idx_range] - waypoint + pos_goal_diff = pipeline_state.x.pos[focus_idx_range] - waypoint inv_euclid_dist = -math.safe_norm(pos_goal_diff) return weight*inv_euclid_dist diff --git a/experiments/Alfredo-simple-walk/seq_training.py b/experiments/Alfredo-simple-walk/seq_training.py index 7c653fd..b44e040 100644 --- a/experiments/Alfredo-simple-walk/seq_training.py +++ b/experiments/Alfredo-simple-walk/seq_training.py @@ -131,7 +131,7 @@ def progress(num_steps, metrics): train_fn = functools.partial( ppo.train, num_timesteps=wandb.config.len_training, - num_evals=20, + num_evals=200, reward_scaling=0.1, episode_length=1000, normalize_observations=True, diff --git a/experiments/Alfredo-simple-walk/vis_traj.py b/experiments/Alfredo-simple-walk/vis_traj.py index 5c4813a..6e94410 100644 --- a/experiments/Alfredo-simple-walk/vis_traj.py +++ b/experiments/Alfredo-simple-walk/vis_traj.py @@ -81,7 +81,7 @@ # yaw_vel = 0.0 # rad/s # jcmd = jp.array([x_vel, y_vel, yaw_vel]) -wcmd = jp.array([0.0, 10.0, 0.0]) +wcmd = jp.array([10.0, 0.0, 1.0]) # generate policy rollout for _ in range(episode_length):