Skip to content

Commit

Permalink
updates from @w-k-jones review
Browse files Browse the repository at this point in the history
  • Loading branch information
freemansw1 committed Feb 23, 2024
1 parent ca5631e commit 8679d44
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 22 deletions.
11 changes: 3 additions & 8 deletions tobac/feature_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,11 +1246,6 @@ def feature_detection_multithreshold(
logging.debug("start feature detection based on thresholds")

ndim_time = internal_utils.find_axis_from_coord(field_in, time_var_name)
if ndim_time is None:
raise ValueError(
"input to feature detection step must include a dimension named "
+ time_var_name
)

# Check whether we need to run 2D or 3D feature detection
if field_in.ndim == 3:
Expand Down Expand Up @@ -1280,7 +1275,7 @@ def feature_detection_multithreshold(
if vertical_axis is None:
# We need to determine vertical axis.
# first, find the name of the vertical axis
vertical_axis_name = internal_utils.find_vertical_axis_from_coord(
vertical_axis_name = internal_utils.find_vertical_coord_name(
field_in, vertical_coord=vertical_coord
)
# then find our axis number.
Expand Down Expand Up @@ -1341,8 +1336,8 @@ def feature_detection_multithreshold(
"given in meter."
)

for i_time, data_i in enumerate(field_in.transpose(time_var_name, ...)):
time_i = data_i[time_var_name].values
for i_time, time_i in enumerate(field_in.coords[time_var_name]):
data_i = field_in.isel({time_var_name: i_time})

features_thresholds = feature_detection_multithreshold_timestep(
data_i,
Expand Down
2 changes: 1 addition & 1 deletion tobac/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def segmentation_timestep(
hdim_1_axis = 0
hdim_2_axis = 1
elif field_in.ndim == 3:
vertical_axis = internal_utils.find_vertical_axis_from_coord(
vertical_axis = internal_utils.find_vertical_coord_name(
field_in, vertical_coord=vertical_coord
)
ndim_vertical = field_in.coord_dims(vertical_axis)
Expand Down
2 changes: 1 addition & 1 deletion tobac/tests/test_xarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
},
"latitude", # coord_looking_for
None,
False,
True,
),
(
("time", "altitude", "x", "y"), # dim_names
Expand Down
4 changes: 1 addition & 3 deletions tobac/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,7 @@ def transform_feature_points(
RADIUS_EARTH_M = 6371000
is_3D = "vdim" in features
if is_3D:
vert_coord = internal_utils.find_vertical_axis_from_coord(
new_dataset, altitude_name
)
vert_coord = internal_utils.find_vertical_coord_name(new_dataset, altitude_name)

lat_coord, lon_coord = internal_utils.detect_latlon_coord_name(
new_dataset, latitude_name=latitude_name, longitude_name=longitude_name
Expand Down
2 changes: 1 addition & 1 deletion tobac/utils/internal/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_indices_of_labels_from_reg_prop_dict(region_property_dict: dict) -> tupl
return [curr_loc_indices, y_indices, x_indices]


def find_vertical_axis_from_coord(
def find_vertical_coord_name(
variable_cube: Union[iris.cube.Cube, xr.DataArray],
vertical_coord: Union[str, None] = None,
) -> str:
Expand Down
2 changes: 1 addition & 1 deletion tobac/utils/internal/iris_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def add_coordinates_3D(
ndim_vertical = vertical_coord
vertical_axis = None
else:
vertical_axis = tb_utils_gi.find_vertical_axis_from_coord(
vertical_axis = tb_utils_gi.find_vertical_coord_name(
variable_cube, vertical_coord=vertical_coord
)

Expand Down
12 changes: 5 additions & 7 deletions tobac/utils/internal/xarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import copy
from typing import Union

import cftime
import numpy as np
import pandas as pd
import xarray as xr
from . import basic as tb_utils_gi
import datetime
import random, string
import random
import string


def find_axis_from_dim_coord(
Expand All @@ -30,8 +29,7 @@ def find_axis_from_dim_coord(
Returns
-------
axis_number: int
the number of the axis of the given coordinate, or None if the coordinate
is not found in the cube or not a dimensional coordinate
the number of the axis of the given coordinate
Raises
------
Expand All @@ -57,8 +55,8 @@ def find_axis_from_dim_coord(
return dim_axis
if dim_axis is None and len(coord_axes) == 1:
return coord_axes[0]

return None
raise ValueError("Coordinate/Dimension " + dim_coord_name + " not found.")
# return None


def find_axis_from_dim(in_da: xr.DataArray, dim_name: str) -> Union[int, None]:
Expand Down

0 comments on commit 8679d44

Please sign in to comment.