diff --git a/.github/doc_env.yaml b/.github/doc_env.yaml index 0e3769b4d..bf830c57f 100644 --- a/.github/doc_env.yaml +++ b/.github/doc_env.yaml @@ -1,5 +1,4 @@ channels: - - pytorch - conda-forge name: sam diff --git a/.github/workflows/build_installers.yaml b/.github/workflows/build_installers.yaml index 0ca5c18f4..699c3c489 100644 --- a/.github/workflows/build_installers.yaml +++ b/.github/workflows/build_installers.yaml @@ -24,7 +24,7 @@ jobs: RUN_SCRIPT: | python version_getter.py mkdir ./${{ matrix.os }}_x86_64 - constructor --output-dir ./${{ matrix.os }}_x86_64 --config-filename construct_${{ matrix.os }}.yaml . + constructor --output-dir ./${{ matrix.os }}_x86_64 --config-filename construct_${{ matrix.os }}.yaml . steps: - name: checkout diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 779dbbbb8..412e19b79 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -31,7 +31,7 @@ jobs: - name: Setup micromamba uses: mamba-org/setup-micromamba@v1 with: - environment-file: environment_cpu.yaml + environment-file: ${{ runner.os == 'Windows' && 'environment_cpu_win.yaml' || 'environment.yaml' }} create-args: >- python=${{ matrix.python-version }} diff --git a/doc/annotation_tools.md b/doc/annotation_tools.md index 4692ed1a2..8527e2dfb 100644 --- a/doc/annotation_tools.md +++ b/doc/annotation_tools.md @@ -22,8 +22,8 @@ You can find additional information on the annotation tools [in the FAQ section] HINT: If you would like to start napari to use `micro-sam` from the plugin menu, you must start it by activating the environment where `micro-sam` has been installed using: ```bash -$ mamba activate -$ napari +conda activate +napari ``` diff --git a/doc/contributing.md b/doc/contributing.md index 0be1717e3..a7069dbd1 100644 --- a/doc/contributing.md +++ b/doc/contributing.md @@ -37,10 +37,10 @@ We use [conda](https://docs.conda.io/en/latest/) to [manage our environments](ht Now you can create the environment, install user and developer dependencies, and micro-sam as an editable installation: ```bash -$ mamba env create environment_gpu.yaml -$ mamba activate sam -$ python -m pip install requirements-dev.txt -$ python -m pip install -e . +conda env create environment.yaml +conda activate sam +python -m pip install requirements-dev.txt +python -m pip install -e . ``` ### Make your changes diff --git a/doc/faq.md b/doc/faq.md index 3661fefb5..9100a29ad 100644 --- a/doc/faq.md +++ b/doc/faq.md @@ -7,12 +7,12 @@ If you encounter a problem or question not addressed here feel free to [open an ### 1. How to install `micro_sam`? -The [installation](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation) for `micro_sam` is supported in three ways: [from mamba](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#from-mamba) (recommended), [from source](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#from-source) and [from installers](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#from-installer). Check out our [tutorial video](https://youtu.be/gcv0fa84mCc) to get started with `micro_sam`, briefly walking you through the installation process and how to start the tool. 
+The [installation](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation) for `micro_sam` is supported in three ways: [from conda](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#from-conda) (recommended), [from source](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#from-source) and [from installers](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#from-installer). Check out our [tutorial video](https://youtu.be/gcv0fa84mCc) to get started with `micro_sam`, briefly walking you through the installation process and how to start the tool. ### 2. I cannot install `micro_sam` using the installer, I am getting some errors. The installer should work out-of-the-box on Windows and Linux platforms. Please open an issue to report the error you encounter. ->NOTE: The installers enable using `micro_sam` without mamba or conda. However, we recommend the installation from mamba / from source to use all its features seamlessly. Specifically, the installers currently only support the CPU and won't enable you to use the GPU (if you have one). +>NOTE: The installers enable using `micro_sam` without conda. However, we recommend the installation from conda or from source to use all its features seamlessly. Specifically, the installers currently only support the CPU and won't enable you to use the GPU (if you have one). ### 3. What is the minimum system requirement for `micro_sam`? @@ -40,7 +40,7 @@ Having a GPU will significantly speed up the annotation tools and especially the ### 5. I am missing a few packages (eg. `ModuleNotFoundError: No module named 'elf.io`). What should I do? -With the latest release 1.0.0, the installation from mamba and source should take care of this and install all the relevant packages for you. +With the latest release 1.0.0, the installation from conda and source should take care of this and install all the relevant packages for you. So please reinstall `micro_sam`, following [the installation guide](#installation). ### 6. Can I install `micro_sam` using pip? @@ -132,7 +132,7 @@ We want to remove these errors, so we would be very grateful if you can [open an ### 10. The objects are not segmented in my 3d data using the interactive annotation tool. -The first thing to check is: a) make sure you are using the latest version of `micro_sam` (pull the latest commit from master if your installation is from source, or update the installation from conda / mamba using `mamba update micro_sam`), and b) try out the steps from the [3d annotation tutorial video](https://youtu.be/nqpyNQSyu74) to verify if this shows the same behaviour (or the same errors) as you faced. For 3d images, it's important to pass the inputs in the python axis convention, ZYX. +The first thing to check is: a) make sure you are using the latest version of `micro_sam` (pull the latest commit from master if your installation is from source, or update the installation from conda using `conda update micro_sam`), and b) try out the steps from the [3d annotation tutorial video](https://youtu.be/nqpyNQSyu74) to verify if this shows the same behaviour (or the same errors) as you faced. For 3d images, it's important to pass the inputs in the python axis convention, ZYX. c) try using a different model and change the projection mode for 3d segmentation. This is also explained in the video. 
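As a small illustration of the ZYX axis convention mentioned in the answer above, here is a minimal sketch (assuming a single multi-page TIFF volume at a hypothetical path and the `annotator_3d` entry point of `micro_sam.sam_annotator`; treat the exact call as illustrative rather than canonical):

```python
import imageio.v3 as imageio
from micro_sam.sam_annotator import annotator_3d

# Load the volume; micro_sam expects 3d data in the python axis convention ZYX.
volume = imageio.imread("my_volume.tif")  # hypothetical example path
assert volume.ndim == 3, "Expected a 3d volume with axes in ZYX order."

# Start the interactive 3d annotation tool on the volume.
annotator_3d(volume)
```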
diff --git a/doc/installation.md b/doc/installation.md index c155f02c1..ccef6097f 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -1,77 +1,87 @@ # Installation There are three ways to install `micro_sam`: -- [From mamba](#from-mamba) is the recommended way if you want to use all functionality. +- [From conda](#from-conda) is the recommended way if you want to use all functionality. - [From source](#from-source) for setting up a development environment to use the latest version and to change and contribute to our software. -- [From installer](#from-installer) to install it without having to use mamba (supported platforms: Windows and Linux, supports only CPU). +- [From installer](#from-installer) to install it without having to use conda (supported platforms: Windows and Linux, supports only CPU). You can find more information on the installation and how to troubleshoot it in [the FAQ section](#installation-questions). -We do **not** recommend installing `micro-sam` with pip. +We do **not support** installing `micro_sam` with pip. -## From mamba +## From conda -[mamba](https://mamba.readthedocs.io/en/latest/) is a drop-in replacement for conda, but much faster. -The steps below may also work with `conda`, but we recommend using `mamba`, especially if the installation does not work with `conda`. -You can follow the instructions [here](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) to install `mamba`. +`conda` is a Python package manager. If you don't have it installed yet, you can follow the instructions [here](https://conda-forge.org/download/) to set it up on your system. +Please make sure that you are using an up-to-date version of conda to install `micro_sam`. +You can also use [mamba](https://mamba.readthedocs.io/en/latest/), which is a drop-in replacement for conda, to install it. In this case, just replace the `conda` command below with `mamba`. -**IMPORTANT**: Make sure to avoid installing anything in the base environment. +**IMPORTANT**: Do not install `micro_sam` in the base conda environment. + +**Installation on Linux and Mac OS:** `micro_sam` can be installed in an existing environment via: ```bash -$ mamba install -c pytorch -c conda-forge micro_sam +conda install -c conda-forge micro_sam ``` -or you can create a new environment (here called `micro-sam`) via: - +or you can create a new environment with it (here called `micro-sam`) via: ```bash -$ mamba create -c pytorch -c conda-forge -n micro-sam micro_sam +conda create -c conda-forge -n micro-sam micro_sam ``` - -if you want to use the GPU you need to install PyTorch from the `pytorch` channel instead of `conda-forge`. For example: - +and then activate it via ```bash -$ mamba create -c pytorch -c nvidia -c conda-forge -n micro-sam micro_sam pytorch pytorch-cuda=12.1 +conda activate micro-sam ``` -NOTE: If you create a new enviroment (eg. here called `micro-sam`), you must activate the environment using - +This will also install `pytorch` from the `conda-forge` channel. If you have a recent enough operating system, it will automatically install the most suitable `pytorch` version for your system. +This means it will install the CPU version if you don't have an NVIDIA GPU, and a GPU version if you do. +However, if you have an older operating system, or a CUDA version older than 12, then it may not install the correct version.
In this case you will have to specify your CUDA version, for example for CUDA 11, like this: ```bash -$ mamba activate micro-sam +conda install -c conda-forge micro_sam "libtorch=*=cuda11*" ``` -You may need to change this command to install the correct CUDA version for your system, see [https://pytorch.org/](https://pytorch.org/) for details. +**Installation on Windows:** +`pytorch` is currently not available on conda-forge for Windows. Thus, you have to install it from the `pytorch` conda channel. In addition, you have to specify two specific dependencies to avoid incompatibilities. +This can be done with the following commands: +```bash +conda install -c pytorch -c conda-forge micro_sam "nifty=1.2.1=*_4" "protobuf<5" +``` +to install `micro_sam` in an existing environment and +```bash +conda create -c pytorch -c conda-forge -n micro-sam micro_sam "nifty=1.2.1=*_4" "protobuf<5" +``` +to create a new environment with it. ## From source To install `micro_sam` from source, we recommend to first set up an environment with the necessary requirements: -- [environment_gpu.yaml](https://github.com/computational-cell-analytics/micro-sam/blob/master/environment_gpu.yaml): sets up an environment with GPU support. -- [environment_cpu.yaml](https://github.com/computational-cell-analytics/micro-sam/blob/master/environment_cpu.yaml): sets up an environment with CPU support. +- [environment.yaml](https://github.com/computational-cell-analytics/micro-sam/blob/master/environment.yaml): to set up an environment on Linux or Mac OS. +- [environment_cpu_win.yaml](https://github.com/computational-cell-analytics/micro-sam/blob/master/environment_cpu_win.yaml): to set up an environment on Windows with CPU support. +- [environment_gpu_win.yaml](https://github.com/computational-cell-analytics/micro-sam/blob/master/environment_gpu_win.yaml): to set up an environment on Windows with GPU support. To create one of these environments and install `micro_sam` into it follow these steps 1. Clone the repository: ```bash -$ git clone https://github.com/computational-cell-analytics/micro-sam +git clone https://github.com/computational-cell-analytics/micro-sam ``` 2. Enter it: ```bash -$ cd micro-sam +cd micro-sam ``` -3. Create the GPU or CPU environment: +3. Create the respective environment: ```bash -$ mamba env create -f <ENV_FILE>.yaml +conda env create -f <ENV_FILE>.yaml ``` 4. Activate the environment: ```bash -$ mamba activate sam +conda activate sam ``` 5. Install `micro_sam`: @@ -89,7 +99,7 @@ We also provide installers for Linux and Windows: - [Mac](https://owncloud.gwdg.de/index.php/s/7YupGgACw9SHy2P) --> -The installers will not enable you to use a GPU, so if you have one then please consider installing `micro_sam` via [mamba](#from-mamba) instead. They will also not enable using the python library. +The installers will not enable you to use a GPU, so if you have one then please consider installing `micro_sam` via [conda](#from-conda) instead. They will also not enable using the python library. ### Linux Installer: diff --git a/doc/start_page.md b/doc/start_page.md index f4a043122..5139c2622 100644 --- a/doc/start_page.md +++ b/doc/start_page.md @@ -21,9 +21,9 @@ If you run into any problems or have questions please [open an issue](https://gi ## Quickstart -You can install `micro_sam` via mamba: +You can install `micro_sam` via conda: ```bash -$ mamba install -c conda-forge micro_sam +conda install -c conda-forge micro_sam ``` We also provide installers for Windows and Linux.
For more details on the available installation options, check out [the installation section](#installation). diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 000000000..26c61f526 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,30 @@ +name: sam +channels: + - conda-forge +dependencies: + - nifty >=1.2.1 + - imagecodecs + - magicgui + - napari >=0.5.0 + - natsort + - pip + - pooch + - pyqt + - python-xxhash + - python-elf >=0.4.8 + # Note: installing the pytorch package from conda-forge will generally + # give you the most optimized version for your system, if you have a modern + # enough OS and CUDA version (CUDA >= 12). For older versions, you can + # specify the CUDA version by pinning libtorch. + # For example, add this line for a CUDA 11 version: + # - libtorch=*=cuda11* + # or, to enforce a CPU installation, change to + # - "pytorch=*=cpu*" + - pytorch >=2.4 + - segment-anything + - torchvision + - torch_em >=0.7.0 + - tqdm + - timm + - pip: + - git+https://github.com/ChaoningZhang/MobileSAM.git diff --git a/environment_cpu.yaml b/environment_cpu_win.yaml similarity index 100% rename from environment_cpu.yaml rename to environment_cpu_win.yaml diff --git a/environment_gpu.yaml b/environment_gpu_win.yaml similarity index 100% rename from environment_gpu.yaml rename to environment_gpu_win.yaml diff --git a/examples/bioimageio/export_model_for_bioengine.py b/examples/bioimageio/export_model_for_bioengine.py new file mode 100644 index 000000000..2683037b4 --- /dev/null +++ b/examples/bioimageio/export_model_for_bioengine.py @@ -0,0 +1,3 @@ +from micro_sam.bioimageio.bioengine_export import export_bioengine_model + +export_bioengine_model("vit_b", "test-export", opset=12) diff --git a/micro_sam/bioimageio/bioengine_export.py b/micro_sam/bioimageio/bioengine_export.py new file mode 100644 index 000000000..e7e63454c --- /dev/null +++ b/micro_sam/bioimageio/bioengine_export.py @@ -0,0 +1,250 @@ +import os +import warnings +from typing import Optional, Union + +import torch +from segment_anything.utils.onnx import SamOnnxModel + +try: + import onnxruntime + onnxruntime_exists = True +except ImportError: + onnxruntime_exists = False + +from ..util import get_sam_model + + +ENCODER_CONFIG = """name: "%s" +backend: "pytorch" +platform: "pytorch_libtorch" + +max_batch_size : 1 +input [ + { + name: "input0__0" + data_type: TYPE_FP32 + dims: [3, -1, -1] + } +] +output [ + { + name: "output0__0" + data_type: TYPE_FP32 + dims: [256, 64, 64] + } +] + +parameters: { + key: "INFERENCE_MODE" + value: { + string_value: "true" + } +}""" + + +DECODER_CONFIG = """name: "%s" +backend: "onnxruntime" +platform: "onnxruntime_onnx" + +parameters: { + key: "INFERENCE_MODE" + value: { + string_value: "true" + } +} + +instance_group { + count: 1 + kind: KIND_CPU +}""" + + +def _to_numpy(tensor): + return tensor.cpu().numpy() + + +def export_image_encoder( + model_type: str, + output_root: Union[str, os.PathLike], + export_name: Optional[str] = None, + checkpoint_path: Optional[str] = None, +) -> None: + """Export SAM image encoder to torchscript. + + The torchscript image encoder can be used for predicting image embeddings + with a backend, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner). + + Args: + model_type: The SAM model type. + output_root: The output root directory where the exported model is saved. + export_name: The name of the exported model. + checkpoint_path: Optional checkpoint for loading the exported model.
+ """ + if export_name is None: + export_name = model_type + name = f"sam-{export_name}-encoder" + + output_folder = os.path.join(output_root, name) + weight_output_folder = os.path.join(output_folder, "1") + os.makedirs(weight_output_folder, exist_ok=True) + + predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + encoder = predictor.model.image_encoder + + encoder.eval() + input_ = torch.rand(1, 3, 1024, 1024) + traced_model = torch.jit.trace(encoder, input_) + weight_path = os.path.join(weight_output_folder, "model.pt") + traced_model.save(weight_path) + + config_output_path = os.path.join(output_folder, "config.pbtxt") + with open(config_output_path, "w") as f: + f.write(ENCODER_CONFIG % name) + + +def export_onnx_model( + model_type, + output_root, + opset: int, + export_name: Optional[str] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + return_single_mask: bool = True, + gelu_approximate: bool = False, + use_stability_score: bool = False, + return_extra_metrics: bool = False, +) -> None: + """Export SAM prompt encoder and mask decoder to onnx. + + The onnx encoder and decoder can be used for interactive segmentation in the browser. + This code is adapted from + https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py + + Args: + model_type: The SAM model type. + output_root: The output root directory where the exported model is saved. + opset: The ONNX opset version. + export_name: The name of the exported model. + checkpoint_path: Optional checkpoint for loading the SAM model. + return_single_mask: Whether the mask decoder returns a single or multiple masks. + gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend + does not have an efficient GeLU implementation. + use_stability_score: Whether to use the stability score instead of the predicted score. + return_extra_metrics: Whether to return a larger set of metrics. 
+ """ + if export_name is None: + export_name = model_type + name = f"sam-{export_name}-decoder" + + output_folder = os.path.join(output_root, name) + weight_output_folder = os.path.join(output_folder, "1") + os.makedirs(weight_output_folder, exist_ok=True) + + _, sam = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True) + weight_path = os.path.join(weight_output_folder, "model.onnx") + + onnx_model = SamOnnxModel( + model=sam, + return_single_mask=return_single_mask, + use_stability_score=use_stability_score, + return_extra_metrics=return_extra_metrics, + ) + + if gelu_approximate: + for n, m in onnx_model.named_modules: + if isinstance(m, torch.nn.GELU): + m.approximate = "tanh" + + dynamic_axes = { + "point_coords": {1: "num_points"}, + "point_labels": {1: "num_points"}, + } + + embed_dim = sam.prompt_encoder.embed_dim + embed_size = sam.prompt_encoder.image_embedding_size + + mask_input_size = [4 * x for x in embed_size] + dummy_inputs = { + "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), + "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), + "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), + "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), + "has_mask_input": torch.tensor([1], dtype=torch.float), + "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), + } + + _ = onnx_model(**dummy_inputs) + + output_names = ["masks", "iou_predictions", "low_res_masks"] + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=UserWarning) + with open(weight_path, "wb") as f: + print(f"Exporting onnx model to {weight_path}...") + torch.onnx.export( + onnx_model, + tuple(dummy_inputs.values()), + f, + export_params=True, + verbose=False, + opset_version=opset, + do_constant_folding=True, + input_names=list(dummy_inputs.keys()), + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + + if onnxruntime_exists: + ort_inputs = {k: _to_numpy(v) for k, v in dummy_inputs.items()} + # set cpu provider default + providers = ["CPUExecutionProvider"] + ort_session = onnxruntime.InferenceSession(weight_path, providers=providers) + _ = ort_session.run(None, ort_inputs) + print("Model has successfully been run with ONNXRuntime.") + + config_output_path = os.path.join(output_folder, "config.pbtxt") + with open(config_output_path, "w") as f: + f.write(DECODER_CONFIG % name) + + +def export_bioengine_model( + model_type, + output_root, + opset: int, + export_name: Optional[str] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + return_single_mask: bool = True, + gelu_approximate: bool = False, + use_stability_score: bool = False, + return_extra_metrics: bool = False, +) -> None: + """Export SAM model to a format compatible with the BioEngine. + + [The bioengine](https://github.com/bioimage-io/bioengine-model-runner) enables running the + image encoder on an online backend, so that SAM can be used in an online tool, or to predict + the image embeddings via the online backend rather than on CPU. + + Args: + model_type: The SAM model type. + output_root: The output root directory where the exported model is saved. + opset: The ONNX opset version. + export_name: The name of the exported model. + checkpoint_path: Optional checkpoint for loading the SAM model. + return_single_mask: Whether the mask decoder returns a single or multiple masks. 
+ gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend + does not have an efficient GeLU implementation. + use_stability_score: Whether to use the stability score instead of the predicted score. + return_extra_metrics: Whether to return a larger set of metrics. + """ + export_image_encoder(model_type, output_root, export_name, checkpoint_path) + export_onnx_model( + model_type=model_type, + output_root=output_root, + opset=opset, + export_name=export_name, + checkpoint_path=checkpoint_path, + return_single_mask=return_single_mask, + gelu_approximate=gelu_approximate, + use_stability_score=use_stability_score, + return_extra_metrics=return_extra_metrics, + ) diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index b033055f4..5d86067a4 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -550,9 +550,13 @@ def run_amg( iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, peft_kwargs: Optional[Dict] = None, + cache_embeddings: bool = False ) -> str: - embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved - os.makedirs(embedding_folder, exist_ok=True) + if cache_embeddings: + embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved + os.makedirs(embedding_folder, exist_ok=True) + else: + embedding_folder = None predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs) amg = AutomaticMaskGenerator(predictor) @@ -572,9 +576,15 @@ def run_amg( ) instance_segmentation.run_instance_segmentation_grid_search_and_inference( - amg, grid_search_values, - val_image_paths, val_gt_paths, test_image_paths, - embedding_folder, prediction_folder, gs_result_folder, + segmenter=amg, + grid_search_values=grid_search_values, + val_image_paths=val_image_paths, + val_gt_paths=val_gt_paths, + test_image_paths=test_image_paths, + embedding_dir=embedding_folder, + prediction_dir=prediction_folder, + result_dir=gs_result_folder, + experiment_folder=experiment_folder, ) return prediction_folder @@ -592,9 +602,13 @@ def run_instance_segmentation_with_decoder( val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, + cache_embeddings: bool = False, ) -> str: - embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved - os.makedirs(embedding_folder, exist_ok=True) + if cache_embeddings: + embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved + os.makedirs(embedding_folder, exist_ok=True) + else: + embedding_folder = None predictor, decoder = get_predictor_and_decoder( model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs, @@ -616,6 +630,6 @@ def run_instance_segmentation_with_decoder( segmenter, grid_search_values, val_image_paths, val_gt_paths, test_image_paths, embedding_dir=embedding_folder, prediction_dir=prediction_folder, - result_dir=gs_result_folder, + result_dir=gs_result_folder, experiment_folder=experiment_folder, ) return prediction_folder diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py index ad1956b95..7440fd4e8 100644 --- a/micro_sam/evaluation/instance_segmentation.py +++ b/micro_sam/evaluation/instance_segmentation.py @@ -247,7 +247,7 @@ def 
run_instance_segmentation_grid_search( def run_instance_segmentation_inference( segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], image_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], + embedding_dir: Optional[Union[str, os.PathLike]], prediction_dir: Union[str, os.PathLike], generate_kwargs: Optional[Dict[str, Any]] = None, ) -> None: @@ -278,12 +278,18 @@ def run_instance_segmentation_inference( assert os.path.exists(image_path), image_path image = imageio.imread(image_path) - embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") + if embedding_dir is None: + embedding_path = None + else: + assert predictor is not None + embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") + image_embeddings = util.precompute_image_embeddings( predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings ) segmenter.initialize(image, image_embeddings) + masks = segmenter.generate(**generate_kwargs) if len(masks) == 0: # the instance segmentation can have no masks, hence we just save empty labels @@ -303,9 +309,7 @@ def run_instance_segmentation_inference( def evaluate_instance_segmentation_grid_search( - result_dir: Union[str, os.PathLike], - grid_search_parameters: List[str], - criterion: str = "mSA" + result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA" ) -> Tuple[Dict[str, Any], float]: """Evaluate gridsearch results. @@ -318,7 +322,6 @@ def evaluate_instance_segmentation_grid_search( The best parameter setting. The evaluation score for the best setting. """ - # Load all the grid search results. gs_files = glob(os.path.join(result_dir, "*.csv")) gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files]) @@ -362,8 +365,9 @@ def run_instance_segmentation_grid_search_and_inference( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], + embedding_dir: Optional[Union[str, os.PathLike]], prediction_dir: Union[str, os.PathLike], + experiment_folder: Union[str, os.PathLike], result_dir: Union[str, os.PathLike], fixed_generate_kwargs: Optional[Dict[str, Any]] = None, verbose_gs: bool = True, @@ -381,6 +385,7 @@ def run_instance_segmentation_grid_search_and_inference( test_image_paths: The input images for inference. embedding_dir: Folder to cache the image embeddings. prediction_dir: Folder to save the predictions. + experiment_folder: Folder for caching best grid search parameters in 'results'. result_dir: Folder to cache the evaluation results per image. fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter. verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. 
@@ -396,7 +401,7 @@ def run_instance_segmentation_grid_search_and_inference( print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str) print() - save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent) + save_grid_search_best_params(best_kwargs, best_msa, experiment_folder) generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs generate_kwargs.update(best_kwargs) diff --git a/micro_sam/evaluation/multi_dimensional_segmentation.py b/micro_sam/evaluation/multi_dimensional_segmentation.py index 07b5820f0..09b55f125 100644 --- a/micro_sam/evaluation/multi_dimensional_segmentation.py +++ b/micro_sam/evaluation/multi_dimensional_segmentation.py @@ -4,13 +4,13 @@ from tqdm import tqdm from math import floor from itertools import product -from typing import Union, Tuple, Optional, List, Dict +from typing import Union, Tuple, Optional, List, Dict, Literal import imageio.v3 as imageio import torch -from elf.evaluation import mean_segmentation_accuracy +from elf.evaluation import mean_segmentation_accuracy, dice_score from .. import util from ..inference import batched_inference @@ -30,7 +30,7 @@ def default_grid_search_values_multi_dimensional_segmentation( iou_threshold_values: The values for `iou_threshold` used in the grid-search. By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used. projection_method_values: The values for `projection` method used in the grid-search. - By default the values `mask`, `bounding_box` and `points` are used. + By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used. box_extension_values: The values for `box_extension` used in the grid-search. By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used. @@ -71,6 +71,7 @@ def segment_slices_from_ground_truth( verbose: bool = False, return_segmentation: bool = False, min_size: int = 0, + evaluation_metric: Literal["sa", "dice"] = "sa", ) -> Union[float, Tuple[np.ndarray, float]]: """Segment all objects in a volume by prompt-based segmentation in one slice per object. @@ -94,6 +95,7 @@ def segment_slices_from_ground_truth( return_segmentation: Whether to return the segmented volume. min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice. + evaluation_metric: The choice of supported metric to evaluate predictions. """ assert volume.ndim == 3 @@ -111,8 +113,12 @@ def segment_slices_from_ground_truth( # Create an empty volume to store incoming segmentations final_segmentation = np.zeros_like(ground_truth) + _segmentation_completed = False + if save_path is not None and os.path.exists(save_path): + _segmentation_completed = True # We avoid rerunning the segmentation if it is completed. 
+ skipped_label_ids = [] - for label_id in label_ids: + for label_id in tqdm(label_ids, desc="Segmenting per object in the volume", disable=not verbose): # Binary label volume per instance (also referred to as object) this_seg = (ground_truth == label_id).astype("int") @@ -127,6 +133,9 @@ def segment_slices_from_ground_truth( skipped_label_ids.append(label_id) continue + if _segmentation_completed: + continue + if verbose: print(f"The object with id {label_id} lies in slice range: {slice_range}") @@ -185,7 +194,10 @@ def segment_slices_from_ground_truth( # Save the volumetric segmentation if save_path is not None: - imageio.imwrite(save_path, final_segmentation, compression="zlib") + if _segmentation_completed: + final_segmentation = imageio.imread(save_path) + else: + imageio.imwrite(save_path, final_segmentation, compression="zlib") # Evaluate the volumetric segmentation if skipped_label_ids: @@ -194,9 +206,16 @@ def segment_slices_from_ground_truth( else: curr_gt = ground_truth - msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True) - results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]} - results = pd.DataFrame.from_dict([results]) + if evaluation_metric == "sa": + msa, sa = mean_segmentation_accuracy( + segmentation=final_segmentation, groundtruth=curr_gt, return_accuracies=True + ) + results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]} + elif evaluation_metric == "dice": + dice = dice_score(segmentation=final_segmentation, groundtruth=curr_gt) + results = {"Dice": dice} + else: + raise ValueError(f"'{evaluation_metric}' is not a supported evaluation metric. Please choose 'sa' / 'dice'.") if return_segmentation: return results, final_segmentation @@ -204,20 +223,25 @@ def segment_slices_from_ground_truth( return results -def _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values): +def _get_best_parameters_from_grid_search_combinations( + result_dir, best_params_path, grid_search_values, evaluation_metric, +): if os.path.exists(best_params_path): print("The best parameters are already saved at:", best_params_path) return - best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys())) + criterion = "mSA" if evaluation_metric == "sa" else "Dice" + best_kwargs, best_metric = evaluate_instance_segmentation_grid_search( + result_dir=result_dir, grid_search_parameters=list(grid_search_values.keys()), criterion=criterion, + ) # let's save the best parameters - best_kwargs["mSA"] = best_msa + best_kwargs[criterion] = best_metric best_param_df = pd.DataFrame.from_dict([best_kwargs]) best_param_df.to_csv(best_params_path) best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items()) - print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str) + print("Best grid-search result:", best_metric, "with parameters:\n", best_param_str) def run_multi_dimensional_segmentation_grid_search( @@ -225,12 +249,13 @@ def run_multi_dimensional_segmentation_grid_search( ground_truth: np.ndarray, model_type: str, checkpoint_path: Union[str, os.PathLike], - embedding_path: Union[str, os.PathLike], + embedding_path: Optional[Union[str, os.PathLike]], result_dir: Union[str, os.PathLike], interactive_seg_mode: str = "box", verbose: bool = False, grid_search_values: Optional[Dict[str, List]] = None, - min_size: int = 0 + min_size: int = 0, + evaluation_metric: Literal["sa", "dice"] = "sa", ): """Run grid search for prompt-based multi-dimensional instance
segmentation. @@ -240,7 +265,7 @@ def run_multi_dimensional_segmentation_grid_search( ``` grid_search_values = { "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9], - "projection": ["mask", "bounding_box", "points"], + "projection": ["mask", "box", "points"], "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0,5], } ``` @@ -254,12 +279,13 @@ def run_multi_dimensional_segmentation_grid_search( model_type: Choice of segment anything model. checkpoint_path: Path to the model checkpoint. embedding_path: Path to cache the computed embeddings. - result_path: Path to save the grid search results. + result_dir: Path to save the grid search results. interactive_seg_mode: Method for guiding prompt-based instance segmentation. verbose: Whether to get the trace for projected segmentations. grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function. min_size: The minimal size for evaluating an object in the ground-truth. The size is measured within the central slice. + evaluation_metric: The choice of metric for evaluating predictions. """ if grid_search_values is None: grid_search_values = default_grid_search_values_multi_dimensional_segmentation() @@ -270,7 +296,9 @@ def run_multi_dimensional_segmentation_grid_search( result_path = os.path.join(result_dir, "all_grid_search_results.csv") best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv") if os.path.exists(result_path): - _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values) + _get_best_parameters_from_grid_search_combinations( + result_dir, best_params_path, grid_search_values, evaluation_metric + ) return best_params_path # Compute all combinations of grid search values. @@ -282,7 +310,7 @@ def run_multi_dimensional_segmentation_grid_search( ] net_list = [] - for gs_kwargs in tqdm(gs_combinations): + for gs_kwargs in tqdm(gs_combinations, desc="Run grid-search for multi-dimensional segmentation"): results = segment_slices_from_ground_truth( volume=volume, ground_truth=ground_truth, @@ -293,6 +321,7 @@ def run_multi_dimensional_segmentation_grid_search( verbose=verbose, return_segmentation=False, min_size=min_size, + evaluation_metric=evaluation_metric, **gs_kwargs ) @@ -303,6 +332,8 @@ def run_multi_dimensional_segmentation_grid_search( res_df = pd.concat(net_list, ignore_index=True) res_df.to_csv(result_path) - _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values) + _get_best_parameters_from_grid_search_combinations( + result_dir, best_params_path, grid_search_values, evaluation_metric + ) print("The best grid-search parameters have been computed and stored at:", best_params_path) return best_params_path diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index febbccf6b..65d95bf39 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -1,6 +1,7 @@ import math from typing import List, Union, Optional +import torch import torch.nn as nn from segment_anything.modeling import Sam @@ -27,11 +28,13 @@ def __init__(self, rank: int, block: nn.Module): super().__init__() self.qkv_proj = block.attn.qkv self.dim = self.qkv_proj.in_features + self.alpha = 1 # From our experiments, 'alpha' as 1 gives the best performance. 
+ self.rank = rank - self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False) - self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False) - self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False) - self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False) + self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False) + self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False) + self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False) + self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False) self.reset_parameters() @@ -45,8 +48,8 @@ def reset_parameters(self): def forward(self, x): qkv = self.qkv_proj(x) # B, N, N, 3 * org_C - new_q = self.w_b_linear_q(self.w_a_linear_q(x)) - new_v = self.w_b_linear_v(self.w_a_linear_v(x)) + new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x)) + new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x)) qkv[:, :, :, :self.dim] += new_q qkv[:, :, :, -self.dim:] += new_v return qkv @@ -105,6 +108,54 @@ def forward(self, x): return qkv +class ScaleShiftLayer(nn.Module): + def __init__(self, layer, dim): + super().__init__() + self.layer = layer + self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,))) + self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,))) + layer = self + + def forward(self, x): + x = self.layer(x) + assert self.scale.shape == self.shift.shape + if x.shape[-1] == self.scale.shape[0]: + return x * self.scale + self.shift + elif x.shape[1] == self.scale.shape[0]: + return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1) + else: + raise ValueError('Input tensors do not match the shape of the scale factors.') + + +class SSFSurgery(nn.Module): + """Operates on all layers in the transformer block for adding learnable scale and shift parameters. + + Args: + rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency. + block: The chosen attention blocks for implementing ssf. + dim: The input dimensions determining the shape of scale and shift parameters. + """ + def __init__(self, rank: int, block: nn.Module): + super().__init__() + self.block = block + + # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer. + if hasattr(block, "attn"): # the minimum assumption is to verify the attention layers. + block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3) + block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features) + block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features) + block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features) + block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0]) + block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0]) + + # If we get the embedding block, add one ScaleShiftLayer + elif hasattr(block, "patch_embed"): + block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels) + + def forward(self, x): + return x + + class SelectiveSurgery(nn.Module): """Base class for selectively allowing gradient updates for certain parameters. """ @@ -123,7 +174,7 @@ def allow_gradient_update_for_parameters( Args: prefix: Matches the part of parameter name in front. suffix: Matches the part of parameter name at the end. - infix: Matches parts of parameter name occuring in between. + infix: Matches parts of parameter name occuring in between. 
""" for k, v in self.block.named_parameters(): if prefix is not None and k.startswith(tuple(prefix)): @@ -141,6 +192,68 @@ def forward(self, x): return x +class AdaptFormer(nn.Module): + """Adds AdaptFormer Module in place of the MLP Layers + + Args: + rank: The rank is not used in this class but kept here for consistency. + block: The chosen encoder block for implementing AdaptFormer. + alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value. + dropout: The dropout rate for the dropout layer between down and up projection layer. + projection_size: The size of the projection layer. + """ + def __init__( + self, + rank: int, + block: nn.Module, + alpha: Optional[Union[str, float]] = "learnable_scalar", # Stable choice from our preliminary exp. + dropout: Optional[float] = None, # Does not have an obvious advantage. + projection_size: int = 64, # Stable choice from our preliminary exp. + ): + super().__init__() + + self.mlp_proj = block.mlp + self.n_embd = block.mlp.lin1.in_features + + if alpha == 'learnable_scalar': + self.alpha = nn.Parameter(torch.ones(1)) + else: + self.alpha = alpha + + self.projection_size = projection_size + self.dropout = dropout + + self.down_proj = nn.Linear(self.n_embd, self.projection_size) + self.non_linear_func = nn.ReLU() + self.up_proj = nn.Linear(self.projection_size, self.n_embd) + + block.mlp = self + + if self.dropout is not None: + self.dropout_layer = nn.Dropout(self.dropout) + + nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) + nn.init.zeros_(self.up_proj.weight) + nn.init.zeros_(self.down_proj.bias) + nn.init.zeros_(self.up_proj.bias) + + def forward(self, x): + residual = x + mlp_output = self.mlp_proj(x) + + down = self.down_proj(x) + down = self.non_linear_func(down) + + if self.dropout is not None: + down = self.dropout_layer(down) + + up = self.up_proj(down) + up = up * self.alpha + output = up + residual + mlp_output + + return output + + class AttentionSurgery(SelectiveSurgery): """Child class for allowing gradient updates for parameters in attention layers. """ @@ -189,7 +302,10 @@ def __init__( super().__init__() assert rank > 0 - assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module." + + assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), ( + "Invalid PEFT module" + ) if attention_layers_to_update: self.peft_layers = attention_layers_to_update @@ -203,17 +319,19 @@ def __init__( for param in model.image_encoder.parameters(): param.requires_grad = False + # Add scale and shift parameters to the patch embedding layers. 
+ if issubclass(self.peft_module, SSFSurgery): + self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed)) + for t_layer_i, blk in enumerate(model.image_encoder.blocks): # If we only want specific layers with PEFT instead of all if t_layer_i not in self.peft_layers: continue if issubclass(self.peft_module, SelectiveSurgery): - peft_block = self.peft_module(block=blk) + self.peft_blocks.append(self.peft_module(block=blk)) else: - peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs) - - self.peft_blocks.append(peft_block) + self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs)) self.peft_blocks = nn.ModuleList(self.peft_blocks) diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index 7828a9e2e..00b989c47 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -139,8 +139,9 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None if threshold is not None: iou = util.compute_iou(seg_prev, seg_z) if iou < threshold: - msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}." - print(msg) + if verbose: + msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}." + print(msg) break segmentation[z] = seg_z diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index fcd58cc44..2dc7efee4 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -163,11 +163,15 @@ def _update_image(self, segmentation_result=None): # Reset all layers. self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32") + self._viewer.layers["current_object"].scale = state.image_scale self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32") + self._viewer.layers["auto_segmentation"].scale = state.image_scale if segmentation_result is None or segmentation_result is False: self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32") else: assert segmentation_result.shape == self._shape self._viewer.layers["committed_objects"].data = segmentation_result - + self._viewer.layers["committed_objects"].scale = state.image_scale + self._viewer.layers["point_prompts"].scale = state.image_scale + self._viewer.layers["prompts"].scale = state.image_scale vutil.clear_annotations(self._viewer, clear_segmentations=False) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index ee42e4ebb..3de8affb5 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -42,6 +42,7 @@ class AnnotatorState(metaclass=Singleton): image_embeddings: Optional[util.ImageEmbeddings] = None predictor: Optional[SamPredictor] = None image_shape: Optional[Tuple[int, int]] = None + image_scale: Optional[Tuple[float, ...]] = None embedding_path: Optional[str] = None data_signature: Optional[str] = None @@ -198,6 +199,7 @@ def reset_state(self): self.image_embeddings = None self.predictor = None self.image_shape = None + self.image_scale = None self.embedding_path = None self.amg = None self.amg_state = None diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4a58a42a7..63ba5347f 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -325,7 +325,7 @@ def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> Non if all_slices: 
vutil.clear_annotations(viewer) else: - i = int(viewer.cursor.position[0]) + i = int(viewer.dims.point[0]) vutil.clear_annotations_slice(viewer, i=i) @@ -341,7 +341,7 @@ def clear_track(viewer: "napari.viewer.Viewer", all_frames: bool = True) -> None _reset_tracking_state(viewer) vutil.clear_annotations(viewer) else: - i = int(viewer.cursor.position[0]) + i = int(viewer.dims.point[0]) vutil.clear_annotations_slice(viewer, i=i) @@ -736,7 +736,9 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None: return None shape = viewer.layers["current_object"].data.shape[1:] - position = viewer.cursor.position + + position_world = viewer.dims.point + position = viewer.layers["point_prompts"].world_to_data(position_world) z = int(position[0]) point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], z) @@ -775,7 +777,7 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None: return None state = AnnotatorState() shape = state.image_shape[1:] - position = viewer.cursor.position + position = viewer.dims.point t = int(position[0]) point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], i=t, track_id=state.current_track_id) @@ -868,7 +870,9 @@ def __init__(self, parent=None): def _initialize_image(self): state = AnnotatorState() image_shape = self.image_selection.get_value().data.shape + image_scale = tuple(self.image_selection.get_value().scale) state.image_shape = image_shape + state.image_scale = image_scale def _create_image_section(self): image_section = QtWidgets.QVBoxLayout() @@ -1083,6 +1087,9 @@ def __call__(self, skip_validate=False): ndim = image.data.ndim state.image_shape = image.data.shape + # Set layer scale + state.image_scale = tuple(image.scale) + # Process tile_shape and halo, set other data. tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y) save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path @@ -1655,7 +1662,7 @@ def __call__(self): if self.volumetric and self.apply_to_volume: worker = self._run_segmentation_3d(kwargs) elif self.volumetric and not self.apply_to_volume: - i = int(self._viewer.cursor.position[0]) + i = int(self._viewer.dims.point[0]) worker = self._run_segmentation_2d(kwargs, i=i) else: worker = self._run_segmentation_2d(kwargs) diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index 1887f371e..7916ea142 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -178,7 +178,7 @@ def point_layer_to_prompts( this_points, this_labels = points, labels else: assert points.shape[1] == 3, f"{points.shape}" - mask = points[:, 0] == i + mask = np.round(points[:, 0]) == i this_points = points[mask][:, 1:] this_labels = labels[mask] assert len(this_points) == len(this_labels) @@ -355,7 +355,7 @@ def segment_slices_with_prompts( image_shape = shape[1:] seg = np.zeros(shape, dtype="uint32") - z_values = point_prompts.data[:, 0] + z_values = np.round(point_prompts.data[:, 0]) z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\ np.zeros(0, dtype="int") diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index e2037aab1..79485b8d0 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -146,12 +146,6 @@ def set_description(self, desc, **kwargs): self._signals.pbar_description.emit(desc) -def _count_parameters(model_parameters): - params = sum(p.numel() for p in model_parameters if p.requires_grad) - params = 
params / 1e6 - print(f"The number of trainable parameters for the provided model is {round(params, 2)}M") - - @contextmanager def _filter_warnings(ignore_warnings): if ignore_warnings: @@ -163,6 +157,12 @@ def _filter_warnings(ignore_warnings): yield +def _count_parameters(model_parameters): + params = sum(p.numel() for p in model_parameters if p.requires_grad) + params = params / 1e6 + print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)") + + def train_sam( name: str, model_type: str, @@ -249,6 +249,7 @@ def train_sam( peft_kwargs=peft_kwargs, **model_kwargs ) + # This class creates all the training data for a batch (inputs, prompts and labels). convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 4461aa9b1..059b08c2a 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -8,12 +8,7 @@ class TestPEFTSam(unittest.TestCase): model_type = "vit_b" - def test_lora_sam(self): - from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery - - _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") - peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) - + def _check_output(self, peft_sam): shape = (3, 1024, 1024) expected_shape = (1, 3, 1024, 1024) with torch.no_grad(): @@ -22,61 +17,54 @@ def test_lora_sam(self): masks = output[0]["masks"] self.assertEqual(masks.shape, expected_shape) + def test_lora_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) + self._check_output(peft_sam) + def test_fact_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=FacTSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_attention_layer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_norm_layer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + 
self._check_output(peft_sam) def test_bias_layer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery) + self._check_output(peft_sam) - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + def test_ssf_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery) + self._check_output(peft_sam) + + def test_adaptformer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5) + self._check_output(peft_sam) if __name__ == "__main__":
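For reference, a minimal usage sketch of the new PEFT modules introduced in this diff, mirroring the keyword arguments used in the tests above (the `vit_b` model type and the `AdaptFormer` settings are taken from the test code; treat this as an illustrative sketch rather than the canonical API):

```python
import torch

from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer, SSFSurgery

# Load the underlying SAM model and wrap its image encoder with a PEFT module.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)
# Alternatively, add learnable scale-and-shift parameters instead:
# peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)

# Run a forward pass on a random image to check the wrapped model.
shape = (3, 1024, 1024)
with torch.no_grad():
    batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
    output = peft_sam(batched_input, multimask_output=True)
print(output[0]["masks"].shape)  # expected: (1, 3, 1024, 1024)
```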