diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index 041686bb..24d4600d 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -1246,11 +1246,6 @@ def feature_detection_multithreshold( logging.debug("start feature detection based on thresholds") ndim_time = internal_utils.find_axis_from_coord(field_in, time_var_name) - if ndim_time is None: - raise ValueError( - "input to feature detection step must include a dimension named " - + time_var_name - ) # Check whether we need to run 2D or 3D feature detection if field_in.ndim == 3: @@ -1280,7 +1275,7 @@ def feature_detection_multithreshold( if vertical_axis is None: # We need to determine vertical axis. # first, find the name of the vertical axis - vertical_axis_name = internal_utils.find_vertical_axis_from_coord( + vertical_axis_name = internal_utils.find_vertical_coord_name( field_in, vertical_coord=vertical_coord ) # then find our axis number. @@ -1341,8 +1336,8 @@ def feature_detection_multithreshold( "given in meter." ) - for i_time, data_i in enumerate(field_in.transpose(time_var_name, ...)): - time_i = data_i[time_var_name].values + for i_time, time_i in enumerate(field_in.coords[time_var_name]): + data_i = field_in.isel({time_var_name: i_time}) features_thresholds = feature_detection_multithreshold_timestep( data_i, diff --git a/tobac/segmentation.py b/tobac/segmentation.py index b59d7975..00f0b21c 100644 --- a/tobac/segmentation.py +++ b/tobac/segmentation.py @@ -457,7 +457,7 @@ def segmentation_timestep( hdim_1_axis = 0 hdim_2_axis = 1 elif field_in.ndim == 3: - vertical_axis = internal_utils.find_vertical_axis_from_coord( + vertical_axis = internal_utils.find_vertical_coord_name( field_in, vertical_coord=vertical_coord ) ndim_vertical = field_in.coord_dims(vertical_axis) diff --git a/tobac/tests/test_xarray_utils.py b/tobac/tests/test_xarray_utils.py index 210a1a90..a73ca31c 100644 --- a/tobac/tests/test_xarray_utils.py +++ b/tobac/tests/test_xarray_utils.py @@ -56,7 +56,7 @@ }, "latitude", # coord_looking_for None, - False, + True, ), ( ("time", "altitude", "x", "y"), # dim_names diff --git a/tobac/utils/general.py b/tobac/utils/general.py index 1f40cc0c..be827014 100644 --- a/tobac/utils/general.py +++ b/tobac/utils/general.py @@ -487,9 +487,7 @@ def transform_feature_points( RADIUS_EARTH_M = 6371000 is_3D = "vdim" in features if is_3D: - vert_coord = internal_utils.find_vertical_axis_from_coord( - new_dataset, altitude_name - ) + vert_coord = internal_utils.find_vertical_coord_name(new_dataset, altitude_name) lat_coord, lon_coord = internal_utils.detect_latlon_coord_name( new_dataset, latitude_name=latitude_name, longitude_name=longitude_name diff --git a/tobac/utils/internal/basic.py b/tobac/utils/internal/basic.py index 1ba787ea..810b2977 100644 --- a/tobac/utils/internal/basic.py +++ b/tobac/utils/internal/basic.py @@ -110,7 +110,7 @@ def get_indices_of_labels_from_reg_prop_dict(region_property_dict: dict) -> tupl return [curr_loc_indices, y_indices, x_indices] -def find_vertical_axis_from_coord( +def find_vertical_coord_name( variable_cube: Union[iris.cube.Cube, xr.DataArray], vertical_coord: Union[str, None] = None, ) -> str: diff --git a/tobac/utils/internal/iris_utils.py b/tobac/utils/internal/iris_utils.py index 3e3c301f..f5210404 100644 --- a/tobac/utils/internal/iris_utils.py +++ b/tobac/utils/internal/iris_utils.py @@ -370,7 +370,7 @@ def add_coordinates_3D( ndim_vertical = vertical_coord vertical_axis = None else: - vertical_axis = tb_utils_gi.find_vertical_axis_from_coord( + vertical_axis = tb_utils_gi.find_vertical_coord_name( variable_cube, vertical_coord=vertical_coord ) diff --git a/tobac/utils/internal/xarray_utils.py b/tobac/utils/internal/xarray_utils.py index 035297cb..e7a649a1 100644 --- a/tobac/utils/internal/xarray_utils.py +++ b/tobac/utils/internal/xarray_utils.py @@ -5,13 +5,12 @@ import copy from typing import Union -import cftime import numpy as np import pandas as pd import xarray as xr from . import basic as tb_utils_gi -import datetime -import random, string +import random +import string def find_axis_from_dim_coord( @@ -30,8 +29,7 @@ def find_axis_from_dim_coord( Returns ------- axis_number: int - the number of the axis of the given coordinate, or None if the coordinate - is not found in the cube or not a dimensional coordinate + the number of the axis of the given coordinate Raises ------ @@ -57,8 +55,8 @@ def find_axis_from_dim_coord( return dim_axis if dim_axis is None and len(coord_axes) == 1: return coord_axes[0] - - return None + raise ValueError("Coordinate/Dimension " + dim_coord_name + " not found.") + # return None def find_axis_from_dim(in_da: xr.DataArray, dim_name: str) -> Union[int, None]: