Skip to content

Commit

Permalink
Add strict=True to zip calls and optimize Dimension class.
Browse files Browse the repository at this point in the history
  • Loading branch information
tasansal committed Dec 17, 2024
1 parent 46c5966 commit 15d6b37
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 36 deletions.
3 changes: 1 addition & 2 deletions src/mdio/converters/segy.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,7 @@ def segy_to_mdio( # noqa: C901

# Index the dataset using a spec that interprets the user provided index headers.
index_fields = []
# TODO: Add strict=True and remove noqa when minimum Python is 3.10
for name, byte, format_ in zip(index_names, index_bytes, index_types): # noqa: B905
for name, byte, format_ in zip(index_names, index_bytes, index_types, strict=True):
index_fields.append(HeaderField(name=name, byte=byte, format=format_))
mdio_spec_grid = mdio_spec.customize(trace_header_fields=index_fields)
segy_grid = SegyFile(url=segy_path, spec=mdio_spec_grid, settings=segy_settings)
Expand Down
11 changes: 3 additions & 8 deletions src/mdio/core/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
from mdio.exceptions import ShapeError


# TODO: once min Python >3.10, remove slots attribute and
# add `slots=True` to dataclass decorator and also add
# `kw_only=True` to enforce keyword only initialization.
@dataclass(eq=False, order=False)
@dataclass(eq=False, order=False, slots=True)
class Dimension:
"""Dimension class.
Expand All @@ -28,8 +25,6 @@ class Dimension:
name: Name of the dimension.
"""

__slots__ = ("coords", "name")

coords: list | tuple | NDArray | range
name: str

Expand Down Expand Up @@ -81,11 +76,11 @@ def __eq__(self, other: Dimension) -> bool:

return hash(self) == hash(other)

def min(self) -> NDArray[np.float]:
def min(self) -> NDArray[float]:
"""Get minimum value of dimension."""
return np.min(self.coords)

def max(self) -> NDArray[np.float]:
def max(self) -> NDArray[float]:
"""Get maximum value of dimension."""
return np.max(self.coords)

Expand Down
15 changes: 5 additions & 10 deletions src/mdio/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ def __init__(self, array: Array, chunk_samples: bool = True):
self.len_chunks = self.len_chunks[:-1] + (self.arr_shape[-1],)

# Compute number of chunks per dimension, and total number of chunks
# TODO: Add strict=True and remove noqa when minimum Python is 3.10
self.dim_chunks = [
ceil(len_dim / chunk)
for len_dim, chunk in zip(self.arr_shape, self.len_chunks) # noqa: B905
for len_dim, chunk in zip(self.arr_shape, self.len_chunks, strict=True)
]
self.num_chunks = np.prod(self.dim_chunks)

Expand All @@ -62,27 +61,23 @@ def __next__(self):
# We build slices here. It is dimension agnostic
current_start = next(self._ranges)

# TODO: Add strict=True and remove noqa when minimum Python is 3.10
start_indices = tuple(
dim * chunk
for dim, chunk in zip(current_start, self.len_chunks) # noqa: B905
for dim, chunk in zip(current_start, self.len_chunks, strict=True)
)

# TODO: Add strict=True and remove noqa when minimum Python is 3.10
stop_indices = tuple(
(dim + 1) * chunk
for dim, chunk in zip(current_start, self.len_chunks) # noqa: B905
for dim, chunk in zip(current_start, self.len_chunks, strict=True)
)

# TODO: Add strict=True and remove noqa when minimum Python is 3.10
slices = tuple(
slice(start, stop)
for start, stop in zip(start_indices, stop_indices) # noqa: B905
for start, stop in zip(start_indices, stop_indices, strict=True)
)

self._idx += 1

return slices

else:
raise StopIteration
raise StopIteration
3 changes: 1 addition & 2 deletions src/mdio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def __init__(
shapes: Shapes of the variables for the `message`.
"""
if names is not None and shapes is not None:
# TODO: Add strict=True and remove noqa when minimum Python is 3.10
shape_dict = zip(names, shapes) # noqa: B905
shape_dict = zip(names, shapes, strict=True)
extras = [f"{name}: {shape}" for name, shape in shape_dict]
extras = " <> ".join(extras)

Expand Down
27 changes: 13 additions & 14 deletions src/mdio/segy/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import time
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from enum import Enum
from enum import auto
from typing import TYPE_CHECKING
from typing import Sequence

import numpy as np
from numpy.lib import recfunctions as rfn
Expand Down Expand Up @@ -233,18 +233,20 @@ def create_trace_index(
dtype=np.int16,
):
"""Update dictionary counter tree for counting trace key for auto index."""
if depth == 0:
# If there's no hierarchical depth, no tracing needed.
return None

# Add index header
trace_no_hdr = np.empty(index_headers[header_names[0]].shape, dtype=dtype)
trace_no_field = np.zeros(index_headers.shape, dtype=dtype)
index_headers = rfn.append_fields(
index_headers, "trace", trace_no_hdr, usemask=False
index_headers, "trace", trace_no_field, usemask=False
)
idx = 0
if depth == 0:
return None

for idx_values in zip( # noqa: B905
*(index_headers[header_names[i]] for i in range(depth))
):
# Extract the relevant columns upfront
headers = [index_headers[name] for name in header_names[:depth]]
for idx, idx_values in enumerate(zip(*headers, strict=True)):
if depth == 1:
counter[idx_values[0]] += 1
index_headers["trace"][idx] = counter[idx_values[0]]
Expand Down Expand Up @@ -486,9 +488,8 @@ def transform(
unique_cables, cable_chan_min, cable_chan_max, geom_type = result
logger.info(f"Ingesting dataset as {geom_type.name}")

# TODO: Add strict=True and remove noqa when min Python is 3.10
for cable, chan_min, chan_max in zip( # noqa: B905
unique_cables, cable_chan_min, cable_chan_max
for cable, chan_min, chan_max in zip(
unique_cables, cable_chan_min, cable_chan_max, strict=True
):
logger.info(
f"Cable: {cable} has min chan: {chan_min} and max chan: {chan_max}"
Expand Down Expand Up @@ -601,15 +602,13 @@ def transform(
unique_shot_lines, unique_guns_in_shot_line, geom_type = result
logger.info(f"Ingesting dataset as shot type: {geom_type.name}")

# TODO: Add strict=True and remove noqa when min Python is 3.10
max_num_guns = 1
for shot_line in unique_shot_lines:
logger.info(
f"shot_line: {shot_line} has guns: {unique_guns_in_shot_line[str(shot_line)]}"
)
num_guns = len(unique_guns_in_shot_line[str(shot_line)])
if num_guns > max_num_guns:
max_num_guns = num_guns
max_num_guns = max(num_guns, max_num_guns)

# This might be slow and potentially could be improved with a rewrite
# to prevent so many lookups
Expand Down

0 comments on commit 15d6b37

Please sign in to comment.