From c76e896aa5e1857b0655ba8ae67cca65b45b76f1 Mon Sep 17 00:00:00 2001
From: mginoya
Date: Fri, 4 Oct 2024 21:40:43 -0400
Subject: [PATCH] new way to do rewards is now working - more cleanup to follow

---
 alfredo/agents/A1/alfredo_1.py                |   1 -
 alfredo/agents/aant/aant.py                   | 217 +++++++-----------
 alfredo/rewards/__init__.py                   |   3 +-
 alfredo/rewards/rConstant.py                  |   7 +-
 alfredo/rewards/rControl.py                   |  57 +----
 alfredo/rewards/rEnergy.py                    |   5 +-
 alfredo/rewards/rHealthy.py                   |   7 +-
 alfredo/rewards/rOrientation.py               |   8 +-
 alfredo/rewards/rSpeed.py                     |  44 ----
 alfredo/rewards/reward.py                     |   5 +-
 .../AAnt-locomotion/one_physics_step.py       |  42 +++-
 11 files changed, 146 insertions(+), 250 deletions(-)
 delete mode 100644 alfredo/rewards/rSpeed.py

diff --git a/alfredo/agents/A1/alfredo_1.py b/alfredo/agents/A1/alfredo_1.py
index c94875a..1c9126f 100644
--- a/alfredo/agents/A1/alfredo_1.py
+++ b/alfredo/agents/A1/alfredo_1.py
@@ -12,7 +12,6 @@ from alfredo.tools import compose_scene
 from alfredo.rewards import rConstant
 from alfredo.rewards import rHealthy_simple_z
-from alfredo.rewards import rSpeed_X
 from alfredo.rewards import rControl_act_ss
 from alfredo.rewards import rTorques
 from alfredo.rewards import rTracking_lin_vel
diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py
index af966c4..7ec4d06 100644
--- a/alfredo/agents/aant/aant.py
+++ b/alfredo/agents/aant/aant.py
@@ -7,9 +7,9 @@ from jax import numpy as jp
 from alfredo.tools import compose_scene
+from alfredo.rewards import Reward
 from alfredo.rewards import rConstant
 from alfredo.rewards import rHealthy_simple_z
-from alfredo.rewards import rSpeed_X
 from alfredo.rewards import rControl_act_ss
 from alfredo.rewards import rTorques
 from alfredo.rewards import rTracking_lin_vel
@@ -22,6 +22,9 @@ class AAnt(PipelineEnv):
     """ """

     def __init__(self,
+                 rewards = {},
+                 env_xml_path = "",
+                 agent_xml_path = "",
                  ctrl_cost_weight=0.5,
                  use_contact_forces=False,
                  contact_cost_weight=5e-4,
@@ -34,20 +37,24 @@ def __init__(self,
                  backend='generalized',
                  **kwargs,):

-        # forcing this model to need an input scene_xml_path or
-        # the combination of env_xml_path and agent_xml_path
-        # if none of these options are present, an error will be thrown
-        path=""
-
-        if "env_xml_path" and "agent_xml_path" in kwargs:
-            env_xp = kwargs["env_xml_path"]
-            agent_xp = kwargs["agent_xml_path"]
-            xml_scene = compose_scene(env_xp, agent_xp)
-            del kwargs["env_xml_path"]
-            del kwargs["agent_xml_path"]
-
+        # env_xml_path and agent_xml_path must be provided
+        if env_xml_path and agent_xml_path:
+            self._env_xml_path = env_xml_path
+            self._agent_xml_path = agent_xml_path
+
+            xml_scene = compose_scene(self._env_xml_path, self._agent_xml_path)
             sys = mjcf.loads(xml_scene)
+        else:
+            raise Exception("both env_xml_path and agent_xml_path must be provided")
+
+        # reward dictionary must be provided
+        if rewards:
+            self._rewards = rewards
+        else:
+            raise Exception("a non-empty rewards dict must be provided")

+        # TODO: clean this up in the future &
+        #       make n_frames a function of input dt
         n_frames = 5

         if backend in ['spring', 'positional']:
@@ -64,8 +71,10 @@ def __init__(self,
         kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

+        # Initialize the superclass "PipelineEnv"
         super().__init__(sys=sys, backend=backend, **kwargs)

+        # Setting other object parameters based on input params
         self._ctrl_cost_weight = ctrl_cost_weight
         self._use_contact_forces = use_contact_forces
         self._contact_cost_weight = contact_cost_weight
@@ -83,151 +92,90 @@ def __init__(self,

     def reset(self, rng: jax.Array) -> State:
+
         rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
-
         low, hi = -self._reset_noise_scale, self._reset_noise_scale
-        jcmd = self._sample_command(rng3)
-        #wcmd = self._sample_waypoint(rng3)
-
-        #print(f"init_q: {self.sys.init_q}")
-        wcmd = jp.array([0.0, 0.0])
-
-        #q = self.sys.init_q
-        #qd = 0 * jax.random.normal(rng2, (self.sys.qd_size(),))
-
+        # initialize position vector with minor randomization in pose
         q = self.sys.init_q + jax.random.uniform(
             rng1, (self.sys.q_size(),), minval=low, maxval=hi
         )
-
+
+        # initialize velocity vector with minor randomization
         qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

+        # generate sample commands
+        jcmd = self._sample_command(rng3)
+        wcmd = jp.array([0.0, 0.0])
+
+        # initialize pipeline_state (the physics state)
+        pipeline_state = self.pipeline_init(q, qd)
+
+        # reset values and metrics
+        reward, done, zero = jp.zeros(3)
+
         state_info = {
             'jcmd':jcmd,
             'wcmd':wcmd,
+            'rewards': {k: 0.0 for k in self._rewards.keys()},
+            'step': 0,
         }
-
-        pipeline_state = self.pipeline_init(q, qd)
-        obs = self._get_obs(pipeline_state, state_info)
-        reward, done, zero = jp.zeros(3)
-        metrics = {
-            'reward_ctrl': zero,
-            'reward_alive': zero,
-            'reward_torque': zero,
-            'reward_lin_vel': zero,
-            'reward_yaw_vel': zero,
-            'reward_upright': zero,
-            'reward_waypoint': zero,
-            'pos_x_world_abs': zero,
-            'pos_y_world_abs': zero,
-            'pos_z_world_abs': zero,
-            #'dist_goal_x': zero,
-            #'dist_goal_y': zero,
-            #'dist_goal_z': zero,
-        }
+        metrics = {'pos_x_world_abs': zero,
+                   'pos_y_world_abs': zero,
+                   'pos_z_world_abs': zero,}
+        for rn, r in self._rewards.items():
+            metrics[rn] = state_info['rewards'][rn]
+
+        # get initial observation vector
+        obs = self._get_obs(pipeline_state, state_info)
+
         return State(pipeline_state, obs, reward, done, metrics, state_info)

     def step(self, state: State, action: jax.Array) -> State:
         """Run one timestep of the environment's dynamics."""
-
+
+        # Save the previous physics state and step physics forward
         pipeline_state0 = state.pipeline_state
         pipeline_state = self.pipeline_step(pipeline_state0, action)

-        #print(f"wcmd: {state.info['wcmd']}")
-        #print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
-        waypoint_cost = rTracking_Waypoint(self.sys,
-                                           pipeline_state,
-                                           state.info['wcmd'],
-                                           weight=0.0,
-                                           focus_idx_range=0)
-
-        lin_vel_reward = rTracking_lin_vel(self.sys,
-                                           pipeline_state,
-                                           jp.array([0, 0, 0]), #dummy values for previous CoM
-                                           jp.array([0, 0, 0]), #dummy values for current CoM
-                                           self.dt,
-                                           state.info['jcmd'],
-                                           weight=15.0,
-                                           focus_idx_range=(0,0))
-
-        yaw_vel_reward = rTracking_yaw_vel(self.sys,
-                                           pipeline_state,
-                                           state.info['jcmd'],
-                                           weight=1.0,
-                                           focus_idx_range=(0,0))
-
-        ctrl_cost = rControl_act_ss(self.sys,
-                                    pipeline_state,
-                                    action,
-                                    weight=0.0)
-
-        torque_cost = rTorques(self.sys,
-                               pipeline_state,
-                               action,
-                               weight=0.0)
-
-        upright_reward = rUpright(self.sys,
-                                  pipeline_state,
-                                  weight=0.0)
-
-        healthy_reward = rHealthy_simple_z(self.sys,
-                                           pipeline_state,
-                                           self._healthy_z_range,
-                                           early_terminate=self._terminate_when_unhealthy,
-                                           weight=0.0,
-                                           focus_idx_range=(0, 2))
-        #reward = 0.0
-        reward = healthy_reward[0]
-        reward += ctrl_cost
-        reward += torque_cost
-        reward += upright_reward
-        reward += waypoint_cost
-        reward += lin_vel_reward
-        reward += yaw_vel_reward
-
-        #print(f"lin_tracking_vel: {lin_vel_reward}")
-        #print(f"yaw_tracking_vel: {yaw_vel_reward}\n")

+        # Add all additional parameters to compute rewards
+        self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd'])
+        self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd'])
+
+        # Compute all rewards and accumulate total reward
+        total_reward = 0.0
+        for rn, r in self._rewards.items():
+            r.add_param('sys', self.sys)
+            r.add_param('pipeline_state', pipeline_state)
+            reward_value = r.compute()
+            state.info['rewards'][rn] = reward_value
+            total_reward += reward_value[0]
+            # print(f'{rn} reward_val = {reward_value}\n')
+
+        # Computing additional metrics as necessary
         pos_world = pipeline_state.x.pos[0]
         abs_pos_world = jp.abs(pos_world)

-        #print(f"wcmd: {state.info['wcmd']}")
-        #print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
-        #wcmd = state.info['wcmd']
-        #dist_goal = pos_world[0:2] - wcmd
-        #print(dist_goal)
-
-        #print(f'true position in world: {pos_world}')
-        #print(f'absolute position in world: {abs_pos_world}')
-        #print(f"dist_goal: {dist_goal}\n")
-
+        # Compute observations
         obs = self._get_obs(pipeline_state, state.info)

-        # print(f"\n")
-        # print(f"healthy_reward? {healthy_reward}")
-        # print(f"\n")
-        #done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
         done = 0.0
+
+        # State management
+        state.info['step'] += 1
+
+        state.metrics.update(state.info['rewards'])
         state.metrics.update(
-            reward_ctrl = ctrl_cost,
-            reward_alive = healthy_reward[0],
-            reward_torque = torque_cost,
-            reward_upright = upright_reward,
-            reward_lin_vel = lin_vel_reward,
-            reward_yaw_vel = yaw_vel_reward,
-            reward_waypoint = waypoint_cost,
             pos_x_world_abs = abs_pos_world[0],
             pos_y_world_abs = abs_pos_world[1],
             pos_z_world_abs = abs_pos_world[2],
-            #dist_goal_x = dist_goal[0],
-            #dist_goal_y = dist_goal[1],
-            #dist_goal_z = dist_goal[2],
         )
-
+
         return state.replace(
-            pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
+            pipeline_state=pipeline_state, obs=obs, reward=total_reward, done=done
         )

     def _get_obs(self, pipeline_state, state_info) -> jax.Array:
@@ -240,16 +188,23 @@ def _get_obs(self, pipeline_state, state_info) -> jax.Array:
         torso_pos = pipeline_state.x.pos[0]

         jcmd = state_info['jcmd']
-        wcmd = state_info['wcmd']
+        #wcmd = state_info['wcmd']

         if self._exclude_current_positions_from_observation:
             qpos = pipeline_state.q[2:]

-        return jp.concatenate([qpos] + [qvel] + [local_rpyrate] + [jcmd]) #[jcmd])
-
+        obs = jp.concatenate([
+            jp.array(qpos),
+            jp.array(qvel),
+            jp.array(local_rpyrate),
+            jp.array(jcmd),
+        ])
+
+        return obs
+
     def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
-        x_range = [-25, 25]
-        y_range = [-25, 25]
+        x_range = [-25, 25]
+        y_range = [-25, 25]
         z_range = [0, 2]

         _, key1, key2, key3 = jax.random.split(rng, 4)
@@ -271,8 +226,8 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
         return wcmd

     def _sample_command(self, rng: jax.Array) -> jax.Array:
-        lin_vel_x_range = [0.0, 0.0] #[m/s]
-        lin_vel_y_range = [0.0, 0.0] #[m/s]
+        lin_vel_x_range = [-3.0, 3.0] #[m/s]
+        lin_vel_y_range = [-3.0, 3.0] #[m/s]
         yaw_vel_range = [-1.0, 1.0] #[rad/s]

         _, key1, key2, key3 = jax.random.split(rng, 4)
diff --git a/alfredo/rewards/__init__.py b/alfredo/rewards/__init__.py
index 761a96e..d8e5d93 100644
--- a/alfredo/rewards/__init__.py
+++ b/alfredo/rewards/__init__.py
@@ -1,5 +1,6 @@
+from .reward import Reward
+
 from .rConstant import *
-from .rSpeed import *
 from .rHealthy import *
 from .rControl import *
 from .rEnergy import *
diff --git a/alfredo/rewards/rConstant.py b/alfredo/rewards/rConstant.py
index ce3b3e3..74f6c06 100644
--- a/alfredo/rewards/rConstant.py
+++ b/alfredo/rewards/rConstant.py
@@ -8,8 +8,7 @@ from jax import numpy as jp

 def rConstant(sys: base.System,
-              pipeline_state: base.State,
-              weight=1.0,
-              focus_idx_range=(0, -1)) -> jp.ndarray:
+              pipeline_state: base.State,
+              focus_idx_range=(0, -1)) -> jax.Array:

-    return jp.array([weight])
+    return jp.array([1.0])
diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py
index 28d79b3..6e332a1 100644
--- a/alfredo/rewards/rControl.py
+++ b/alfredo/rewards/rControl.py
@@ -10,94 +10,59 @@ def rControl_act_ss(sys: base.System,
                     pipeline_state: base.State,
                     action: jp.ndarray,
-                    weight=1.0,
-                    focus_idx_range=(0, -1)) -> jp.ndarray:
+                    focus_idx_range=(0, -1)) -> jax.Array:

-    ctrl_cost = weight * jp.sum(jp.square(action))
+    ctrl_cost = jp.sum(jp.square(action))

-    return ctrl_cost
+    return jp.array([ctrl_cost])

 def rTracking_lin_vel(sys: base.System,
                       pipeline_state: base.State,
-                      CoM_prev: jp.ndarray,
-                      CoM_now: jp.ndarray,
-                      dt,
                       jcmd: jax.Array,
-                      weight=1.0,
                       sigma=0.25,
-                      focus_idx_range=(0, 1)) -> jp.ndarray:
+                      focus_idx_range=(0, 1)) -> jax.Array:

-    #local_vel = math.rotate(pipeline_state.xd.vel[focus_idx_range[0]:focus_idx_range[1]],
-    #                        math.quat_inv(pipeline_state.x.rot[focus_idx_range[0]:focus_idx_range[1]]))
     local_vel = math.rotate(pipeline_state.xd.vel[0],
                             math.quat_inv(pipeline_state.x.rot[0]))

-    #print(f"rot_quat: {pipeline_state.x.rot[0]}")
-    #print(f"global_vel: {pipeline_state.xd.vel[0]}")
-    #print(f"local_vel: {local_vel}\n")
-    #print(f"com_prev:{CoM_prev}, com_now:{CoM_now}")
-    #local_vel = (CoM_prev - CoM_now)/dt
-    #print(f"jcmd[:2]: {jcmd[:2]}")
-    #print(f"loca_vel[:2]: {local_vel[:2]}")

     lv_error = jp.sum(jp.square(jcmd[:2] - local_vel[:2])) # just taking a look at x, y velocities
     lv_reward = jp.exp(-lv_error/sigma)

-    #print(f"lv_error: {lv_error}")
-    #print(f"lv_reward: {lv_reward}")
-
-    return weight*lv_reward
+    return jp.array([lv_reward])

 def rTracking_yaw_vel(sys: base.System,
                       pipeline_state: base.State,
                       jcmd: jax.Array,
                       sigma=0.25,
-                      weight=1.0,
-                      focus_idx_range=(0, 1)) -> jp.ndarray:
+                      focus_idx_range=(0, 1)) -> jax.Array:

-    #local_yaw_vel = math.rotate(pipeline_state.xd.ang[focus_idx_range[0]:focus_idx_range[1]],
-    #                            math.quat_inv(pipeline_state.x.rotate[focus_idx_range[0], focus_idx_range[1]]))
     local_ang_vel = math.rotate(pipeline_state.xd.ang[0],
                                 math.quat_inv(pipeline_state.x.rot[0]))

-    #print(f"global_ang_vel: {pipeline_state.xd.vel[0]}")
-    #print(f"local_ang_vel: {local_ang_vel}\n")
-
     yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2])
     yv_reward = jp.exp(-yaw_vel_error/sigma)

-    return weight*yv_reward
+    return jp.array([yv_reward])

 def rTracking_Waypoint(sys: base.System,
                        pipeline_state: base.State,
-                       waypoint: jax.Array,
-                       weight=1.0,
-                       focus_idx_range=0) -> jp.ndarray:
-
-    # x_i = pipeline_state.x.vmap().do(
-    #     base.Transform.create(pos=sys.link.inertia.transform.pos)
-    # )
+                       wcmd: jax.Array,
+                       focus_idx_range=0) -> jax.Array:

-    #print(f"wcmd: {waypoint}")
-    #print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
     torso_pos = pipeline_state.x.pos[focus_idx_range]
-    pos_goal_diff = torso_pos[0:2] - waypoint[0:2]
-    #print(f"pos_goal_diff: {pos_goal_diff}")
+    pos_goal_diff = torso_pos[0:2] - wcmd[0:2]
     pos_sum_abs_diff = -jp.sum(jp.abs(pos_goal_diff))
-    #inv_euclid_dist = -math.safe_norm(pos_goal_diff)
-    #print(f"pos_sum_abs_diff: {pos_sum_abs_diff}")
-
-    #return weight*inv_euclid_dist
-    return weight*pos_sum_abs_diff
+    return jp.array([pos_sum_abs_diff])

 def rStand_still(sys: base.System,
                  pipeline_state: base.State,
                  jcmd: jax.Array,
                  joint_angles: jax.Array,
                  default_pose: jax.Array,
-                 weight: 1.0,
-                 focus_idx_range=0) -> jp.ndarray:
+                 focus_idx_range=0) -> jax.Array:

     close_to_still = jp.sum(jp.abs(joint_angles - default_pose)) * math.normalize(jcmd[:2])[1] < 0.1

-    return weight * close_to_still
+    return jp.array([close_to_still])
diff --git a/alfredo/rewards/rEnergy.py b/alfredo/rewards/rEnergy.py
index 1adeb6f..760d2d8 100644
--- a/alfredo/rewards/rEnergy.py
+++ b/alfredo/rewards/rEnergy.py
@@ -11,8 +11,7 @@ def rTorques(sys: base.System,
              pipeline_state: base.State,
              action: jp.ndarray,
-             weight=1.0,
-             focus_idx_range=(0, -1)) -> jp.ndarray:
+             focus_idx_range=(0, -1)) -> jax.Array:

     s_idx = focus_idx_range[0]
     e_idx = focus_idx_range[1]
@@ -25,4 +24,4 @@ def rTorques(sys: base.System,

     tr = jp.sqrt(jp.sum(jp.square(torque))) + jp.sum(jp.abs(torque))

-    return weight*tr
+    return jp.array([tr])
diff --git a/alfredo/rewards/rHealthy.py b/alfredo/rewards/rHealthy.py
index 49ed08f..2a696ca 100644
--- a/alfredo/rewards/rHealthy.py
+++ b/alfredo/rewards/rHealthy.py
@@ -11,8 +11,7 @@ def rHealthy_simple_z(sys: base.System,
                       pipeline_state: base.State,
                       z_range: Tuple,
                       early_terminate: True,
-                      weight=1.0,
-                      focus_idx_range=(0, -1)) -> jp.ndarray:
+                      focus_idx_range=(0, -1)) -> jax.Array:

     min_z, max_z = z_range
     focus_s = focus_idx_range[0]
@@ -24,8 +23,8 @@ def rHealthy_simple_z(sys: base.System,
     is_healthy = jp.where(focus_x_pos > max_z, x=0.0, y=is_healthy)

     if early_terminate:
-        hr = weight
+        hr = 1.0
     else:
-        hr = weight * is_healthy
+        hr = 1.0 * is_healthy

     return jp.array([hr, is_healthy])
diff --git a/alfredo/rewards/rOrientation.py b/alfredo/rewards/rOrientation.py
index 8a2f93e..340c171 100644
--- a/alfredo/rewards/rOrientation.py
+++ b/alfredo/rewards/rOrientation.py
@@ -9,13 +9,9 @@ def rUpright(sys: base.System,
              pipeline_state: base.State,
-             weight = 1.0,
-             focus_idx_range = (0,0)) -> jp.ndarray:
+             focus_idx_range = (0,0)) -> jax.Array:

     up = jp.array([0.0, 0.0, 1.0])
-    #print(f"torso orientation? {pipeline_state.x.rot[0]}")
     rot_up = math.rotate(up, pipeline_state.x.rot[0])
-    #print(f"rot_up vector: {rot_up}")
-    #print(f"dot product: {jp.dot(up, rot_up)}")

-    return weight*jp.dot(up, rot_up)
+    return jp.dot(up, rot_up)
diff --git a/alfredo/rewards/rSpeed.py b/alfredo/rewards/rSpeed.py
deleted file mode 100644
index 0738b32..0000000
--- a/alfredo/rewards/rSpeed.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import Tuple
-
-import jax
-from brax import actuator, base, math
-from brax.envs import PipelineEnv, State
-from brax.io import mjcf
-from etils import epath
-from jax import numpy as jp
-
-def rSpeed_X(sys: base.System,
-             pipeline_state: base.State,
-             CoM_prev: jp.ndarray,
-             CoM_now: jp.ndarray,
-             dt,
-             weight=1.0,
-             focus_idx_range=(0, -1)) -> jp.ndarray:
-
-
-    velocity = (CoM_now - CoM_prev) / dt
-
-    focus_s = focus_idx_range[0]
-    focus_e = focus_idx_range[-1]
-
-    sxr = weight * velocity[0]
-
-    return jp.array([sxr, velocity[0], velocity[1]])
-
-def rSpeed_Y(sys: base.System,
-             pipeline_state: base.State,
-             CoM_prev: jp.ndarray,
-             CoM_now: jp.ndarray,
-             dt,
-             weight=1.0,
-             focus_idx_range=(0, -1)) -> jp.ndarray:
-
-
-    velocity = (CoM_now - CoM_prev) / dt
-
-    focus_s = focus_idx_range[0]
-    focus_e = focus_idx_range[-1]
-
-    syr = weight * velocity[1]
-
-    return jp.array([syr, velocity[0], velocity[1]])
diff --git a/alfredo/rewards/reward.py b/alfredo/rewards/reward.py
index 6876a16..b2ca1bb 100644
--- a/alfredo/rewards/reward.py
+++ b/alfredo/rewards/reward.py
@@ -27,5 +27,8 @@ def compute(self):
         scale and general parameters are set.
         Otherwise, this errors out quite spectacularly
         """
+
+        res = self.f(**self.params)
+        res = res.at[0].multiply(self.scale) #may not be the best way to do this

-        return self.scale*self.f(**self.params)
+        return res
diff --git a/experiments/AAnt-locomotion/one_physics_step.py b/experiments/AAnt-locomotion/one_physics_step.py
index 536b8b7..37dcaf0 100644
--- a/experiments/AAnt-locomotion/one_physics_step.py
+++ b/experiments/AAnt-locomotion/one_physics_step.py
@@ -16,39 +16,58 @@
 from alfredo.agents.aant import AAnt
+from alfredo.tools import analyze_neural_params
+
+from alfredo.rewards import Reward
+from alfredo.rewards import rConstant
+from alfredo.rewards import rHealthy_simple_z
+from alfredo.rewards import rControl_act_ss
+from alfredo.rewards import rTorques
+from alfredo.rewards import rTracking_lin_vel
+from alfredo.rewards import rTracking_yaw_vel
+from alfredo.rewards import rUpright
+from alfredo.rewards import rTracking_Waypoint
+from alfredo.rewards import rStand_still
+
 backend = "positional"

 # Load desired model xml and trained param set
 # get filepaths from commandline args
 cwd = os.getcwd()

-# get the filepath to the env and agent xmls
 import alfredo.scenes as scenes

+# Define reward structure
+rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=15.0, ps={}),
+           'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.2, ps={})}

+# Get the filepath to the env and agent xmls
+import alfredo.scenes as scenes
 import alfredo.agents as agents

 agents_fp = os.path.dirname(agents.__file__)
 agent_xml_path = f"{agents_fp}/aant/aant.xml"

 scenes_fp = os.path.dirname(scenes.__file__)
-
 env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
-
 tpf_path = f"{cwd}/{sys.argv[-1]}"

 print(f"agent description file: {agent_xml_path}")
 print(f"environment description file: {env_xml_paths[0]}")
-print(f"neural parameter file: {tpf_path}")
+print(f"neural parameter file: {tpf_path}\n")

+# Load neural parameters
 params = model.load_params(tpf_path)
+summary = analyze_neural_params(params)
+print(f"summary: {summary}\n")

 # create an env and initial state
-env = AAnt(backend=backend,
+env = AAnt(backend=backend,
+           rewards=rewards,
            env_xml_path=env_xml_paths[0],
            agent_xml_path=agent_xml_path)

 rng = jax.random.PRNGKey(seed=3)
 state = env.reset(rng=rng)
 #state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

+# Initialize inference neural network
 normalize = lambda x, y: x
 normalize = running_statistics.normalize
@@ -60,10 +79,11 @@
 policy_params = (params[0], params[1])
 inference_fn = make_policy(policy_params)

-wcmd = jp.array([0.0, 10.0])
+# Reset the env
 key_envs, _ = jax.random.split(rng)
 state = env.reset(rng=key_envs)

+# Debug printing
 print(f"q: {state.pipeline_state.q}")
 print(f"\n")
 print(f"qd: {state.pipeline_state.qd}")
@@ -81,7 +101,11 @@
 print(f"done: {state.done}")
 print(f"\n")

-episode_length = 10
+# Rollout trajectory one physics step at a time
+episode_length = 1
+
+jcmd = jp.array([1.0, 0.0, -0.7])
+
 for _ in range(episode_length):
     print(f"\n---------------------------------------------------------------\n")
@@ -90,7 +114,7 @@
     print(f"act_rng: {act_rng}")
     print(f"\n")

-    state.info['wcmd'] = wcmd
+    state.info['jcmd'] = jcmd

     print(f"state info: {state.info}")
     print(f"\n")
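
For reference, the sketch below shows how the new Reward wrapper is meant to be wired together, reconstructed only from what this patch exposes: the Reward(f, sc=..., ps={}) construction in one_physics_step.py, the add_param()/compute() calls in AAnt.step(), and the res.at[0].multiply(self.scale) line in alfredo/rewards/reward.py. The class body and the toy r_action_cost function are illustrative assumptions, not the actual contents of reward.py, and may differ from the real implementation in detail.

import jax
from jax import numpy as jp


class Reward:
    """Hypothetical sketch of the Reward wrapper (assumption, not the real reward.py)."""

    def __init__(self, f, sc=1.0, ps=None):
        self.f = f                            # reward function returning a 1-element jax array
        self.scale = sc                       # multiplier applied to element 0 of the result
        self.params = dict(ps) if ps else {}  # keyword arguments passed to f

    def add_param(self, name, value):
        # Per-step inputs such as 'sys', 'pipeline_state', or 'jcmd' are
        # added (or overwritten) here before compute() is called.
        self.params[name] = value

    def compute(self):
        # Call the wrapped function and scale only element 0, mirroring
        # res.at[0].multiply(self.scale) in the patched compute().
        res = self.f(**self.params)
        return res.at[0].multiply(self.scale)


# Toy reward function (illustrative only): penalize squared action magnitude.
def r_action_cost(action: jax.Array) -> jax.Array:
    return jp.array([-jp.sum(jp.square(action))])


# Usage mirroring AAnt.step(): build the dict once, feed per-step params,
# then accumulate the scalar total from element 0 of each result.
rewards = {'r_act': Reward(r_action_cost, sc=0.5, ps={})}

action = jp.array([0.1, -0.2, 0.3])
total_reward = 0.0
for name, r in rewards.items():
    r.add_param('action', action)
    value = r.compute()
    total_reward += value[0]

print(total_reward)  # -0.07: squared-action penalty scaled by sc=0.5

Keeping the scale inside Reward, rather than passing a weight argument to every reward function, is what allows the per-function weight parameters to be dropped throughout alfredo/rewards in this patch.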