Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
HeegerGao committed Mar 28, 2024
1 parent 36e5bb3 commit 36e0698
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 102 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html

## Data Preparation

Please put your data in the `data/{your_exp_name}`. We provide the demonstrations for the mug experiment at `data/mug/pick`, both for training and testing.
Please put your data in `data/{your_exp_name}`. We provide demonstrations for the mug-pick experiment at `data/mug/pick`, for both training and testing.

The demonstration file is a .npz file and is in the following data structure:
```
Expand All @@ -67,9 +67,16 @@ As stated in our paper, there is an SE(3)-invariant network $\phi$ that extracts
2. `python scripts/training/train_mani.py`

After training, you will get a `seg_net.pth` and a `mani_net.pth` under `experiments/{your_exp_name}`.

Different hyperparameters in the config file lead to different performance, training speed, and memory cost. Have a try!

## Evaluation

Run `python scripts/testing/infer.py`. You can select the testing demonstrations in the input arguments. After this you will get a `pred_pose.npz` that records the predicted target pose, and a open3d window will visualize the result.
Run `python scripts/testing/infer.py`. You can select the testing demonstrations in the input arguments. After this you will get a `pred_pose.npz` that records the predicted target pose.

We provide different scripts for result and feature visualization in `scripts/testing`.

We provide pretrained models for the `mug/pick` experiment in `experiments/mug/pick`.

## Citing
```
Expand Down
2 changes: 1 addition & 1 deletion config/mug/pick.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"seg": {
"model": "SE3SegNet",
"device": "cuda:0",
"device": "cuda:1",
"data_aug": true,
"aug_methods": [
"downsample_table",
Expand Down
45 changes: 0 additions & 45 deletions config/mug/place.json

This file was deleted.

Binary file added data/mug/pick/all.npz
Binary file not shown.
Binary file added data/mug/pick/distracting.npz
Binary file not shown.
Binary file added data/mug/pick/instance.npz
Binary file not shown.
Binary file added data/mug/pick/newpose.npz
Binary file not shown.
Binary file added data/mug/pick/training.npz
Binary file not shown.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ apex==0.9.10dev
e3nn==0.5.1
omegaconf==2.3.0
open3d==0.17.0
opencv-python==4.9.0.80
potpourri3d==1.0.0
pynvml==11.5.0
scipy
Expand Down
19 changes: 9 additions & 10 deletions scripts/testing/infer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys
sys.path.append(".")
import os
import torch
from networks import *
Expand All @@ -6,22 +8,20 @@
import argparse
import numpy as np
from utils.utils import modified_gram_schmidt
from utils.vis import vis_result


def main(args):
all_cfg = OmegaConf.load(f"config/{args.exp_name}/{args.pick_or_place}.json")
cfg_seg = all_cfg.seg
cfg_mani = all_cfg.mani

wd = os.path.join("experiments", args.exp_name, args.pick_or_place, args.setting)
pcd_path = os.path.join(os.getcwd(), wd, "pcd.npz")
pcd_path = os.path.join("data", args.exp_name, args.pick_or_place, f"{args.setting}.npz")
pcd = np.load(pcd_path)

input_xyz = torch.tensor(pcd["xyz"]).float().unsqueeze(0).to(cfg_seg.device)
input_rgb = torch.tensor(pcd["rgb"]).float().unsqueeze(0).to(cfg_seg.device)

model_dir = os.path.join(os.getcwd(), "experiments", args.exp_name, args.pick_or_place)
model_dir = os.path.join("experiments", args.exp_name, args.pick_or_place, "good_models")
policy_seg = globals()[cfg_seg.model](voxel_size=cfg_seg.voxel_size, radius_threshold=cfg_seg.radius_threshold).float().to(cfg_seg.device)
policy_seg.load_state_dict(torch.load(os.path.join(model_dir, "segnet.pth")))
policy_seg.eval()
Expand All @@ -30,6 +30,7 @@ def main(args):
policy_mani.load_state_dict(torch.load(os.path.join(model_dir, "maninet.pth")))
policy_mani.eval()

assert cfg_seg.device == cfg_mani.device, "Device mismatch between segmentation and manipulation networks!"

# preprocess the input the same way as during training
data = {
Expand Down Expand Up @@ -60,20 +61,18 @@ def main(args):
pred_pos = output_pos.detach().cpu().numpy().reshape(3)
pred_rot = out_dir_schmidt.detach().cpu().numpy()

result_path = os.path.join(os.getcwd(), wd, "pred_pose.npz")
result_path = os.path.join(wd, f"{args.setting}_pred_pose.npz")
np.savez(result_path,
pred_pos=pred_pos,
pred_rot=pred_rot,
)
print(f"Result saved to: {result_path}")

vis_result(input_xyz, input_rgb, pred_pos, pred_rot)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('exp_name', type=str, default="mug")
parser.add_argument('pick_or_place', type=str, choices=["pick", "place"], default="pick")
parser.add_argument('setting', type=str, default='new-pose')
parser.add_argument('-exp_name', type=str, default="mug")
parser.add_argument('-pick_or_place', type=str, choices=["pick", "place"], default="pick")
parser.add_argument('-setting', type=str, default='newpose')
args = parser.parse_args()

main(args)
24 changes: 24 additions & 0 deletions scripts/testing/vis_heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import open3d as o3d
import numpy as np
import torch

def interpolate_color(color1, color2, weight):
    """Linearly blend two colors component by component.

    Args:
        color1: Sequence of channel values returned when ``weight`` is 0.
        color2: Sequence of channel values returned when ``weight`` is 1.
        weight: Blend factor applied to ``color2``; ``color1`` gets ``1 - weight``.

    Returns:
        A list with one blended value per channel pair. Channels are paired
        with ``zip``, so extra trailing channels of the longer input are ignored.
    """
    blended = []
    for start, end in zip(color1, color2):
        blended.append(start * (1 - weight) + end * weight)
    return blended

def transfer_weight_to_color_heatmap(xyz, heatmap, color0=[0, 0, 0.4], color1 = [1, 0.1, 0]):
    """Build an Open3D point cloud whose colors encode per-point weights.

    Args:
        xyz: Point coordinates handed to ``o3d.utility.Vector3dVector``.
        heatmap: Iterable of per-point weights; every value must lie in [0, 1].
        color0: Color used for weight 0.
        color1: Color used for weight 1.

    Returns:
        An ``o3d.geometry.PointCloud`` holding ``xyz`` with one interpolated
        color per heatmap entry.

    Raises:
        AssertionError: If any weight is outside [0, 1].
    """
    assert np.max(heatmap) <= 1 and np.min(heatmap) >= 0

    # One blended color per weight, in heatmap order.
    per_point_colors = np.array(
        [interpolate_color(color0, color1, weight) for weight in heatmap]
    )

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(per_point_colors)
    return cloud

def vis_global_heatmap(pcd_path):
    """Load a saved point cloud and display it re-colored as a heatmap.

    The per-point weight is the mean over the last axis of the stored colors;
    it is mapped to a color by ``transfer_weight_to_color_heatmap`` and shown
    in an Open3D viewer window.

    Args:
        pcd_path: Path of the point-cloud file to read.
    """
    stored_cloud = o3d.io.read_point_cloud(pcd_path)
    positions = np.asarray(stored_cloud.points)
    weights = np.asarray(stored_cloud.colors).mean(axis=-1)
    recolored_cloud = transfer_weight_to_color_heatmap(positions, weights)
    o3d.visualization.draw_geometries([recolored_cloud])

if __name__ == "__main__":
    # The original entry point referenced the undefined names `pcd_name` and
    # `i` (NameError) and never called the visualizer. Both path components
    # are now read from the command line and the heatmap is actually shown.
    import argparse

    parser = argparse.ArgumentParser(
        description="Visualize a saved segmentation position heatmap.")
    parser.add_argument("pcd_name", type=str,
                        help="name part of pcd/seg/pos_heatmap_{pcd_name}_{i}.pcd")
    # NOTE(review): `i` looks like an index suffix of the saved file; it is
    # kept as a string because only the file name depends on it.
    parser.add_argument("i", type=str,
                        help="index part of pcd/seg/pos_heatmap_{pcd_name}_{i}.pcd")
    args = parser.parse_args()

    pcd_path = f"pcd/seg/pos_heatmap_{args.pcd_name}_{args.i}.pcd"
    vis_global_heatmap(pcd_path)
51 changes: 51 additions & 0 deletions scripts/testing/vis_ori_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import open3d as o3d
import numpy as np
import torch

def rotation_matrix_from_vectors(vec1, vec2):
    """Find the rotation matrix that aligns vec1 to vec2.

    Args:
        vec1: Source direction, any array-like with 3 elements (need not be
            unit length).
        vec2: Target direction, any array-like with 3 elements (need not be
            unit length).

    Returns:
        A 3x3 ``numpy.ndarray`` ``R`` such that ``R @ (vec1 / |vec1|)`` equals
        ``vec2 / |vec2|``.

    The general case uses the Rodrigues form ``I + K + K^2 * (1 - c) / s^2``.
    That expression divides by zero when the vectors are parallel or
    anti-parallel (``s == 0``), so those two cases are handled explicitly
    instead of returning a matrix of NaNs.
    """
    a = np.asarray(vec1, dtype=float).reshape(3)
    b = np.asarray(vec2, dtype=float).reshape(3)
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)

    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)

    if s < 1e-8:
        if c > 0:
            # Same direction: nothing to rotate.
            return np.eye(3)
        # Opposite direction: rotate by pi about any axis perpendicular to a.
        # Crossing with the basis vector least aligned with a keeps the
        # resulting axis well away from zero length.
        basis = np.zeros(3)
        basis[np.argmin(np.abs(a))] = 1.0
        axis = np.cross(a, basis)
        axis = axis / np.linalg.norm(axis)
        # Rotation by pi about unit axis n is 2 * n n^T - I.
        return 2.0 * np.outer(axis, axis) - np.eye(3)

    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))

