From af6ca8b4cc1623f29da8fcbe31d2720c8243bd19 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Sat, 18 Dec 2021 00:57:18 +0000 Subject: [PATCH 1/5] Refactor HEStainExtractor and StainNormalizer Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/__init__.py | 14 +- monai/apps/pathology/transforms/__init__.py | 14 +- .../pathology/transforms/stain/__init__.py | 14 +- .../apps/pathology/transforms/stain/array.py | 221 +++++++++--------- .../pathology/transforms/stain/dictionary.py | 73 +++--- tests/test_pathology_he_stain.py | 50 ++-- tests/test_pathology_he_stain_dict.py | 53 +++-- 7 files changed, 225 insertions(+), 214 deletions(-) diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index 80f32403ea..a09e176b58 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -12,13 +12,13 @@ from .data import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCachePatchWSIDataset from .handlers import ProbMapProducer from .metrics import LesionFROC -from .transforms.stain.array import ExtractHEStains, NormalizeHEStains +from .transforms.stain.array import HEStainExtractor, StainNormalizer from .transforms.stain.dictionary import ( - ExtractHEStainsd, - ExtractHEStainsD, - ExtractHEStainsDict, - NormalizeHEStainsd, - NormalizeHEStainsD, - NormalizeHEStainsDict, + HEStainExtractord, + HEStainExtractorD, + HEStainExtractorDict, + StainNormalizerd, + StainNormalizerD, + StainNormalizerDict, ) from .utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index b418e20279..ec16b9004f 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -11,12 +11,12 @@ from .spatial.array import SplitOnGrid, TileOnGrid from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict -from .stain.array import ExtractHEStains, NormalizeHEStains +from .stain.array import HEStainExtractor, StainNormalizer from .stain.dictionary import ( - ExtractHEStainsd, - ExtractHEStainsD, - ExtractHEStainsDict, - NormalizeHEStainsd, - NormalizeHEStainsD, - NormalizeHEStainsDict, + HEStainExtractord, + HEStainExtractorD, + HEStainExtractorDict, + StainNormalizerd, + StainNormalizerD, + StainNormalizerDict, ) diff --git a/monai/apps/pathology/transforms/stain/__init__.py b/monai/apps/pathology/transforms/stain/__init__.py index 824f40a579..aa2ce3d61c 100644 --- a/monai/apps/pathology/transforms/stain/__init__.py +++ b/monai/apps/pathology/transforms/stain/__init__.py @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .array import ExtractHEStains, NormalizeHEStains +from .array import HEStainExtractor, StainNormalizer from .dictionary import ( - ExtractHEStainsd, - ExtractHEStainsD, - ExtractHEStainsDict, - NormalizeHEStainsd, - NormalizeHEStainsD, - NormalizeHEStainsDict, + HEStainExtractord, + HEStainExtractorD, + HEStainExtractorDict, + StainNormalizerd, + StainNormalizerD, + StainNormalizerDict, ) diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py index 3a03777299..f094e9ca83 100644 --- a/monai/apps/pathology/transforms/stain/array.py +++ b/monai/apps/pathology/transforms/stain/array.py @@ -16,181 +16,182 @@ from monai.transforms.transform import Transform -class ExtractHEStains(Transform): - """Class to extract a target stain from an image, using stain deconvolution (see Note). +class HEStainExtractor: + """Extract stain coefficients from an image. Args: - tli: transmitted light intensity. Defaults to 240. - alpha: tolerance in percentile for the pseudo-min (alpha percentile) - and pseudo-max (100 - alpha percentile). Defaults to 1. - beta: absorbance threshold for transparent pixels. Defaults to 0.15 - max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). - Defaults to (1.9705, 1.0308). + source_intensity: transmitted light intensity. + Defaults to 240. + alpha: percentiles to ignore for outliers, so to calculate min and max, + if only consider (alpha, 100-alpha) percentiles. Defaults to 1. + beta: absorbance threshold for transparent pixels. + Defaults to 0.15 Note: - For more information refer to: - - the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf - - the previous implementations: - - - MATLAB: https://github.com/mitkovetta/staining-normalization - - Python: https://github.com/schaugf/HEnorm_python + Please refer to this paper for further information on the method: + Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf """ def __init__( self, - tli: float = 240, + source_intensity: float = 240, alpha: float = 1, beta: float = 0.15, - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), ) -> None: - self.tli = tli + self.source_intensity = source_intensity self.alpha = alpha self.beta = beta - self.max_cref = np.array(max_cref) - def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: - """Perform Stain Deconvolution and return stain matrix for the image. + def calculate_flat_absorbance(self, image): + """Calculate absorbace and remove transparent pixels""" + # calculate absorbance + image = image.astype(np.float32, copy=False) + 1.0 + absorbance = -np.log(image.clip(max=self.source_intensity) / self.source_intensity) + + # reshape to form a CxN matrix + c = absorbance.shape[0] + absorbance = absorbance.reshape((c, -1)) + + # remove transparent pixels + absorbance = absorbance[np.all(absorbance > self.beta, axis=1)] + if len(absorbance) == 0: + raise ValueError("All pixels of the input image are below the absorbance threshold.") + + return absorbance + + def _stain_decomposition(self, absorbance: np.ndarray) -> np.ndarray: + """Calculate the matrix of stain coefficient from the image. Args: - image: uint8 RGB image to perform stain deconvolution on + absorbance: absorbance matrix to perform stain extraction on Return: - he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) + stain_coeff: stain attenuation coefficient matrix derive from the + image, where first column is H, second column is E, and + rows are RGB values """ - # check image type and values - if not isinstance(image, np.ndarray): - raise TypeError("Image must be of type numpy.ndarray.") - if image.min() < 0: - raise ValueError("Image should not have negative values.") - if image.max() > 255: - raise ValueError("Image should not have values greater than 255.") - - # reshape image and calculate absorbance - image = image.reshape((-1, 3)) - image = image.astype(np.float32, copy=False) + 1.0 - absorbance = -np.log(image.clip(max=self.tli) / self.tli) - - # remove transparent pixels - absorbance_hat = absorbance[np.all(absorbance > self.beta, axis=1)] - if len(absorbance_hat) == 0: - raise ValueError("All pixels of the input image are below the absorbance threshold.") # compute eigenvectors - _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32, copy=False)) + _, eigvecs = np.linalg.eigh(np.cov(absorbance).astype(np.float32, copy=False)) # project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues - t_hat = absorbance_hat.dot(eigvecs[:, 1:3]) + projection = np.dot(eigvecs[:, -2:].T, absorbance) - # find the min and max vectors and project back to absorbance space - phi = np.arctan2(t_hat[:, 1], t_hat[:, 0]) + # find the vectors that span the whole data (min and max angles) + phi = np.arctan2(projection[1], projection[0]) min_phi = np.percentile(phi, self.alpha) max_phi = np.percentile(phi, 100 - self.alpha) - v_min = eigvecs[:, 1:3].dot(np.array([(np.cos(min_phi), np.sin(min_phi))], dtype=np.float32).T) - v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T) + # project back to absorbance space + v_min = eigvecs[:, -2:].dot(np.array([(np.cos(min_phi), np.sin(min_phi))], dtype=np.float32).T) + v_max = eigvecs[:, -2:].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T) - # a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second + # make the vector corresponding to hematoxylin first and eosin second (based on R channel) if v_min[0] > v_max[0]: - he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T + stain_coeff = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T else: - he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T + stain_coeff = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T - return he + return stain_coeff def __call__(self, image: np.ndarray) -> np.ndarray: """Perform stain extraction. Args: - image: uint8 RGB image to extract stain from + image: RGB image to extract stain from - return: - target_he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) + Return: + ref_stain_coeff: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) """ + # check image type and values if not isinstance(image, np.ndarray): - raise TypeError("Image must be of type numpy.ndarray.") - - target_he = self._deconvolution_extract_stain(image) - return target_he + raise TypeError("Image must be of type cupy.ndarray.") + if image.min() < 0: + raise ValueError("Image should not have negative values.") + absorbance = self.calculate_flat_absorbance(image) + ref_stain_coeff = self._stain_decomposition(absorbance) + return ref_stain_coeff -class NormalizeHEStains(Transform): - """Class to normalize patches/images to a reference or target image stain (see Note). - Performs stain deconvolution of the source image using the ExtractHEStains - class, to obtain the stain matrix and calculate the stain concentration matrix - for the image. Then, performs the inverse Beer-Lambert transform to recreate the - patch using the target H&E stain matrix provided. If no target stain provided, a default - reference stain is used. Similarly, if no maximum stain concentrations are provided, a - reference maximum stain concentrations matrix is used. +class StainNormalizer: + """Normalize images to a reference stain color matrix. - Args: - tli: transmitted light intensity. Defaults to 240. - alpha: tolerance in percentile for the pseudo-min (alpha percentile) and - pseudo-max (100 - alpha percentile). Defaults to 1. - beta: absorbance threshold for transparent pixels. Defaults to 0.15. - target_he: target stain matrix. Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)). - max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). - Defaults to [1.9705, 1.0308]. + First, it extracts the stain coefficient matrix from the image using the provided stain extractor. + Then, it calculates the stain concentrations based on Beer-Lamber Law. + Next, it reconstructs the image using the provided reference stain matrix (stain-normalized image). - Note: - For more information refer to: - - the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf - - the previous implementations: + Parameters + ---------- + source_intensity: transmitted light intensity. + Defaults to 240. + alpha: percentiles to ignore for outliers, so to calculate min and max, + if only consider (alpha, 100-alpha) percentiles. Defaults to 1. + ref_stain_coeff: reference stain attenuation coefficient matrix. + Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)). + ref_max_conc: reference maximum stain concentrations for + Hematoxylin & Eosin (H&E). Defaults to (1.9705, 1.0308). - - MATLAB: https://github.com/mitkovetta/staining-normalization - - Python: https://github.com/schaugf/HEnorm_python """ def __init__( self, - tli: float = 240, + source_intensity: float = 240, alpha: float = 1, - beta: float = 0.15, - target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + ref_stain_coeff: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), + ref_max_conc: Union[tuple, np.ndarray] = (1.9705, 1.0308), + stain_extractor=None, ) -> None: - self.tli = tli - self.target_he = np.array(target_he) - self.max_cref = np.array(max_cref) - self.stain_extractor = ExtractHEStains(tli=self.tli, alpha=alpha, beta=beta, max_cref=self.max_cref) + self.source_intensity = source_intensity + self.alpha = alpha + self.ref_stain_coeff = np.array(ref_stain_coeff) + self.ref_max_conc = np.array(ref_max_conc) + if stain_extractor is None: + self.stain_extractor = HEStainExtractor() + else: + self.stain_extractor = stain_extractor def __call__(self, image: np.ndarray) -> np.ndarray: """Perform stain normalization. - Args: - image: uint8 RGB image/patch to be stain normalized, pixel values between 0 and 255 + Parameters + ---------- + image: uint8 RGB image to be stain normalized, pixel values between 0 and 255 - Return: + Returns + ------- image_norm: stain normalized image/patch """ # check image type and values if not isinstance(image, np.ndarray): - raise TypeError("Image must be of type numpy.ndarray.") + raise TypeError("Image must be of type cupy.ndarray.") if image.min() < 0: raise ValueError("Image should not have negative values.") - if image.max() > 255: - raise ValueError("Image should not have values greater than 255.") - # extract stain of the image - he = self.stain_extractor(image) + if self.source_intensity < 0: + raise ValueError("Source transmitted light intensity must be a positive value.") + + # derive stain coefficient matrix from the image + stain_coeff = self.stain_extractor(image) - # reshape image and calculate absorbance - h, w, _ = image.shape - image = image.reshape((-1, 3)) - image = image.astype(np.float32) + 1.0 - absorbance = -np.log(image.clip(max=self.tli) / self.tli) + # calculate absorbance + image = image.astype(np.float32, copy=False) + 1.0 + absorbance = -np.log(image.clip(max=self.source_intensity) / self.source_intensity) - # rows correspond to channels (RGB), columns to absorbance values - y = np.reshape(absorbance, (-1, 3)).T + # reshape to form a CxN matrix + c, h, w = absorbance.shape + absorbance = absorbance.reshape((c, -1)) - # determine concentrations of the individual stains - conc = np.linalg.lstsq(he, y, rcond=None)[0] + # calculate concentrations of the each stain, based on Beer-Lambert Law + conc_raw = np.linalg.lstsq(stain_coeff, absorbance, rcond=None)[0] # normalize stain concentrations - max_conc = np.asarray([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32) - tmp = np.divide(max_conc, self.max_cref, dtype=np.float32) - image_c = np.divide(conc, tmp[:, np.newaxis], dtype=np.float32) - - image_norm: np.ndarray = np.multiply(self.tli, np.exp(-self.target_he.dot(image_c)), dtype=np.float32) - image_norm[image_norm > 255] = 254 - image_norm = np.reshape(image_norm.T, (h, w, 3)).astype(np.uint8) + max_conc = np.percentile(conc_raw, 100 - self.alpha, axis=1) + normalization_factors = self.ref_max_conc / max_conc + conc_norm = conc_raw * normalization_factors[:, np.newaxis] + + # reconstruct the image based on the reference stain matrix + image_norm: np.ndarray = np.multiply( + self.source_intensity, np.exp(-self.ref_stain_coeff.dot(conc_norm)), dtype=np.float32 + ) + image_norm = np.reshape(image_norm, (c, h, w)).astype(np.uint8) return image_norm diff --git a/monai/apps/pathology/transforms/stain/dictionary.py b/monai/apps/pathology/transforms/stain/dictionary.py index 976af1e7c7..d310a1b084 100644 --- a/monai/apps/pathology/transforms/stain/dictionary.py +++ b/monai/apps/pathology/transforms/stain/dictionary.py @@ -22,22 +22,22 @@ from monai.config import KeysCollection from monai.transforms.transform import MapTransform -from .array import ExtractHEStains, NormalizeHEStains +from .array import HEStainExtractor, StainNormalizer -class ExtractHEStainsd(MapTransform): - """Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.ExtractHEStains`. +class HEStainExtractord(MapTransform): + """Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.HEStainExtractor`. Class to extract a target stain from an image, using stain deconvolution. Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - tli: transmitted light intensity. Defaults to 240. - alpha: tolerance in percentile for the pseudo-min (alpha percentile) - and pseudo-max (100 - alpha percentile). Defaults to 1. - beta: absorbance threshold for transparent pixels. Defaults to 0.15 - max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). - Defaults to (1.9705, 1.0308). + source_intensity: transmitted light intensity. + Defaults to 240. + alpha: percentiles to ignore for outliers, so to calculate min and max, + if only consider (alpha, 100-alpha) percentiles. Defaults to 1. + beta: absorbance threshold for transparent pixels. + Defaults to 0.15 allow_missing_keys: don't raise exception if key is missing. """ @@ -45,14 +45,13 @@ class ExtractHEStainsd(MapTransform): def __init__( self, keys: KeysCollection, - tli: float = 240, + source_intensity: float = 240, alpha: float = 1, beta: float = 0.15, - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.extractor = ExtractHEStains(tli=tli, alpha=alpha, beta=beta, max_cref=max_cref) + self.extractor = HEStainExtractor(source_intensity=source_intensity, alpha=alpha, beta=beta) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -61,28 +60,26 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class NormalizeHEStainsd(MapTransform): - """Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.NormalizeHEStains`. +class StainNormalizerd(MapTransform): + """Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.StainNormalizer`. - Class to normalize patches/images to a reference or target image stain. + Normalize images to a reference stain color matrix. - Performs stain deconvolution of the source image using the ExtractHEStains - class, to obtain the stain matrix and calculate the stain concentration matrix - for the image. Then, performs the inverse Beer-Lambert transform to recreate the - patch using the target H&E stain matrix provided. If no target stain provided, a default - reference stain is used. Similarly, if no maximum stain concentrations are provided, a - reference maximum stain concentrations matrix is used. + First, it extracts the stain coefficient matrix from the image using the provided stain extractor. + Then, it calculates the stain concentrations based on Beer-Lamber Law. + Next, it reconstructs the image using the provided reference stain matrix (stain-normalized image). Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - tli: transmitted light intensity. Defaults to 240. - alpha: tolerance in percentile for the pseudo-min (alpha percentile) and - pseudo-max (100 - alpha percentile). Defaults to 1. - beta: absorbance threshold for transparent pixels. Defaults to 0.15. - target_he: target stain matrix. Defaults to None. - max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). - Defaults to None. + source_intensity: transmitted light intensity. + Defaults to 240. + alpha: percentiles to ignore for outliers, so to calculate min and max, + if only consider (alpha, 100-alpha) percentiles. Defaults to 1. + ref_stain_coeff: reference stain attenuation coefficient matrix. + Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)). + ref_max_conc: reference maximum stain concentrations for + Hematoxylin & Eosin (H&E). Defaults to (1.9705, 1.0308). allow_missing_keys: don't raise exception if key is missing. """ @@ -90,15 +87,21 @@ class NormalizeHEStainsd(MapTransform): def __init__( self, keys: KeysCollection, - tli: float = 240, + source_intensity: float = 240, alpha: float = 1, - beta: float = 0.15, - target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + ref_stain_coeff: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), + ref_max_conc: Union[tuple, np.ndarray] = (1.9705, 1.0308), + stain_extractor=None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.normalizer = NormalizeHEStains(tli=tli, alpha=alpha, beta=beta, target_he=target_he, max_cref=max_cref) + self.normalizer = StainNormalizer( + source_intensity=source_intensity, + alpha=alpha, + ref_stain_coeff=ref_stain_coeff, + ref_max_conc=ref_max_conc, + stain_extractor=stain_extractor, + ) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -107,5 +110,5 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -ExtractHEStainsDict = ExtractHEStainsD = ExtractHEStainsd -NormalizeHEStainsDict = NormalizeHEStainsD = NormalizeHEStainsd +HEStainExtractorDict = HEStainExtractorD = HEStainExtractord +StainNormalizerDict = StainNormalizerD = StainNormalizerd diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py index 7f76c3f03e..35d3598dd8 100644 --- a/tests/test_pathology_he_stain.py +++ b/tests/test_pathology_he_stain.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.apps.pathology.transforms import ExtractHEStains, NormalizeHEStains +from monai.apps.pathology.transforms import HEStainExtractor, StainNormalizer # None inputs EXTRACT_STAINS_TEST_CASE_0 = (None,) @@ -23,23 +23,23 @@ NORMALIZE_STAINS_TEST_CASE_00: tuple = ({}, None, None) # input pixels with negative values -NEGATIVE_VALUE_TEST_CASE = [np.full((3, 2, 3), -1)] +NEGATIVE_VALUE_TEST_CASE = [np.full((3, 2, 4), -1)] # input pixels with greater than 255 values -INVALID_VALUE_TEST_CASE = [np.full((3, 2, 3), 256)] +INVALID_VALUE_TEST_CASE = [np.full((3, 2, 4), 256)] # input pixels all transparent and below the beta absorbance threshold -EXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] +EXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 4), 240)] # input pixels uniformly filled, but above beta absorbance threshold -EXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 3), 100)] +EXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 4), 100)] # input pixels uniformly filled (different value), but above beta absorbance threshold -EXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 3), 150)] +EXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 4), 150)] # input pixels uniformly filled with zeros, leading to two identical stains extracted EXTRACT_STAINS_TEST_CASE_4 = [ - np.zeros((3, 2, 3)), + np.zeros((3, 2, 4)), np.array([[0.0, 0.0], [0.70710678, 0.70710678], [0.70710678, 0.70710678]]), ] @@ -51,27 +51,27 @@ # input pixels all transparent and below the beta absorbance threshold -NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] +NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 5), 240)] # input pixels uniformly filled with zeros, and target stain matrix provided -NORMALIZE_STAINS_TEST_CASE_2 = [{"target_he": np.full((3, 2), 1)}, np.zeros((3, 2, 3)), np.full((3, 2, 3), 11)] +NORMALIZE_STAINS_TEST_CASE_2 = [{"ref_stain_coeff": np.full((3, 2), 1)}, np.zeros((3, 2, 4)), np.full((3, 2, 4), 11)] # input pixels uniformly filled with zeros, and target stain matrix not provided NORMALIZE_STAINS_TEST_CASE_3 = [ {}, np.zeros((3, 2, 3)), - np.array([[[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]), + np.array([[[63, 63, 63], [63, 63, 63]], [[25, 25, 25], [25, 25, 25]], [[60, 60, 60], [60, 60, 60]]]), ] # input pixels not uniformly filled NORMALIZE_STAINS_TEST_CASE_4 = [ - {"target_he": np.full((3, 2), 1)}, + {"ref_stain_coeff": np.full((3, 2), 1)}, np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]), - np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]), + np.array([[[87, 33, 33], [33, 33, 33]], [[87, 33, 33], [33, 33, 33]], [[87, 33, 33], [33, 33, 33]]]), ] -class TestExtractHEStains(unittest.TestCase): +class TestHEStainExtractor(unittest.TestCase): @parameterized.expand( [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1] ) @@ -85,10 +85,10 @@ def test_transparent_image(self, image): """ if image is None: with self.assertRaises(TypeError): - ExtractHEStains()(image) + HEStainExtractor()(image) else: with self.assertRaises(ValueError): - ExtractHEStains()(image) + HEStainExtractor()(image) @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_2, EXTRACT_STAINS_TEST_CASE_3]) def test_identical_result_vectors(self, image): @@ -102,9 +102,9 @@ def test_identical_result_vectors(self, image): """ if image is None: with self.assertRaises(TypeError): - ExtractHEStains()(image) + HEStainExtractor()(image) else: - result = ExtractHEStains()(image) + result = HEStainExtractor()(image) np.testing.assert_array_equal(result[:, 0], result[:, 1]) @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5]) @@ -137,13 +137,13 @@ def test_result_value(self, image, expected_data): """ if image is None: with self.assertRaises(TypeError): - ExtractHEStains()(image) + HEStainExtractor()(image) else: - result = ExtractHEStains()(image) + result = HEStainExtractor()(image) np.testing.assert_allclose(result, expected_data) -class TestNormalizeHEStains(unittest.TestCase): +class TestStainNormalizer(unittest.TestCase): @parameterized.expand( [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1] ) @@ -157,10 +157,10 @@ def test_transparent_image(self, image): """ if image is None: with self.assertRaises(TypeError): - NormalizeHEStains()(image) + StainNormalizer()(image) else: with self.assertRaises(ValueError): - NormalizeHEStains()(image) + StainNormalizer()(image) @parameterized.expand( [ @@ -204,7 +204,7 @@ def test_result_value(self, argments, image, expected_data): For test case 4: - For this non-uniformly filled image, the stain extracted should be [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the - ExtractHEStains class. Solving the linear least squares problem (since + HEStainExtractor class. Solving the linear least squares problem (since absorbance matrix = stain matrix * concentration matrix), we obtain the concentration matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508], [5.8022, 0, 0, 0, 0, 0]] @@ -217,9 +217,9 @@ def test_result_value(self, argments, image, expected_data): """ if image is None: with self.assertRaises(TypeError): - NormalizeHEStains()(image) + StainNormalizer()(image) else: - result = NormalizeHEStains(**argments)(image) + result = StainNormalizer(**argments)(image) np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py index 2ba2c3f71b..528eca91a3 100644 --- a/tests/test_pathology_he_stain_dict.py +++ b/tests/test_pathology_he_stain_dict.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.apps.pathology.transforms import ExtractHEStainsD, NormalizeHEStainsD +from monai.apps.pathology.transforms import HEStainExtractorD, StainNormalizerD # None inputs EXTRACT_STAINS_TEST_CASE_0 = (None,) @@ -22,18 +22,24 @@ NORMALIZE_STAINS_TEST_CASE_0 = (None,) NORMALIZE_STAINS_TEST_CASE_00: tuple = ({}, None, None) +# input pixels with negative values +NEGATIVE_VALUE_TEST_CASE = [np.full((3, 2, 4), -1)] + +# input pixels with greater than 255 values +INVALID_VALUE_TEST_CASE = [np.full((3, 2, 4), 256)] + # input pixels all transparent and below the beta absorbance threshold -EXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] +EXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 4), 240)] # input pixels uniformly filled, but above beta absorbance threshold -EXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 3), 100)] +EXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 4), 100)] # input pixels uniformly filled (different value), but above beta absorbance threshold -EXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 3), 150)] +EXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 4), 150)] # input pixels uniformly filled with zeros, leading to two identical stains extracted EXTRACT_STAINS_TEST_CASE_4 = [ - np.zeros((3, 2, 3)), + np.zeros((3, 2, 4)), np.array([[0.0, 0.0], [0.70710678, 0.70710678], [0.70710678, 0.70710678]]), ] @@ -43,28 +49,29 @@ np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]), ] + # input pixels all transparent and below the beta absorbance threshold -NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] +NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 5), 240)] # input pixels uniformly filled with zeros, and target stain matrix provided -NORMALIZE_STAINS_TEST_CASE_2 = [{"target_he": np.full((3, 2), 1)}, np.zeros((3, 2, 3)), np.full((3, 2, 3), 11)] +NORMALIZE_STAINS_TEST_CASE_2 = [{"ref_stain_coeff": np.full((3, 2), 1)}, np.zeros((3, 2, 4)), np.full((3, 2, 4), 11)] # input pixels uniformly filled with zeros, and target stain matrix not provided NORMALIZE_STAINS_TEST_CASE_3 = [ {}, np.zeros((3, 2, 3)), - np.array([[[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]), + np.array([[[63, 63, 63], [63, 63, 63]], [[25, 25, 25], [25, 25, 25]], [[60, 60, 60], [60, 60, 60]]]), ] # input pixels not uniformly filled NORMALIZE_STAINS_TEST_CASE_4 = [ - {"target_he": np.full((3, 2), 1)}, + {"ref_stain_coeff": np.full((3, 2), 1)}, np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]), - np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]), + np.array([[[87, 33, 33], [33, 33, 33]], [[87, 33, 33], [33, 33, 33]], [[87, 33, 33], [33, 33, 33]]]), ] -class TestExtractHEStainsD(unittest.TestCase): +class TestHEStainExtractorD(unittest.TestCase): @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1]) def test_transparent_image(self, image): """ @@ -77,10 +84,10 @@ def test_transparent_image(self, image): key = "image" if image is None: with self.assertRaises(TypeError): - ExtractHEStainsD([key])({key: image}) + HEStainExtractorD([key])({key: image}) else: with self.assertRaises(ValueError): - ExtractHEStainsD([key])({key: image}) + HEStainExtractorD([key])({key: image}) @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_2, EXTRACT_STAINS_TEST_CASE_3]) def test_identical_result_vectors(self, image): @@ -95,9 +102,9 @@ def test_identical_result_vectors(self, image): key = "image" if image is None: with self.assertRaises(TypeError): - ExtractHEStainsD([key])({key: image}) + HEStainExtractorD([key])({key: image}) else: - result = ExtractHEStainsD([key])({key: image}) + result = HEStainExtractorD([key])({key: image}) np.testing.assert_array_equal(result[key][:, 0], result[key][:, 1]) @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5]) @@ -131,13 +138,13 @@ def test_result_value(self, image, expected_data): key = "image" if image is None: with self.assertRaises(TypeError): - ExtractHEStainsD([key])({key: image}) + HEStainExtractorD([key])({key: image}) else: - result = ExtractHEStainsD([key])({key: image}) + result = HEStainExtractorD([key])({key: image}) np.testing.assert_allclose(result[key], expected_data) -class TestNormalizeHEStainsD(unittest.TestCase): +class TestStainNormalizerD(unittest.TestCase): @parameterized.expand([NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1]) def test_transparent_image(self, image): """ @@ -150,10 +157,10 @@ def test_transparent_image(self, image): key = "image" if image is None: with self.assertRaises(TypeError): - NormalizeHEStainsD([key])({key: image}) + StainNormalizerD([key])({key: image}) else: with self.assertRaises(ValueError): - NormalizeHEStainsD([key])({key: image}) + StainNormalizerD([key])({key: image}) @parameterized.expand( [ @@ -197,7 +204,7 @@ def test_result_value(self, argments, image, expected_data): For test case 4: - For this non-uniformly filled image, the stain extracted should be [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the - ExtractHEStains class. Solving the linear least squares problem (since + HEStainExtractor class. Solving the linear least squares problem (since absorbance matrix = stain matrix * concentration matrix), we obtain the concentration matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508], [5.8022, 0, 0, 0, 0, 0]] @@ -211,9 +218,9 @@ def test_result_value(self, argments, image, expected_data): key = "image" if image is None: with self.assertRaises(TypeError): - NormalizeHEStainsD([key])({key: image}) + StainNormalizerD([key])({key: image}) else: - result = NormalizeHEStainsD([key], **argments)({key: image}) + result = StainNormalizerD([key], **argments)({key: image}) np.testing.assert_allclose(result[key], expected_data) From 8404d3fcec0885958b8339f7c7066fc6fa691c60 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Sat, 18 Dec 2021 01:00:12 +0000 Subject: [PATCH 2/5] Update docstring and add transform Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/transforms/stain/array.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py index f094e9ca83..1ce47c5a41 100644 --- a/monai/apps/pathology/transforms/stain/array.py +++ b/monai/apps/pathology/transforms/stain/array.py @@ -16,7 +16,7 @@ from monai.transforms.transform import Transform -class HEStainExtractor: +class HEStainExtractor(Transform): """Extract stain coefficients from an image. Args: @@ -113,15 +113,14 @@ def __call__(self, image: np.ndarray) -> np.ndarray: return ref_stain_coeff -class StainNormalizer: +class StainNormalizer(Transform): """Normalize images to a reference stain color matrix. First, it extracts the stain coefficient matrix from the image using the provided stain extractor. Then, it calculates the stain concentrations based on Beer-Lamber Law. Next, it reconstructs the image using the provided reference stain matrix (stain-normalized image). - Parameters - ---------- + Args: source_intensity: transmitted light intensity. Defaults to 240. alpha: percentiles to ignore for outliers, so to calculate min and max, @@ -153,12 +152,10 @@ def __init__( def __call__(self, image: np.ndarray) -> np.ndarray: """Perform stain normalization. - Parameters - ---------- + Args: image: uint8 RGB image to be stain normalized, pixel values between 0 and 255 - Returns - ------- + Return: image_norm: stain normalized image/patch """ # check image type and values From a047ad17d103d74683419f5e5e37a262c8854769 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Sat, 18 Dec 2021 01:31:54 +0000 Subject: [PATCH 3/5] Update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/apps.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f4f7aff2d2..312034d625 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -100,15 +100,15 @@ Clara MMARs :members: .. automodule:: monai.apps.pathology.transforms.stain.array -.. autoclass:: ExtractHEStains +.. autoclass:: HEStainExtractor :members: -.. autoclass:: NormalizeHEStains +.. autoclass:: StainNormalizers :members: .. automodule:: monai.apps.pathology.transforms.stain.dictionary -.. autoclass:: ExtractHEStainsd +.. autoclass:: HEStainExtractord :members: -.. autoclass:: NormalizeHEStainsd +.. autoclass:: StainNormalizersd :members: .. automodule:: monai.apps.pathology.transforms.spatial.array From d58e20f8212e7413a064ac1fcf217d040268179e Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Sat, 18 Dec 2021 01:34:12 +0000 Subject: [PATCH 4/5] Fix a typo Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/apps.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 312034d625..c3b272971d 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -102,13 +102,13 @@ Clara MMARs .. automodule:: monai.apps.pathology.transforms.stain.array .. autoclass:: HEStainExtractor :members: -.. autoclass:: StainNormalizers +.. autoclass:: StainNormalizer :members: .. automodule:: monai.apps.pathology.transforms.stain.dictionary .. autoclass:: HEStainExtractord :members: -.. autoclass:: StainNormalizersd +.. autoclass:: StainNormalizerd :members: .. automodule:: monai.apps.pathology.transforms.spatial.array From 3552c898245169cac09dfd170110425e64b58f80 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Sat, 18 Dec 2021 02:19:45 +0000 Subject: [PATCH 5/5] Fix formatting Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/transforms/stain/array.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py index 1ce47c5a41..6b820b98d9 100644 --- a/monai/apps/pathology/transforms/stain/array.py +++ b/monai/apps/pathology/transforms/stain/array.py @@ -32,12 +32,7 @@ class HEStainExtractor(Transform): Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf """ - def __init__( - self, - source_intensity: float = 240, - alpha: float = 1, - beta: float = 0.15, - ) -> None: + def __init__(self, source_intensity: float = 240, alpha: float = 1, beta: float = 0.15) -> None: self.source_intensity = source_intensity self.alpha = alpha self.beta = beta