Skip to content

Commit

Permalink
Refactors to address @w-k-jones comments
Browse files Browse the repository at this point in the history
  • Loading branch information
freemansw1 committed Feb 28, 2024
1 parent 8679d44 commit 4d54d24
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions tobac/utils/internal/xarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ def find_axis_from_dim_coord(
Returns ValueError if there are more than one matching dimension name or
if the dimension/coordinate isn't found.
"""

dim_axis = find_axis_from_dim(in_da, dim_coord_name)
try:
dim_axis = find_axis_from_dim(in_da, dim_coord_name)
except ValueError:
dim_axis = None

try:
coord_axes = find_axis_from_coord(in_da, dim_coord_name)
Expand Down Expand Up @@ -96,7 +98,7 @@ def find_axis_from_dim(in_da: xr.DataArray, dim_name: str) -> Union[int, None]:
"More than one matching dimension. Need to specify which axis number or rename "
"your dimensions."
)
return None
raise ValueError("Dimension not found. ")


def find_axis_from_coord(in_da: xr.DataArray, coord_name: str) -> tuple[int]:
Expand Down Expand Up @@ -135,18 +137,18 @@ def find_axis_from_coord(in_da: xr.DataArray, coord_name: str) -> tuple[int]:

if len(all_matching_coords) > 1:
raise ValueError("Too many matching coords")
return tuple()
raise ValueError("No matching coords")


def find_vertical_coord_name(
variable_cube: xr.DataArray,
variable_da: xr.DataArray,
vertical_coord: Union[str, None] = None,
) -> str:
"""Function to find the vertical coordinate in the iris cube
Parameters
----------
variable_cube: xarray.DataArray
variable_da: xarray.DataArray
Input variable cube, containing a vertical coordinate.
vertical_coord: str
Vertical coordinate name. If None, this function tries to auto-detect.
Expand All @@ -162,7 +164,7 @@ def find_vertical_coord_name(
Raised if the vertical coordinate isn't found in the cube.
"""

list_coord_names = variable_cube.coords
list_coord_names = variable_da.coords

if vertical_coord is None or vertical_coord == "auto":
# find the intersection
Expand Down Expand Up @@ -347,14 +349,20 @@ def add_coordinates_to_features(
hdim2_name_original = variable_da.dims[hdim2_axis]

# generate random names for the new coordinates that are based on i, j, k values
hdim1_name_new = "".join(
random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits)
for _ in range(16)
)
hdim2_name_new = "".join(
random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits)
for _ in range(16)
)
hdim1_name_new = "__temp_hdim1_name"
hdim2_name_new = "__temp_hdim2_name"
vdim_name_new = "__temp_vdim_name"

if (
hdim1_name_new in variable_da.dims
or hdim2_name_new in variable_da.dims
or vdim_name_new in variable_da.dims
):
raise ValueError(
"Cannot have dimensions named {0}, {1}, or {2}".format(
hdim1_name_new, hdim2_name_new, vdim_name_new
)
)

dim_new_names = {
hdim1_name_original: hdim1_name_new,
Expand All @@ -367,12 +375,6 @@ def add_coordinates_to_features(

if is_3d:
vdim_name_original = variable_da.dims[vertical_axis]
vdim_name_new = "".join(
random.choice(
string.ascii_uppercase + string.ascii_lowercase + string.digits
)
for _ in range(16)
)
dim_interp_coords[vdim_name_new] = xr.DataArray(
return_feat_df["vdim"].values, dims="features"
)
Expand All @@ -383,9 +385,9 @@ def add_coordinates_to_features(
# dataset
renamed_dim_da = variable_da.swap_dims(dim_new_names)
interpolated_df = renamed_dim_da.interp(coords=dim_interp_coords)
interpolated_df = interpolated_df.drop([hdim1_name_new, hdim2_name_new])
if is_3d:
interpolated_df = interpolated_df.drop([vdim_name_new])
interpolated_df = interpolated_df.drop_vars(
[hdim1_name_new, hdim2_name_new, vdim_name_new], errors="ignore"
)
return_feat_df[time_dim_name] = variable_da[time_dim_name].values[
return_feat_df["frame"]
]
Expand Down

0 comments on commit 4d54d24

Please sign in to comment.