diff --git a/rt_utils/image_helper.py b/rt_utils/image_helper.py index b030987..8fce9d3 100644 --- a/rt_utils/image_helper.py +++ b/rt_utils/image_helper.py @@ -48,19 +48,27 @@ def get_contours_coords(roi_data: ROIData, series_data): series_contours = [] for i, series_slice in enumerate(series_data): - mask_slice = roi_data.mask[:, :, i] + if roi_data.polygon is None: + mask_slice = roi_data.mask[:, :, i] - # Do not add ROI's for blank slices - if np.sum(mask_slice) == 0: - series_contours.append([]) - continue + # Do not add ROI's for blank slices + if np.sum(mask_slice) == 0: + series_contours.append([]) + continue + + # Create pin hole mask if specified + if roi_data.use_pin_hole: + mask_slice = create_pin_hole_mask(mask_slice, roi_data.approximate_contours) - # Create pin hole mask if specified - if roi_data.use_pin_hole: - mask_slice = create_pin_hole_mask(mask_slice, roi_data.approximate_contours) + # Get contours from mask + contours, _ = find_mask_contours(mask_slice, roi_data.approximate_contours) + else: + if roi_data.polygon[i].area == 0: + # empty ROI + series_contours.append([]) + continue + contours = [roi_data.polygon[i].coords.tolist()] - # Get contours from mask - contours, _ = find_mask_contours(mask_slice, roi_data.approximate_contours) validate_contours(contours) # Format for DICOM diff --git a/rt_utils/rtstruct.py b/rt_utils/rtstruct.py index dfe82be..73ef1c5 100644 --- a/rt_utils/rtstruct.py +++ b/rt_utils/rtstruct.py @@ -3,7 +3,7 @@ import numpy as np from pydicom.dataset import FileDataset -from rt_utils.utils import ROIData +from rt_utils.utils import ROIData, Polygon2D from . import ds_helper, image_helper @@ -28,7 +28,8 @@ def set_series_description(self, description: str): def add_roi( self, - mask: np.ndarray, + mask: np.ndarray=None, + polygon: list=None, color: Union[str, List[int]] = None, name: str = None, description: str = "", @@ -37,17 +38,22 @@ def add_roi( roi_generation_algorithm: Union[str, int] = 0, ): """ - Add a ROI to the rtstruct given a 3D binary mask for the ROI's at each slice + Add a ROI to the rtstruct given a 3D binary mask or list of polygons for the ROI's at each slice Optionally input a color or name for the ROI If use_pin_hole is set to true, will cut a pinhole through ROI's with holes in them so that they are represented with one contour If approximate_contours is set to False, no approximation will be done when generating contour data, leading to much larger amount of contour data """ - # TODO test if name already exists - self.validate_mask(mask) + assert isinstance(mask, (type(None), np.ndarray)) and isinstance(polygon, (type(None), list)) + assert (mask is None) ^ (polygon is None), "Only one of 'mas' and 'polygon' can be set." + if mask is not None: + self.validate_mask(mask) + else: + self.validate_polygon(polygon) + data = mask if mask is not None else polygon roi_number = len(self.ds.StructureSetROISequence) + 1 roi_data = ROIData( - mask, + data, color, roi_number, name, @@ -88,6 +94,19 @@ def validate_mask(self, mask: np.ndarray) -> bool: return True + def validate_polygon(self, polygon: list) -> None: + """ + polygon should be in the format of list of lists. + The inner loop list contains Polygon2D objects representing + ROIs of correspoding slice. The innter loop list can be empty. + + """ + for poly in polygon: + if not isinstance(poly, Polygon2D): + raise RTStruct.ROIException( + f"polygon must be list of Polygon2D objects" + ) + def get_roi_names(self) -> List[str]: """ Returns a list of the names of all ROI within the RTStruct diff --git a/rt_utils/utils.py b/rt_utils/utils.py index f04089e..69bdac1 100644 --- a/rt_utils/utils.py +++ b/rt_utils/utils.py @@ -2,6 +2,8 @@ from random import randrange from pydicom.uid import PYDICOM_IMPLEMENTATION_UID from dataclasses import dataclass +import numpy as np +from PIL import Image, ImageDraw COLOR_PALETTE = [ [255, 0, 255], @@ -41,16 +43,40 @@ class SOPClassUID: @dataclass class ROIData: """Data class to easily pass ROI data to helper methods.""" - - mask: str - color: Union[str, List[int]] - number: int - name: str - frame_of_reference_uid: int - description: str = "" - use_pin_hole: bool = False - approximate_contours: bool = True - roi_generation_algorithm: Union[str, int] = 0 + def __init__(self, + data, + color:str, + number:int, + name: str, + frame_of_reference_uid:int, + description:str, + use_pin_hole:bool=False, + approximate_contours:bool=True, + roi_generation_algorithm: Union[str, int] = 0) -> None: + """ + The ROI data can be in two formats. + 1, a [H, W, N] tensor contain N binary masks where N ths number of slices. + 2, a list of contour coordinates representing the vertex of a polygon ROI + """ + assert isinstance(data, (np.ndarray, list)) + if isinstance(data, np.ndarray): + self.mask = data + self.polygon = None + else: + self.polygon = self.valaidate_polygon(data) + self.mask=self.polygon2mask(data) + self.polygon = data + # set attributes + self.color = color + self.number = number + self.name = name + self.frame_of_reference_uid = frame_of_reference_uid + self.description = description + self.use_pin_hole = use_pin_hole + self.approximate_contours = approximate_contours + self.roi_generation_algorithm = roi_generation_algorithm + + self.__post_init__() def __post_init__(self): self.validate_color() @@ -125,3 +151,58 @@ def validate_roi_generation_algoirthm(self): type(self.roi_generation_algorithm) ) ) + + def valaidate_polygon(self, polygon): + if len(polygon) == 0: + raise ValueError('Empty polygon') + return polygon + + @staticmethod + def polygon2mask(polygon): + h, w = polygon[0].h, polygon[0].w + mask = np.concatenate([p.mask[:, :, None] for p in polygon], axis=-1) + return mask + + +class Polygon2D: + def __init__(self, coords, h, w) -> None: + """ + coords: coordinates of vertice of a polygon [x1, y1, x2, y2, ..., xn, yn] + """ + assert len(coords) % 2 == 0, 'invalid size of coords' + self._coords = np.array(coords).reshape(-1, 2) + self._h, self._w = h, w + self._mask = None + self._area = -1 + + @property + def h(self): + return self._h + + @property + def w(self): + return self._w + + @property + def coords(self): + return self._coords + + @property + def area(self): + if self._area > 0: + return self._area + else: + return self.mask.sum() + + @property + def mask(self): + if self._mask is not None: + return self._mask + else: + if self.coords.shape[0] <= 1: + self._mask = np.zeros((self.h, self.w), dtype=bool) + else: + img = Image.new('L', (self.w, self.h), 0) + ImageDraw.Draw(img).polygon(self.coords.flatten().tolist(), outline=1, fill=1) + self._mask = np.array(img, dtype=bool) + return self._mask