def draw_arrow(unit_vector, translation, color):
    """Create an Open3D arrow mesh pointing along ``unit_vector``.

    Args:
        unit_vector: Direction of the arrow; its norm also scales the
            cylinder and cone heights.
        translation: Offset applied to the arrow after it has been rotated
            about the origin.
        color: Uniform color painted on the mesh.

    Returns:
        The rotated, translated and colored arrow ``TriangleMesh``.
    """
    # Open3D builds arrows along +Z, so rotate from +Z to the requested direction.
    z_axis = np.array([0, 0, 1])
    alignment = rotation_matrix_from_vectors(z_axis, unit_vector)

    mesh = o3d.geometry.TriangleMesh.create_arrow(
        cylinder_radius=0.00025,
        cone_radius=0.0009,
        cylinder_height=np.linalg.norm(unit_vector) * 0.007,
        cone_height=np.linalg.norm(unit_vector) * 0.004,
    )
    mesh.paint_uniform_color(color)
    mesh.rotate(alignment, center=(0, 0, 0))
    mesh.translate(translation)
    return mesh

def draw_ori_feature(ball_pcd, ball_oris):
    """Show a point cloud with three colored orientation arrows per point.

    For every row ``i`` of ``ball_oris`` the slices ``[0:3]``, ``[3:6]`` and
    ``[6:9]`` are normalized and drawn as red, green and blue arrows anchored
    at point ``i`` of ``ball_pcd``.

    Args:
        ball_pcd: Open3D point cloud providing the arrow anchor points.
        ball_oris: Per-point features; rows are sliced and converted with
            ``.numpy()``.
            # NOTE(review): presumably a CPU torch tensor with at least 9
            # columns — confirm against the producer of the feature file.
    """
    anchors = np.asarray(ball_pcd.points)
    # (start column, arrow color) for the red, green and blue vectors.
    channels = ((0, (1, 0, 0)), (3, (0, 1, 0)), (6, (0, 0, 1)))

    arrows = []
    for i in range(ball_oris.shape[0]):
        feature = ball_oris[i]
        for start, color in channels:
            vector = feature[start:start + 3].numpy()
            direction = vector / np.linalg.norm(vector)
            arrows.append(draw_arrow(direction, anchors[i], color))

    # Visualize the point cloud and the vectors
    o3d.visualization.draw_geometries([ball_pcd, *arrows])


