R2D2 share encoder for varied cameras #103

Open · wants to merge 5 commits into base: rgb-cam-encoder
9 changes: 6 additions & 3 deletions robomimic/algo/diffusion_policy.py
@@ -26,6 +26,10 @@
 import robomimic.utils.tensor_utils as TensorUtils
 import robomimic.utils.obs_utils as ObsUtils
 
+# import torch.distributed as dist
+# from torch.nn.parallel import DistributedDataParallel as DDP
+
+
 @register_algo_factory_func("diffusion_policy")
 def algo_config_to_class(algo_config):
     """
@@ -76,8 +80,8 @@ def _create_networks(self):
         # the final arch has 2 parts
         nets = nn.ModuleDict({
             'policy': nn.ModuleDict({
-                'obs_encoder': obs_encoder,
-                'noise_pred_net': noise_pred_net
+                'obs_encoder': torch.nn.parallel.DataParallel(obs_encoder, device_ids=list(range(0,8))),
+                'noise_pred_net': torch.nn.parallel.DataParallel(noise_pred_net, device_ids=list(range(0,8))),
             })
         })

@@ -173,7 +177,6 @@ def train_on_batch(self, batch, epoch, validate=False):
         action_dim = self.ac_dim
         B = batch['actions'].shape[0]
 
-
         with TorchUtils.maybe_no_grad(no_grad=validate):
             info = super(DiffusionPolicyUNet, self).train_on_batch(batch, epoch, validate=validate)
             actions = batch['actions']
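The hunk above pins DataParallel to device_ids=list(range(0,8)), which assumes exactly eight visible GPUs, and the commented-out imports suggest DistributedDataParallel was also being considered. Below is a minimal sketch of a more portable wrapping; wrap_data_parallel is a hypothetical helper for illustration, not part of this PR:

import torch
import torch.nn as nn

def wrap_data_parallel(module: nn.Module) -> nn.Module:
    # Only parallelize when more than one GPU is visible; fall back to the
    # bare module otherwise, so CPU and single-GPU runs keep working.
    n = torch.cuda.device_count()
    if n > 1:
        return nn.DataParallel(module, device_ids=list(range(n)))
    return module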
1 change: 1 addition & 0 deletions robomimic/config/base_config.py
@@ -291,6 +291,7 @@ def observation_config(self):
         self.observation.encoder.low_dim.obs_randomizer_kwargs.do_not_lock_keys()
 
         # =============== RGB default encoder (ResNet backbone + linear layer output) ===============
+        self.observation.encoder.rgb.fuser = None  # How to combine the outputs of multi-camera vision encoders
         self.observation.encoder.rgb.core_class = "VisualCore"  # Default VisualCore class combines backbone (like ResNet-18) with pooling operation (like spatial softmax)
         self.observation.encoder.rgb.core_kwargs = Config()  # See models/obs_core.py for important kwargs to set and defaults used
         self.observation.encoder.rgb.core_kwargs.do_not_lock_keys()
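The new fuser option defaults to None, which preserves the old flatten-and-concatenate behavior. A sketch of how it would be toggled from user code, assuming the standard config_factory entry point (the option only takes effect together with the obs_nets.py changes below):

from robomimic.config import config_factory

config = config_factory(algo_name="diffusion_policy")
assert config.observation.encoder.rgb.fuser is None   # default: flatten + concatenate
config.observation.encoder.rgb.fuser = "transformer"  # opt in to the transformer fuser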
4 changes: 1 addition & 3 deletions robomimic/models/obs_core.py
@@ -232,6 +232,7 @@ def __init__(
         self.backbone = eval(backbone_class)(**backbone_kwargs)
 
         feat_shape = self.backbone.output_shape(input_shape)
+        self.feat_shape = feat_shape
         # net_list = [self.backbone]
         self.nets["backbone"] = self.backbone
 
@@ -254,8 +255,6 @@ def __init__(
         else:
             self.pool = None
 
-        post_proc_list.append(torch.nn.Conv1d(in_channels = feat_shape[-1], out_channels=feature_dimension*2, kernel_size=1))
-        post_proc_list.append(torch.nn.Conv1d(in_channels = feature_dimension*2, out_channels=feature_dimension, kernel_size=1))
 
         # flatten layer
         if self.flatten:
@@ -312,7 +311,6 @@ def forward(self, inputs):
 
         x = self.nets["backbone"].forward(image=image, intrinsics=intrinsics, extrinsics=extrinsics)
         # x = self.nets["backbone"].forward(inputs["image"])
-        x = x.permute(0, 2, 1) # B, L, C --> B, C, L
         x = self.nets["post_proc"](x)
 
         return x
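Net effect of these three hunks: VisualCore now records its backbone token shape in self.feat_shape and stops permuting/projecting features itself; channel reduction moves into the ObservationEncoder fuser below. A shape sketch under assumed sizes (49 tokens, 2048 channels; real values depend on backbone and crop size):

import torch

x = torch.randn(2, 49, 2048)       # backbone output [B, L, C] (assumed sizes)
# old forward: x = x.permute(0, 2, 1) -> [B, C, L], then two 1x1 Conv1d layers
#              projected C down to feature_dimension
# new forward: x reaches self.nets["post_proc"] as [B, L, C], unchanged;
#              self.feat_shape = [49, 2048] lets the downstream fuser size its layers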
70 changes: 62 additions & 8 deletions robomimic/models/obs_nets.py
@@ -59,7 +59,7 @@ def obs_encoder_factory(
             obs_modality2: dict
                 ...
     """
-    enc = ObservationEncoder(feature_activation=feature_activation)
+    enc = ObservationEncoder(feature_activation=feature_activation, fuser=encoder_kwargs["rgb"]["fuser"])
     for k, obs_shape in obs_shapes.items():
         obs_modality = ObsUtils.OBS_KEYS_TO_MODALITIES[k]
         enc_kwargs = deepcopy(ObsUtils.DEFAULT_ENCODER_KWARGS[obs_modality]) if encoder_kwargs is None else \
@@ -103,13 +103,24 @@ def obs_encoder_factory(
 
         input_maps = enc_kwargs.get("input_maps", {})
 
+        if any("camera/image/varied_camera" in s for s in enc.obs_shapes.keys()) and ("camera/image/varied_camera" in k):
+            existing_varied_cam = [a for a in enc.obs_shapes.keys() if "camera/image/varied_camera" in a][0]
+            share = existing_varied_cam
+            net_class = None
+            net_kwargs = None
+        else:
+            share = None
+            net_class = enc_kwargs["core_class"]
+            net_kwargs = enc_kwargs["core_kwargs"]
+
         enc.register_obs_key(
             name=k,
             shape=obs_shape,
             input_map=input_maps.get(k, None),
-            net_class=enc_kwargs["core_class"],
-            net_kwargs=enc_kwargs["core_kwargs"],
+            net_class=net_class,
+            net_kwargs=net_kwargs,
             randomizers=randomizers,
+            share_net_from=share,
         )
 
     enc.make()
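The branch above implements the encoder sharing in this PR's title: the first varied camera registers a fresh network, and every later varied-camera key is registered with share_net_from pointing at it, so all varied cameras reuse one set of weights while the hand camera keeps its own. A standalone sketch of the selection rule (hypothetical key list, not the robomimic API):

obs_keys = [
    "camera/image/hand_camera_left_image",
    "camera/image/varied_camera_1_left_image",
    "camera/image/varied_camera_2_left_image",
]
registered = []
for k in obs_keys:
    varied = [a for a in registered if "camera/image/varied_camera" in a]
    if varied and ("camera/image/varied_camera" in k):
        print(k, "-> shares encoder with", varied[0])
    else:
        print(k, "-> builds its own encoder")
    registered.append(k)
# camera/image/hand_camera_left_image -> builds its own encoder
# camera/image/varied_camera_1_left_image -> builds its own encoder
# camera/image/varied_camera_2_left_image -> shares encoder with camera/image/varied_camera_1_left_image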
@@ -123,7 +134,7 @@ class ObservationEncoder(Module):
     Call @register_obs_key to register observation keys with the encoder and then
     finally call @make to create the encoder networks.
     """
-    def __init__(self, feature_activation=nn.ReLU):
+    def __init__(self, feature_activation=nn.ReLU, fuser=None):
         """
         Args:
            feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
@@ -138,6 +149,8 @@ def __init__(self, feature_activation=nn.ReLU):
         self.obs_nets = nn.ModuleDict()
         self.obs_randomizers = nn.ModuleDict()
         self.feature_activation = feature_activation
+        self.fuser = fuser
+        self.num_images = 0
         self._locked = False
 
     def register_obs_key(
@@ -171,6 +184,8 @@ def register_obs_key(
         """
         assert not self._locked, "ObservationEncoder: @register_obs_key called after @make"
         assert name not in self.obs_shapes, "ObservationEncoder: modality {} already exists".format(name)
+        if "image" in name:
+            self.num_images += 1
 
         if net is not None:
             assert isinstance(net, Module), "ObservationEncoder: @net must be instance of Module class"
@@ -224,6 +239,39 @@ def _create_layers(self):
         if self.feature_activation is not None:
             self.activation = self.feature_activation()
 
+        if self.fuser == "transformer":
+            ## Define a fuser which takes multiple camera features as [B, sequence of pixels, features]
+            ## and encodes them with a transformer
+            input_features = self.obs_nets["camera/image/hand_camera_left_image"].feat_shape
+            # First scale down the number of features per pixel
+            self.c1 = torch.nn.Conv1d(in_channels=input_features[1], out_channels=512, kernel_size=1)
+            self.c2 = torch.nn.Conv1d(in_channels=512, out_channels=512, kernel_size=1)
+            layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
+            self.fusernetwork = nn.TransformerEncoder(layer, num_layers=6)
+            # Finally, flatten and apply a linear layer before concatenation with the other low-dim features
+            self.l1 = torch.nn.Linear(input_features[0] * self.num_images * 512, 2048)
+
+    def final_collation(self, feats):
+        if self.fuser is None:
+            ## Default fuser: everything is flattened and concatenated together
+            feats = [TensorUtils.flatten(x, begin_axis=1) for x in feats.values()]
+            return torch.cat(feats, dim=-1)
+        elif self.fuser == "transformer":
+            keys_with_images = [a for a in feats.keys() if "image" in a]
+            keys_without_images = [a for a in feats.keys() if "image" not in a]
+            non_image_feats = [TensorUtils.flatten(feats[k], begin_axis=1) for k in keys_without_images]
+
+            all_image_feats = []
+            for k in keys_with_images:
+                all_image_feats.append(feats[k])
+            all_image_feats = torch.cat(all_image_feats, dim=1)
+            all_image_feats_postconv = self.c2(F.relu(self.c1(all_image_feats.permute(0, 2, 1)))).permute(0, 2, 1)
+            all_image_feats_posttrans = self.fusernetwork(all_image_feats_postconv)
+            output = self.l1(TensorUtils.flatten(all_image_feats_posttrans, begin_axis=1))
+            return torch.cat(non_image_feats + [output], dim=-1)
+
+
 
     def forward(self, obs_dict):
         """
         Processes modalities according to the ordering in @self.obs_shapes. For each
@@ -248,7 +296,7 @@ def forward(self, obs_dict):
         )
 
         # process modalities by order given by @self.obs_shapes
-        feats = []
+        feats = {}
         for k in self.obs_shapes:
             if self.obs_input_maps[k] is not None:
                 x = dict()
@@ -284,18 +332,22 @@
                 if rand is not None:
                     x = rand.forward_out(x)
             # flatten to [B, D]
-            x = TensorUtils.flatten(x, begin_axis=1)
-            feats.append(x)
+            # x = TensorUtils.flatten(x, begin_axis=1)
+            feats[k] = x
 
+        output = self.final_collation(feats)
         # concatenate all features together
-        return torch.cat(feats, dim=-1)
+        return output
 
     def output_shape(self, input_shape=None):
         """
         Compute the output shape of the encoder.
         """
         feat_dim = 0
         for k in self.obs_shapes:
+            # If the fuser is a transformer for image features, don't naively concatenate flattened shapes
+            if (self.fuser == "transformer") and ("image" in k):
+                continue
             feat_shape = self.obs_shapes[k]
             for rand in self.obs_randomizers[k]:
                 if rand is not None:
@@ -306,6 +358,8 @@
                 if rand is not None:
                     feat_shape = rand.output_shape_out(feat_shape)
             feat_dim += int(np.prod(feat_shape))
+        if self.fuser == "transformer":
+            feat_dim += 2048
         return [feat_dim]
 
     def __repr__(self):
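End-to-end, the transformer fuser consumes per-camera token maps [B, L, C], concatenates them along the token axis, projects channels to d_model=512 with two 1x1 convolutions, runs a 6-layer transformer encoder, and flattens into a single 2048-dim feature; this is why output_shape skips image keys and adds a flat 2048 for all of them combined. Note the fuser sizes itself from the hand camera's feat_shape, which assumes every camera encoder emits the same token shape. A self-contained shape walkthrough under assumed sizes (L=49 tokens, C=2048 channels, 3 cameras; not code from the PR):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, C, num_images = 4, 49, 2048, 3
feats = torch.randn(B, L * num_images, C)  # per-camera [B, L, C] maps, concatenated on dim 1

c1 = nn.Conv1d(in_channels=C, out_channels=512, kernel_size=1)
c2 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=1)
fuser = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True),
    num_layers=6,
)
l1 = nn.Linear(L * num_images * 512, 2048)

x = c2(F.relu(c1(feats.permute(0, 2, 1)))).permute(0, 2, 1)  # [B, 3L, 512]
x = fuser(x)                                                 # [B, 3L, 512]
out = l1(x.flatten(start_dim=1))                             # [B, 2048]
print(out.shape)  # torch.Size([4, 2048])

One thing worth flagging in review: final_collation calls F.relu, so obs_nets.py needs torch.nn.functional imported as F if it is not already.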
104 changes: 80 additions & 24 deletions robomimic/scripts/config_gen/diffusion_gen.py
@@ -27,6 +27,13 @@ def make_generator_helper(args):
         values=[1000],
     )
 
+    generator.add_param(
+        key="train.batch_size",
+        name="",
+        group=-1,
+        values=[8],
+    )
+
     # use ddim by default
     generator.add_param(
         key="algo.ddim.enabled",
@@ -55,10 +62,10 @@
         name="ds",
         group=2,
         values=[
-            [{"path": p} for p in scan_datasets("~/Downloads/example_pen_in_cup", postfix="trajectory_im128.h5")],
+            ["DATAPATH"]
         ],
         value_names=[
-            "pen-in-cup",
+            "data",
         ],
     )
     generator.add_param(
@@ -98,13 +105,35 @@
             # "3cams-stereo",
         ]
     )
+    generator.add_param(
+        key="observation.encoder.rgb.obs_randomizer_class",
+        name="obsrand",
+        group=130,
+        values=[
+            # "CropRandomizer", # crop only
+            # "ColorRandomizer", # jitter only
+            ["ColorRandomizer", "CropRandomizer"], # jitter, followed by crop
+        ],
+        hidename=True,
+    )
+    generator.add_param(
+        key="observation.encoder.rgb.obs_randomizer_kwargs",
+        name="obsrandargs",
+        group=-1,
+        values=[
+            # {"crop_height": 116, "crop_width": 116, "num_crops": 1, "pos_enc": False}, # crop only
+            # {}, # jitter only
+            [{}, {"crop_height": 116, "crop_width": 116, "num_crops": 1, "pos_enc": False}], # jitter, followed by crop
+        ],
+        hidename=True,
+    )
 
     generator.add_param(
         key="observation.modalities.obs.low_dim",
         name="ldkeys",
         group=2498,
         values=[
-            ["robot_state/cartesian_position", "robot_state/gripper_position"],
+            # ["robot_state/cartesian_position", "robot_state/gripper_position"],
            [
                "robot_state/cartesian_position", "robot_state/gripper_position",
                "camera/extrinsics/hand_camera_left", "camera/intrinsics/hand_camera_left", # "camera/extrinsics/hand_camera_left_gripper_offset",
@@ -116,27 +145,28 @@ def make_generator_helper(args):
             ]
         ],
         value_names=[
-            "proprio",
+            # "proprio",
             "proprio-cam",
         ],
         hidename=False,
     )
+    ## All of the following are needed for DeFiNe style encoding
     generator.add_param(
         key="observation.encoder.rgb.input_maps",
         name="",
         group=2498,
         values=[
-            {
-                "camera/image/hand_camera_left_image": {
-                    "image": "camera/image/hand_camera_left_image",
-                },
-                "camera/image/varied_camera_1_left_image": {
-                    "image": "camera/image/varied_camera_1_left_image",
-                },
-                "camera/image/varied_camera_2_left_image": {
-                    "image": "camera/image/varied_camera_2_left_image",
-                },
-            },
+            # {
+            #     "camera/image/hand_camera_left_image": {
+            #         "image": "camera/image/hand_camera_left_image",
+            #     },
+            #     "camera/image/varied_camera_1_left_image": {
+            #         "image": "camera/image/varied_camera_1_left_image",
+            #     },
+            #     "camera/image/varied_camera_2_left_image": {
+            #         "image": "camera/image/varied_camera_2_left_image",
+            #     },
+            # },
             {
                 "camera/image/hand_camera_left_image": {
                     "image": "camera/image/hand_camera_left_image",
@@ -156,7 +186,7 @@ def make_generator_helper(args):
             },
         ],
         value_names=[
-            "define-image-only",
+            # "define-image-only",
             "define",
         ],
         hidename=True,
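The retained "define" input map (partially collapsed above) is what makes the DeFiNe-style encoding work: each image key is routed into its encoder together with that camera's calibration, matching the image/intrinsics/extrinsics arguments consumed by VisualCore.forward in obs_core.py. A hypothetical single-camera entry illustrating the expected structure (the hidden contents of the diff may differ):

{
    "camera/image/hand_camera_left_image": {
        "image": "camera/image/hand_camera_left_image",
        "intrinsics": "camera/intrinsics/hand_camera_left",
        "extrinsics": "camera/extrinsics/hand_camera_left",
    },
}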
@@ -166,13 +196,21 @@
         name="",
         group=2498,
         values=[
-            False,
+            # False,
             True,
         ],
         hidename=True,
     )
+    generator.add_param(
+        key="observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained",
+        name="",
+        group=2498,
+        values=[
+            # False,
+            True,
+        ],
+        hidename=True,
+    )
 
 
     generator.add_param(
         key="observation.encoder.rgb.core_class",
         name="visenc",
@@ -191,7 +229,25 @@
         name="visdim",
         group=1234,
         values=[
-            64, #512
+            None
         ],
         hidename=True,
     )
+    generator.add_param(
+        key="observation.encoder.rgb.core_kwargs.flatten",
+        name="flatten",
+        group=1234,
+        values=[
+            False
+        ],
+        hidename=True,
+    )
+    generator.add_param(
+        key="observation.encoder.rgb.fuser",
+        name="fuser",
+        group=1234,
+        values=[
+            "transformer"
+        ],
+        hidename=True,
+    )
@@ -202,17 +258,17 @@
     #     name="backbone",
     #     group=1234,
     #     values=[
-    #         "ResNet18Conv",
-    #         # "ResNet50Conv",
+    #         # "ResNet18Conv",
+    #         "ResNet50Conv",
     #     ],
     # )
     # generator.add_param(
     #     key="observation.encoder.rgb.core_kwargs.feature_dimension",
     #     name="visdim",
     #     group=1234,
     #     values=[
-    #         64,
-    #         # 512,
+    #         # 64,
+    #         512,
    #     ],
    # )

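Taken together, the group-1234 params pin the vision encoder into token-map mode for the fuser: feature_dimension=None keeps the backbone channel count, flatten=False preserves [B, L, C] outputs, and fuser="transformer" enables the new collation path. A sketch of the equivalent hand-written config fragment (hypothetical; the generator above produces these values per run):

config.observation.encoder.rgb.core_class = "VisualCore"
config.observation.encoder.rgb.core_kwargs.feature_dimension = None  # keep backbone channels
config.observation.encoder.rgb.core_kwargs.flatten = False           # hand [B, L, C] tokens to the fuser
config.observation.encoder.rgb.fuser = "transformer"                 # enable transformer collation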