Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes seqfish reader v2 #268

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"readfcs",
"tifffile>=2023.8.12",
"ome-types",
"xmltodict",
]

[project.optional-dependencies]
Expand Down
4 changes: 1 addition & 3 deletions src/spatialdata_io/_constants/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class SeqfishKeys(ModeEnum):
TIFF_FILE = ".tiff"
GEOJSON_FILE = ".geojson"
# file identifiers
ROI = "Roi"
TRANSCRIPT_COORDINATES = "TranscriptList"
DAPI = "DAPI"
COUNTS_FILE = "CellxGene"
Expand All @@ -78,6 +77,7 @@ class SeqfishKeys(ModeEnum):
# transcripts
TRANSCRIPTS_X = "x"
TRANSCRIPTS_Y = "y"
TRANSCRIPTS_Z = "z"
FEATURE_KEY = "name"
INSTANCE_KEY_POINTS = "cell"
# cells
Expand All @@ -88,8 +88,6 @@ class SeqfishKeys(ModeEnum):
SPATIAL_KEY = "spatial"
REGION_KEY = "region"
INSTANCE_KEY_TABLE = "instance_id"
SCALEFEFACTOR_X = "PhysicalSizeX"
SCALEFEFACTOR_Y = "PhysicalSizeY"


@unique
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_io/readers/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pathlib import Path
from typing import Any, Union

