Skip to content

Commit

Permalink
add family finding
Browse files Browse the repository at this point in the history
  • Loading branch information
freemansw1 committed Jan 17, 2024
1 parent 654c813 commit 429ee42
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 7 deletions.
22 changes: 19 additions & 3 deletions tobac/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,10 +591,26 @@ def test_transform_feature_points_3D():
def test_identify_feature_families():
"""tests tobac.utils.general.identify_feature_families"""
orig_feat_df_1 = tb_test.generate_single_feature(
0, 95, 10, max_h1=1000, max_h2=1000
10, 30, 10, max_h1=50, max_h2=50, feature_num=1
)
orig_feat_df_2 = tb_test.generate_single_feature(
5, 105, 20, max_h1=1000, max_h2=1000
30, 30, 20, max_h1=50, max_h2=50, feature_num=2
)

orig_feat_df = tb_utils.combine_feature_dataframes(
[orig_feat_df_1, orig_feat_df_2], renumber_features=False
)

orig_feat_df = tb_utils.combine_feature_dataframes([orig_feat_df_1, orig_feat_df_2])
# make fake segmentation
test_arr = np.zeros((2, 50, 50), dtype=int)
test_arr[0, 5:15, 20:40] = 1
test_arr[0, 15:40, 20:40] = 2

test_xr = xr.DataArray(data=test_arr, dims=["time", "hdim_1", "hdim_2"])

out_df, out_grid = tb_utils.general.identify_feature_families(
orig_feat_df, test_xr, return_grid=True, family_column_name="family"
)
assert np.unique(out_df["family"] == 1)
assert np.all(out_grid[0, 5:15, 20:40] == 1)
assert np.all(out_grid[0, 15:40, 20:40] == 1)
53 changes: 50 additions & 3 deletions tobac/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"""
import copy
import logging
from typing import Union

import pandas as pd
import skimage

from . import internal as internal_utils
import numpy as np
Expand Down Expand Up @@ -919,8 +921,11 @@ def standardize_track_dataset(TrackedFeatures, Mask, Projection=None):
def identify_feature_families(
feature_df: pd.DataFrame,
in_segmentation: xr.DataArray,
start_family_number: int = 0,
):
return_grid: bool = False,
family_column_name: str = "feature_family_id",
unsegmented_point_values: int = 0,
below_threshold_values: int = -1,
) -> Union[tuple[pd.DataFrame, xr.DataArray], pd.DataFrame]:
"""
Parameters
Expand All @@ -934,4 +939,46 @@ def identify_feature_families(
-------
"""
pass

# we need to label the data, but we currently label using skimage label, not dask label.


# 3D should be 4-D (time, then 3 spatial).
# 2D should be 3-D (time, then 2 spatial)
is_3D = len(in_segmentation.shape) ==4
seg_family_dict = dict()
out_families = copy.deepcopy(in_segmentation)

for time_index in range(in_segmentation.shape[0]):
in_arr = np.array(in_segmentation.values[time_index])

segmented_arr = np.logical_and(
in_arr != unsegmented_point_values, in_arr != below_threshold_values
)
# These are our families
family_labeled_data = skimage.measure.label(
segmented_arr,
)

# now we need to note feature->family relationship in the dataframe.
segmentation_props = skimage.measure.regionprops(in_arr)

# associate feature ID -> family ID
for seg_area in segmentation_props:
if is_3D:
seg_family = family_labeled_data[seg_area.coords[0, 0], seg_area.coords[0, 1], seg_area.cords[0,2]]
else:
seg_family = family_labeled_data[seg_area.coords[0, 0], seg_area.coords[0, 1]]
seg_family_dict[seg_area.label] = seg_family

out_families[time_index] = segmented_arr

family_series = pd.Series(seg_family_dict, name=family_column_name)
feature_series = pd.Series({x: x for x in seg_family_dict.keys()}, name="feature")
family_df = pd.concat([family_series, feature_series], axis=1)
out_df = feature_df.merge(family_df, on="feature", how="inner")

if return_grid:
return out_df, out_families
else:
return out_df
2 changes: 1 addition & 1 deletion tobac/utils/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import skimage.measure
import xarray as xr
import iris
import iris.cube
import warnings


Expand Down Expand Up @@ -110,7 +111,6 @@ def iris_to_xarray(func):
Function including decorator
"""

import iris
import xarray

def wrapper(*args, **kwargs):
Expand Down

0 comments on commit 429ee42

Please sign in to comment.