# The original guard compared the string literal "__name__" with "__main__",
# which is always False, so the script silently did nothing when executed.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Visualize per-point orientation features as RGB arrows.")
    # The original file name contained an un-interpolated `{pcd_name}`
    # placeholder (missing f-string prefix, undefined variable); the name is
    # now taken from the command line.
    parser.add_argument("pcd_name", type=str,
                        help="name part of the saved ori_feature_{pcd_name}.pt file")
    parser.add_argument("--pcd_path", type=str, default="eval.pcd",
                        help="point cloud whose points anchor the arrows")
    args = parser.parse_args()

    ball_pcd = o3d.io.read_point_cloud(args.pcd_path)
    # NOTE: torch.load unpickles the file; only load feature files you trust.
    ori_feature = torch.load(f"ori_feature_{args.pcd_name}.pt",
                             map_location=torch.device('cpu')).detach()
    draw_ori_feature(ball_pcd, ori_feature)
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
import os
import colorsys

def main():
result_path = os.path.join("pred_pose.npz")
def main(pcd_path, result_path):
result = np.load(result_path)
pred_pos = result["pred_pos"]
pred_rot = result["pred_rot"]
pred_trans = np.identity(4)
pred_trans[:3, :3] = pred_rot
pred_trans[:3, 3] = pred_pos

