From 15d6b379c1ee2df9a86b106fe1c3d2fbbb755f42 Mon Sep 17 00:00:00 2001 From: Altay Sansal Date: Tue, 17 Dec 2024 09:38:27 -0600 Subject: [PATCH] Add `strict=True` to zip calls and optimize Dimension class. --- src/mdio/converters/segy.py | 3 +-- src/mdio/core/dimension.py | 11 +++-------- src/mdio/core/indexing.py | 15 +++++---------- src/mdio/exceptions.py | 3 +-- src/mdio/segy/geometry.py | 27 +++++++++++++-------------- 5 files changed, 23 insertions(+), 36 deletions(-) diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index b1b4fa43..fd68dffe 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -385,8 +385,7 @@ def segy_to_mdio( # noqa: C901 # Index the dataset using a spec that interprets the user provided index headers. index_fields = [] - # TODO: Add strict=True and remove noqa when minimum Python is 3.10 - for name, byte, format_ in zip(index_names, index_bytes, index_types): # noqa: B905 + for name, byte, format_ in zip(index_names, index_bytes, index_types, strict=True): index_fields.append(HeaderField(name=name, byte=byte, format=format_)) mdio_spec_grid = mdio_spec.customize(trace_header_fields=index_fields) segy_grid = SegyFile(url=segy_path, spec=mdio_spec_grid, settings=segy_settings) diff --git a/src/mdio/core/dimension.py b/src/mdio/core/dimension.py index 0cd16350..8151e8d9 100644 --- a/src/mdio/core/dimension.py +++ b/src/mdio/core/dimension.py @@ -13,10 +13,7 @@ from mdio.exceptions import ShapeError -# TODO: once min Python >3.10, remove slots attribute and -# add `slots=True` to dataclass decorator and also add -# `kw_only=True` to enforce keyword only initialization. -@dataclass(eq=False, order=False) +@dataclass(eq=False, order=False, slots=True) class Dimension: """Dimension class. @@ -28,8 +25,6 @@ class Dimension: name: Name of the dimension. """ - __slots__ = ("coords", "name") - coords: list | tuple | NDArray | range name: str @@ -81,11 +76,11 @@ def __eq__(self, other: Dimension) -> bool: return hash(self) == hash(other) - def min(self) -> NDArray[np.float]: + def min(self) -> NDArray[float]: """Get minimum value of dimension.""" return np.min(self.coords) - def max(self) -> NDArray[np.float]: + def max(self) -> NDArray[float]: """Get maximum value of dimension.""" return np.max(self.coords) diff --git a/src/mdio/core/indexing.py b/src/mdio/core/indexing.py index 4db965fb..864ec8d8 100644 --- a/src/mdio/core/indexing.py +++ b/src/mdio/core/indexing.py @@ -35,10 +35,9 @@ def __init__(self, array: Array, chunk_samples: bool = True): self.len_chunks = self.len_chunks[:-1] + (self.arr_shape[-1],) # Compute number of chunks per dimension, and total number of chunks - # TODO: Add strict=True and remove noqa when minimum Python is 3.10 self.dim_chunks = [ ceil(len_dim / chunk) - for len_dim, chunk in zip(self.arr_shape, self.len_chunks) # noqa: B905 + for len_dim, chunk in zip(self.arr_shape, self.len_chunks, strict=True) ] self.num_chunks = np.prod(self.dim_chunks) @@ -62,27 +61,23 @@ def __next__(self): # We build slices here. It is dimension agnostic current_start = next(self._ranges) - # TODO: Add strict=True and remove noqa when minimum Python is 3.10 start_indices = tuple( dim * chunk - for dim, chunk in zip(current_start, self.len_chunks) # noqa: B905 + for dim, chunk in zip(current_start, self.len_chunks, strict=True) ) - # TODO: Add strict=True and remove noqa when minimum Python is 3.10 stop_indices = tuple( (dim + 1) * chunk - for dim, chunk in zip(current_start, self.len_chunks) # noqa: B905 + for dim, chunk in zip(current_start, self.len_chunks, strict=True) ) - # TODO: Add strict=True and remove noqa when minimum Python is 3.10 slices = tuple( slice(start, stop) - for start, stop in zip(start_indices, stop_indices) # noqa: B905 + for start, stop in zip(start_indices, stop_indices, strict=True) ) self._idx += 1 return slices - else: - raise StopIteration + raise StopIteration diff --git a/src/mdio/exceptions.py b/src/mdio/exceptions.py index 23f30069..740eb1fc 100644 --- a/src/mdio/exceptions.py +++ b/src/mdio/exceptions.py @@ -24,8 +24,7 @@ def __init__( shapes: Shapes of the variables for the `message`. """ if names is not None and shapes is not None: - # TODO: Add strict=True and remove noqa when minimum Python is 3.10 - shape_dict = zip(names, shapes) # noqa: B905 + shape_dict = zip(names, shapes, strict=True) extras = [f"{name}: {shape}" for name, shape in shape_dict] extras = " <> ".join(extras) diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index c48753a0..b48b15a6 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -6,10 +6,10 @@ import time from abc import ABC from abc import abstractmethod +from collections.abc import Sequence from enum import Enum from enum import auto from typing import TYPE_CHECKING -from typing import Sequence import numpy as np from numpy.lib import recfunctions as rfn @@ -233,18 +233,20 @@ def create_trace_index( dtype=np.int16, ): """Update dictionary counter tree for counting trace key for auto index.""" + if depth == 0: + # If there's no hierarchical depth, no tracing needed. + return None + # Add index header - trace_no_hdr = np.empty(index_headers[header_names[0]].shape, dtype=dtype) + trace_no_field = np.zeros(index_headers.shape, dtype=dtype) index_headers = rfn.append_fields( - index_headers, "trace", trace_no_hdr, usemask=False + index_headers, "trace", trace_no_field, usemask=False ) idx = 0 - if depth == 0: - return None - for idx_values in zip( # noqa: B905 - *(index_headers[header_names[i]] for i in range(depth)) - ): + # Extract the relevant columns upfront + headers = [index_headers[name] for name in header_names[:depth]] + for idx, idx_values in enumerate(zip(*headers, strict=True)): if depth == 1: counter[idx_values[0]] += 1 index_headers["trace"][idx] = counter[idx_values[0]] @@ -486,9 +488,8 @@ def transform( unique_cables, cable_chan_min, cable_chan_max, geom_type = result logger.info(f"Ingesting dataset as {geom_type.name}") - # TODO: Add strict=True and remove noqa when min Python is 3.10 - for cable, chan_min, chan_max in zip( # noqa: B905 - unique_cables, cable_chan_min, cable_chan_max + for cable, chan_min, chan_max in zip( + unique_cables, cable_chan_min, cable_chan_max, strict=True ): logger.info( f"Cable: {cable} has min chan: {chan_min} and max chan: {chan_max}" @@ -601,15 +602,13 @@ def transform( unique_shot_lines, unique_guns_in_shot_line, geom_type = result logger.info(f"Ingesting dataset as shot type: {geom_type.name}") - # TODO: Add strict=True and remove noqa when min Python is 3.10 max_num_guns = 1 for shot_line in unique_shot_lines: logger.info( f"shot_line: {shot_line} has guns: {unique_guns_in_shot_line[str(shot_line)]}" ) num_guns = len(unique_guns_in_shot_line[str(shot_line)]) - if num_guns > max_num_guns: - max_num_guns = num_guns + max_num_guns = max(num_guns, max_num_guns) # This might be slow and potentially could be improved with a rewrite # to prevent so many lookups