-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Action Normalization #57
base: master
Are you sure you want to change the base?
Changes from all commits
abbee68
8662524
ae4bd7d
ed5e8c6
9a92b36
a2c2e47
c169db8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -421,7 +421,7 @@ class RolloutPolicy(object): | |
""" | ||
Wraps @Algo object to make it easy to run policies in a rollout loop. | ||
""" | ||
def __init__(self, policy, obs_normalization_stats=None): | ||
def __init__(self, policy, obs_normalization_stats=None, action_normalization_stats=None): | ||
""" | ||
Args: | ||
policy (Algo instance): @Algo object to wrap to prepare for rollouts | ||
|
@@ -430,9 +430,15 @@ def __init__(self, policy, obs_normalization_stats=None): | |
normalization. This should map observation keys to dicts | ||
with a "mean" and "std" of shape (1, ...) where ... is the default | ||
shape for the observation. | ||
|
||
action_normalization_stats (dict): optionally pass a dictionary for action | ||
normalization. This should be a dict with keys | ||
"scale" and "offset" of shape (1, ...) where ... is the default | ||
shape for the action. | ||
""" | ||
self.policy = policy | ||
self.obs_normalization_stats = obs_normalization_stats | ||
self.action_normalization_stats = action_normalization_stats | ||
|
||
def start_episode(self): | ||
""" | ||
|
@@ -474,4 +480,8 @@ def __call__(self, ob, goal=None): | |
if goal is not None: | ||
goal = self._prepare_observation(goal) | ||
ac = self.policy.get_action(obs_dict=ob, goal_dict=goal) | ||
return TensorUtils.to_numpy(ac[0]) | ||
ac = TensorUtils.to_numpy(ac) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any reason for changing |
||
if self.action_normalization_stats is not None: | ||
ac = ObsUtils.unnormalize_actions(ac, self.action_normalization_stats) | ||
ac = ac[0] | ||
return ac |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
""" | ||
Example: | ||
python robomimic/scripts/set_dataset_attr.py --glob 'datasets/**/*_abs.hdf5' --env_args env_kwargs.controller_configs.control_delta=false absolute_actions=true | ||
""" | ||
import argparse | ||
import pathlib | ||
import json | ||
import sys | ||
import tqdm | ||
import h5py | ||
|
||
def update_env_args_dict(env_args_dict: dict, key: tuple, value):
    """
    Set a (possibly nested) key in @env_args_dict to @value, creating
    intermediate dicts as needed, and return the (mutated) dict.

    Args:
        env_args_dict (dict): dictionary to update in place
        key (tuple or None): sequence of nested keys, e.g.
            ("env_kwargs", "controller_configs", "control_delta").
            If None or empty, the dict is returned unchanged.
        value: value to assign at the nested key

    Returns:
        env_args_dict (dict): the same dict, updated
    """
    # nothing to do for a missing/empty key path
    if key is None or len(key) == 0:
        return env_args_dict
    if len(key) == 1:
        env_args_dict[key[0]] = value
        return env_args_dict
    this_key = key[0]
    # create the intermediate level; also replace a non-dict value that
    # would otherwise make the recursive call raise a TypeError
    if this_key not in env_args_dict or not isinstance(env_args_dict[this_key], dict):
        env_args_dict[this_key] = dict()
    update_env_args_dict(env_args_dict[this_key], key[1:], value)
    return env_args_dict
|
||
def main():
    """
    Bulk-update HDF5 dataset attributes.

    Parses positional ``key=value`` attribute pairs (values are JSON-decoded),
    optionally a dotted ``--env_args`` override (merged into the JSON stored in
    ``data.attrs['env_args']``), matches files via ``--glob`` relative to the
    current working directory, asks the user to confirm, then writes the
    attributes into each file's ``data`` group.
    """
    parser = argparse.ArgumentParser()

    # glob pattern (relative to cwd) selecting the hdf5 files to modify
    parser.add_argument(
        "--glob",
        type=str,
        required=True
    )

    # single dotted-key override for the env_args JSON blob,
    # e.g. env_kwargs.controller_configs.control_delta=false
    parser.add_argument(
        "--env_args",
        type=str,
        default=None
    )

    # remaining positional args are plain key=value attribute assignments
    parser.add_argument(
        'attrs',
        nargs='*'
    )

    args = parser.parse_args()

    # parse attrs to set
    # format: key=value
    # values are parsed with json
    attrs_dict = dict()
    for attr_arg in args.attrs:
        # split on the FIRST '=' only, so JSON values that themselves
        # contain '=' are preserved intact
        key, svalue = attr_arg.split("=", 1)
        value = json.loads(svalue)
        attrs_dict[key] = value

    # parse env_args update
    env_args_key = None
    env_args_value = None
    if args.env_args is not None:
        # same first-'=' split as above; the key part is a dotted path
        key, svalue = args.env_args.split('=', 1)
        env_args_key = key.split('.')
        env_args_value = json.loads(svalue)

    # find files
    file_paths = list(pathlib.Path.cwd().glob(args.glob))

    # confirm with the user before touching anything on disk
    print("Found matching files:")
    for f in file_paths:
        print(f)
    print("Are you sure to modify these files with the following attributes:")
    print(json.dumps(attrs_dict, indent=2))
    if env_args_key is not None:
        print("env_args." + '.'.join(env_args_key) + '=' + str(env_args_value))
    # NOTE: any answer containing 'y' (e.g. "yes") proceeds; anything else aborts
    result = input("[y/n]?")
    if 'y' not in result:
        sys.exit(0)

    # execute
    for file_path in tqdm.tqdm(file_paths):
        with h5py.File(str(file_path), mode='r+') as file:
            # update env_args: decode stored JSON, patch the nested key,
            # and write the re-encoded JSON back
            if env_args_key is not None:
                env_args = file['data'].attrs['env_args']
                env_args_dict = json.loads(env_args)
                env_args_dict = update_env_args_dict(
                    env_args_dict=env_args_dict,
                    key=env_args_key, value=env_args_value)
                env_args = json.dumps(env_args_dict)
                file['data'].attrs['env_args'] = env_args

            # update other attrs directly on the data group
            file['data'].attrs.update(attrs_dict)


if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ def __init__( | |
hdf5_cache_mode=None, | ||
hdf5_use_swmr=True, | ||
hdf5_normalize_obs=False, | ||
hdf5_normalize_action=None, | ||
filter_by_attribute=None, | ||
load_next_obs=True, | ||
): | ||
|
@@ -74,6 +75,9 @@ def __init__( | |
hdf5_normalize_obs (bool): if True, normalize observations by computing the mean observation | ||
and std of each observation (in each dimension and modality), and normalizing to unit | ||
mean and variance in each dimension. | ||
|
||
hdf5_normalize_action (bool or None): if True, normalize actions' range to [-1,1]. If None, | ||
this value is determined by the hdf5_file['data'].attrs['absolute_actions'] attribute. | ||
|
||
filter_by_attribute (str): if provided, use the provided filter key to look up a subset of | ||
demonstrations to load | ||
|
@@ -86,6 +90,11 @@ def __init__( | |
self.hdf5_use_swmr = hdf5_use_swmr | ||
self.hdf5_normalize_obs = hdf5_normalize_obs | ||
self._hdf5_file = None | ||
|
||
if hdf5_normalize_action is None: | ||
hdf5_normalize_action = self.hdf5_file['data'].attrs.get('absolute_actions', False) | ||
self.hdf5_normalize_action = hdf5_normalize_action | ||
|
||
|
||
assert hdf5_cache_mode in ["all", "low_dim", None] | ||
self.hdf5_cache_mode = hdf5_cache_mode | ||
|
@@ -119,6 +128,10 @@ def __init__( | |
self.obs_normalization_stats = None | ||
if self.hdf5_normalize_obs: | ||
self.obs_normalization_stats = self.normalize_obs() | ||
|
||
self.action_normalization_stats = None | ||
if self.hdf5_normalize_action: | ||
self.action_normalization_stats = self.normalize_actions() | ||
|
||
# maybe store dataset in memory for fast access | ||
if self.hdf5_cache_mode in ["all", "low_dim"]: | ||
|
@@ -366,6 +379,98 @@ def get_obs_normalization_stats(self): | |
assert self.hdf5_normalize_obs, "not using observation normalization!" | ||
return deepcopy(self.obs_normalization_stats) | ||
|
||
def normalize_actions(self):
    """
    Computes dataset-wide statistics for the actions (per dimension) and
    returns an affine normalizer mapping the action range into [-1, 1].

    Streams over all demos in @self.demos, maintaining running min / max /
    mean / sum-of-squared-diffs via a numerically stable parallel merge, then
    converts the min/max into "scale" and "offset" such that
    normalized = raw * scale + offset.

    Returns:
        action_normalization_stats (dict): dict with keys "scale" and
            "offset", each of shape (1, ...) matching the action dimension.
    """
    def _compute_traj_stats(traj_obs_dict):
        """
        Helper function to compute statistics over a single trajectory of observations.
        Returns per-key dicts with n, mean, sqdiff (sum of squared deviations), min, max.
        """
        traj_stats = { k : {} for k in traj_obs_dict }
        for k in traj_obs_dict:
            traj_stats[k]["n"] = traj_obs_dict[k].shape[0]
            traj_stats[k]["mean"] = traj_obs_dict[k].mean(axis=0, keepdims=True) # [1, ...]
            traj_stats[k]["sqdiff"] = ((traj_obs_dict[k] - traj_stats[k]["mean"]) ** 2).sum(axis=0, keepdims=True) # [1, ...]
            traj_stats[k]["min"] = traj_obs_dict[k].min(axis=0, keepdims=True)
            traj_stats[k]["max"] = traj_obs_dict[k].max(axis=0, keepdims=True)
        return traj_stats

    def _aggregate_traj_stats(traj_stats_a, traj_stats_b):
        """
        Helper function to aggregate trajectory statistics.
        See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        for more information.
        """
        merged_stats = {}
        for k in traj_stats_a:
            n_a, avg_a, M2_a, min_a, max_a = traj_stats_a[k]["n"], traj_stats_a[k]["mean"], traj_stats_a[k]["sqdiff"], traj_stats_a[k]["min"], traj_stats_a[k]["max"]
            n_b, avg_b, M2_b, min_b, max_b = traj_stats_b[k]["n"], traj_stats_b[k]["mean"], traj_stats_b[k]["sqdiff"], traj_stats_b[k]["min"], traj_stats_b[k]["max"]
            n = n_a + n_b
            # weighted mean of the two partitions
            mean = (n_a * avg_a + n_b * avg_b) / n
            delta = (avg_b - avg_a)
            # Chan et al. pairwise update for the sum of squared deviations
            M2 = M2_a + M2_b + (delta ** 2) * (n_a * n_b) / n
            min_ = np.minimum(min_a, min_b)
            max_ = np.maximum(max_a, max_b)
            merged_stats[k] = dict(n=n, mean=mean, sqdiff=M2, min=min_, max=max_)
        return merged_stats

    # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate
    # with the previous statistics.
    def get_action_traj(ep):
        # load the action array for demo @ep as float32
        action_traj = dict()
        action_traj['actions'] = self.hdf5_file["data/{}/actions".format(ep)][()].astype('float32')
        return action_traj

    # seed the running stats with the first demo, then fold in the rest
    ep = self.demos[0]
    action_traj = get_action_traj(ep)
    merged_stats = _compute_traj_stats(action_traj)
    print("SequenceDataset: normalizing actions...")
    for ep in LogUtils.custom_tqdm(self.demos[1:]):
        action_traj = get_action_traj(ep)
        traj_stats = _compute_traj_stats(action_traj)
        merged_stats = _aggregate_traj_stats(merged_stats, traj_stats)

    normalization_stats = { k : {} for k in merged_stats }
    for k in merged_stats:
        normalization_stats[k]["mean"] = merged_stats[k]["mean"].astype('float32')
        # population std = sqrt(sum of squared deviations / n)
        normalization_stats[k]["std"] = np.sqrt(merged_stats[k]["sqdiff"] / merged_stats[k]["n"]).astype('float32')
        normalization_stats[k]["min"] = merged_stats[k]["min"].astype('float32')
        normalization_stats[k]["max"] = merged_stats[k]["max"].astype('float32')

    # convert min and max to scale and offset
    stats = normalization_stats['actions']
    range_eps = 1e-4
    input_min = stats['min']
    input_max = stats['max']
    output_min = -1.0
    output_max = 1.0

    input_range = input_max - input_min
    # dimensions with (near-)zero range cannot be rescaled meaningfully:
    # give them scale 1 and an offset that centers them in [-1, 1]
    ignore_dim = input_range < range_eps
    input_range[ignore_dim] = output_max - output_min
    scale = (output_max - output_min) / input_range
    offset = output_min - scale * input_min
    offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

    action_normalization_stats = {
        "scale": scale,
        "offset": offset
    }
    return action_normalization_stats
|
||
def get_action_normalization_stats(self):
    """
    Returns a copy of the affine action-normalization statistics computed by
    normalize_actions(), or None if action normalization is disabled.

    Returns:
        action_normalization_stats (dict or None): a dictionary for action
        normalization with "scale" and "offset" of shape (1, ...) where ... is the default
        shape for the action.
    """
    return deepcopy(self.action_normalization_stats)
|
||
def get_dataset_for_ep(self, ep, key): | ||
""" | ||
Helper utility to get a dataset for a specific demonstration. | ||
|
@@ -443,6 +548,8 @@ def get_item(self, index): | |
) | ||
if self.hdf5_normalize_obs: | ||
meta["obs"] = ObsUtils.normalize_obs(meta["obs"], obs_normalization_stats=self.obs_normalization_stats) | ||
if self.hdf5_normalize_action: | ||
meta["actions"] = ObsUtils.normalize_actions(meta["actions"], action_normalization_stats=self.action_normalization_stats) | ||
|
||
if self.load_next_obs: | ||
meta["next_obs"] = self.get_obs_sequence_from_demo( | ||
|
@@ -513,7 +620,7 @@ def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_sta | |
|
||
seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True) | ||
pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad) | ||
pad_mask = pad_mask[:, None].astype(np.bool) | ||
pad_mask = pad_mask[:, None].astype(bool) | ||
|
||
return seq, pad_mask | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -499,6 +499,17 @@ def normalize_obs(obs_dict, obs_normalization_stats): | |
|
||
return obs_dict | ||
|
||
def normalize_actions(actions, action_normalization_stats):
    """
    Normalize raw actions using precomputed affine statistics:
    normalized = actions * scale + offset. Inverse of unnormalize_actions.

    Args:
        actions (np.array): raw action array; last dims broadcast against
            the stats, which have shape (1, ...) — the default action shape
        action_normalization_stats (dict): dict with "scale" and "offset"
            entries of shape (1, ...)

    Returns:
        actions (np.array): normalized action array
    """
    scale = action_normalization_stats['scale']
    offset = action_normalization_stats['offset']
    actions = actions * scale + offset
    return actions
|
||
def unnormalize_actions(actions, action_normalization_stats):
    """
    Map normalized actions back to the raw action space:
    raw = (actions - offset) / scale. Inverse of normalize_actions.

    Args:
        actions (np.array): normalized action array; last dims broadcast
            against the stats, which have shape (1, ...) — the default
            action shape
        action_normalization_stats (dict): dict with "scale" and "offset"
            entries of shape (1, ...)

    Returns:
        actions (np.array): unnormalized (raw) action array
    """
    scale = action_normalization_stats['scale']
    offset = action_normalization_stats['offset']
    actions = (actions - offset) / scale
    return actions
|
||
def has_modality(modality, obs_keys): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some comments in the function docstring for
`action_normalization_stats`?
Similar to how it's already done for `obs_normalization_stats`.