-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from tartavull/more-rewards-joystick
More rewards joystick + reward structure organization
Showing
26 changed files
with
903 additions
and
737 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from . import A1 | ||
from . import aant |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .aant import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
from brax import base | ||
from brax import math | ||
from brax.envs.base import PipelineEnv, State | ||
from brax.io import mjcf | ||
from etils import epath | ||
import jax | ||
from jax import numpy as jp | ||
|
||
from alfredo.tools import compose_scene | ||
|
||
class AAnt(PipelineEnv): | ||
""" """ | ||
|
||
def __init__(self, | ||
rewards = {}, | ||
env_xml_path = "", | ||
agent_xml_path = "", | ||
terminate_when_unhealthy=False, | ||
reset_noise_scale=0.1, | ||
exclude_current_positions_from_observation=False, | ||
backend='generalized', | ||
**kwargs,): | ||
|
||
# env_xml_path and agent_xml_path must be provided | ||
if env_xml_path and agent_xml_path: | ||
self._env_xml_path = env_xml_path | ||
self._agent_xml_path = agent_xml_path | ||
|
||
xml_scene = compose_scene(self._env_xml_path, self._agent_xml_path) | ||
sys = mjcf.loads(xml_scene) | ||
else: | ||
raise Exception("env_xml_path & agent_xml_path both must be provided") | ||
|
||
# reward dictionary must be provided | ||
if rewards: | ||
self._rewards = rewards | ||
else: | ||
raise Exception("reward_Structure must be in kwargs") | ||
|
||
# TODO: clean this up in the future & | ||
# make n_frames a function of input dt | ||
n_frames = 5 | ||
|
||
if backend in ['spring', 'positional']: | ||
sys = sys.replace(dt=0.005) | ||
n_frames = 10 | ||
|
||
if backend == 'positional': | ||
# TODO: does the same actuator strength work as in spring | ||
sys = sys.replace( | ||
actuator=sys.actuator.replace( | ||
gear=200 * jp.ones_like(sys.actuator.gear) | ||
) | ||
) | ||
|
||
kwargs['n_frames'] = kwargs.get('n_frames', n_frames) | ||
|
||
# Initialize the superclass "PipelineEnv" | ||
super().__init__(sys=sys, backend=backend, **kwargs) | ||
|
||
# Setting other object parameters based on input params | ||
self._terminate_when_unhealthy = terminate_when_unhealthy | ||
self._reset_noise_scale = reset_noise_scale | ||
self._exclude_current_positions_from_observation = ( | ||
exclude_current_positions_from_observation | ||
) | ||
|
||
|
||
def reset(self, rng: jax.Array) -> State: | ||
|
||
rng, rng1, rng2, rng3 = jax.random.split(rng, 4) | ||
low, hi = -self._reset_noise_scale, self._reset_noise_scale | ||
|
||
# initialize position vector with minor randomization in pose | ||
q = self.sys.init_q + jax.random.uniform( | ||
rng1, (self.sys.q_size(),), minval=low, maxval=hi | ||
) | ||
|
||
# initialize velocity vector with minor randomization | ||
qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),)) | ||
|
||
# generate sample commands | ||
jcmd = self._sample_command(rng3) | ||
wcmd = jp.array([0.0, 0.0]) | ||
|
||
# initialize pipeline_state (the physics state) | ||
pipeline_state = self.pipeline_init(q, qd) | ||
|
||
# reset values and metrics | ||
reward, done, zero = jp.zeros(3) | ||
|
||
state_info = { | ||
'jcmd':jcmd, | ||
'wcmd':wcmd, | ||
'rewards': {k: 0.0 for k in self._rewards.keys()}, | ||
'step': 0, | ||
} | ||
|
||
metrics = {'pos_x_world_abs': zero, | ||
'pos_y_world_abs': zero, | ||
'pos_z_world_abs': zero,} | ||
|
||
for rn, r in self._rewards.items(): | ||
metrics[rn] = state_info['rewards'][rn] | ||
|
||
# get initial observation vector | ||
obs = self._get_obs(pipeline_state, state_info) | ||
|
||
return State(pipeline_state, obs, reward, done, metrics, state_info) | ||
|
||
def step(self, state: State, action: jax.Array) -> State: | ||
"""Run one timestep of the environment's dynamics.""" | ||
|
||
# Save the previous physics state and step physics forward | ||
pipeline_state0 = state.pipeline_state | ||
pipeline_state = self.pipeline_step(pipeline_state0, action) | ||
|
||
# Add all additional parameters to compute rewards | ||
self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd']) | ||
self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd']) | ||
|
||
# Compute all rewards and accumulate total reward | ||
total_reward = 0.0 | ||
for rn, r in self._rewards.items(): | ||
r.add_param('sys', self.sys) | ||
r.add_param('pipeline_state', pipeline_state) | ||
|
||
reward_value = r.compute() | ||
state.info['rewards'][rn] = reward_value[0] | ||
total_reward += reward_value[0] | ||
# print(f'{rn} reward_val = {reward_value}\n') | ||
|
||
# Computing additional metrics as necessary | ||
pos_world = pipeline_state.x.pos[0] | ||
abs_pos_world = jp.abs(pos_world) | ||
|
||
# Compute observations | ||
obs = self._get_obs(pipeline_state, state.info) | ||
done = 0.0 | ||
|
||
# State management | ||
state.info['step'] += 1 | ||
|
||
state.metrics.update(state.info['rewards']) | ||
|
||
state.metrics.update( | ||
pos_x_world_abs = abs_pos_world[0], | ||
pos_y_world_abs = abs_pos_world[1], | ||
pos_z_world_abs = abs_pos_world[2], | ||
) | ||
|
||
return state.replace( | ||
pipeline_state=pipeline_state, obs=obs, reward=total_reward, done=done | ||
) | ||
|
||
def _get_obs(self, pipeline_state, state_info) -> jax.Array: | ||
"""Observe ant body position and velocities.""" | ||
qpos = pipeline_state.q | ||
qvel = pipeline_state.qd | ||
|
||
inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0]) | ||
local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot) | ||
torso_pos = pipeline_state.x.pos[0] | ||
|
||
jcmd = state_info['jcmd'] | ||
#wcmd = state_info['wcmd'] | ||
|
||
if self._exclude_current_positions_from_observation: | ||
qpos = pipeline_state.q[2:] | ||
|
||
obs = jp.concatenate([ | ||
jp.array(qpos), | ||
jp.array(qvel), | ||
jp.array(local_rpyrate), | ||
jp.array(jcmd), | ||
]) | ||
|
||
return obs | ||
|
||
def _sample_waypoint(self, rng: jax.Array) -> jax.Array: | ||
x_range = [-25, 25] | ||
y_range = [-25, 25] | ||
z_range = [0, 2] | ||
|
||
_, key1, key2, key3 = jax.random.split(rng, 4) | ||
|
||
x = jax.random.uniform( | ||
key1, (1,), minval=x_range[0], maxval=x_range[1] | ||
) | ||
|
||
y = jax.random.uniform( | ||
key2, (1,), minval=y_range[0], maxval=y_range[1] | ||
) | ||
|
||
z = jax.random.uniform( | ||
key3, (1,), minval=z_range[0], maxval=z_range[1] | ||
) | ||
|
||
wcmd = jp.array([x[0], y[0]]) | ||
|
||
return wcmd | ||
|
||
def _sample_command(self, rng: jax.Array) -> jax.Array: | ||
lin_vel_x_range = [-3.0, 3.0] #[m/s] | ||
lin_vel_y_range = [-3.0, 3.0] #[m/s] | ||
yaw_vel_range = [-1.0, 1.0] #[rad/s] | ||
|
||
_, key1, key2, key3 = jax.random.split(rng, 4) | ||
|
||
lin_vel_x = jax.random.uniform( | ||
key1, (1,), minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1] | ||
) | ||
|
||
lin_vel_y = jax.random.uniform( | ||
key2, (1,), minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1] | ||
) | ||
|
||
yaw_vel = jax.random.uniform( | ||
key3, (1,), minval=yaw_vel_range[0], maxval=yaw_vel_range[1] | ||
) | ||
|
||
jcmd = jp.array([lin_vel_x[0], lin_vel_y[0], yaw_vel[0]]) | ||
|
||
return jcmd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
<agent> | ||
|
||
<custom> | ||
<numeric data="0.0 0.0 0.75 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/> | ||
<numeric data="1000" name="constraint_limit_stiffness"/> | ||
<numeric data="4000" name="constraint_stiffness"/> | ||
<numeric data="10" name="constraint_ang_damping"/> | ||
<numeric data="20" name="constraint_vel_damping"/> | ||
<numeric data="0.5" name="joint_scale_pos"/> | ||
<numeric data="0.2" name="joint_scale_ang"/> | ||
<numeric data="0.0" name="ang_damping"/> | ||
<numeric data="1" name="spring_mass_scale"/> | ||
<numeric data="1" name="spring_inertia_scale"/> | ||
<numeric data="15" name="solver_maxls"/> | ||
</custom> | ||
|
||
<body name="torso" pos="0 0 1.0"> | ||
<camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/> | ||
<geom name="torso_geom" contype="1" pos="0 0 0" size="0.25" type="sphere"/> | ||
<joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/> | ||
|
||
<body name="front_left_leg" pos="0 0 0"> | ||
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/> | ||
|
||
<body name="aux_1" pos="0.2 0.2 0"> | ||
<joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/> | ||
|
||
<body pos="0.2 0.2 0"> | ||
<joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/> | ||
<geom name="left_foot_geom" contype="1" pos="0.4 0.4 0" size="0.08" type="sphere" mass="0"/> | ||
</body> | ||
</body> | ||
</body> | ||
|
||
<body name="front_right_leg" pos="0 0 0"> | ||
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/> | ||
|
||
<body name="aux_2" pos="-0.2 0.2 0"> | ||
<joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/> | ||
|
||
<body pos="-0.2 0.2 0"> | ||
<joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/> | ||
<geom name="right_foot_geom" contype="1" pos="-0.4 0.4 0" size="0.08" type="sphere" mass="0"/> | ||
</body> | ||
</body> | ||
</body> | ||
|
||
<body name="back_leg" pos="0 0 0"> | ||
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/> | ||
|
||
<body name="aux_3" pos="-0.2 -0.2 0"> | ||
<joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/> | ||
|
||
<body pos="-0.2 -0.2 0"> | ||
<joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/> | ||
<geom name="third_foot_geom" contype="1" pos="-0.4 -0.4 0" size="0.08" type="sphere" mass="0"/> | ||
</body> | ||
</body> | ||
</body> | ||
|
||
<body name="right_back_leg" pos="0 0 0"> | ||
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/> | ||
|
||
<body name="aux_4" pos="0.2 -0.2 0"> | ||
<joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/> | ||
|
||
<body pos="0.2 -0.2 0"> | ||
<joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/> | ||
<geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/> | ||
<geom name="fourth_foot_geom" contype="1" pos="0.4 -0.4 0" size="0.08" type="sphere" mass="0"/> | ||
</body> | ||
</body> | ||
</body> | ||
</body> | ||
|
||
<actuator> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/> | ||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/> | ||
</actuator> | ||
|
||
</agent> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
from .reward import Reward | ||
|
||
from .rConstant import * | ||
from .rSpeed import * | ||
from .rHealthy import * | ||
from .rControl import * | ||
from .rEnergy import * | ||
from .rOrientation import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Tuple | ||
|
||
import jax | ||
from brax import actuator, base, math | ||
from brax.envs import PipelineEnv, State | ||
from brax.io import mjcf | ||
from etils import epath | ||
from jax import numpy as jp | ||
|
||
|
||
def rTorques(sys: base.System, | ||
pipeline_state: base.State, | ||
action: jp.ndarray, | ||
focus_idx_range=(0, -1)) -> jax.Array: | ||
|
||
s_idx = focus_idx_range[0] | ||
e_idx = focus_idx_range[1] | ||
|
||
torque = actuator.to_tau(sys, | ||
action, | ||
pipeline_state.q[s_idx:e_idx], | ||
pipeline_state.qd[s_idx:e_idx]) | ||
|
||
|
||
tr = jp.sqrt(jp.sum(jp.square(torque))) + jp.sum(jp.abs(torque)) | ||
|
||
return jp.array([tr]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Tuple | ||
|
||
import jax | ||
from brax import actuator, base, math | ||
from brax.envs import PipelineEnv, State | ||
from brax.io import mjcf | ||
from etils import epath | ||
from jax import numpy as jp | ||
|
||
def rUpright(sys: base.System, | ||
pipeline_state: base.State, | ||
focus_idx_range = (0,0)) -> jax.Array: | ||
|
||
up = jp.array([0.0, 0.0, 1.0]) | ||
rot_up = math.rotate(up, pipeline_state.x.rot[0]) | ||
|
||
return jp.dot(up, rot_up) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from brax import base | ||
from jax import numpy as jp | ||
|
||
class Reward: | ||
|
||
def __init__(self, f, sc, ps): | ||
""" | ||
:param f: A function handle (ie. function that computes this reward) | ||
:param sc: A float that gets multiplied to base computation provided by f | ||
:param ps: A dictionary of parameters required for the reward computation | ||
""" | ||
|
||
self.f = f | ||
self.scale = sc | ||
self.params = ps | ||
|
||
def add_param(self, p_name, p_value): | ||
""" | ||
Updates self.params dictionary with provided key and value | ||
""" | ||
|
||
self.params[p_name] = p_value | ||
|
||
def compute(self): | ||
""" | ||
computes reward as specified by self.f given | ||
scale and general parameters are set. | ||
Otherwise, this errors out quite spectacularly | ||
""" | ||
|
||
res = self.f(**self.params) | ||
res = res.at[0].multiply(self.scale) #may not be the best way to do this | ||
|
||
return res | ||
|
||
def __str__(self): | ||
""" | ||
provides a standard string output | ||
""" | ||
|
||
return f'reward: {self.f}, scale: {self.scale}' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .tXMLCompose import * | ||
from .tAnalyzeNetwork import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import jax | ||
import jax.numpy as jp | ||
|
||
def analyze_neural_params(params): | ||
""" | ||
This function provides a metric summary of input neural parameter file | ||
Structure of contents are included in a tuple where index: | ||
1. RunningStatisticState | ||
2. FrozenDict(Policy Network Params) | ||
3. FrozenDict(Value Network Params) | ||
""" | ||
|
||
summary = {} | ||
|
||
summary['Running Statistics'] = params[0] | ||
summary['Policy Network'] = analyze_one_neural_network(params[1]) | ||
summary['Value Network'] = analyze_one_neural_network(params[2]) | ||
|
||
return summary | ||
|
||
|
||
def analyze_one_neural_network(sn_params): | ||
""" | ||
Helper function that unpacks a single network parameters | ||
Assuming the contents of these parameters are provided as: | ||
FrozenDict({params: {'hidden_x': {bias: Jax.Array, kernel: Jax.Array}}}) | ||
where x in hidden_x represents order of hidden layer | ||
(but these also include input and output layers?) | ||
""" | ||
|
||
summary = {} | ||
|
||
num_layers = 0 | ||
per_layer_info = {} | ||
total_parameters = 0 | ||
|
||
for k, v in sn_params.items(): | ||
|
||
for layer_name, layer_data in v.items(): | ||
|
||
num_layers += 1 | ||
param_count = 0 | ||
|
||
for param_name, param_data in layer_data.items(): | ||
|
||
if param_name == 'bias': | ||
per_layer_info[layer_name] = {'size_of_layer': len(param_data)} | ||
param_count += len(param_data) | ||
# print(len(param_data)) | ||
|
||
if param_name == 'kernel': | ||
param_count += len(param_data)*len(param_data[0]) | ||
# print(len(param_data)*len(param_data[0])) | ||
|
||
total_parameters += param_count | ||
|
||
summary['num_layers'] = num_layers | ||
summary['per_layer_info'] = per_layer_info | ||
summary['total_parameters'] = total_parameters | ||
|
||
return summary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import functools | ||
import os | ||
import re | ||
import sys | ||
from datetime import datetime | ||
|
||
import brax | ||
import jax | ||
import matplotlib.pyplot as plt | ||
from brax import envs, math | ||
from brax.envs.wrappers import training | ||
from brax.io import html, json, model | ||
from brax.training.acme import running_statistics | ||
from brax.training.agents.ppo import networks as ppo_networks | ||
from jax import numpy as jp | ||
|
||
from alfredo.agents.aant import AAnt | ||
|
||
from alfredo.tools import analyze_neural_params | ||
|
||
from alfredo.rewards import Reward | ||
from alfredo.rewards import rConstant | ||
from alfredo.rewards import rHealthy_simple_z | ||
from alfredo.rewards import rControl_act_ss | ||
from alfredo.rewards import rTorques | ||
from alfredo.rewards import rTracking_lin_vel | ||
from alfredo.rewards import rTracking_yaw_vel | ||
from alfredo.rewards import rUpright | ||
from alfredo.rewards import rTracking_Waypoint | ||
from alfredo.rewards import rStand_still | ||
|
||
backend = "positional" | ||
|
||
# Load desired model xml and trained param set | ||
# get filepaths from commandline args | ||
cwd = os.getcwd() | ||
|
||
# Define reward structure | ||
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=15.0, ps={}), | ||
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.2, ps={})} | ||
|
||
print(rewards) | ||
|
||
# Get the filepath to the env and agent xmls | ||
import alfredo.scenes as scenes | ||
import alfredo.agents as agents | ||
agents_fp = os.path.dirname(agents.__file__) | ||
agent_xml_path = f"{agents_fp}/aant/aant.xml" | ||
|
||
scenes_fp = os.path.dirname(scenes.__file__) | ||
env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] | ||
tpf_path = f"{cwd}/{sys.argv[-1]}" | ||
|
||
print(f"agent description file: {agent_xml_path}") | ||
print(f"environment description file: {env_xml_paths[0]}") | ||
print(f"neural parameter file: {tpf_path}\n") | ||
|
||
# Load neural parameters | ||
params = model.load_params(tpf_path) | ||
summary = analyze_neural_params(params) | ||
print(f"summary: {summary}\n") | ||
|
||
# create an env and initial state | ||
env = AAnt(backend=backend, | ||
rewards=rewards, | ||
env_xml_path=env_xml_paths[0], | ||
agent_xml_path=agent_xml_path) | ||
|
||
rng = jax.random.PRNGKey(seed=3) | ||
state = env.reset(rng=rng) | ||
#state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0)) | ||
|
||
# Initialize inference neural network | ||
normalize = lambda x, y: x | ||
normalize = running_statistics.normalize | ||
|
||
ppo_network = ppo_networks.make_ppo_networks( | ||
state.obs.shape[-1], env.action_size, preprocess_observations_fn=normalize | ||
) | ||
|
||
make_policy = ppo_networks.make_inference_fn(ppo_network) | ||
policy_params = (params[0], params[1]) | ||
inference_fn = make_policy(policy_params) | ||
|
||
# Reset the env | ||
key_envs, _ = jax.random.split(rng) | ||
state = env.reset(rng=key_envs) | ||
#state = jax.jit(env.reset)(rng=key_envs) | ||
|
||
# Debug printing | ||
print(f"q: {state.pipeline_state.q}") | ||
print(f"\n") | ||
print(f"qd: {state.pipeline_state.qd}") | ||
print(f"\n") | ||
print(f"x: {state.pipeline_state.x}") | ||
print(f"\n") | ||
print(f"xd: {state.pipeline_state.xd}") | ||
print(f"\n") | ||
print(f"contact: {state.pipeline_state.contact}") | ||
print(f"\n") | ||
print(f"reward: {state.reward}") | ||
print(f"\n") | ||
print(state.metrics) | ||
print(f"\n") | ||
print(f"done: {state.done}") | ||
print(f"\n") | ||
|
||
# Rollout trajectory one physics step at a time | ||
episode_length = 1 | ||
|
||
jcmd = jp.array([1.0, 0.0, -0.7]) | ||
|
||
for _ in range(episode_length): | ||
print(f"\n---------------------------------------------------------------\n") | ||
|
||
act_rng, rng = jax.random.split(rng) | ||
print(f"rng: {rng}") | ||
print(f"act_rng: {act_rng}") | ||
print(f"\n") | ||
|
||
state.info['jcmd'] = jcmd | ||
print(f"state info: {state.info}") | ||
print(f"\n") | ||
|
||
act, _ = inference_fn(state.obs, act_rng) | ||
print(f"observation: {state.obs}") | ||
print(f"\n") | ||
print(f"action: {act}") | ||
print(f"\n") | ||
|
||
state = env.step(state, act) | ||
|
||
print(f"q: {state.pipeline_state.q}") | ||
print(f"\n") | ||
print(f"qd: {state.pipeline_state.qd}") | ||
print(f"\n") | ||
print(f"x: {state.pipeline_state.x}") | ||
print(f"\n") | ||
print(f"xd: {state.pipeline_state.xd}") | ||
print(f"\n") | ||
print(f"contact: {state.pipeline_state.contact}") | ||
print(f"\n") | ||
print(f"reward: {state.reward}") | ||
print(f"\n") | ||
print(state.metrics) | ||
print(f"\n") | ||
print(f"done: {state.done}") | ||
print(f"\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import functools | ||
import os | ||
import re | ||
from datetime import datetime | ||
|
||
import brax | ||
import flax | ||
import jax | ||
import optax | ||
import wandb | ||
from brax import envs | ||
|
||
from brax.io import html, json, model | ||
from brax.training.acme import running_statistics, specs | ||
from brax.training.agents.ppo import losses as ppo_losses | ||
from brax.training.agents.ppo import networks as ppo_networks | ||
from jax import numpy as jp | ||
|
||
from alfredo.agents.aant import AAnt | ||
from alfredo.train import ppo | ||
|
||
from alfredo.rewards import Reward | ||
from alfredo.rewards import rTracking_lin_vel | ||
from alfredo.rewards import rTracking_yaw_vel | ||
|
||
|
||
# Define Reward Structure | ||
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}), | ||
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})} | ||
|
||
# Initialize a new run | ||
wandb.init( | ||
project="aant", | ||
config={ | ||
"env_name": "AAnt", | ||
"seed": 13, | ||
|
||
"training_params": { | ||
"backend": "positional", | ||
"len_training": 1_500_000, | ||
"num_evals": 500, | ||
"num_envs": 2048, | ||
"batch_size": 2048, | ||
"num_minibatches": 8, | ||
"updates_per_batch": 8, | ||
"episode_len": 1000, | ||
"unroll_len": 10, | ||
"reward_scaling":1, | ||
"action_repeat": 1, | ||
"discounting": 0.97, | ||
"learning_rate": 3e-4, | ||
"entropy_cost": 1e-3, | ||
"reward_scaling": 0.1, | ||
"normalize_obs": True, | ||
}, | ||
|
||
"rewards":rewards, | ||
|
||
"aux_model_params":{ | ||
|
||
} | ||
}, | ||
) | ||
|
||
# define callback function that will report training progress | ||
def progress(num_steps, metrics): | ||
print(num_steps) | ||
print(metrics) | ||
|
||
epi_len = wandb.config.training_params['episode_len'] | ||
|
||
log_dict = {'step': num_steps} | ||
|
||
for mn, m in metrics.items(): | ||
name_in_log = mn.split('/')[-1] | ||
log_dict[name_in_log] = m/epi_len | ||
|
||
wandb.log(log_dict) | ||
|
||
|
||
# get the filepath to the env and agent xmls | ||
cwd = os.getcwd() | ||
|
||
import alfredo.scenes as scenes | ||
import alfredo.agents as agents | ||
agents_fp = os.path.dirname(agents.__file__) | ||
agent_xml_path = f"{agents_fp}/aant/aant.xml" | ||
|
||
scenes_fp = os.path.dirname(scenes.__file__) | ||
|
||
env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml", | ||
f"{scenes_fp}/flatworld/flatworld_A1_env.xml", | ||
f"{scenes_fp}/flatworld/flatworld_A1_env.xml", | ||
f"{scenes_fp}/flatworld/flatworld_A1_env.xml"] | ||
|
||
# make and save initial ppo_network | ||
key = jax.random.PRNGKey(wandb.config.seed) | ||
global_key, local_key = jax.random.split(key) | ||
key_policy, key_value = jax.random.split(global_key) | ||
|
||
env = AAnt(backend=wandb.config.training_params['backend'], | ||
rewards=rewards, | ||
env_xml_path=env_xml_paths[0], | ||
agent_xml_path=agent_xml_path) | ||
|
||
rng = jax.random.PRNGKey(seed=0) | ||
state = env.reset(rng) | ||
|
||
normalize_fn = running_statistics.normalize | ||
|
||
ppo_network = ppo_networks.make_ppo_networks( | ||
env.observation_size, env.action_size, normalize_fn | ||
) | ||
|
||
init_params = ppo_losses.PPONetworkParams( | ||
policy=ppo_network.policy_network.init(key_policy), | ||
value=ppo_network.value_network.init(key_value), | ||
) | ||
|
||
normalizer_params = running_statistics.init_state( | ||
specs.Array(env.observation_size, jp.float32) | ||
) | ||
|
||
params_to_save = (normalizer_params, init_params.policy, init_params.value) | ||
|
||
model.save_params(f"param-store/AAnt_params_0", params_to_save) | ||
|
||
# ============================ | ||
# Training & Saving Params | ||
# ============================ | ||
i = 0 | ||
|
||
for p in env_xml_paths: | ||
|
||
d_and_t = datetime.now() | ||
print(f"[{d_and_t}] loop start for model: {i}") | ||
env = AAnt(backend=wandb.config.training_params['backend'], | ||
rewards=rewards, | ||
env_xml_path=p, | ||
agent_xml_path=agent_xml_path) | ||
|
||
mF = f"{cwd}/param-store/{wandb.config.env_name}_params_{i}" | ||
mParams = model.load_params(mF) | ||
|
||
d_and_t = datetime.now() | ||
print(f"[{d_and_t}] jitting start for model: {i}") | ||
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=wandb.config.seed)) | ||
d_and_t = datetime.now() | ||
print(f"[{d_and_t}] jitting end for model: {i}") | ||
|
||
# define new training function | ||
train_fn = functools.partial( | ||
ppo.train, | ||
num_timesteps=wandb.config.training_params['len_training'], | ||
num_evals=wandb.config.training_params['num_evals'], | ||
reward_scaling=wandb.config.training_params['reward_scaling'], | ||
episode_length=wandb.config.training_params['episode_len'], | ||
normalize_observations=wandb.config.training_params['normalize_obs'], | ||
action_repeat=wandb.config.training_params['action_repeat'], | ||
unroll_length=wandb.config.training_params['unroll_len'], | ||
num_minibatches=wandb.config.training_params['num_minibatches'], | ||
num_updates_per_batch=wandb.config.training_params['updates_per_batch'], | ||
discounting=wandb.config.training_params['discounting'], | ||
learning_rate=wandb.config.training_params['learning_rate'], | ||
entropy_cost=wandb.config.training_params['entropy_cost'], | ||
num_envs=wandb.config.training_params['num_envs'], | ||
batch_size=wandb.config.training_params['batch_size'], | ||
seed=wandb.config.seed, | ||
in_params=mParams, | ||
) | ||
|
||
d_and_t = datetime.now() | ||
print(f"[{d_and_t}] training start for model: {i}") | ||
_, params, _, ts = train_fn(environment=env, progress_fn=progress) | ||
d_and_t = datetime.now() | ||
print(f"[{d_and_t}] training end for model: {i}") | ||
|
||
i += 1 | ||
next_m_name = f"param-store/{wandb.config.env_name}_params_{i}" | ||
model.save_params(next_m_name, params) | ||
|
||
d_and_t = datetime.now() | ||
print(f"[{d_and_t}] loop end for model: {i}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.