Commit: add demo

HeegerGao committed Mar 27, 2024
1 parent 9efa7a9 commit 36e5bb3
Showing 8 changed files with 57 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -135,4 +135,4 @@ wandb/

pcd/
experiments/
data/
test/
68 changes: 49 additions & 19 deletions README.md
@@ -1,52 +1,82 @@
<h2 align="center">
<b>RiEMann: Near Real-Time SE(3)-Equivariant Robot Manipulation without Point Cloud Segmentation</b>

<!-- <div align="center">
<a href="" target="_blank">
<img src="https://img.shields.io/badge/Paper-arXiv-green" alt="Paper arXiv"></a>
<div align="center">
<a href="" target="_blank">
<img src="https://img.shields.io/badge/Paper-arXiv-green" alt="Paper ArXiv"></a>
<a href="https://riemann-web.github.io/" target="_blank">
<img src="https://img.shields.io/badge/Page-RiEMann-blue" alt="Project Page"/></a>
</div> -->
</div>
</h2>

This is the official repository of **RiEMann: Near Real-Time SE(3)-Equivariant Robot Manipulation without Point Cloud Segmentation**.
This is the official code repository of **RiEMann: Near Real-Time SE(3)-Equivariant Robot Manipulation without Point Cloud Segmentation**.

<!-- For more information, please visit our [**project page**](). -->

## Overview

## Installation
RiEMann is an SE(3)-equivariant robot manipulation algorithm that can generalize to novel SE(3) object poses with only 5 to 10 demonstrations.

Please follow the steps below to perform the installation:
![image](imgs/web_teaser.gif)

## Installation

### 1. Create virtual environment
```bash
conda create -n equi python==3.8
conda activate equi
```

### 2. Install Special Dependencies
Install [torch-sparse](https://github.com/rusty1s/pytorch_sparse), [torch-scatter](https://github.com/rusty1s/pytorch_scatter), [torch-cluster](https://github.com/rusty1s/pytorch_cluster), and [dgl](https://www.dgl.ai/pages/start.html) according to their official installation guidance.
Other Python versions are untested and may not work.

### 3. Install General Dependencies
### 2. Installation

```pip install -r requirements.txt```
1. Install [PyTorch](https://pytorch.org/). RiEMann is tested with CUDA 11.7 and PyTorch 2.0.1; other versions are untested.

## Run
2. Install [torch-sparse](https://github.com/rusty1s/pytorch_sparse), [torch-scatter](https://github.com/rusty1s/pytorch_scatter), [torch-cluster](https://github.com/rusty1s/pytorch_cluster), and [dgl](https://www.dgl.ai/pages/start.html) according to their official installation guidance. We recommend using the following commands:
```
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html
pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html
```

### Data
3. ```pip install -r requirements.txt```

Please put your data in the `data/exp_name`.
## Data Preparation

## Training
Please put your data in `data/{your_exp_name}`. We provide demonstrations for the mug experiment at `data/mug/pick`, for both training and testing.

`python scripts/training/train_seg.py`
The demonstration file is a `.npz` file with the following structure:
```
{
"xyz": np.array([traj_num, video_len, point_num, 3]),
"rgb": np.array([traj_num, video_len, point_num, 3]),
"seg_center": np.array([traj_num, video_len, 3]),
"axes": np.array([traj_num, video_len, 9])
}
```
where *seg_center* is the target position and *axes* is the target rotation, represented as the unit orthonormal basis vectors v1, v2, v3 of the rotation matrix: [v1x, v1y, v1z, v2x, v2y, v2z, v3x, v3y, v3z].
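The 9-dimensional *axes* vector can be folded back into a 3x3 rotation matrix. A minimal sketch, assuming the three axis vectors form the columns of the matrix (check this against your own convention before relying on it):

```python
import numpy as np

def axes_to_rotation(axes):
    """Reshape the flat axes vector [v1x, v1y, v1z, v2x, ..., v3z]
    into a 3x3 rotation matrix whose columns are v1, v2, v3."""
    v1, v2, v3 = axes[0:3], axes[3:6], axes[6:9]
    return np.stack([v1, v2, v3], axis=1)

# Example: the identity rotation encoded as a 9-vector.
axes = np.array([1., 0., 0., 0., 1., 0., 0., 0., 1.])
R = axes_to_rotation(axes)
assert np.allclose(R @ R.T, np.eye(3))  # columns are orthonormal
```

The same reshape applies per frame to the `axes` array of shape `[traj_num, video_len, 9]`.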

`python scripts/training/train_mani.py`
## Training

## Evaluation
As stated in our paper, there is an SE(3)-invariant network $\phi$ that extracts the saliency map, and an SE(3)-equivariant network $\psi$ that leverages the saliency map to predict the target pose. **We must first train $\phi$, then train $\psi$.**

`python scripts/testing/infer.py`
1. `python scripts/training/train_seg.py`

2. `python scripts/training/train_mani.py`

After training, you will find `seg_net.pth` and `mani_net.pth` under `experiments/{your_exp_name}`.
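Independently of the model internals, SE(3)-equivariance is a checkable property: transforming the input point cloud by (R, t) should transform the predicted target position the same way. A minimal sketch of such a check, using a toy centroid predictor in place of the real networks (the helper name and the simplification to a position-only output are illustrative, not part of the repo):

```python
import numpy as np

def check_se3_equivariance(predict, xyz, R, t, atol=1e-5):
    """Return True if `predict` (mapping an (N, 3) cloud to a
    3-vector target position) commutes with the rigid transform
    x -> R @ x + t applied to every input point."""
    p = predict(xyz)
    p_transformed = predict(xyz @ R.T + t)
    return np.allclose(p_transformed, R @ p + t, atol=atol)

# A centroid predictor is trivially SE(3)-equivariant:
rng = np.random.default_rng(0)
xyz = rng.normal(size=(100, 3))
R, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
assert check_se3_equivariance(lambda pc: pc.mean(axis=0), xyz, R, np.array([0.1, 0.2, 0.3]))
```

A check like this can be run against the trained `mani_net` output to sanity-test generalization to novel object poses.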
## Evaluation

Run `python scripts/testing/infer.py`. You can select the testing demonstrations via the input arguments. Afterwards you will get a `pred_pose.npz` that records the predicted target pose, and an Open3D window will visualize the result.
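The exact keys stored in `pred_pose.npz` are not documented here, so inspect the file with `np.load(...).files`. As an illustration, a predicted position and rotation can be assembled into a 4x4 homogeneous transform for downstream robot control (the key names in the comment are hypothetical):

```python
import numpy as np

def to_homogeneous(position, rotation):
    """Assemble a 4x4 homogeneous transform from a 3-vector
    position and a 3x3 rotation matrix."""
    T = np.eye(4)
    T[:3, :3] = rotation
    T[:3, 3] = position
    return T

# Hypothetical usage -- check pred_pose.npz for the actual key names:
# pred = np.load("pred_pose.npz")
# T = to_homogeneous(pred["seg_center"], pred["axes"].reshape(3, 3))
T = to_homogeneous(np.array([0.1, 0.2, 0.3]), np.eye(3))
```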

## Citing
```
@inproceedings{gao2024riemann,
title={RiEMann: Near Real-Time SE(3)-Equivariant Robot Manipulation without Point Cloud Segmentation},
author={Gao, Chongkai and Xue, Zhengrong and Deng, Shuying and Liang, Tianhai and Yang, Siqi and Zhu, Yuke},
booktitle={arXiv preprint arXiv:},
year={2024}
}
```
4 changes: 2 additions & 2 deletions config/mug/pick.json
@@ -1,7 +1,7 @@
{
"seg": {
"model": "SE3SegNet",
"device": "cuda:2",
"device": "cuda:0",
"data_aug": true,
"aug_methods": [
"downsample_table",
@@ -22,7 +22,7 @@
},
"mani": {
"model": "SE3ManiNet",
"device": "cuda:2",
"device": "cuda:1",
"demo_dir": "data/demos",
"exp_name": "mug",
"data_aug": true,
2 changes: 1 addition & 1 deletion config/mug/place.json
@@ -1,7 +1,7 @@
{
"seg": {
"model": "SE3SegNet",
"device": "cuda:1",
"device": "cuda:0",
"data_aug": true,
"aug_methods": [
"downsample_table",
Binary file added data/mug/pick/demo.npz
Binary file not shown.
Binary file added imgs/web_teaser.gif
10 changes: 2 additions & 8 deletions requirements.txt
@@ -1,16 +1,10 @@
apex==0.9.10dev
dgl==1.1.2+cu117
e3nn==0.5.1
numpy==1.23.5
omegaconf==2.3.0
open3d==0.17.0
potpourri3d==1.0.0
pynvml==11.5.0
scipy==1.12.0
torch==2.0.1
torch_cluster==1.6.1
torch_scatter==2.1.1
torchvision==0.15.2
tqdm==4.66.1
scipy
tqdm==4.66.2
transforms3d==0.4.1
wandb==0.16.3
4 changes: 2 additions & 2 deletions scripts/testing/infer.py
@@ -14,7 +14,7 @@ def main(args):
cfg_seg = all_cfg.seg
cfg_mani = all_cfg.mani

wd = os.path.join("experiments", args.exp_name, args.pick_or_place, "deploy", args.setting)
wd = os.path.join("experiments", args.exp_name, args.pick_or_place, args.setting)
pcd_path = os.path.join(os.getcwd(), wd, "pcd.npz")
pcd = np.load(pcd_path)

@@ -73,7 +73,7 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument('exp_name', type=str, default="mug")
parser.add_argument('pick_or_place', type=str, choices=["pick", "place"], default="pick")
parser.add_argument('setting', type=str, default='val-a-4')
parser.add_argument('setting', type=str, default='new-pose')
args = parser.parse_args()

main(args)
