debugged, training now works and all experiment files have been updated
mginoya committed Oct 7, 2024
1 parent 4c3d7f8 commit 6862330
Showing 5 changed files with 99 additions and 66 deletions.
6 changes: 3 additions & 3 deletions alfredo/agents/aant/aant.py
@@ -17,7 +17,7 @@ def __init__(self,
agent_xml_path = "",
terminate_when_unhealthy=False,
reset_noise_scale=0.1,
exclude_current_positions_from_observation=True,
exclude_current_positions_from_observation=False,
backend='generalized',
**kwargs,):

@@ -105,7 +105,7 @@ def reset(self, rng: jax.Array) -> State:

# get initial observation vector
obs = self._get_obs(pipeline_state, state_info)

return State(pipeline_state, obs, reward, done, metrics, state_info)

def step(self, state: State, action: jax.Array) -> State:
@@ -126,7 +126,7 @@ def step(self, state: State, action: jax.Array) -> State:
r.add_param('pipeline_state', pipeline_state)

reward_value = r.compute()
state.info['rewards'][rn] = reward_value
state.info['rewards'][rn] = reward_value[0]
total_reward += reward_value[0]
# print(f'{rn} reward_val = {reward_value}\n')
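
For context on the reward_value[0] change above, here is a minimal sketch of the reward-accumulation loop in step(), assuming Reward.compute() returns a JAX array whose first element is the scaled reward term (consistent with res.at[0].multiply(self.scale) in reward.py below). The helper name and the layout of info are illustrative only. Storing the scalar rather than the whole array also keeps each state.info entry's shape stable from step to step, which is presumably what jit-compiled training needs here.

    def accumulate_rewards(rewards, info):
        """Sum per-term rewards and record each scalar term in info['rewards'] (illustrative)."""
        total_reward = 0.0
        for rn, r in rewards.items():
            reward_value = r.compute()             # JAX array; element 0 holds the scaled term
            info['rewards'][rn] = reward_value[0]  # store the scalar, not the whole array
            total_reward += reward_value[0]
        return total_reward, info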

7 changes: 7 additions & 0 deletions alfredo/rewards/reward.py
@@ -32,3 +32,10 @@ def compute(self):
res = res.at[0].multiply(self.scale) #may not be the best way to do this

return res

def __str__(self):
"""
provides a standard string output
"""

return f'reward: {self.f}, scale: {self.scale}'
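
A simplified stand-in for the Reward class, only to show how the new __str__ makes statements like print(rewards) (added in one_physics_step.py below) readable. The constructor keywords f, sc, ps match how the experiment files build rewards; everything else here is an assumption. Note that printing a dict goes through __repr__, so the sketch aliases __repr__ to __str__ to get the custom text when the whole rewards dict is printed.

    class RewardSketch:
        """Illustrative stand-in for alfredo.rewards.Reward, not the real class."""

        def __init__(self, f, sc=1.0, ps=None):
            self.f = f            # reward function, e.g. rTracking_lin_vel
            self.scale = sc       # multiplier applied to element 0 of the result
            self.ps = ps or {}    # extra parameters passed to f

        def __str__(self):
            return f'reward: {self.f}, scale: {self.scale}'

        __repr__ = __str__        # dict printing uses repr, so alias it

    def rTracking_lin_vel_stub(*args, **kwargs):
        """Placeholder for the real reward function."""

    rewards = {'r_lin_vel': RewardSketch(rTracking_lin_vel_stub, sc=15.0, ps={})}
    print(rewards)  # {'r_lin_vel': reward: <function rTracking_lin_vel_stub at 0x...>, scale: 15.0}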
6 changes: 5 additions & 1 deletion experiments/AAnt-locomotion/one_physics_step.py
@@ -39,6 +39,8 @@
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=15.0, ps={}),
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.2, ps={})}

print(rewards)

# Get the filepath to the env and agent xmls
import alfredo.scenes as scenes
import alfredo.agents as agents
@@ -65,7 +67,8 @@
agent_xml_path=agent_xml_path)

rng = jax.random.PRNGKey(seed=3)
state = env.reset(rng=rng) #state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
state = env.reset(rng=rng)
#state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

# Initialize inference neural network
normalize = lambda x, y: x
@@ -82,6 +85,7 @@
# Reset the env
key_envs, _ = jax.random.split(rng)
state = env.reset(rng=key_envs)
#state = jax.jit(env.reset)(rng=key_envs)

# Debug printing
print(f"q: {state.pipeline_state.q}")
129 changes: 70 additions & 59 deletions experiments/AAnt-locomotion/training.py
@@ -6,11 +6,10 @@
import brax
import flax
import jax
import matplotlib.pyplot as plt
import optax
import wandb
from brax import envs
# from brax.envs.wrappers import training

from brax.io import html, json, model
from brax.training.acme import running_statistics, specs
from brax.training.agents.ppo import losses as ppo_losses
@@ -20,60 +19,68 @@
from alfredo.agents.aant import AAnt
from alfredo.train import ppo

from alfredo.rewards import Reward
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel


# Define Reward Structure
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}

# Initialize a new run
wandb.init(
project="aant",
config={
"env_name": "AAnt",
"backend": "positional",
"seed": 13,
"len_training": 1_500_000,
"num_evals": 500,
"num_envs": 2048,
"batch_size": 2048,
"num_minibatches": 8,
"updates_per_batch": 8,
"episode_len": 1000,
"unroll_len": 10,
"reward_scaling":1,
"action_repeat": 1,
"discounting": 0.97,
"learning_rate": 3e-4,
"entropy_cost": 1e-3,
"reward_scaling": 0.1,
"normalize_obs": True,

"training_params": {
"backend": "positional",
"len_training": 1_500_000,
"num_evals": 500,
"num_envs": 2048,
"batch_size": 2048,
"num_minibatches": 8,
"updates_per_batch": 8,
"episode_len": 1000,
"unroll_len": 10,
"reward_scaling":1,
"action_repeat": 1,
"discounting": 0.97,
"learning_rate": 3e-4,
"entropy_cost": 1e-3,
"reward_scaling": 0.1,
"normalize_obs": True,
},

"rewards":rewards,

"aux_model_params":{

}
},
)

normalize_fn = running_statistics.normalize

# define callback function that will report training progress
def progress(num_steps, metrics):
print(num_steps)
print(metrics)
epi_len = wandb.config.episode_len
wandb.log(
{
"step": num_steps,
"Total Reward": metrics["eval/episode_reward"]/epi_len,
"Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len,
"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"]/epi_len,
"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"]/epi_len,
"Alive Reward": metrics["eval/episode_reward_alive"]/epi_len,
"Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len,
"Upright Reward": metrics["eval/episode_reward_upright"]/epi_len,
"Torque Reward": metrics["eval/episode_reward_torque"]/epi_len,
"Abs Pos X World": metrics["eval/episode_pos_x_world_abs"]/epi_len,
"Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"]/epi_len,
"Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"]/epi_len,
#"Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len,
#"Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len,
#"Dist Goal Z": metrics["eval/episode_dist_goal_z"]/epi_len,
}
)

epi_len = wandb.config.training_params['episode_len']

log_dict = {'step': num_steps}

for mn, m in metrics.items():
name_in_log = mn.split('/')[-1]
log_dict[name_in_log] = m/epi_len

wandb.log(log_dict)

cwd = os.getcwd()

# get the filepath to the env and agent xmls
cwd = os.getcwd()

import alfredo.scenes as scenes
import alfredo.agents as agents
agents_fp = os.path.dirname(agents.__file__)
@@ -91,13 +98,16 @@ def progress(num_steps, metrics):
global_key, local_key = jax.random.split(key)
key_policy, key_value = jax.random.split(global_key)

env = AAnt(backend=wandb.config.backend,
env = AAnt(backend=wandb.config.training_params['backend'],
rewards=rewards,
env_xml_path=env_xml_paths[0],
agent_xml_path=agent_xml_path)

rng = jax.random.PRNGKey(seed=1)
rng = jax.random.PRNGKey(seed=0)
state = env.reset(rng)

normalize_fn = running_statistics.normalize

ppo_network = ppo_networks.make_ppo_networks(
env.observation_size, env.action_size, normalize_fn
)
@@ -124,7 +134,8 @@ def progress(num_steps, metrics):

d_and_t = datetime.now()
print(f"[{d_and_t}] loop start for model: {i}")
env = AAnt(backend=wandb.config.backend,
env = AAnt(backend=wandb.config.training_params['backend'],
rewards=rewards,
env_xml_path=p,
agent_xml_path=agent_xml_path)

@@ -133,27 +144,27 @@ def progress(num_steps, metrics):

d_and_t = datetime.now()
print(f"[{d_and_t}] jitting start for model: {i}")
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=1))
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=wandb.config.seed))
d_and_t = datetime.now()
print(f"[{d_and_t}] jitting end for model: {i}")

# define new training function
train_fn = functools.partial(
ppo.train,
num_timesteps=wandb.config.len_training,
num_evals=wandb.config.num_evals,
reward_scaling=wandb.config.reward_scaling,
episode_length=wandb.config.episode_len,
normalize_observations=wandb.config.normalize_obs,
action_repeat=wandb.config.action_repeat,
unroll_length=wandb.config.unroll_len,
num_minibatches=wandb.config.num_minibatches,
num_updates_per_batch=wandb.config.updates_per_batch,
discounting=wandb.config.discounting,
learning_rate=wandb.config.learning_rate,
entropy_cost=wandb.config.entropy_cost,
num_envs=wandb.config.num_envs,
batch_size=wandb.config.batch_size,
num_timesteps=wandb.config.training_params['len_training'],
num_evals=wandb.config.training_params['num_evals'],
reward_scaling=wandb.config.training_params['reward_scaling'],
episode_length=wandb.config.training_params['episode_len'],
normalize_observations=wandb.config.training_params['normalize_obs'],
action_repeat=wandb.config.training_params['action_repeat'],
unroll_length=wandb.config.training_params['unroll_len'],
num_minibatches=wandb.config.training_params['num_minibatches'],
num_updates_per_batch=wandb.config.training_params['updates_per_batch'],
discounting=wandb.config.training_params['discounting'],
learning_rate=wandb.config.training_params['learning_rate'],
entropy_cost=wandb.config.training_params['entropy_cost'],
num_envs=wandb.config.training_params['num_envs'],
batch_size=wandb.config.training_params['batch_size'],
seed=wandb.config.seed,
in_params=mParams,
)
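
Two short illustrations of how the refactored training.py pieces fit together. First, the generic progress loop above replaces the long hard-coded list of wandb keys: it logs every metric the evaluator reports, divides by the episode length so rewards are per step, and new reward terms show up in wandb without touching the callback. With made-up values:

    epi_len = 1000                       # wandb.config.training_params['episode_len']
    metrics = {                          # illustrative values only
        'eval/episode_reward': 1200.0,
        'eval/episode_reward_lin_vel': 800.0,
        'eval/episode_reward_yaw_vel': 150.0,
    }

    log_dict = {'step': 5_000_000}
    for mn, m in metrics.items():
        name_in_log = mn.split('/')[-1]      # 'eval/episode_reward' -> 'episode_reward'
        log_dict[name_in_log] = m / epi_len  # per-step average
    # log_dict == {'step': 5000000, 'episode_reward': 1.2,
    #              'episode_reward_lin_vel': 0.8, 'episode_reward_yaw_vel': 0.15}

Second, a hedged sketch of consuming train_fn, assuming alfredo's ppo.train keeps brax's convention of taking the environment plus a progress_fn and returning (make_policy, params, metrics); the save path is illustrative only:

    # make_policy, params, _ = train_fn(environment=env, progress_fn=progress)
    # model.save_params(f"{cwd}/param_files/model_{i}", params)  # later loaded by vis_traj.py via model.load_params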
17 changes: 14 additions & 3 deletions experiments/AAnt-locomotion/vis_traj.py
@@ -16,6 +16,10 @@

from alfredo.agents.aant import AAnt

from alfredo.rewards import Reward
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel

backend = "positional"

# Load desired model xml and trained param set
@@ -39,8 +43,15 @@

params = model.load_params(tpf_path)

# Define Reward Structure
# For visualizing, this is just to be able to create the env
# May want to make this not necessary in the future ..?
rewards = {'r_lin_vel': Reward(rTracking_lin_vel, sc=8.0, ps={}),
'r_yaw_vel': Reward(rTracking_yaw_vel, sc=1.0, ps={})}

# create an env with auto-reset and load previously trained parameters
env = AAnt(backend=backend,
env = AAnt(backend=backend,
rewards=rewards,
env_xml_path=env_xml_path,
agent_xml_path=agent_xml_path)

@@ -75,8 +86,8 @@
jit_inference_fn = jax.jit(inference_fn)

x_vel = 0.0 # m/s
y_vel = 0.0 # m/s
yaw_vel = -0.7 # rad/s
y_vel = 3.0 # m/s
yaw_vel = 0.0 # rad/s
jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([10.0, 10.0])
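
For reference, the command vectors set above: jcmd packs the velocity-tracking targets (x and y in m/s, yaw rate in rad/s, per the inline comments), and wcmd is an x/y waypoint target (assumed to be in meters). How AAnt consumes them is defined by the env and not shown here.

    import jax.numpy as jp

    x_vel, y_vel, yaw_vel = 0.0, 3.0, 0.0     # m/s, m/s, rad/s
    jcmd = jp.array([x_vel, y_vel, yaw_vel])  # velocity-tracking command
    wcmd = jp.array([10.0, 10.0])             # waypoint command (assumed meters)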
