Skip to content

Commit

Permalink
Add support for histopathology models (#834)
Browse files Browse the repository at this point in the history
* Add support for histopathology models

* Remove expected failure from bioimageio export tests
  • Loading branch information
anwai98 authored Jan 16, 2025
1 parent 3a366df commit 87e3d85
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 12 deletions.
16 changes: 16 additions & 0 deletions doc/bioimageio/histopathology_v1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Segment Anything for Histopathology

This is a [Segment Anything]https://segment-anything.com/) model that was specialized for histopathology with [micro_sam](https://github.com/computational-cell-analytics/micro-sam).
This model uses a %s vision transformer as image encoder.

Segment Anything is a model for interactive and automatic instance segmentation.
We improve it for histopathology by finetuning on a large and diverse microscopy dataset.
It should perform well for nucleus segmentation in histopathology datasets.

See [the dataset overview](https://github.com/computational-cell-analytics/micro-sam/blob/master/doc/datasets/histopathology_v%i.md) for further informations on the training data and the [micro_sam documentation](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html) for details on how to use the model for interactive and automatic segmentation.

## Validation

The easiest way to validate the model is to visually check the segmentation quality for your data.
If you have annotations you can use for validation you can also quantitative validation, see [here for details](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#9-how-can-i-evaluate-a-model-i-have-finetuned).
Please note that the required quality for segmentation always depends on the analysis task you want to solve.
22 changes: 13 additions & 9 deletions micro_sam/bioimageio/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
"tags": ["segment-anything", "instance-segmentation"],
}

# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269 # noqa
# We had this parameter in bioimageio.spec. This has been removed. We just make a copy of the same parameter.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

Expand Down Expand Up @@ -204,7 +208,7 @@ def _check_model(model_description, input_paths, result_paths):
image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic"))
point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("bhwc"))
point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))

Expand Down Expand Up @@ -292,8 +296,8 @@ def export_sam_model(
# NOTE: to support 1 and 3 channels we can add another preprocessing.
# Best solution: Have a pre-processing for this! (1C -> RGB)
spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
],
test_tensor=spec.FileDescr(source=input_paths["image"]),
data=spec.IntervalOrRatioDataDescr(type="uint8")
Expand All @@ -307,7 +311,7 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
],
Expand All @@ -323,11 +327,11 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.IndexInputAxis(
id=spec.AxisId("point"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
],
Expand All @@ -343,11 +347,11 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.IndexInputAxis(
id=spec.AxisId("point"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
],
test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
Expand All @@ -362,7 +366,7 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.ChannelAxis(channel_names=["channel"]),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
Expand Down
14 changes: 14 additions & 0 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def models():
"vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
"vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
"vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
# Histopathology models:
"vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
"vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
"vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
}
# Additional decoders for instance segmentation.
decoder_registry = {
Expand All @@ -123,6 +127,10 @@ def models():
"vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
"vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
"vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
# Histopathology models:
"vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
"vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
"vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
}
registry = {**encoder_registry, **decoder_registry}

Expand All @@ -137,6 +145,9 @@ def models():
"vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa
"vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
"vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa
"vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
"vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
"vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
}

decoder_urls = {
Expand All @@ -146,6 +157,9 @@ def models():
"vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa
"vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa
"vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa
"vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
"vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
"vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
}
urls = {**encoder_urls, **decoder_urls}

Expand Down
132 changes: 132 additions & 0 deletions scripts/model_export/export_histopathology_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
import xxhash
import argparse
import warnings
from glob import glob

import h5py

import bioimageio.spec.model.v0_5 as spec

from micro_sam.bioimageio import export_sam_model

from models import get_id_and_emoji


MODEL_TO_NAME = {
"vit_b_histopathology": "SAM Histopathology Generalist (ViT-B)",
"vit_l_histopathology": "SAM Histopathology Generalist (ViT-L)",
"vit_h_histopathology": "SAM Histopathology Generalist (ViT-H)",
}

BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
OUTPUT_FOLDER = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/exported_models/"
PUMA_ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data/puma"


def create_doc(model_type, version):
template_file = os.path.join(
os.path.split(__file__)[0], "../../doc/bioimageio", f"histopathology_v{version}.md"
)
assert os.path.exists(template_file), template_file
with open(template_file, "r") as f:
template = f.read()

doc = template % (model_type, version)
return doc


def get_data():
input_paths = glob(os.path.join(PUMA_ROOT, "test", "preprocessed", "training_set_*.h5"))
# Choose the first input path
input_path = input_paths[0]

with h5py.File(input_path, "r") as f:
image = f["raw"][:]
label_image = f["labels/nuclei"][:]

# Convert to channels first.
image = image.transpose(1, 2, 0)

return image, label_image


def compute_checksum(path):
xxh_checksum = xxhash.xxh128()
with open(path, "rb") as f:
while True:
data = f.read(BUF_SIZE)
if not data:
break
xxh_checksum.update(data)
return xxh_checksum.hexdigest()


def export_model(model_path, model_type, version, email):
output_folder = os.path.join(OUTPUT_FOLDER, "histopathology")
os.makedirs(output_folder, exist_ok=True)

model_name = f"{model_type}_histopathology"

output_path = os.path.join(output_folder, model_name)
if os.path.exists(output_path):
print("The model", model_name, "has already been exported.")
return

image, label_image = get_data()
covers = ["./covers/cover_lm.png"] # HACK: We use existing covers.
doc = create_doc(model_type, version)

model_id, emoji = get_id_and_emoji(model_name)
uploader = spec.Uploader(email=email)

export_name = MODEL_TO_NAME[model_name]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
export_sam_model(
image, label_image,
name=export_name,
model_type=model_type,
checkpoint_path=model_path,
output_path=output_path,
documentation=doc,
covers=covers,
id=model_id,
id_emoji=emoji,
uploader=uploader,
)

# NOTE: I needed to unzip the files myself. Not sure how this worked before. Maybe something changed in spec?
from torch_em.data.datasets.util import unzip
unzip(zip_path=output_path, dst=(output_path + ".unzip"))

print("Exported model", model_id)
encoder_path = os.path.join(output_path + ".unzip", f"{model_type}.pt")
encoder_checksum = compute_checksum(encoder_path)
print("Encoder:")
print(model_name, f"xxh128:{encoder_checksum}")

decoder_path = os.path.join(output_path + ".unzip", f"{model_type}_decoder.pt")
decoder_checksum = compute_checksum(decoder_path)
print("Decoder:")
print(f"{model_name}_decoder", f"xxh128:{decoder_checksum}")


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--email", required=True)
parser.add_argument("-v", "--version", default=1, type=int)
parser.add_argument("-c", "--checkpoint", required=True, type=str)
parser.add_argument("-m", "--model_type", required=True, type=str)
args = parser.parse_args()

export_model(
model_path=args.checkpoint,
model_type=args.model_type,
version=1,
email=args.email,
)


if __name__ == "__main__":
main()
5 changes: 2 additions & 3 deletions test/test_bioimageio/test_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from shutil import rmtree

import bioimageio.spec

import micro_sam.util as util
from micro_sam.sample_data import synthetic_data

spec_minor = int(bioimageio.spec.__version__.split(".")[1])


@unittest.skipIf(spec_minor < 5, "Needs bioimagio.spec >= 0.5")
@unittest.expectedFailure
class TestModelExport(unittest.TestCase):
tmp_folder = "tmp"
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
Expand All @@ -20,9 +20,8 @@ def setUp(self):
os.makedirs(self.tmp_folder, exist_ok=True)

def tearDown(self):
rmtree(self.tmp_folder)
rmtree(self.tmp_folder, ignore_errors=True)

@unittest.expectedFailure
def test_model_export(self):
from micro_sam.bioimageio import export_sam_model
image, labels = synthetic_data(shape=(1024, 1022))
Expand Down

0 comments on commit 87e3d85

Please sign in to comment.