Merge pull request #35 from tartavull/more-rewards-joystick
More rewards joystick + reward structure organization
Manit Ginoya authored Oct 7, 2024
2 parents f1896f7 + a9fefa1 commit 06c8e23
Showing 26 changed files with 903 additions and 737 deletions.
1 change: 0 additions & 1 deletion alfredo/agents/A1/__init__.py

This file was deleted.

102 changes: 0 additions & 102 deletions alfredo/agents/A1/a1.xml

This file was deleted.

244 changes: 0 additions & 244 deletions alfredo/agents/A1/alfredo_1.py

This file was deleted.

2 changes: 1 addition & 1 deletion alfredo/agents/__init__.py
@@ -1 +1 @@
from . import A1
from . import aant
1 change: 1 addition & 0 deletions alfredo/agents/aant/__init__.py
@@ -0,0 +1 @@
from .aant import *
224 changes: 224 additions & 0 deletions alfredo/agents/aant/aant.py
@@ -0,0 +1,224 @@
from brax import base
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from etils import epath
import jax
from jax import numpy as jp

from alfredo.tools import compose_scene

class AAnt(PipelineEnv):
"""An ant-like agent ("AAnt") environment built on brax's PipelineEnv, with a configurable reward structure and a scene composed from separate environment and agent XML files."""

def __init__(self,
rewards = {},
env_xml_path = "",
agent_xml_path = "",
terminate_when_unhealthy=False,
reset_noise_scale=0.1,
exclude_current_positions_from_observation=False,
backend='generalized',
**kwargs,):

# env_xml_path and agent_xml_path must be provided
if env_xml_path and agent_xml_path:
self._env_xml_path = env_xml_path
self._agent_xml_path = agent_xml_path

xml_scene = compose_scene(self._env_xml_path, self._agent_xml_path)
sys = mjcf.loads(xml_scene)
else:
raise Exception("env_xml_path & agent_xml_path both must be provided")

# reward dictionary must be provided
if rewards:
self._rewards = rewards
else:
raise Exception("a non-empty rewards dictionary must be provided")

# TODO: clean this up in the future &
# make n_frames a function of input dt
n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.005)
n_frames = 10

if backend == 'positional':
# TODO: does the same actuator strength work as in spring
sys = sys.replace(
actuator=sys.actuator.replace(
gear=200 * jp.ones_like(sys.actuator.gear)
)
)

kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

# Initialize the superclass "PipelineEnv"
super().__init__(sys=sys, backend=backend, **kwargs)

# Setting other object parameters based on input params
self._terminate_when_unhealthy = terminate_when_unhealthy
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)


def reset(self, rng: jax.Array) -> State:

rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
low, hi = -self._reset_noise_scale, self._reset_noise_scale

# initialize position vector with minor randomization in pose
q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
)

# initialize velocity vector with minor randomization
qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

# generate sample commands
jcmd = self._sample_command(rng3)
wcmd = jp.array([0.0, 0.0])

# initialize pipeline_state (the physics state)
pipeline_state = self.pipeline_init(q, qd)

# reset values and metrics
reward, done, zero = jp.zeros(3)

state_info = {
'jcmd':jcmd,
'wcmd':wcmd,
'rewards': {k: 0.0 for k in self._rewards.keys()},
'step': 0,
}

metrics = {'pos_x_world_abs': zero,
'pos_y_world_abs': zero,
'pos_z_world_abs': zero,}

for rn, r in self._rewards.items():
metrics[rn] = state_info['rewards'][rn]

# get initial observation vector
obs = self._get_obs(pipeline_state, state_info)

return State(pipeline_state, obs, reward, done, metrics, state_info)

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""

# Save the previous physics state and step physics forward
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

# Add all additional parameters to compute rewards
self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd'])
self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd'])

# Compute all rewards and accumulate total reward
total_reward = 0.0
for rn, r in self._rewards.items():
r.add_param('sys', self.sys)
r.add_param('pipeline_state', pipeline_state)

reward_value = r.compute()
state.info['rewards'][rn] = reward_value[0]
total_reward += reward_value[0]
# print(f'{rn} reward_val = {reward_value}\n')

# Computing additional metrics as necessary
pos_world = pipeline_state.x.pos[0]
abs_pos_world = jp.abs(pos_world)

# Compute observations
obs = self._get_obs(pipeline_state, state.info)
done = 0.0

# State management
state.info['step'] += 1

state.metrics.update(state.info['rewards'])

state.metrics.update(
pos_x_world_abs = abs_pos_world[0],
pos_y_world_abs = abs_pos_world[1],
pos_z_world_abs = abs_pos_world[2],
)

return state.replace(
pipeline_state=pipeline_state, obs=obs, reward=total_reward, done=done
)

def _get_obs(self, pipeline_state, state_info) -> jax.Array:
"""Observe ant body position and velocities."""
qpos = pipeline_state.q
qvel = pipeline_state.qd

inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)
torso_pos = pipeline_state.x.pos[0]

jcmd = state_info['jcmd']
#wcmd = state_info['wcmd']

if self._exclude_current_positions_from_observation:
qpos = pipeline_state.q[2:]

obs = jp.concatenate([
jp.array(qpos),
jp.array(qvel),
jp.array(local_rpyrate),
jp.array(jcmd),
])

return obs

def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
x_range = [-25, 25]
y_range = [-25, 25]
z_range = [0, 2]

_, key1, key2, key3 = jax.random.split(rng, 4)

x = jax.random.uniform(
key1, (1,), minval=x_range[0], maxval=x_range[1]
)

y = jax.random.uniform(
key2, (1,), minval=y_range[0], maxval=y_range[1]
)

z = jax.random.uniform(
key3, (1,), minval=z_range[0], maxval=z_range[1]
)

wcmd = jp.array([x[0], y[0]])

return wcmd

def _sample_command(self, rng: jax.Array) -> jax.Array:
lin_vel_x_range = [-3.0, 3.0] #[m/s]
lin_vel_y_range = [-3.0, 3.0] #[m/s]
yaw_vel_range = [-1.0, 1.0] #[rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)

lin_vel_x = jax.random.uniform(
key1, (1,), minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1]
)

lin_vel_y = jax.random.uniform(
key2, (1,), minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1]
)

yaw_vel = jax.random.uniform(
key3, (1,), minval=yaw_vel_range[0], maxval=yaw_vel_range[1]
)

jcmd = jp.array([lin_vel_x[0], lin_vel_y[0], yaw_vel[0]])

return jcmd
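
For orientation, a condensed construction-and-step sketch follows. It is pieced together from the experiment scripts included later in this diff (one_physics_step.py and training.py); the paths, backend, and reward scales are taken from those scripts rather than from any additional API, and the rewards dict must contain at least the 'r_lin_vel' and 'r_yaw_vel' entries because AAnt.step() adds the joystick command to exactly those two terms.

import os
import jax
from jax import numpy as jp

import alfredo.agents as agents
import alfredo.scenes as scenes
from alfredo.agents.aant import AAnt
from alfredo.rewards import Reward, rTracking_lin_vel, rTracking_yaw_vel

# Locate the agent and environment XML files shipped with the package
agent_xml_path = os.path.join(os.path.dirname(agents.__file__), "aant", "aant.xml")
env_xml_path = os.path.join(os.path.dirname(scenes.__file__), "flatworld", "flatworld_A1_env.xml")

# The two reward terms that AAnt.step() expects to find in the dictionary
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
           'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}

env = AAnt(backend='positional',
           rewards=rewards,
           env_xml_path=env_xml_path,
           agent_xml_path=agent_xml_path)

state = jax.jit(env.reset)(jax.random.PRNGKey(0))
state = jax.jit(env.step)(state, jp.zeros(env.action_size))  # one zero-torque step
print(state.reward, state.metrics)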
94 changes: 94 additions & 0 deletions alfredo/agents/aant/aant.xml
@@ -0,0 +1,94 @@
<agent>

<custom>
<numeric data="0.0 0.0 0.75 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
<numeric data="1000" name="constraint_limit_stiffness"/>
<numeric data="4000" name="constraint_stiffness"/>
<numeric data="10" name="constraint_ang_damping"/>
<numeric data="20" name="constraint_vel_damping"/>
<numeric data="0.5" name="joint_scale_pos"/>
<numeric data="0.2" name="joint_scale_ang"/>
<numeric data="0.0" name="ang_damping"/>
<numeric data="1" name="spring_mass_scale"/>
<numeric data="1" name="spring_inertia_scale"/>
<numeric data="15" name="solver_maxls"/>
</custom>

<body name="torso" pos="0 0 1.0">
<camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
<geom name="torso_geom" contype="1" pos="0 0 0" size="0.25" type="sphere"/>
<joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>

<body name="front_left_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>

<body name="aux_1" pos="0.2 0.2 0">
<joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>

<body pos="0.2 0.2 0">
<joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
<geom name="left_foot_geom" contype="1" pos="0.4 0.4 0" size="0.08" type="sphere" mass="0"/>
</body>
</body>
</body>

<body name="front_right_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>

<body name="aux_2" pos="-0.2 0.2 0">
<joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>

<body pos="-0.2 0.2 0">
<joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
<geom name="right_foot_geom" contype="1" pos="-0.4 0.4 0" size="0.08" type="sphere" mass="0"/>
</body>
</body>
</body>

<body name="back_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>

<body name="aux_3" pos="-0.2 -0.2 0">
<joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>

<body pos="-0.2 -0.2 0">
<joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
<geom name="third_foot_geom" contype="1" pos="-0.4 -0.4 0" size="0.08" type="sphere" mass="0"/>
</body>
</body>
</body>

<body name="right_back_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>

<body name="aux_4" pos="0.2 -0.2 0">
<joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>

<body pos="0.2 -0.2 0">
<joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
<geom name="fourth_foot_geom" contype="1" pos="0.4 -0.4 0" size="0.08" type="sphere" mass="0"/>
</body>
</body>
</body>
</body>

<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
</actuator>

</agent>
5 changes: 4 additions & 1 deletion alfredo/rewards/__init__.py
@@ -1,4 +1,7 @@
from .reward import Reward

from .rConstant import *
from .rSpeed import *
from .rHealthy import *
from .rControl import *
from .rEnergy import *
from .rOrientation import *
7 changes: 3 additions & 4 deletions alfredo/rewards/rConstant.py
@@ -8,8 +8,7 @@
from jax import numpy as jp

def rConstant(sys: base.System,
pipeline_state: base.State,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:
pipeline_state: base.State,
focus_idx_range=(0, -1)) -> jax.Array:

return jp.array([weight])
return jp.array([1.0])
56 changes: 53 additions & 3 deletions alfredo/rewards/rControl.py
@@ -10,9 +10,59 @@
def rControl_act_ss(sys: base.System,
pipeline_state: base.State,
action: jp.ndarray,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:
focus_idx_range=(0, -1)) -> jax.Array:

ctrl_cost = weight * jp.sum(jp.square(action))

return ctrl_cost
return jp.array([ctrl_cost])

def rTracking_lin_vel(sys: base.System,
pipeline_state: base.State,
jcmd: jax.Array,
sigma=0.25,
focus_idx_range=(0, 1)) -> jax.Array:

local_vel = math.rotate(pipeline_state.xd.vel[0],
math.quat_inv(pipeline_state.x.rot[0]))

lv_error = jp.sum(jp.square(jcmd[:2] - local_vel[:2])) # just taking a look at x, y velocities
lv_reward = jp.exp(-lv_error/sigma)

return jp.array([lv_reward])

def rTracking_yaw_vel(sys: base.System,
pipeline_state: base.State,
jcmd: jax.Array,
sigma=0.25,
focus_idx_range=(0, 1)) -> jax.Array:

local_ang_vel = math.rotate(pipeline_state.xd.ang[0],
math.quat_inv(pipeline_state.x.rot[0]))

yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2])
yv_reward = jp.exp(-yaw_vel_error/sigma)

return jp.array([yv_reward])


def rTracking_Waypoint(sys: base.System,
pipeline_state: base.State,
wcmd: jax.Array,
focus_idx_range=0) -> jax.Array:

torso_pos = pipeline_state.x.pos[focus_idx_range]
pos_goal_diff = torso_pos[0:2] - wcmd[0:2]
pos_sum_abs_diff = -jp.sum(jp.abs(pos_goal_diff))

return jp.array([pos_sum_abs_diff])

def rStand_still(sys: base.System,
pipeline_state: base.State,
jcmd: jax.Array,
joint_angles: jax.Array,
default_pose: jax.Array,
focus_idx_range=0) -> jax.Array:

close_to_still = jp.sum(jp.abs(joint_angles - default_pose)) * math.normalize(jcmd[:2])[1] < 0.1

return jp.array([close_to_still])
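
Both tracking terms above share an exponential kernel, exp(-error / sigma): the reward is 1.0 when the commanded velocity is matched exactly and decays smoothly toward 0 as the squared error grows, with sigma controlling how forgiving the shaping is. A short illustrative snippet (not part of the commit):

from jax import numpy as jp

# Approximate values of the kernel used by rTracking_lin_vel / rTracking_yaw_vel
sigma = 0.25
for sq_err in [0.0, 0.1, 0.25, 1.0]:
    print(sq_err, float(jp.exp(-sq_err / sigma)))  # ~1.00, ~0.67, ~0.37, ~0.02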
27 changes: 27 additions & 0 deletions alfredo/rewards/rEnergy.py
@@ -0,0 +1,27 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp


def rTorques(sys: base.System,
pipeline_state: base.State,
action: jp.ndarray,
focus_idx_range=(0, -1)) -> jax.Array:

s_idx = focus_idx_range[0]
e_idx = focus_idx_range[1]

torque = actuator.to_tau(sys,
action,
pipeline_state.q[s_idx:e_idx],
pipeline_state.qd[s_idx:e_idx])


tr = jp.sqrt(jp.sum(jp.square(torque))) + jp.sum(jp.abs(torque))

return jp.array([tr])
9 changes: 4 additions & 5 deletions alfredo/rewards/rHealthy.py
@@ -11,9 +11,8 @@ def rHealthy_simple_z(sys: base.System,
pipeline_state: base.State,
z_range: Tuple,
early_terminate: True,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:

focus_idx_range=(0, -1)) -> jax.Array:

min_z, max_z = z_range
focus_s = focus_idx_range[0]
focus_e = focus_idx_range[-1]
@@ -24,8 +23,8 @@ def rHealthy_simple_z(sys: base.System,
is_healthy = jp.where(focus_x_pos > max_z, x=0.0, y=is_healthy)

if early_terminate:
hr = weight
hr = 1.0
else:
hr = weight * is_healthy
hr = 1.0 * is_healthy

return jp.array([hr, is_healthy])
17 changes: 17 additions & 0 deletions alfredo/rewards/rOrientation.py
@@ -0,0 +1,17 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp

def rUpright(sys: base.System,
pipeline_state: base.State,
focus_idx_range = (0,0)) -> jax.Array:

up = jp.array([0.0, 0.0, 1.0])
rot_up = math.rotate(up, pipeline_state.x.rot[0])

return jp.array([jp.dot(up, rot_up)])  # wrapped in an array for consistency with the other reward terms and Reward.compute()
44 changes: 0 additions & 44 deletions alfredo/rewards/rSpeed.py

This file was deleted.

41 changes: 41 additions & 0 deletions alfredo/rewards/reward.py
@@ -0,0 +1,41 @@
from brax import base
from jax import numpy as jp

class Reward:

def __init__(self, f, sc, ps):
"""
:param f: A function handle (i.e. the function that computes this reward)
:param sc: A float scale that multiplies the base value computed by f
:param ps: A dictionary of parameters required for the reward computation
"""

self.f = f
self.scale = sc
self.params = ps

def add_param(self, p_name, p_value):
"""
Updates self.params dictionary with provided key and value
"""

self.params[p_name] = p_value

def compute(self):
"""
Computes the reward via self.f, assuming the scale and all
required parameters have been set.
Otherwise, this errors out quite spectacularly.
"""

res = self.f(**self.params)
res = res.at[0].multiply(self.scale) #may not be the best way to do this

return res

def __str__(self):
"""
provides a standard string output
"""

return f'reward: {self.f}, scale: {self.scale}'
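
A minimal usage sketch of the Reward wrapper (illustrative only; it uses rConstant, which ignores its sys and pipeline_state arguments, so None placeholders are sufficient here):

from alfredo.rewards import Reward, rConstant

r = Reward(rConstant, sc=2.0, ps={})
r.add_param('sys', None)             # rConstant does not read these,
r.add_param('pipeline_state', None)  # so None placeholders are enough
print(r.compute())                   # [2.0]: base value 1.0 scaled by sc=2.0
print(r)                             # reward: <function rConstant ...>, scale: 2.0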
16 changes: 1 addition & 15 deletions alfredo/scenes/flatworld/flatworld_A1_env.xml
@@ -8,21 +8,7 @@
<motor ctrllimited="true" ctrlrange="-.4 .4" />
</default>

<option iterations="8" timestep="0.003" />

<custom>
<numeric data="2500" name="constraint_limit_stiffness" />
<numeric data="27000" name="constraint_stiffness" />
<numeric data="30" name="constraint_ang_damping" />
<numeric data="80" name="constraint_vel_damping" />
<numeric data="-0.05" name="ang_damping" />
<numeric data="0.5" name="joint_scale_pos" />
<numeric data="0.1" name="joint_scale_ang" />
<numeric data="0" name="spring_mass_scale" />
<numeric data="1" name="spring_inertia_scale" />
<numeric data="20" name="matrix_inv_iterations" />
<numeric data="15" name="solver_maxls" />
</custom>
<option iterations="30" timestep="0.002" />

<size nkey="5" nuser_geom="1" />

1 change: 1 addition & 0 deletions alfredo/tools/__init__.py
@@ -1 +1,2 @@
from .tXMLCompose import *
from .tAnalyzeNetwork import *
63 changes: 63 additions & 0 deletions alfredo/tools/tAnalyzeNetwork.py
@@ -0,0 +1,63 @@
import jax
import jax.numpy as jp

def analyze_neural_params(params):
"""
Provides a metric summary of a loaded set of neural network parameters.
The parameters are expected as a tuple whose entries are:
0. RunningStatisticsState
1. FrozenDict (policy network params)
2. FrozenDict (value network params)
"""

summary = {}

summary['Running Statistics'] = params[0]
summary['Policy Network'] = analyze_one_neural_network(params[1])
summary['Value Network'] = analyze_one_neural_network(params[2])

return summary


def analyze_one_neural_network(sn_params):
"""
Helper function that unpacks the parameters of a single network.
The contents are assumed to be provided as:
FrozenDict({'params': {'hidden_x': {'bias': jax.Array, 'kernel': jax.Array}}})
where x in hidden_x denotes the order of the hidden layer
(but these also include the input and output layers?)
"""

summary = {}

num_layers = 0
per_layer_info = {}
total_parameters = 0

for k, v in sn_params.items():

for layer_name, layer_data in v.items():

num_layers += 1
param_count = 0

for param_name, param_data in layer_data.items():

if param_name == 'bias':
per_layer_info[layer_name] = {'size_of_layer': len(param_data)}
param_count += len(param_data)
# print(len(param_data))

if param_name == 'kernel':
param_count += len(param_data)*len(param_data[0])
# print(len(param_data)*len(param_data[0]))

total_parameters += param_count

summary['num_layers'] = num_layers
summary['per_layer_info'] = per_layer_info
summary['total_parameters'] = total_parameters

return summary
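
As a sanity check, the sketch below feeds analyze_one_neural_network a hypothetical two-layer parameter dict standing in for the FrozenDict that brax produces; the function is assumed to be importable from alfredo.tools via the star import added in tools/__init__.py.

import jax.numpy as jp
from alfredo.tools import analyze_one_neural_network  # assumed re-exported by the star import

# Hypothetical layer shapes: hidden_0 has an 8x32 kernel + 32 bias (288 params),
# hidden_1 has a 32x4 kernel + 4 bias (132 params), so total_parameters = 420.
sn_params = {'params': {
    'hidden_0': {'bias': jp.zeros(32), 'kernel': jp.zeros((8, 32))},
    'hidden_1': {'bias': jp.zeros(4),  'kernel': jp.zeros((32, 4))},
}}

print(analyze_one_neural_network(sn_params))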
3 changes: 3 additions & 0 deletions alfredo/tools/tXMLCompose.py
@@ -15,10 +15,13 @@ def compose_scene(xml_env, xml_agent):
worldbody = env_root.find('worldbody')

ag_root = agent_tree.getroot()
ag_custom = ag_root.find('custom')
ag_body = ag_root.find('body')
ag_actuator = ag_root.find('actuator')


worldbody.append(ag_body)
env_root.append(ag_custom)
env_root.append(ag_actuator)

beautify(env_root)
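
For context, aant.py calls xml_scene = compose_scene(env_xml_path, agent_xml_path) and then mjcf.loads(xml_scene), so compose_scene is assumed here to take two file paths and return the merged XML as a string. The toy sketch below (not part of the commit) writes a minimal environment containing a worldbody plus a minimal agent carrying the custom, body, and actuator elements that this function grafts in:

import os
import tempfile
from alfredo.tools import compose_scene

env_xml = """<mujoco>
  <worldbody>
    <geom name="floor" type="plane" size="40 40 40"/>
  </worldbody>
</mujoco>"""

agent_xml = """<agent>
  <custom><numeric data="15" name="solver_maxls"/></custom>
  <body name="torso" pos="0 0 1"><geom name="torso_geom" type="sphere" size="0.25"/></body>
  <actuator></actuator>
</agent>"""

with tempfile.TemporaryDirectory() as d:
    env_path = os.path.join(d, "env.xml")
    agent_path = os.path.join(d, "agent.xml")
    with open(env_path, "w") as f:
        f.write(env_xml)
    with open(agent_path, "w") as f:
        f.write(agent_xml)
    print(compose_scene(env_path, agent_path))  # merged scene XML string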
148 changes: 148 additions & 0 deletions experiments/AAnt-locomotion/one_physics_step.py
@@ -0,0 +1,148 @@
import functools
import os
import re
import sys
from datetime import datetime

import brax
import jax
import matplotlib.pyplot as plt
from brax import envs, math
from brax.envs.wrappers import training
from brax.io import html, json, model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from jax import numpy as jp

from alfredo.agents.aant import AAnt

from alfredo.tools import analyze_neural_params

from alfredo.rewards import Reward
from alfredo.rewards import rConstant
from alfredo.rewards import rHealthy_simple_z
from alfredo.rewards import rControl_act_ss
from alfredo.rewards import rTorques
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel
from alfredo.rewards import rUpright
from alfredo.rewards import rTracking_Waypoint
from alfredo.rewards import rStand_still

backend = "positional"

# Load desired model xml and trained param set
# get filepaths from commandline args
cwd = os.getcwd()

# Define reward structure
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=15.0, ps={}),
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.2, ps={})}

print(rewards)

# Get the filepath to the env and agent xmls
import alfredo.scenes as scenes
import alfredo.agents as agents
agents_fp = os.path.dirname(agents.__file__)
agent_xml_path = f"{agents_fp}/aant/aant.xml"

scenes_fp = os.path.dirname(scenes.__file__)
env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
tpf_path = f"{cwd}/{sys.argv[-1]}"

print(f"agent description file: {agent_xml_path}")
print(f"environment description file: {env_xml_paths[0]}")
print(f"neural parameter file: {tpf_path}\n")

# Load neural parameters
params = model.load_params(tpf_path)
summary = analyze_neural_params(params)
print(f"summary: {summary}\n")

# create an env and initial state
env = AAnt(backend=backend,
rewards=rewards,
env_xml_path=env_xml_paths[0],
agent_xml_path=agent_xml_path)

rng = jax.random.PRNGKey(seed=3)
state = env.reset(rng=rng)
#state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

# Initialize inference neural network
normalize = lambda x, y: x
normalize = running_statistics.normalize

ppo_network = ppo_networks.make_ppo_networks(
state.obs.shape[-1], env.action_size, preprocess_observations_fn=normalize
)

make_policy = ppo_networks.make_inference_fn(ppo_network)
policy_params = (params[0], params[1])
inference_fn = make_policy(policy_params)

# Reset the env
key_envs, _ = jax.random.split(rng)
state = env.reset(rng=key_envs)
#state = jax.jit(env.reset)(rng=key_envs)

# Debug printing
print(f"q: {state.pipeline_state.q}")
print(f"\n")
print(f"qd: {state.pipeline_state.qd}")
print(f"\n")
print(f"x: {state.pipeline_state.x}")
print(f"\n")
print(f"xd: {state.pipeline_state.xd}")
print(f"\n")
print(f"contact: {state.pipeline_state.contact}")
print(f"\n")
print(f"reward: {state.reward}")
print(f"\n")
print(state.metrics)
print(f"\n")
print(f"done: {state.done}")
print(f"\n")

# Rollout trajectory one physics step at a time
episode_length = 1

jcmd = jp.array([1.0, 0.0, -0.7])

for _ in range(episode_length):
print(f"\n---------------------------------------------------------------\n")

act_rng, rng = jax.random.split(rng)
print(f"rng: {rng}")
print(f"act_rng: {act_rng}")
print(f"\n")

state.info['jcmd'] = jcmd
print(f"state info: {state.info}")
print(f"\n")

act, _ = inference_fn(state.obs, act_rng)
print(f"observation: {state.obs}")
print(f"\n")
print(f"action: {act}")
print(f"\n")

state = env.step(state, act)

print(f"q: {state.pipeline_state.q}")
print(f"\n")
print(f"qd: {state.pipeline_state.qd}")
print(f"\n")
print(f"x: {state.pipeline_state.x}")
print(f"\n")
print(f"xd: {state.pipeline_state.xd}")
print(f"\n")
print(f"contact: {state.pipeline_state.contact}")
print(f"\n")
print(f"reward: {state.reward}")
print(f"\n")
print(state.metrics)
print(f"\n")
print(f"done: {state.done}")
print(f"\n")
183 changes: 183 additions & 0 deletions experiments/AAnt-locomotion/training.py
@@ -0,0 +1,183 @@
import functools
import os
import re
from datetime import datetime

import brax
import flax
import jax
import optax
import wandb
from brax import envs

from brax.io import html, json, model
from brax.training.acme import running_statistics, specs
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.agents.ppo import networks as ppo_networks
from jax import numpy as jp

from alfredo.agents.aant import AAnt
from alfredo.train import ppo

from alfredo.rewards import Reward
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel


# Define Reward Structure
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}

# Initialize a new run
wandb.init(
project="aant",
config={
"env_name": "AAnt",
"seed": 13,

"training_params": {
"backend": "positional",
"len_training": 1_500_000,
"num_evals": 500,
"num_envs": 2048,
"batch_size": 2048,
"num_minibatches": 8,
"updates_per_batch": 8,
"episode_len": 1000,
"unroll_len": 10,
"reward_scaling":1,
"action_repeat": 1,
"discounting": 0.97,
"learning_rate": 3e-4,
"entropy_cost": 1e-3,
"reward_scaling": 0.1,
"normalize_obs": True,
},

"rewards":rewards,

"aux_model_params":{

}
},
)

# define callback function that will report training progress
def progress(num_steps, metrics):
print(num_steps)
print(metrics)

epi_len = wandb.config.training_params['episode_len']

log_dict = {'step': num_steps}

for mn, m in metrics.items():
name_in_log = mn.split('/')[-1]
log_dict[name_in_log] = m/epi_len

wandb.log(log_dict)


# get the filepath to the env and agent xmls
cwd = os.getcwd()

import alfredo.scenes as scenes
import alfredo.agents as agents
agents_fp = os.path.dirname(agents.__file__)
agent_xml_path = f"{agents_fp}/aant/aant.xml"

scenes_fp = os.path.dirname(scenes.__file__)

env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]

# make and save initial ppo_network
key = jax.random.PRNGKey(wandb.config.seed)
global_key, local_key = jax.random.split(key)
key_policy, key_value = jax.random.split(global_key)

env = AAnt(backend=wandb.config.training_params['backend'],
rewards=rewards,
env_xml_path=env_xml_paths[0],
agent_xml_path=agent_xml_path)

rng = jax.random.PRNGKey(seed=0)
state = env.reset(rng)

normalize_fn = running_statistics.normalize

ppo_network = ppo_networks.make_ppo_networks(
env.observation_size, env.action_size, normalize_fn
)

init_params = ppo_losses.PPONetworkParams(
policy=ppo_network.policy_network.init(key_policy),
value=ppo_network.value_network.init(key_value),
)

normalizer_params = running_statistics.init_state(
specs.Array(env.observation_size, jp.float32)
)

params_to_save = (normalizer_params, init_params.policy, init_params.value)

model.save_params(f"param-store/AAnt_params_0", params_to_save)

# ============================
# Training & Saving Params
# ============================
i = 0

for p in env_xml_paths:

d_and_t = datetime.now()
print(f"[{d_and_t}] loop start for model: {i}")
env = AAnt(backend=wandb.config.training_params['backend'],
rewards=rewards,
env_xml_path=p,
agent_xml_path=agent_xml_path)

mF = f"{cwd}/param-store/{wandb.config.env_name}_params_{i}"
mParams = model.load_params(mF)

d_and_t = datetime.now()
print(f"[{d_and_t}] jitting start for model: {i}")
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=wandb.config.seed))
d_and_t = datetime.now()
print(f"[{d_and_t}] jitting end for model: {i}")

# define new training function
train_fn = functools.partial(
ppo.train,
num_timesteps=wandb.config.training_params['len_training'],
num_evals=wandb.config.training_params['num_evals'],
reward_scaling=wandb.config.training_params['reward_scaling'],
episode_length=wandb.config.training_params['episode_len'],
normalize_observations=wandb.config.training_params['normalize_obs'],
action_repeat=wandb.config.training_params['action_repeat'],
unroll_length=wandb.config.training_params['unroll_len'],
num_minibatches=wandb.config.training_params['num_minibatches'],
num_updates_per_batch=wandb.config.training_params['updates_per_batch'],
discounting=wandb.config.training_params['discounting'],
learning_rate=wandb.config.training_params['learning_rate'],
entropy_cost=wandb.config.training_params['entropy_cost'],
num_envs=wandb.config.training_params['num_envs'],
batch_size=wandb.config.training_params['batch_size'],
seed=wandb.config.seed,
in_params=mParams,
)

d_and_t = datetime.now()
print(f"[{d_and_t}] training start for model: {i}")
_, params, _, ts = train_fn(environment=env, progress_fn=progress)
d_and_t = datetime.now()
print(f"[{d_and_t}] training end for model: {i}")

i += 1
next_m_name = f"param-store/{wandb.config.env_name}_params_{i}"
model.save_params(next_m_name, params)

d_and_t = datetime.now()
print(f"[{d_and_t}] loop end for model: {i}")
@@ -14,7 +14,11 @@
from brax.training.agents.ppo import networks as ppo_networks
from jax import numpy as jp

from alfredo.agents.A1.alfredo_1 import Alfredo
from alfredo.agents.aant import AAnt

from alfredo.rewards import Reward
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel

backend = "positional"

@@ -24,10 +28,9 @@

# get the filepath to the env and agent xmls
import alfredo.scenes as scenes

import alfredo.agents as agents
agents_fp = os.path.dirname(agents.__file__)
agent_xml_path = f"{agents_fp}/A1/a1.xml"
agent_xml_path = f"{agents_fp}/aant/aant.xml"

scenes_fp = os.path.dirname(scenes.__file__)

@@ -40,26 +43,33 @@

params = model.load_params(tpf_path)

# Define Reward Structure
# For visualizing, this is just to be able to create the env
# May want to make this not necessary in the future ..?
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}

# create an env with auto-reset and load previously trained parameters
env = Alfredo(backend=backend,
env_xml_path=env_xml_path,
agent_xml_path=agent_xml_path)
env = AAnt(backend=backend,
rewards=rewards,
env_xml_path=env_xml_path,
agent_xml_path=agent_xml_path)

auto_reset = True
episode_length = 1000
action_repeat = 1

if episode_length is not None:
env = training.EpisodeWrapper(env, episode_length, action_repeat)
#if episode_length is not None:
# env = training.EpisodeWrapper(env, episode_length, action_repeat)

if auto_reset:
env = training.AutoResetWrapper(env)
#if auto_reset:
# env = training.AutoResetWrapper(env)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

rollout = []
rng = jax.random.PRNGKey(seed=1)
rng = jax.random.PRNGKey(seed=13194)
state = jit_env_reset(rng=rng)

normalize = lambda x, y: x
@@ -75,12 +85,26 @@

jit_inference_fn = jax.jit(inference_fn)

x_vel = 0.0 # m/s
y_vel = 3.0 # m/s
yaw_vel = 0.0 # rad/s
jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([10.0, 10.0])

# generate policy rollout
for _ in range(episode_length):
rollout.append(state.pipeline_state)
act_rng, rng = jax.random.split(rng)

state.info['jcmd'] = jcmd
state.info['wcmd'] = wcmd
act, _ = jit_inference_fn(state.obs, act_rng)
state = jit_env_step(state, act)
print(state.info)


print(rollout[-1])

html_string = html.render(env.sys.replace(dt=env.dt), rollout)

31 changes: 0 additions & 31 deletions experiments/Alfredo-simple-walk/README.md

This file was deleted.

164 changes: 0 additions & 164 deletions experiments/Alfredo-simple-walk/seq_training.py

This file was deleted.

54 changes: 0 additions & 54 deletions experiments/Alfredo-simple-walk/vis_new_model.py

This file was deleted.

57 changes: 0 additions & 57 deletions experiments/Alfredo-simulate-step/one_physics_step.py

This file was deleted.
