Skip to content

Commit

Permalink
formatting and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
freemansw1 committed Jan 17, 2024
1 parent 429ee42 commit 043e5d1
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions tobac/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,25 +927,39 @@ def identify_feature_families(
below_threshold_values: int = -1,
) -> Union[tuple[pd.DataFrame, xr.DataArray], pd.DataFrame]:
"""
Function to identify families/storm systems by identifying where segmentation touches.
At a given time, segmentation areas are considered part of the same family if they
touch at any point.
Parameters
----------
feature_df: pd.DataFrame
Input feature dataframe
in_segmentation: xr.Data
start_family_number
in_segmentation: xr.DataArray
Input segmentation
return_grid: bool
Whether to return the segmentation grid showing families
family_column_name: str
The name in the output dataframe of the family ID
unsegmented_point_values: int
The value in the input segmentation for unsegmented but above threshold points
below_threshold_values: int
The value in the input segmentation for below threshold points
Returns
-------
pd.DataFrame and xr.DataArray or pd.DataFrame
Input dataframe with family IDs associated with each feature
if return_grid is True, the segmentation grid showing families is
also returned.
"""

# we need to label the data, but we currently label using skimage label, not dask label.


# 3D should be 4-D (time, then 3 spatial).
# 2D should be 3-D (time, then 2 spatial)
is_3D = len(in_segmentation.shape) ==4
is_3D = len(in_segmentation.shape) == 4
seg_family_dict = dict()
out_families = copy.deepcopy(in_segmentation)

Expand All @@ -966,9 +980,13 @@ def identify_feature_families(
# associate feature ID -> family ID
for seg_area in segmentation_props:
if is_3D:
seg_family = family_labeled_data[seg_area.coords[0, 0], seg_area.coords[0, 1], seg_area.cords[0,2]]
seg_family = family_labeled_data[
seg_area.coords[0, 0], seg_area.coords[0, 1], seg_area.cords[0, 2]
]
else:
seg_family = family_labeled_data[seg_area.coords[0, 0], seg_area.coords[0, 1]]
seg_family = family_labeled_data[
seg_area.coords[0, 0], seg_area.coords[0, 1]
]
seg_family_dict[seg_area.label] = seg_family

out_families[time_index] = segmented_arr
Expand Down

0 comments on commit 043e5d1

Please sign in to comment.