from anndata import AnnData, read_text
from anndata import AnnData
from anndata.io import read_text
from h5py import File
from ome_types import from_tiff
from ome_types.model import Pixels, UnitsLength
Expand Down
158 changes: 109 additions & 49 deletions src/spatialdata_io/readers/seqfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import re
import xml.etree.ElementTree as ET
import warnings
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
Expand All @@ -12,8 +12,10 @@
import numpy as np
import pandas as pd
import tifffile
import xmltodict
from dask_image.imread import imread
from spatialdata import SpatialData
from spatialdata._logging import logger
from spatialdata.models import (
Image2DModel,
Labels2DModel,
Expand All @@ -28,29 +30,31 @@

__all__ = ["seqfish"]

LARGE_IMAGE_THRESHOLD = 100_000_000

@inject_docs(vx=SK)

@inject_docs(vx=SK, megapixels_value=str(int(LARGE_IMAGE_THRESHOLD / 1e6)))
def seqfish(
path: str | Path,
load_images: bool = True,
load_labels: bool = True,
load_points: bool = True,
load_shapes: bool = True,
cells_as_circles: bool = False,
rois: list[int] | None = None,
rois: list[str] | None = None,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
raster_models_scale_factors: list[float] | None = None,
raster_models_scale_factors: list[int] | None = None,
) -> SpatialData:
"""
Read *seqfish* formatted dataset.

This function reads the following files:

- ```{vx.ROI!r}{vx.COUNTS_FILE!r}{vx.CSV_FILE!r}```: Counts and metadata file.
- ```{vx.ROI!r}{vx.CELL_COORDINATES!r}{vx.CSV_FILE!r}```: Cell coordinates file.
- ```{vx.ROI!r}{vx.DAPI!r}{vx.TIFF_FILE!r}```: High resolution tiff image.
- ```{vx.ROI!r}{vx.SEGMENTATION!r}{vx.TIFF_FILE!r}```: Cell mask file.
- ```{vx.ROI!r}{vx.TRANSCRIPT_COORDINATES!r}{vx.CSV_FILE!r}```: Transcript coordinates file.
- ```<roi_prefix>_{vx.COUNTS_FILE!r}{vx.CSV_FILE!r}```: Counts and metadata file.
- ```<roi_prefix>_{vx.CELL_COORDINATES!r}{vx.CSV_FILE!r}```: Cell coordinates file.
- ```<roi_prefix>_{vx.DAPI!r}{vx.TIFF_FILE!r}```: High resolution tiff image.
- ```<roi_prefix>_{vx.SEGMENTATION!r}{vx.TIFF_FILE!r}```: Cell mask file.
- ```<roi_prefix>_{vx.TRANSCRIPT_COORDINATES!r}{vx.CSV_FILE!r}```: Transcript coordinates file.

.. seealso::

Expand All @@ -71,9 +75,13 @@ def seqfish(
cells_as_circles
Whether to read cells also as circles instead of labels.
rois
Which ROIs (specified as integers) to load. Only necessary if multiple ROIs present.
Which ROIs (specified as strings, without trailing "_") to load (the ROI strings are used as prefixes for the
filenames). If `None`, all ROIs are loaded.
imread_kwargs
Keyword arguments to pass to :func:`dask_image.imread.imread`.
raster_models_scale_factors
Scale factors to downscale high-resolution images and labels. The scale factors will be automatically set to
obtain a multi-scale image for all the images and labels that are larger than {megapixels_value} megapixels.

Returns
-------
Expand All @@ -92,24 +100,29 @@ def seqfish(
>>> sdata.write("path/to/data.zarr")
"""
path = Path(path)
count_file_pattern = re.compile(rf"(.*?){re.escape(SK.CELL_COORDINATES)}{re.escape(SK.CSV_FILE)}$")
count_file_pattern = re.compile(rf"(.*?)_{re.escape(SK.CELL_COORDINATES)}{re.escape(SK.CSV_FILE)}$")
count_files = [f for f in os.listdir(path) if count_file_pattern.match(f)]
if not count_files:
raise ValueError(
f"No files matching the pattern {count_file_pattern} were found. Cannot infer the naming scheme."
)

roi_pattern = re.compile(f"^{SK.ROI}(\\d+)")
found_rois = {m.group(1) for i in os.listdir(path) if (m := roi_pattern.match(i))}
if rois is None:
rois_str = [f"{SK.ROI}{roi}" for roi in found_rois]
elif isinstance(rois, list):
rois_str_set = set()
for count_file in count_files:
found = count_file_pattern.match(count_file)
if found is None:
raise ValueError(f"File {count_file} does not match the expected pattern.")
rois_str_set.add(found.group(1))
logger.info(f"Found ROIs: {rois_str_set}")
rois_str = list(rois_str_set)

if isinstance(rois, list):
for roi in rois:
if str(roi) not in found_rois:
if str(roi) not in rois_str_set:
raise ValueError(f"ROI{roi} not found.")
rois_str = [f"{SK.ROI}{roi}" for roi in rois]
else:
raise ValueError("Invalid type for 'roi'. Must be list[int] or None.")
rois_str = rois
elif rois is not None:
raise ValueError("Invalid type for 'roi'. Must be list[str] or None.")

def get_cell_file(roi: str) -> str:
return f"{roi}_{SK.CELL_COORDINATES}{SK.CSV_FILE}"
Expand Down Expand Up @@ -167,33 +180,44 @@ def get_transcript_file(roi: str) -> str:
scaled = {}
for roi_str in rois_str:
scaled[roi_str] = Scale(
np.array(_get_scale_factors(path / get_dapi_file(roi_str), SK.SCALEFEFACTOR_X, SK.SCALEFEFACTOR_Y)),
np.array(_get_scale_factors_scale0(path / get_dapi_file(roi_str))),
axes=("y", "x"),
)

def _get_scale_factors(raster_path: Path, raster_models_scale_factors: list[int] | None) -> list[int] | None:
n_pixels = _get_n_pixels(raster_path)
if n_pixels > LARGE_IMAGE_THRESHOLD and raster_models_scale_factors is None:
return [2, 2, 2]
else:
return raster_models_scale_factors

if load_images:
images = {
f"{os.path.splitext(get_dapi_file(x))[0]}": Image2DModel.parse(
imread(path / get_dapi_file(x), **imread_kwargs),
images = {}
for x in rois_str:
image_path = path / get_dapi_file(x)
scale_factors = _get_scale_factors(image_path, raster_models_scale_factors)

images[f"{os.path.splitext(get_dapi_file(x))[0]}"] = Image2DModel.parse(
imread(image_path, **imread_kwargs),
dims=("c", "y", "x"),
scale_factors=raster_models_scale_factors,
transformations={"global": scaled[x]},
scale_factors=scale_factors,
transformations={x: scaled[x]},
)
for x in rois_str
}
else:
images = {}

if load_labels:
labels = {
f"{os.path.splitext(get_cell_segmentation_labels_file(x))[0]}": Labels2DModel.parse(
imread(path / get_cell_segmentation_labels_file(x), **imread_kwargs).squeeze(),
labels = {}
for x in rois_str:
labels_path = path / get_cell_segmentation_labels_file(x)
scale_factors = _get_scale_factors(labels_path, raster_models_scale_factors)

labels[f"{os.path.splitext(get_cell_segmentation_labels_file(x))[0]}"] = Labels2DModel.parse(
imread(labels_path, **imread_kwargs).squeeze(),
dims=("y", "x"),
scale_factors=raster_models_scale_factors,
transformations={"global": scaled[x]},
scale_factors=scale_factors,
transformations={x: scaled[x]},
)
for x in rois_str
}
else:
labels = {}

Expand All @@ -206,13 +230,20 @@ def get_transcript_file(roi: str) -> str:
p = pd.read_csv(path / get_transcript_file(x), delimiter=",")
instance_key_points = SK.INSTANCE_KEY_POINTS.value if SK.INSTANCE_KEY_POINTS.value in p.columns else None

coordinates = {"x": SK.TRANSCRIPTS_X, "y": SK.TRANSCRIPTS_Y, "z": SK.TRANSCRIPTS_Z}
if SK.TRANSCRIPTS_Z not in p.columns:
coordinates.pop("z")
warnings.warn(
f"Column {SK.TRANSCRIPTS_Z} not found in {get_transcript_file(x)}.", UserWarning, stacklevel=2
)

# call parser
points[name] = PointsModel.parse(
p,
coordinates={"x": SK.TRANSCRIPTS_X, "y": SK.TRANSCRIPTS_Y},
coordinates=coordinates,
feature_key=SK.FEATURE_KEY.value,
instance_key=instance_key_points,
transformations={"global": Identity()},
transformations={x: Identity()},
)

shapes = {}
Expand All @@ -223,15 +254,15 @@ def get_transcript_file(roi: str) -> str:
geometry=0,
radius=np.sqrt(adata.obs[SK.AREA].to_numpy() / np.pi),
index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
transformations={"global": Identity()},
transformations={x: Identity()},
)
if load_shapes:
for x in rois_str:
for x, adata in zip(rois_str, tables.values()):
# this assumes that the index matches the instance key of the table. A more robust approach could be
# implemented, as described here https://github.com/scverse/spatialdata-io/issues/249
shapes[f"{os.path.splitext(get_cell_segmentation_shapes_file(x))[0]}"] = ShapesModel.parse(
path / get_cell_segmentation_shapes_file(x),
transformations={"global": scaled[x]},
transformations={x: scaled[x]},
index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
)

Expand All @@ -240,12 +271,41 @@ def get_transcript_file(roi: str) -> str:
return sdata


def _get_scale_factors(DAPI_path: Path, scalefactor_x_key: str, scalefactor_y_key: str) -> list[float]:
    """Read the physical pixel sizes (X, Y) from an OME-TIFF's XML metadata.

    NOTE(review): if no element carries ``scalefactor_x_key`` the loop falls
    through and the function implicitly returns ``None``, despite the
    ``list[float]`` annotation — callers should not rely on that path.
    """
    with tifffile.TiffFile(DAPI_path) as tif:
        ome_metadata = tif.ome_metadata
    root = ET.fromstring(ome_metadata)
    # Scan every element; the first one exposing the X key is also read for
    # the Y key (presumably an OME "Pixels" element that defines both —
    # TODO confirm against the OME-XML schema).
    for element in root.iter():
        if scalefactor_x_key in element.attrib.keys():
            scalefactor_x = element.attrib[scalefactor_x_key]
            scalefactor_y = element.attrib[scalefactor_y_key]
            return [float(scalefactor_x), float(scalefactor_y)]
def _is_ome_tiff_multiscale(ome_tiff_file: Path) -> bool:
    """
    Determine whether an OME-TIFF file contains a multiscale pyramid.

    Parameters
    ----------
    ome_tiff_file
        Path to the OME-TIFF file.

    Returns
    -------
    ``True`` when a pyramid level beyond the full-resolution one exists.
    """
    # Probing for pyramid level 1 directly is more robust than parsing the
    # OME-XML metadata, which for some files omits the multiscale information.
    has_second_level = True
    try:
        zarr_tiff_store = tifffile.imread(ome_tiff_file, is_ome=True, level=1, aszarr=True)
        zarr_tiff_store.close()
    except IndexError:
        # Requesting a nonexistent level raises IndexError -> single-scale.
        has_second_level = False
    return has_second_level


def _get_n_pixels(ome_tiff_file: Path) -> int:
    """Return the total pixel count of the full-resolution page of an OME-TIFF.

    Parameters
    ----------
    ome_tiff_file
        Path to the OME-TIFF file.

    Returns
    -------
    Product of the first page's shape (includes the channel axis if present).
    """
    with tifffile.TiffFile(ome_tiff_file, is_ome=True) as tif:
        # Page 0 is the highest-resolution level of the (possibly pyramidal) file.
        shape = tif.pages[0].shape
    # int() guarantees a plain Python int (np.prod returns a NumPy scalar),
    # replacing the original runtime assert, which `python -O` would strip.
    return int(np.prod(shape))


def _get_scale_factors_scale0(DAPI_path: Path) -> list[float]:
    """Read the physical pixel sizes (X, Y) of the full-resolution image from
    the OME-XML metadata of an OME-TIFF file.

    Parameters
    ----------
    DAPI_path
        Path to the OME-TIFF file.

    Returns
    -------
    ``[PhysicalSizeX, PhysicalSizeY]`` as floats.

    Raises
    ------
    KeyError
        If the OME metadata lacks the ``PhysicalSizeX``/``PhysicalSizeY`` attributes.
    """
    with tifffile.TiffFile(DAPI_path, is_ome=True) as tif:
        ome_metadata = xmltodict.parse(tif.ome_metadata)
    image = ome_metadata["OME"]["Image"]
    if isinstance(image, list):
        # xmltodict collapses a single <Image> element to a dict but yields a
        # list when several are declared; scale0 corresponds to the first one.
        image = image[0]
    pixels = image["Pixels"]
    return [float(pixels["@PhysicalSizeX"]), float(pixels["@PhysicalSizeY"])]
5 changes: 3 additions & 2 deletions tests/test_seqfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
@pytest.mark.parametrize(
"dataset,expected", [("seqfish-2-test-dataset/instrument 2 official", "{'y': (0, 108), 'x': (0, 108)}")]
)
@pytest.mark.parametrize("rois", [[1], None])
@pytest.mark.parametrize("rois", [["Roi1"], None])
@pytest.mark.parametrize("cells_as_circles", [False, True])
def test_example_data(dataset: str, expected: str, rois: list[int] | None, cells_as_circles: bool) -> None:
f = Path("./data") / dataset
assert f.is_dir()
sdata = seqfish(f, cells_as_circles=cells_as_circles, rois=rois)
from spatialdata import get_extent

extent = get_extent(sdata, exact=False)
extent = get_extent(sdata, exact=False, coordinate_system="Roi1")
extent = {ax: (math.floor(extent[ax][0]), math.ceil(extent[ax][1])) for ax in extent}
del extent["z"]
if cells_as_circles:
# manual correction required to take into account for the circle radii
expected = "{'y': (-2, 109), 'x': (-2, 109)}"
Expand Down
Loading