pcd_path = os.path.join("pcd_ee_frame.npz")
pcd = np.load(pcd_path)
xyz = pcd["xyz"]
rgb = pcd["rgb"]
Expand All @@ -21,8 +19,6 @@ def main():
pcd.points = o3d.utility.Vector3dVector(xyz)
pcd.colors = o3d.utility.Vector3dVector(rgb)

coor_ori = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])

coor_pred = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2, origin=[0, 0, 0])
coor_pred = coor_pred.transform(pred_trans)

Expand All @@ -33,7 +29,6 @@ def main():
adjusted_colors = np.array([colorsys.hsv_to_rgb(*color) for color in hsv_colors])
pcd.colors = o3d.utility.Vector3dVector(adjusted_colors)


vis = o3d.visualization.Visualizer()
vis.create_window()
vis.add_geometry(pcd)
Expand All @@ -45,4 +40,6 @@ def main():
vis.destroy_window()

if __name__ == "__main__":
    # Use the setting name that matches the shipped demonstration file
    # (data/mug/pick/newpose.npz) and the default `-setting` of
    # scripts/testing/infer.py, which writes `{setting}_pred_pose.npz`.
    # The previous "new-pose" spelling pointed at files that do not exist.
    setting = "newpose"
    pcd_path = os.path.join("data", "mug", "pick", f"{setting}.npz")
    result_path = os.path.join("experiments", "mug", "pick", setting, f"{setting}_pred_pose.npz")
    main(pcd_path, result_path)
37 changes: 0 additions & 37 deletions utils/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,3 @@ def save_pcd_as_pcd(xyz, rgb, save_file="./pcd/test.pcd", draw_heatmap=False):
pcd.colors=o3d.utility.Vector3dVector(rgb)

o3d.io.write_point_cloud(save_file, pcd)

def vis_result(xyz, rgb, pred_pos, pred_rot, point_size=10.0, color_enhance=True):
    """Show a colored point cloud together with a predicted pose frame.

    Args:
        xyz: Point coordinates, torch tensor or array-like; flattened to (N, 3).
        rgb: Point colors, torch tensor or array-like; flattened to (N, 3).
        pred_pos: Predicted position with 3 elements (tensor or array-like).
        pred_rot: Predicted 3x3 rotation (tensor or array-like).
        point_size: Render point size passed to the Open3D render options.
        color_enhance: When True, multiply the HSV saturation of every point
            color by 1.4 (clipped to [0, 1]) before display.

    Opens a blocking Open3D window; returns nothing.
    """
    def _as_numpy(value):
        # Convert every argument on its own. The original converted all four
        # only when `xyz` was a tensor, which crashed with AttributeError when
        # the pose had already been turned into NumPy arrays by the caller.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    # reshape also drops a leading batch dimension such as (1, N, 3).
    xyz = _as_numpy(xyz).reshape(-1, 3)
    rgb = _as_numpy(rgb).reshape(-1, 3)
    pred_pos = _as_numpy(pred_pos).reshape(3)
    pred_rot = _as_numpy(pred_rot).reshape(3, 3)

    # Homogeneous transform placing the coordinate frame at the predicted pose.
    pred_trans = np.identity(4)
    pred_trans[:3, :3] = pred_rot
    pred_trans[:3, 3] = pred_pos

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    pcd.colors = o3d.utility.Vector3dVector(rgb)

    coor_pred = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2, origin=[0, 0, 0])
    coor_pred = coor_pred.transform(pred_trans)

    if color_enhance:
        # enhance color for better visualization
        colors = np.asarray(pcd.colors)
        hsv_colors = np.array([colorsys.rgb_to_hsv(*color) for color in colors])
        hsv_colors[:, 1] *= 1.4
        hsv_colors[:, 1] = np.clip(hsv_colors[:, 1], 0, 1)
        adjusted_colors = np.array([colorsys.hsv_to_rgb(*color) for color in hsv_colors])
        pcd.colors = o3d.utility.Vector3dVector(adjusted_colors)

    vis = o3d.visualization.Visualizer()
    vis.create_window()
    vis.add_geometry(pcd)
    vis.add_geometry(coor_pred)
    opt = vis.get_render_option()
    opt.point_size = point_size  # change point size

    vis.run()
    vis.destroy_window()

0 comments on commit 36e0698

Please sign in to comment.