├── ckpts
│ ├── sam_vit_h_4b8939.pth #original SAM weights
│ ├── vit_h_mask_2024-01-15-05-42-30 #finetuned mask decoder and training log
│ ├── vit_h_semantic_mask_2024-01-15-22-51-30 #trained semantic mask decoder and training log
├── data
│ ├── processed #preprocessed dataset
│ ├── sam_embedding #image embedding generated by sam encoder
│ ├── Testing #testing split of original dataset
│ ├── Training #training and validation split of original dataset
│ ├── centers.txt
│ └── widths.txt
├── id_to_color.txt #mapping from organ class to color
└── sam_on_btcv
    ├── segment_anything #adds build_sem_sam.py and SemanticMaskDecoder to the original SAM
    ├── btcv_dataset.py #Dataset class
    ├── criterion.py #Dice loss
    ├── grid_sam.py #applications using grid-point prompts
    ├── myAutomaticMaskGenerator.py #automatic mask generator that supports the semantic mask decoder
    ├── my_pridictor.py #predictor that supports the semantic mask decoder
    ├── preprocess_dataset.py #data preprocessing
    ├── finetune.py #finetune mask decoder
    ├── train_semantic.py #train semantic mask decoder
    └── visualize.py #applications for visualization
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install opencv-python tqdm SimpleITK open3d matplotlib
1. Download the pretrained SAM weights into ckpts/
2. Download the dataset into data/Testing and data/Training
3. Run preprocess_dataset.py to preprocess the data
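Before running preprocess_dataset.py, it can help to confirm that the downloads landed where the directory tree expects them. The following is a minimal sketch, not part of this repository; the helper name `missing_paths` is hypothetical, and it assumes the script is run from the repository root.

```python
from pathlib import Path

# Paths taken from the directory tree above, relative to the repository root.
EXPECTED_PATHS = [
    "ckpts/sam_vit_h_4b8939.pth",
    "data/Training",
    "data/Testing",
]


def missing_paths(root="."):
    """Return the expected paths that do not exist under root (hypothetical helper)."""
    root = Path(root)
    return [p for p in EXPECTED_PATHS if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Layout looks complete.")
```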
1. Run finetune.py to finetune the mask decoder (without semantic labels) on BTCV
2. Run train_semantic.py to train the semantic mask decoder (you need to modify the parameter 'from_pretrain' at the bottom of the script)
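Both training scripts rely on a Dice loss (criterion.py). For readers unfamiliar with it, here is a minimal NumPy sketch of a soft Dice loss for a single binary mask. It is an illustration only and is not the code in criterion.py; the function name, the `eps` smoothing term, and the choice to sum over all pixels are assumptions.

```python
import numpy as np


def dice_loss(probs, target, eps=1e-6):
    """Soft Dice loss: 1 - 2*sum(P*T) / (sum(P) + sum(T)), over all pixels.

    probs  : predicted foreground probabilities in [0, 1]
    target : ground-truth binary mask (0 or 1), same shape as probs
    eps    : small constant (assumed) that avoids division by zero
    """
    probs = np.asarray(probs, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    intersection = (probs * target).sum()
    denom = probs.sum() + target.sum()
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)
```

A perfect prediction gives a loss close to 0, and a prediction with no overlap gives a loss close to 1.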
After the training above, you can run semantic automatic mask generation on each image.
python grid_sam.py
This generates a semantic segmentation for each slice. The results are saved as single-channel images with pixel values ranging from 0 to 13.
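Because the saved label images only hold values from 0 to 13, they look almost black in a standard image viewer. The sketch below shows one way to turn such a label image into an RGB image with NumPy. The palette here is hypothetical (random colors, with label 0 assumed to be background and drawn in black); it does not read id_to_color.txt, whose format is not described here.

```python
import numpy as np

NUM_CLASSES = 14  # labels 0..13, as described above


def make_palette(num_classes=NUM_CLASSES, seed=0):
    """Hypothetical palette: black for label 0, random RGB colors for the rest."""
    rng = np.random.default_rng(seed)
    palette = rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)
    palette[0] = 0  # assumption: label 0 is background
    return palette


def colorize(label_map, palette):
    """Map an (H, W) integer label image to an (H, W, 3) RGB image."""
    label_map = np.asarray(label_map)
    if label_map.min() < 0 or label_map.max() >= len(palette):
        raise ValueError("label outside palette range")
    # Integer-array indexing looks up one palette row per pixel.
    return palette[label_map]
```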
visualize.py supports a series of applications for visualization:
Visualize the semantic segmentation result.
Draw figures of the training log, including loss, Dice and accuracy.
Save the semantic segmentation results of the 2D slices as a 3D point cloud.
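To illustrate the last item, the following sketch stacks per-slice label images into a volume and extracts one 3D point per labelled voxel. It is not the code in visualize.py; the function name, the `spacing` parameter, and the choice to skip label 0 are assumptions.

```python
import numpy as np


def slices_to_points(label_slices, spacing=(1.0, 1.0, 1.0)):
    """Convert a sequence of (H, W) label slices into a labelled point set.

    Returns (points, labels): points is (N, 3) in x, y, z order, and
    labels is (N,) with the class of each point. Voxels with label 0
    are skipped (assumed to be background).
    """
    volume = np.stack(label_slices, axis=0)  # shape (D, H, W)
    z, y, x = np.nonzero(volume)
    labels = volume[z, y, x]
    sx, sy, sz = spacing  # assumed voxel size along x, y, z
    points = np.stack([x * sx, y * sy, z * sz], axis=1).astype(np.float64)
    return points, labels
```

With open3d installed, the resulting array can be wrapped in a point cloud by assigning `o3d.utility.Vector3dVector(points)` to the `points` attribute of an `o3d.geometry.PointCloud` and written to disk with `o3d.io.write_point_cloud`.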