├── ckpts
│ ├── sam_vit_h_4b8939.pth #original SAM weights
│ ├── vit_h_mask_2024-01-15-05-42-30 #finetuned mask decoder and its training log
│ ├── vit_h_semantic_mask_2024-01-15-22-51-30 #trained semantic mask decoder and its training log
├── data
│ ├── processed #preprocessed dataset
│ ├── sam_embedding #image embeddings generated by the SAM encoder
│ ├── Testing #testing split of the original dataset
│ ├── Training #training and validation splits of the original dataset
│ ├── centers.txt
│ └── widths.txt
├── id_to_color.txt #mapping from organ class to color
└── sam_on_btcv
    ├── segment_anything #adds build_sem_sam.py and SemanticMaskDecoder on top of the original SAM
    ├── btcv_dataset.py #Dataset class
    ├── criterion.py #Dice loss
    ├── grid_sam.py #applications using grid-point prompts
    ├── myAutomaticMaskGenerator.py #automatic mask generator that supports the semantic mask decoder
    ├── my_pridictor.py #predictor that supports the semantic mask decoder
    ├── preprocess_dataset.py #data preprocessing
    ├── finetune.py #finetune the mask decoder
    ├── train_semantic.py #train the semantic mask decoder
    └── visualize.py #visualization applications
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install opencv-python tqdm SimpleITK open3d matplotlib
1. Download the pretrained SAM weights (sam_vit_h_4b8939.pth) into ckpts/
2. Download the BTCV dataset into data/Testing and data/Training
3. Run preprocess_dataset.py to preprocess the data (a quick layout check is sketched after this list)
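Before preprocessing, it can help to confirm the expected layout from the tree above is in place. A minimal sketch (the exact files preprocess_dataset.py expects may differ; these paths simply mirror the tree):

# check_setup.py -- sanity-check the expected layout before preprocessing
# (a hedged sketch; adjust paths if your checkout differs from the tree above)
from pathlib import Path

REQUIRED = [
    Path("ckpts/sam_vit_h_4b8939.pth"),  # pretrained SAM ViT-H weights
    Path("data/Training"),               # training/validation split of BTCV
    Path("data/Testing"),                # testing split of BTCV
]

missing = [str(p) for p in REQUIRED if not p.exists()]
if missing:
    raise SystemExit(f"Missing before preprocessing: {missing}")
print("Layout looks good; run preprocess_dataset.py next.")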
1. Run finetune.py to finetune the mask decoder (without semantics) on BTCV
2. Run train_semantic.py to train the semantic mask decoder (first modify the 'from_pretrain' parameter at the bottom of the script, as sketched after this list)
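For step 2, 'from_pretrain' should point at the mask-decoder checkpoint produced by finetune.py. A hedged sketch of what the bottom of train_semantic.py might look like (the entry-point name and the exact checkpoint path are assumptions, not the repo's actual code):

# Bottom of train_semantic.py (illustrative sketch only; the real script's
# entry point and checkpoint layout may differ)
if __name__ == "__main__":
    # Point 'from_pretrain' at the mask decoder finetuned by finetune.py,
    # i.e. the run directory created under ckpts/ (assumed path below).
    from_pretrain = "ckpts/vit_h_mask_2024-01-15-05-42-30"
    train(from_pretrain=from_pretrain)  # hypothetical training entry point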
After the training above, you can run semantic automatic mask generation on each image:
python grid_sam.py
This generates a semantic segmentation for each slice; results are saved as single-channel images whose pixel values range from 0 to 13 (background plus the 13 BTCV organ classes).
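To inspect one of these label maps outside the repo's own tooling, a minimal sketch like the following works (the filename is a placeholder, and a generic matplotlib colormap stands in for the class-to-color mapping defined in id_to_color.txt):

# view_labelmap.py -- colorize a saved single-channel semantic label map
# (a sketch; 'slice_0001.png' is a placeholder name, and a generic colormap
# is used instead of the repo's id_to_color.txt mapping)
import cv2
import numpy as np
import matplotlib.pyplot as plt

labels = cv2.imread("slice_0001.png", cv2.IMREAD_GRAYSCALE)  # values in [0, 13]
print("classes present:", np.unique(labels))

plt.imshow(labels, cmap="tab20", vmin=0, vmax=13)  # one distinct color per class
plt.colorbar(label="organ class id")
plt.axis("off")
plt.show()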
visualize.py provides several visualization utilities:
vis_semantic_masks: visualize the semantic segmentation results.
plot_history: plot training curves from the log, including loss, Dice, and accuracy.
save_to_ply: stack the per-slice 2D semantic segmentation results into a 3D point cloud and save it as a .ply file (see the sketch below).
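The point-cloud export can be approximated with open3d (already in the dependency list). A minimal sketch, assuming the per-slice label maps are stacked along z and colored by class id; the output location, spacing, and colormap are assumptions, not the repo's actual save_to_ply implementation:

# ply_sketch.py -- stack 2D label slices into a colored 3D point cloud
# (illustrative only; save_to_ply in visualize.py is the repo's real version)
import glob
import cv2
import numpy as np
import open3d as o3d
import matplotlib.cm as cm

slices = sorted(glob.glob("results/*.png"))  # assumed output location of grid_sam.py
volume = np.stack([cv2.imread(p, cv2.IMREAD_GRAYSCALE) for p in slices])  # (Z, H, W)

z, y, x = np.nonzero(volume)        # keep only foreground voxels (class id > 0)
labels = volume[z, y, x]

colors = cm.tab20(labels / 13.0)[:, :3]  # map class ids 1..13 to RGB

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.stack([x, y, z], axis=1).astype(np.float64))
pcd.colors = o3d.utility.Vector3dVector(colors)
o3d.io.write_point_cloud("semantic_volume.ply", pcd)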