Skip to content

Commit

Permalink
Added preservation of iris datatypes based on whether or not iris dat…
Browse files Browse the repository at this point in the history
…a was passed
  • Loading branch information
freemansw1 committed Dec 4, 2023
1 parent e763274 commit 752d285
Show file tree
Hide file tree
Showing 6 changed files with 351 additions and 302 deletions.
16 changes: 12 additions & 4 deletions tobac/feature_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ def feature_detection_threshold(
return features_threshold, regions


@internal_utils.irispandas_to_xarray
@internal_utils.irispandas_to_xarray()
def feature_detection_multithreshold_timestep(
data_i: xr.DataArray,
i_time: int,
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def feature_detection_multithreshold_timestep(
return features_thresholds


@internal_utils.irispandas_to_xarray
@internal_utils.irispandas_to_xarray(save_iris_info=True)
def feature_detection_multithreshold(
field_in: xr.DataArray,
dxy: float = None,
Expand All @@ -1153,6 +1153,7 @@ def feature_detection_multithreshold(
dz: Union[float, None] = None,
strict_thresholding: bool = False,
statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
**kwargs,
) -> pd.DataFrame:
"""Perform feature detection based on contiguous regions.
Expand Down Expand Up @@ -1401,10 +1402,17 @@ def feature_detection_multithreshold(
# features_filtered.drop(columns=['idx','num','threshold_value'],inplace=True)
if "vdim" in features:
features = add_coordinates_3D(
features, field_in, vertical_coord=vertical_coord
features,
field_in,
vertical_coord=vertical_coord,
preserve_iris_datetime_types=kwargs["converted_from_iris"],
)
else:
features = add_coordinates(features, field_in)
features = add_coordinates(
features,
field_in,
preserve_iris_datetime_types=kwargs["converted_from_iris"],
)
else:
features = None
logging.debug("No features detected")
Expand Down
14 changes: 8 additions & 6 deletions tobac/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ def test_function_kwarg(test_input, kwarg=None):
def test_function_tuple_output(test_input, kwarg=None):
return (test_input, test_input)

decorated_function_kwarg = decorator(test_function_kwarg)
decorated_function_tuple = decorator(test_function_tuple_output)
decorator_i = decorator()
decorated_function_kwarg = decorator_i(test_function_kwarg)
decorated_function_tuple = decorator_i(test_function_tuple_output)

if input_types[0] == xarray.DataArray:
data = xarray.DataArray.from_iris(tobac.testing.make_simple_sample_data_2D())
Expand Down Expand Up @@ -227,15 +228,16 @@ def test_xarray_workflow():
data_xarray = xarray.DataArray.from_iris(deepcopy(data))

# Testing the get_spacings utility
get_spacings_xarray = xarray_to_iris(tobac.utils.get_spacings)
xarray_to_iris_i = xarray_to_iris()
get_spacings_xarray = xarray_to_iris_i(tobac.utils.get_spacings)
dxy, dt = tobac.utils.get_spacings(data)
dxy_xarray, dt_xarray = get_spacings_xarray(data_xarray)

assert dxy == dxy_xarray
assert dt == dt_xarray

# Testing feature detection
feature_detection_xarray = xarray_to_iris(
feature_detection_xarray = xarray_to_iris_i(
tobac.feature_detection.feature_detection_multithreshold
)
features = tobac.feature_detection.feature_detection_multithreshold(
Expand All @@ -246,7 +248,7 @@ def test_xarray_workflow():
assert_frame_equal(features, features_xarray)

# Testing the segmentation
segmentation_xarray = xarray_to_iris(tobac.segmentation.segmentation)
segmentation_xarray = xarray_to_iris_i(tobac.segmentation.segmentation)
mask, features = tobac.segmentation.segmentation(features, data, dxy, threshold=1.0)
mask_xarray, features_xarray = segmentation_xarray(
features_xarray, data_xarray, dxy_xarray, threshold=1.0
Expand All @@ -255,7 +257,7 @@ def test_xarray_workflow():
assert (mask.data == mask_xarray.to_iris().data).all()

# testing tracking
tracking_xarray = xarray_to_iris(tobac.tracking.linking_trackpy)
tracking_xarray = xarray_to_iris_i(tobac.tracking.linking_trackpy)
track = tobac.tracking.linking_trackpy(features, data, dt, dxy, v_max=100.0)
track_xarray = tracking_xarray(
features_xarray, data_xarray, dt_xarray, dxy_xarray, v_max=100.0
Expand Down
2 changes: 1 addition & 1 deletion tobac/utils/bulk_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_statistics(
return features


@decorators.iris_to_xarray
@decorators.iris_to_xarray()
def get_statistics_from_mask(
features: pd.DataFrame,
segmentation_mask: xr.DataArray,
Expand Down
Loading

0 comments on commit 752d285

Please sign in to comment.