diff --git a/.vscode/settings.json b/.vscode/settings.json index eb47532fc..25c378bcd 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,4 +9,5 @@ "restructuredtext.confPath": "${workspaceFolder}\\docs", "files.eol": "\n", "python.formatting.provider": "black", + "editor.formatOnSave": true } \ No newline at end of file diff --git a/mikeio/__init__.py b/mikeio/__init__.py index b16868394..9caafc3c4 100644 --- a/mikeio/__init__.py +++ b/mikeio/__init__.py @@ -1,5 +1,7 @@ +from __future__ import annotations from pathlib import Path from platform import architecture +from typing import Sequence # PEP0440 compatible formatted version, see: # https://www.python.org/dev/peps/pep-0440/ @@ -43,8 +45,14 @@ from .xyz import read_xyz - -def read(filename, *, items=None, time=None, keepdims=False, **kwargs) -> Dataset: +def read( + filename: str | Path, + *, + items: str | int | Sequence[str |int] | None = None, + time: int | str | slice | None = None, + keepdims: bool = False, + **kwargs, +) -> Dataset: """Read all or a subset of the data from a dfs file All dfs files can be subsetted with the *items* and *time* arguments. But @@ -127,7 +135,7 @@ def read(filename, *, items=None, time=None, keepdims=False, **kwargs) -> Datase return dfs.read(items=items, time=time, keepdims=keepdims, **kwargs) -def open(filename: str, **kwargs): +def open(filename: str | Path, **kwargs): """Open a dfs/mesh file (and read the header) The typical workflow for small dfs files is to read all data @@ -177,6 +185,7 @@ def open(filename: str, **kwargs): return reader_klass(filename, **kwargs) + __all__ = [ "DataArray", "Dataset", @@ -199,4 +208,4 @@ def open(filename: str, **kwargs): "read_xyz", "read", "open", -] \ No newline at end of file +] diff --git a/mikeio/_interpolation.py b/mikeio/_interpolation.py index 54a8cf027..69ca15eeb 100644 --- a/mikeio/_interpolation.py +++ b/mikeio/_interpolation.py @@ -1,10 +1,15 @@ +from __future__ import annotations +from typing import Tuple, TYPE_CHECKING import numpy as np +from numpy.typing import NDArray + +if TYPE_CHECKING: + from .dataset import Dataset, DataArray from .spatial import GeometryUndefined -# class Interpolation2D: -def get_idw_interpolant(distances, p=2): +def get_idw_interpolant(distances:NDArray[np.floating], p:float=2) -> NDArray[np.floating]: """IDW interpolant for 2d array of distances https://pro.arcgis.com/en/pro-app/help/analysis/geostatistical-analyst/how-inverse-distance-weighted-interpolation-works.htm @@ -40,7 +45,7 @@ def get_idw_interpolant(distances, p=2): return weights -def interp2d(data, elem_ids, weights=None, shape=None): +def interp2d(data: NDArray[np.floating] | Dataset | DataArray, elem_ids: NDArray[np.integer], weights:NDArray[np.floating] | None=None, shape:Tuple[int,...]|None=None) -> NDArray[np.floating] | Dataset: """interp spatially in data (2d only) Parameters @@ -124,7 +129,7 @@ def interp2d(data, elem_ids, weights=None, shape=None): return idatitem -def _interp_itemstep(data, elem_ids, weights=None): +def _interp_itemstep(data: NDArray[np.floating], elem_ids: NDArray[np.integer], weights:NDArray[np.floating] | None =None) -> NDArray[np.floating]: """Interpolate a single item and time step Parameters diff --git a/mikeio/_spectral.py b/mikeio/_spectral.py index 4facdaa2e..32c7586b5 100644 --- a/mikeio/_spectral.py +++ b/mikeio/_spectral.py @@ -1,23 +1,28 @@ +from __future__ import annotations +from typing import Sequence, Tuple, Literal import numpy as np +from matplotlib.axes import Axes +from 
matplotlib.projections.polar import PolarAxes +from numpy.typing import NDArray def plot_2dspectrum( - spectrum, - frequencies, - directions, - plot_type="contourf", - title=None, - label=None, - cmap="Reds", - vmin=1e-5, - vmax=None, - r_as_periods=True, - rmin=None, - rmax=None, - levels=None, - figsize=(7, 7), - add_colorbar=True, -): + spectrum :NDArray[np.floating], + frequencies :NDArray[np.floating], + directions :NDArray[np.floating], + plot_type: str ="contourf", + title:str | None =None, + label:str | None=None, + cmap:str="Reds", + vmin:float | None =1e-5, + vmax:float | None =None, + r_as_periods:bool=True, + rmin:float | None=None, + rmax:float | None=None, + levels: int| Sequence[float] | None=None, + figsize: Tuple[float,float]=(7, 7), + add_colorbar:bool=True, +) -> Axes: """ Plot spectrum in polar coordinates @@ -67,7 +72,7 @@ def plot_2dspectrum( >>> dfs.plot_spectrum(spectrum, rmax=9, title="Wave spectrum T<9s") """ - import matplotlib.pyplot as plt # type: ignore + import matplotlib.pyplot as plt if (frequencies is None or len(frequencies) <= 1) and ( directions is None or len(directions) <= 1 @@ -83,13 +88,14 @@ def plot_2dspectrum( spectrum = np.fliplr(spectrum) fig = plt.figure(figsize=figsize) - ax = plt.subplot(111, polar=True) + ax = plt.subplot(111, polar=True) # type: ignore + assert isinstance(ax, PolarAxes) ax.set_theta_direction(-1) ax.set_theta_zero_location("N") ddir = dirs[1] - dirs[0] - def is_circular(dir): + def is_circular(dir: NDArray[np.floating]) -> bool: dir_diff = np.mod(dir[0], 2 * np.pi) - np.mod(dir[-1] + ddir, 2 * np.pi) return np.abs(dir_diff) < 1e-6 @@ -116,9 +122,9 @@ def is_circular(dir): if levels is None: levels = 10 n_levels = 10 - if np.isscalar(levels): + if isinstance(levels, int): n_levels = levels - levels = np.linspace(vmin, vmax, n_levels) + levels = np.linspace(vmin, vmax, n_levels) # type: ignore if plot_type != "shaded": spectrum[spectrum < vmin] = np.nan @@ -128,7 +134,7 @@ def is_circular(dir): dirs, freq, spectrum.T, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax ) elif plot_type == "contour": - colorax = ax.contour( + colorax = ax.contour( # type: ignore dirs, freq, spectrum.T, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax ) # ax.clabel(colorax, fmt="%1.2f", inline=1, fontsize=9) @@ -136,9 +142,9 @@ def is_circular(dir): ax.set_title(label) elif plot_type in ("patch", "shaded", "box"): - shading = "gouraud" if plot_type == "shaded" else "auto" + shading: Literal['flat', 'nearest', 'gouraud', 'auto'] = "gouraud" if plot_type == "shaded" else "auto" ax.grid(False) # Remove major grid - colorax = ax.pcolormesh( + colorax = ax.pcolormesh( # type: ignore dirs, freq, spectrum.T, @@ -147,14 +153,14 @@ def is_circular(dir): vmin=vmin, vmax=vmax, ) - ax.grid("on") + ax.grid("on") # type: ignore else: raise ValueError( f"plot_type '{plot_type}' not supported (contour, contourf, patch, shaded)" ) # TODO: optional - ax.set_thetagrids( + ax.set_thetagrids( # type: ignore [0.0, 45, 90.0, 135, 180.0, 225, 270.0, 315], labels=["N", "N-E", "E", "S-E", "S", "S-W", "W", "N-W"], ) @@ -171,9 +177,9 @@ def is_circular(dir): # ax.set_xticks(dfs.directions, minor=True); if rmin is not None: - ax.set_rmin(rmin) + ax.set_rmin(rmin) # type: ignore if rmax is not None: - ax.set_rmax(rmax) + ax.set_rmax(rmax) # type: ignore if add_colorbar: cbar = fig.colorbar(colorax) @@ -188,8 +194,9 @@ def is_circular(dir): return ax -def calc_m0_from_spectrum(spec, f, dir=None, tail=True): +def calc_m0_from_spectrum(spec: NDArray[np.floating], f: 
NDArray[np.floating] | None, dir: NDArray[np.floating] | None =None, tail:bool=True) -> NDArray[np.floating]: if f is None: + assert dir is not None nd = len(dir) dtheta = (dir[-1] - dir[0]) / (nd - 1) return np.sum(spec, axis=-1) * dtheta * np.pi / 180.0 @@ -208,7 +215,7 @@ def calc_m0_from_spectrum(spec, f, dir=None, tail=True): return m0 -def _f_to_df(f): +def _f_to_df(f: NDArray[np.floating]) -> NDArray[np.floating]: """Frequency bins for equidistant or logrithmic frequency axis""" if np.isclose(np.diff(f).min(), np.diff(f).max()): # equidistant frequency bins diff --git a/mikeio/_time.py b/mikeio/_time.py index d1e65b167..7bd8b7e83 100644 --- a/mikeio/_time.py +++ b/mikeio/_time.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime from dataclasses import dataclass -from typing import List, Iterable, Optional +from typing import List, Iterable import pandas as pd @@ -14,9 +14,7 @@ class DateTimeSelector: def isel( self, - x: Optional[ - int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice - ] = None, + x: int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice | None = None, ) -> List[int]: """Select time steps from a pandas DatetimeIndex diff --git a/mikeio/_track.py b/mikeio/_track.py index 4612c0e53..1f4ffc169 100644 --- a/mikeio/_track.py +++ b/mikeio/_track.py @@ -1,6 +1,5 @@ from __future__ import annotations from pathlib import Path -from datetime import datetime from typing import Callable, Sequence, Tuple import numpy as np @@ -15,8 +14,8 @@ def _extract_track( *, deletevalue: float, - start_time: datetime, - end_time: datetime, + start_time: pd.Timestamp, + end_time: pd.Timestamp, timestep: float, geometry: GeometryFM2D, track: str | Dataset | pd.DataFrame, diff --git a/mikeio/dataset/_data_plot.py b/mikeio/dataset/_data_plot.py index 06c8e64fc..04d394dc1 100644 --- a/mikeio/dataset/_data_plot.py +++ b/mikeio/dataset/_data_plot.py @@ -1,17 +1,23 @@ +from __future__ import annotations +from typing import Any, Tuple, TYPE_CHECKING + import numpy as np +from matplotlib.axes import Axes from ..spatial._FM_utils import _plot_map, _plot_vertical_profile from .._spectral import plot_2dspectrum +if TYPE_CHECKING: + from ..dataset import DataArray class _DataArrayPlotter: """Context aware plotter (sensible plotting according to geometry)""" - def __init__(self, da) -> None: + def __init__(self, da: "DataArray") -> None: self.da = da - def __call__(self, ax=None, figsize=None, **kwargs): + def __call__(self, ax:Axes | None=None, figsize=None, **kwargs): """Plot DataArray according to geometry Parameters @@ -59,7 +65,7 @@ def _get_fig_ax(ax=None, figsize=None): fig = plt.gcf() return fig, ax - def hist(self, ax=None, figsize=None, title=None, **kwargs): + def hist(self, ax: Axes| None=None, figsize:Tuple[float,float] | None=None, title: str | None=None, **kwargs: Any) -> Axes: """Plot DataArray as histogram (using ax.hist) Parameters @@ -91,7 +97,7 @@ def hist(self, ax=None, figsize=None, title=None, **kwargs): ax.set_title(title) return self._hist(ax, **kwargs) - def _hist(self, ax, **kwargs): + def _hist(self, ax: Axes, **kwargs): result = ax.hist(self.da.values.ravel(), **kwargs) ax.set_xlabel(self._label_txt()) return result diff --git a/mikeio/dataset/_data_utils.py b/mikeio/dataset/_data_utils.py index 01765cab6..a4af31b41 100644 --- a/mikeio/dataset/_data_utils.py +++ b/mikeio/dataset/_data_utils.py @@ -1,8 +1,8 @@ from __future__ import annotations +from datetime import datetime import re -from typing import Iterable, Sequence, 
Sized, Tuple, Union, List +from typing import Iterable, Sized, List -import numpy as np import pandas as pd from .._time import DateTimeSelector @@ -19,7 +19,7 @@ def _n_selected_timesteps(x: Sized, k: slice | Sized) -> int: return len(k) -def _get_time_idx_list(time: pd.DatetimeIndex, steps) -> Union [List[int], slice]: +def _get_time_idx_list(time: pd.DatetimeIndex, steps: int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice) -> List[int] | slice: """Find list of idx in DatetimeIndex""" # indexing with a slice needs to be handled differently, since slicing returns a view @@ -31,138 +31,8 @@ def _get_time_idx_list(time: pd.DatetimeIndex, steps) -> Union [List[int], slice dts = DateTimeSelector(time) return dts.isel(steps) -# TODO this only used by DataArray, so consider to move it there -class DataUtilsMixin: - """DataArray Utils""" - - @staticmethod - def _to_safe_name(name: str) -> str: - return _to_safe_name(name) - - @staticmethod - def _time_by_agg_axis( - time: pd.DatetimeIndex, axis: int | Sequence[int] - ) -> pd.DatetimeIndex: - """New DatetimeIndex after aggregating over time axis""" - if axis == 0 or (isinstance(axis, Sequence) and 0 in axis): - time = pd.DatetimeIndex([time[0]]) - - return time - - @staticmethod - def _get_time_idx_list(time: pd.DatetimeIndex, steps): - """Find list of idx in DatetimeIndex""" - - return _get_time_idx_list(time, steps) - - @staticmethod - def _n_selected_timesteps(time, k): - return _n_selected_timesteps(time, k) - - @staticmethod - def _is_boolean_mask(x) -> bool: - if hasattr(x, "dtype"): # isinstance(x, (np.ndarray, DataArray)): - return x.dtype == np.dtype("bool") - return False - - @staticmethod - def _get_by_boolean_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray: - if data.shape != mask.shape: - return data[np.broadcast_to(mask, data.shape)] - return data[mask] - - @staticmethod - def _set_by_boolean_mask(data: np.ndarray, mask: np.ndarray, value) -> None: - if data.shape != mask.shape: - data[np.broadcast_to(mask, data.shape)] = value - else: - data[mask] = value - - @staticmethod - def _parse_time(time) -> pd.DatetimeIndex: - """Allow anything that we can create a DatetimeIndex from""" - if time is None: - time = [pd.Timestamp(2018, 1, 1)] # TODO is this the correct epoch? - if isinstance(time, str) or (not isinstance(time, Iterable)): - time = [time] - - if not isinstance(time, pd.DatetimeIndex): - index = pd.DatetimeIndex(time) - else: - index = time - - if not index.is_monotonic_increasing: - raise ValueError( - "Time must be monotonic increasing (only equal or increasing) instances." - ) - assert isinstance(index, pd.DatetimeIndex) - return index - - @staticmethod - def _parse_axis(data_shape, dims, axis) -> int | Tuple[int]: - # TODO change to return tuple always - # axis = 0 if axis == "time" else axis - if (axis == "spatial") or (axis == "space"): - if len(data_shape) == 1: - if dims[0][0] == "t": - raise ValueError(f"space axis cannot be selected from dims {dims}") - return 0 - if "frequency" in dims or "directions" in dims: - space_name = "node" if "node" in dims else "element" - return dims.index(space_name) - else: - axis = 1 if (len(data_shape) == 2) else tuple(range(1, len(data_shape))) - if axis is None: - axis = 0 if (len(data_shape) == 1) else tuple(range(0, len(data_shape))) - - if isinstance(axis, str): - axis = "time" if axis == "t" else axis - if axis in dims: - return dims.index(axis) - else: - raise ValueError( - f"axis argument '{axis}' not supported! 
Must be None, int, list of int or 'time' or 'space'" - ) - - return axis - - @staticmethod - def _axis_to_spatial_axis(dims, axis): - # subtract 1 if has time axis; assumes axis is integer - return axis - int(dims[0] == "time") - - @staticmethod - def _parse_interp_time(old_time, new_time): - if isinstance(new_time, pd.DatetimeIndex): - t_out_index = new_time - elif hasattr(new_time, "time"): - t_out_index = pd.DatetimeIndex(new_time.time) - else: - # offset = pd.tseries.offsets.DateOffset(seconds=new_time) # This seems identical, but doesn't work with slicing - offset = pd.Timedelta(seconds=new_time) - t_out_index = pd.date_range( - start=old_time[0], end=old_time[-1], freq=offset - ) - - return t_out_index - - @staticmethod - def _interpolate_time( - intime, - outtime, - data: np.ndarray, - method: str | int, - extrapolate: bool, - fill_value: float, - ): - from scipy.interpolate import interp1d # type: ignore - - interpolator = interp1d( - intime, - data, - axis=0, - kind=method, - bounds_error=not extrapolate, - fill_value=fill_value, - ) - return interpolator(outtime) + + + + + diff --git a/mikeio/dataset/_dataarray.py b/mikeio/dataset/_dataarray.py index 9848b8f3c..1221e3b4c 100644 --- a/mikeio/dataset/_dataarray.py +++ b/mikeio/dataset/_dataarray.py @@ -3,15 +3,33 @@ from copy import deepcopy from datetime import datetime from functools import cached_property -from typing import Iterable, Optional, Sequence, Tuple, Mapping +from typing import ( + Any, + Iterable, + Sequence, + Tuple, + Mapping, + Union, + Sized, + Literal, + TYPE_CHECKING, + overload, + MutableMapping, + Callable, +) import numpy as np +from numpy.typing import NDArray, ArrayLike, DTypeLike import pandas as pd -from mikecore.DfsuFile import DfsuFileType # type: ignore +from mikecore.DfsuFile import DfsuFileType -from ._data_utils import DataUtilsMixin from ..eum import EUMType, EUMUnit, ItemInfo +from ._data_utils import _get_time_idx_list, _n_selected_timesteps + +if TYPE_CHECKING: + from ._dataset import Dataset + import xarray from ..spatial import ( @@ -46,12 +64,28 @@ _DataArrayPlotterLineSpectrum, ) +GeometryType = Union[ + GeometryUndefined, + GeometryPoint2D, + GeometryPoint3D, + GeometryFM2D, + GeometryFM3D, + GeometryFMAreaSpectrum, + GeometryFMLineSpectrum, + GeometryFMPointSpectrum, + GeometryFMVerticalColumn, + GeometryFMVerticalProfile, + Grid1D, + Grid2D, + Grid3D, +] + class _DataArraySpectrumToHm0: def __init__(self, da: "DataArray") -> None: self.da = da - def __call__(self, tail=True): + def __call__(self, tail: bool = True) -> "DataArray": # TODO: if action_density m0 = calc_m0_from_spectrum( self.da.to_numpy(), @@ -64,7 +98,7 @@ def __call__(self, tail=True): item = ItemInfo(EUMType.Significant_wave_height) g = self.da.geometry if isinstance(g, GeometryFMPointSpectrum): - geometry = GeometryPoint2D(x=g.x, y=g.y) + geometry: Any = GeometryPoint2D(x=g.x, y=g.y) elif isinstance(g, GeometryFMLineSpectrum): geometry = Grid1D( nx=g.n_nodes, @@ -89,7 +123,7 @@ def __call__(self, tail=True): ) -class DataArray(DataUtilsMixin): +class DataArray: """DataArray with data and metadata for a single item in a dfs file The DataArray has these main properties: @@ -104,14 +138,14 @@ class DataArray(DataUtilsMixin): def __init__( self, - data, + data: NDArray[np.floating], *, - time: Optional[pd.DatetimeIndex | str] = None, - item: Optional[ItemInfo] = None, - geometry=GeometryUndefined(), - zn=None, - dims: Optional[Sequence[str]] = None, - ): + time: pd.DatetimeIndex | str | None = None, + item: ItemInfo | None 
= None, + geometry: GeometryType = GeometryUndefined(), + zn: NDArray[np.floating] | None = None, + dims: Sequence[str] | None = None, + ) -> None: # TODO: add optional validation validate=True self._values = self._parse_data(data) self.time: pd.DatetimeIndex = self._parse_time(time) @@ -126,7 +160,7 @@ def __init__( self.plot = self._get_plotter_by_geometry() @staticmethod - def _parse_data(data): + def _parse_data(data: ArrayLike) -> Any: # NDArray[np.floating] | float: validation_errors = [] for p in ("shape", "ndim", "dtype"): if not hasattr(data, p): @@ -138,7 +172,9 @@ def _parse_data(data): ) return data - def _parse_dims(self, dims, geometry) -> Tuple[str, ...]: + def _parse_dims( + self, dims: Sequence[str] | None, geometry: GeometryType + ) -> Tuple[str, ...]: if dims is None: return self._guess_dims(self.ndim, self.shape, self.n_timesteps, geometry) else: @@ -153,8 +189,9 @@ def _parse_dims(self, dims, geometry) -> Tuple[str, ...]: return tuple(dims) @staticmethod - def _guess_dims(ndim, shape, n_timesteps, geometry): - + def _guess_dims( + ndim: int, shape: Tuple[int, ...], n_timesteps: int, geometry: GeometryType + ) -> Tuple[str, ...]: # TODO delete default dims to geometry # This is not very robust, but is probably a reasonable guess @@ -199,24 +236,29 @@ def _guess_dims(ndim, shape, n_timesteps, geometry): dims.append("x") return tuple(dims) - def _check_time_data_length(self, time): + def _check_time_data_length(self, time: Sized) -> None: if "time" in self.dims and len(time) != self._values.shape[0]: raise ValueError( f"Number of timesteps ({len(time)}) does not fit with data shape {self.values.shape}" ) @staticmethod - def _parse_item(item) -> ItemInfo: + def _parse_item(item: ItemInfo | str | EUMType | None) -> ItemInfo: + if isinstance(item, ItemInfo): + return item + if item is None: return ItemInfo("NoName") - if not isinstance(item, ItemInfo): + if isinstance(item, (str, EUMType, EUMUnit)): return ItemInfo(item) - - return item + + raise ValueError("item must be str, EUMType or EUMUnit") @staticmethod - def _parse_geometry(geometry, dims, shape): + def _parse_geometry( + geometry: Any, dims: Tuple[str, ...], shape: Tuple[int, ...] + ) -> Any: if len(dims) > 1 and ( geometry is None or isinstance(geometry, GeometryUndefined) ): @@ -264,7 +306,9 @@ def _parse_geometry(geometry, dims, shape): return geometry @staticmethod - def _parse_zn(zn, geometry, n_timesteps): + def _parse_zn( + zn: NDArray[np.floating] | None, geometry: GeometryType, n_timesteps: int + ) -> NDArray[np.floating] | None: if zn is not None: if isinstance(geometry, _GeometryFMLayered): # TODO: np.squeeze(zn) if n_timesteps=1 ? 
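(For context, not part of the patch: a minimal sketch of how the constructor annotated above can be called. The array, item name and time axis are invented for illustration, and `mikeio.ItemInfo`/`mikeio.EUMType` are assumed to be re-exported at package level.)

```python
import numpy as np
import pandas as pd
import mikeio

# invented data: 10 hourly values at 3 points; dims are guessed from the shape
data = np.random.default_rng(0).random((10, 3))
time = pd.date_range("2018-01-01", periods=10, freq="H")

da = mikeio.DataArray(
    data,
    time=time,
    item=mikeio.ItemInfo("Water Level", mikeio.EUMType.Water_Level),
)
print(da.dims, da.timestep)  # e.g. ('time', 'x') 3600.0
```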
@@ -280,11 +324,10 @@ def _parse_zn(zn, geometry, n_timesteps): raise ValueError("zn can only be provided for layered dfsu data") return zn - def _is_compatible(self, other, raise_error=False): + def _is_compatible(self, other: "DataArray", raise_error: bool = False) -> bool: """check if other DataArray has equivalent dimensions, time and geometry""" problems = [] - if not isinstance(other, DataArray): - return False + assert isinstance(other, DataArray) if self.shape != other.shape: problems.append("shape of data must be the same") if self.n_timesteps != other.n_timesteps: @@ -315,9 +358,9 @@ def _is_compatible(self, other, raise_error=False): return len(problems) == 0 - def _get_plotter_by_geometry(self): + def _get_plotter_by_geometry(self) -> Any: # TODO: this is explicit, but with consistent naming, we could create this mapping automatically - PLOTTER_MAP = { + PLOTTER_MAP: Any = { GeometryFMVerticalProfile: _DataArrayPlotterFMVerticalProfile, GeometryFMVerticalColumn: _DataArrayPlotterFMVerticalColumn, GeometryFMPointSpectrum: _DataArrayPlotterPointSpectrum, @@ -332,8 +375,16 @@ def _get_plotter_by_geometry(self): plotter = PLOTTER_MAP.get(type(self.geometry), _DataArrayPlotter) return plotter(self) - def _set_spectral_attributes(self, geometry): + def _set_spectral_attributes(self, geometry: GeometryType) -> None: if hasattr(geometry, "frequencies") and hasattr(geometry, "directions"): + assert isinstance( + geometry, + ( + GeometryFMAreaSpectrum, + GeometryFMLineSpectrum, + GeometryFMPointSpectrum, + ), + ) self.frequencies = geometry.frequencies self.n_frequencies = geometry.n_frequencies self.directions = geometry.directions @@ -345,6 +396,7 @@ def _set_spectral_attributes(self, geometry): @property def name(self) -> str: """Name of this DataArray (=da.item.name)""" + assert isinstance(self.item.name, str) return self.item.name @name.setter @@ -362,13 +414,12 @@ def unit(self) -> EUMUnit: return self.item.unit @property - def start_time(self): + def start_time(self) -> datetime: """First time instance (as datetime)""" - # TODO: use pd.Timestamp instead return self.time[0].to_pydatetime() @property - def end_time(self): + def end_time(self) -> datetime: """Last time instance (as datetime)""" # TODO: use pd.Timestamp instead return self.time[-1].to_pydatetime() @@ -381,14 +432,14 @@ def is_equidistant(self) -> bool: return len(self.time.to_series().diff().dropna().unique()) == 1 @property - def timestep(self) -> Optional[float]: + def timestep(self) -> float | None: """Time step in seconds if equidistant (and at least two time instances); otherwise None """ dt = None if len(self.time) > 1 and self.is_equidistant: - first: pd.Timestamp = self.time[0] # type: ignore - second: pd.Timestamp = self.time[1] # type: ignore + first: pd.Timestamp = self.time[0] + second: pd.Timestamp = self.time[1] dt = (second - first).total_seconds() return dt @@ -398,41 +449,42 @@ def n_timesteps(self) -> int: return len(self.time) @property - def shape(self): + def shape(self) -> Any: """Tuple of array dimensions""" return self.values.shape @property def ndim(self) -> int: """Number of array dimensions""" + assert isinstance(self.values.ndim, int) return self.values.ndim @property - def dtype(self): + def dtype(self) -> DTypeLike: """Data-type of the array elements""" return self.values.dtype @property - def values(self) -> np.ndarray: + def values(self) -> NDArray[np.floating]: """Values as a np.ndarray (equivalent to to_numpy())""" return self._values @values.setter - def values(self, value): + def 
values(self, value: NDArray[np.floating] | float) -> None: if np.isscalar(self._values): if not np.isscalar(value): raise ValueError("Shape of new data is wrong (should be scalar)") - elif value.shape != self._values.shape: + elif value.shape != self._values.shape: # type: ignore raise ValueError("Shape of new data is wrong") - self._values = value + self._values = value # type: ignore - def to_numpy(self) -> np.ndarray: + def to_numpy(self) -> NDArray[np.floating]: """Values as a np.ndarray (equivalent to values)""" return self._values @property - def _has_time_axis(self): + def _has_time_axis(self) -> bool: return self.dims[0][0] == "t" def dropna(self) -> "DataArray": @@ -454,14 +506,14 @@ def flipud(self) -> "DataArray": self.values = np.flip(self.values, axis=first_non_t_axis) return self - def describe(self, **kwargs) -> pd.DataFrame: + def describe(self, **kwargs: Any) -> pd.DataFrame: """Generate descriptive statistics by wrapping :py:meth:`pandas.DataFrame.describe` - + Parameters ---------- **kwargs Keyword arguments passed to :py:meth:`pandas.DataFrame.describe` - + Returns ------- pd.DataFrame @@ -509,8 +561,7 @@ def squeeze(self) -> "DataArray": # else: # raise ValueError("Invalid mask") - def __getitem__(self, key) -> "DataArray": - + def __getitem__(self, key: Any) -> "DataArray": da = self dims = self.dims key = self._getitem_parse_key(key) @@ -524,7 +575,7 @@ def __getitem__(self, key) -> "DataArray": da = da.isel(k, axis=dims[j]) return da - def _getitem_parse_key(self, key): + def _getitem_parse_key(self, key: Any) -> Any: if isinstance(key, tuple): # is it multiindex or just a tuple of indexes for first axis? # da[2,3,4] and da[(2,3,4)] both have the key=(2,3,4) @@ -551,13 +602,18 @@ def _getitem_parse_key(self, key): ) return key - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: NDArray[np.floating]) -> None: if self._is_boolean_mask(key): mask = key if isinstance(key, np.ndarray) else key.values return self._set_by_boolean_mask(self._values, mask, value) self._values[key] = value - def isel(self, idx=None, axis=0, **kwargs) -> "DataArray": + def isel( + self, + idx: int | Sequence[int] | slice | None = None, + axis: int | str = 0, + **kwargs: Any, + ) -> "DataArray": """Return a new DataArray whose data is given by integer indexing along the specified dimension(s). 
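(Illustrative only: a rough usage sketch for the `isel` signature typed above. The file name and item name are placeholders for any 2D dfsu file.)

```python
import mikeio

da = mikeio.read("HD2D.dfsu")["Surface elevation"]  # placeholder file/item

first = da.isel(time=0)                      # single time step via keyword
subset = da.isel([0, 1, 2], axis="element")  # element indices with a named axis
same = da[0]                                 # __getitem__ shortcut for the first step
```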
@@ -669,8 +725,9 @@ def isel(self, idx=None, axis=0, **kwargs) -> "DataArray": idx_slice = None if isinstance(idx, slice): idx_slice = idx + assert isinstance(axis, int) idx = list(range(*idx.indices(self.shape[axis]))) - if idx is None or (not np.isscalar(idx) and len(idx) == 0): + if idx is None or (not np.isscalar(idx) and len(idx) == 0): # type: ignore raise ValueError( "Empty index is not allowed" ) # TODO other option would be to have a NullDataArray @@ -688,7 +745,8 @@ def isel(self, idx=None, axis=0, **kwargs) -> "DataArray": geometry = GeometryUndefined() zn = None if hasattr(self.geometry, "isel"): - spatial_axis = self._axis_to_spatial_axis(self.dims, axis) + assert isinstance(axis, int) + spatial_axis = axis - 1 if self.dims[0] == "time" else axis geometry = self.geometry.isel(idx, axis=spatial_axis) # TOOD this is ugly @@ -696,7 +754,7 @@ def isel(self, idx=None, axis=0, **kwargs) -> "DataArray": node_ids, _ = self.geometry._get_nodes_and_table_for_elements( idx, node_layers="all" ) - zn = self._zn[:, node_ids] + zn = self._zn[:, node_ids] # type: ignore # reduce dims only if singleton idx dims = ( @@ -732,8 +790,8 @@ def isel(self, idx=None, axis=0, **kwargs) -> "DataArray": def sel( self, *, - time: Optional[str | pd.DatetimeIndex | "DataArray"] = None, - **kwargs, + time: str | pd.DatetimeIndex | "DataArray" | None = None, + **kwargs: Any, ) -> "DataArray": """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). @@ -841,12 +899,11 @@ def sel( """ if any([isinstance(v, slice) for v in kwargs.values()]): return self._sel_with_slice(kwargs) - + da = self # select in space if len(kwargs) > 0: - idx = self.geometry.find_index(**kwargs) if isinstance(idx, tuple): # TODO: support for dfs3 @@ -870,24 +927,24 @@ def sel( da = da[time] # __getitem__ is 🚀 return da - - def _sel_with_slice(self, kwargs: Mapping[str,slice]) -> "DataArray": + + def _sel_with_slice(self, kwargs: Mapping[str, slice]) -> "DataArray": for k, v in kwargs.items(): if isinstance(v, slice): - idx_start = self.geometry.find_index(**{k:v.start}) - idx_stop = self.geometry.find_index(**{k:v.stop}) + idx_start = self.geometry.find_index(**{k: v.start}) + idx_stop = self.geometry.find_index(**{k: v.stop}) pos = 0 if isinstance(idx_start, tuple): if k == "x": pos = 0 if k == "y": pos = 1 - + start = idx_start[pos][0] if idx_start is not None else None stop = idx_stop[pos][0] if idx_stop is not None else None idx = slice(start, stop) - + self = self.isel(idx, axis=k) return self @@ -896,13 +953,13 @@ def interp( # TODO find out optimal syntax to allow interpolation to single point, new time, grid, mesh... 
self, # *, # TODO: make this a keyword-only argument in the future - time: Optional[pd.DatetimeIndex | "DataArray"] = None, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, + time: pd.DatetimeIndex | "DataArray" | None = None, + x: float | None = None, + y: float | None = None, + z: float | None = None, n_nearest: int = 3, - interpolant=None, - **kwargs, + interpolant: Tuple[Any, Any] | None = None, + **kwargs: Any, ) -> "DataArray": """Interpolate data in time and space @@ -954,7 +1011,9 @@ def interp( if z is not None: raise NotImplementedError() - geometry: GeometryPoint2D | GeometryPoint3D | GeometryUndefined = GeometryUndefined() + geometry: GeometryPoint2D | GeometryPoint3D | GeometryUndefined = ( + GeometryUndefined() + ) # interp in space if (x is not None) or (y is not None) or (z is not None): @@ -990,7 +1049,7 @@ def interp( x=x, y=y, projection=self.geometry.projection ) # this is not supported yet (see above) - #else: + # else: # geometry = GeometryPoint3D( # x=x, y=y, z=z, projection=self.geometry.projection # ) @@ -1009,7 +1068,7 @@ def interp( def __dataarray_read_item_time_func( self, item: int, step: int - ) -> Tuple[np.ndarray, float]: + ) -> Tuple[NDArray[np.floating], float]: "Used by _extract_track" # Ignore item argument data = self.isel(time=step).to_numpy() @@ -1017,7 +1076,12 @@ def __dataarray_read_item_time_func( return data, time - def extract_track(self, track, method="nearest", dtype=np.float32): + def extract_track( + self, + track: pd.DataFrame, + method: Literal["nearest", "inverse_distance"] = "nearest", + dtype: DTypeLike = np.float32, + ) -> "Dataset": """ Extract data along a moving track @@ -1040,6 +1104,8 @@ def extract_track(self, track, method="nearest", dtype=np.float32): """ from .._track import _extract_track + assert self.timestep is not None + return _extract_track( deletevalue=self.deletevalue, start_time=self.start_time, @@ -1060,9 +1126,9 @@ def interp_time( self, dt: float | pd.DatetimeIndex | "DataArray", *, - method="linear", - extrapolate=True, - fill_value=np.nan, + method: str = "linear", + extrapolate: bool = True, + fill_value: float = np.nan, ) -> "DataArray": """Temporal interpolation @@ -1083,20 +1149,32 @@ def interp_time( ------- DataArray """ + from scipy.interpolate import interp1d # type: ignore + t_out_index = self._parse_interp_time(self.time, dt) t_in = self.time.values.astype(float) t_out = t_out_index.values.astype(float) - data = self._interpolate_time( - t_in, t_out, self.to_numpy(), method, extrapolate, fill_value - ) + data = interp1d( + t_in, + self.to_numpy(), + axis=0, + kind=method, + bounds_error=not extrapolate, + fill_value=fill_value, + )(t_out) zn = ( None if self._zn is None - else self._interpolate_time( - t_in, t_out, self._zn, method, extrapolate, fill_value - ) + else interp1d( + t_in, + self._zn, + axis=0, + kind=method, + bounds_error=not extrapolate, + fill_value=fill_value, + )(t_out) ) return DataArray( @@ -1107,7 +1185,7 @@ def interp_time( zn=zn, ) - def interp_na(self, axis="time", **kwargs) -> "DataArray": + def interp_na(self, axis: str = "time", **kwargs: Any) -> "DataArray": """Fill in NaNs by interpolating according to different methods. 
Wrapper of :py:meth:`xarray.DataArray.interpolate_na` @@ -1138,8 +1216,8 @@ def interp_na(self, axis="time", **kwargs) -> "DataArray": def interp_like( self, other: "DataArray" | Grid2D | GeometryFM2D | pd.DatetimeIndex, - interpolant=None, - **kwargs, + interpolant: Tuple[Any, Any] | None = None, + **kwargs: Any, ) -> "DataArray": """Interpolate in space (and in time) to other geometry (and time axis) @@ -1210,10 +1288,14 @@ def interp_like( if hasattr(other, "time"): dai = dai.interp_time(other.time) + assert isinstance(dai, DataArray) + return dai @staticmethod - def concat(dataarrays: Sequence["DataArray"], keep="last") -> "DataArray": + def concat( + dataarrays: Sequence["DataArray"], keep: Literal["last"] = "last" + ) -> "DataArray": """Concatenate DataArrays along the time axis Parameters @@ -1249,7 +1331,7 @@ def concat(dataarrays: Sequence["DataArray"], keep="last") -> "DataArray": # ============= Aggregation methods =========== - def max(self, axis=0, **kwargs) -> "DataArray": + def max(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Max value along an axis Parameters @@ -1268,7 +1350,7 @@ def max(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.max, **kwargs) - def min(self, axis=0, **kwargs) -> "DataArray": + def min(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Min value along an axis Parameters @@ -1287,7 +1369,7 @@ def min(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.min, **kwargs) - def mean(self, axis=0, **kwargs) -> "DataArray": + def mean(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Mean value along an axis Parameters @@ -1306,7 +1388,7 @@ def mean(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.mean, **kwargs) - def std(self, axis=0, **kwargs) -> "DataArray": + def std(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Standard deviation values along an axis Parameters @@ -1325,7 +1407,7 @@ def std(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.std, **kwargs) - def ptp(self, axis=0, **kwargs) -> "DataArray": + def ptp(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Range (max - min) a.k.a Peak to Peak along an axis Parameters @@ -1340,7 +1422,9 @@ def ptp(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.ptp, **kwargs) - def average(self, weights, axis=0, **kwargs) -> "DataArray": + def average( + self, weights: NDArray[np.floating], axis: int | str = 0, **kwargs: Any + ) -> "DataArray": """Compute the weighted average along the specified axis. 
Parameters @@ -1365,7 +1449,7 @@ def average(self, weights, axis=0, **kwargs) -> "DataArray": >>> da2 = da.average(axis="space", weights=area) """ - def func(x, axis, keepdims): + def func(x, axis, keepdims): # type: ignore if keepdims: raise NotImplementedError() @@ -1373,7 +1457,7 @@ def func(x, axis, keepdims): return self.aggregate(axis=axis, func=func, **kwargs) - def nanmax(self, axis=0, **kwargs) -> "DataArray": + def nanmax(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Max value along an axis (NaN removed) Parameters @@ -1392,7 +1476,7 @@ def nanmax(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.nanmax, **kwargs) - def nanmin(self, axis=0, **kwargs) -> "DataArray": + def nanmin(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Min value along an axis (NaN removed) Parameters @@ -1411,7 +1495,7 @@ def nanmin(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.nanmin, **kwargs) - def nanmean(self, axis=0, **kwargs) -> "DataArray": + def nanmean(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Mean value along an axis (NaN removed) Parameters @@ -1430,7 +1514,7 @@ def nanmean(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.nanmean, **kwargs) - def nanstd(self, axis=0, **kwargs) -> "DataArray": + def nanstd(self, axis: int | str = 0, **kwargs: Any) -> "DataArray": """Standard deviation value along an axis (NaN removed) Parameters @@ -1449,7 +1533,9 @@ def nanstd(self, axis=0, **kwargs) -> "DataArray": """ return self.aggregate(axis=axis, func=np.nanstd, **kwargs) - def aggregate(self, axis=0, func=np.nanmean, **kwargs) -> "DataArray": + def aggregate( + self, axis: int | str = 0, func: Callable[..., Any] = np.nanmean, **kwargs: Any + ) -> "DataArray": """Aggregate along an axis Parameters @@ -1473,14 +1559,12 @@ def aggregate(self, axis=0, func=np.nanmean, **kwargs) -> "DataArray": axis = self._parse_axis(self.shape, self.dims, axis) time = self._time_by_agg_axis(self.time, axis) - if isinstance(axis, int): axes = (axis,) else: axes = axis dims = tuple([d for i, d in enumerate(self.dims) if i not in axes]) - item = deepcopy(self.item) if "name" in kwargs: @@ -1507,7 +1591,17 @@ def aggregate(self, axis=0, func=np.nanmean, **kwargs) -> "DataArray": zn=zn, ) - def quantile(self, q, *, axis=0, **kwargs): + @overload + def quantile(self, q: float, **kwargs: Any) -> "DataArray": + ... + + @overload + def quantile(self, q: Sequence[float], **kwargs: Any) -> "Dataset": + ... + + def quantile( + self, q: float | Sequence[float], *, axis: int | str = 0, **kwargs: Any + ) -> "DataArray" | "Dataset": """Compute the q-th quantile of the data along the specified axis. Wrapping np.quantile @@ -1537,7 +1631,17 @@ def quantile(self, q, *, axis=0, **kwargs): """ return self._quantile(q, axis=axis, func=np.quantile, **kwargs) - def nanquantile(self, q, *, axis=0, **kwargs): + @overload + def nanquantile(self, q: float, **kwargs: Any) -> "DataArray": + ... + + @overload + def nanquantile(self, q: Sequence[float], **kwargs: Any) -> "Dataset": + ... + + def nanquantile( + self, q: float | Sequence[float], *, axis: int | str = 0, **kwargs: Any + ) -> "DataArray" | "Dataset": """Compute the q-th quantile of the data along the specified axis, while ignoring nan values. 
Wrapping np.nanquantile @@ -1567,11 +1671,11 @@ def nanquantile(self, q, *, axis=0, **kwargs): """ return self._quantile(q, axis=axis, func=np.nanquantile, **kwargs) - def _quantile(self, q, *, axis=0, func=np.quantile, **kwargs): - + def _quantile(self, q, *, axis: int | str = 0, func=np.quantile, **kwargs: Any): # type: ignore from mikeio import Dataset axis = self._parse_axis(self.shape, self.dims, axis) + assert isinstance(axis, int) time = self._time_by_agg_axis(self.time, axis) if np.isscalar(q): @@ -1638,7 +1742,7 @@ def __abs__(self) -> "DataArray": def _apply_unary_math_operation(self, func) -> "DataArray": try: data = func(self.values) - + except TypeError: raise TypeError("Math operation could not be applied to DataArray") @@ -1734,7 +1838,7 @@ def _to_dataset(self): {self.name: self} ) # Single-item Dataset (All info is contained in the DataArray, no need for additional info) - def to_dfs(self, filename, **kwargs) -> None: + def to_dfs(self, filename, **kwargs: Any) -> None: """Write data to a new dfs file Parameters @@ -1778,12 +1882,12 @@ def to_pandas(self) -> pd.Series: return pd.Series(data=self.to_numpy(), index=self.time, name=self.name) - def to_xarray(self): + def to_xarray(self) -> "xarray.DataArray": """Export to xarray.DataArray""" import xarray as xr - coords = {} + coords: MutableMapping[str, Any] = {} if self._has_time_axis: coords["time"] = xr.DataArray(self.time, dims="time") @@ -1825,7 +1929,6 @@ def to_xarray(self): # =============================================== def __repr__(self) -> str: - out = [""] if self.name is not None: out.append(f"name: {self.name}") @@ -1858,7 +1961,6 @@ def _geometry_txt(self) -> str: return f"geometry: {self.geometry}" def _values_txt(self) -> str: - if self.ndim == 0 or (self.ndim == 1 and len(self.values) == 1): return f"values: {self.values}" elif self.ndim == 1 and len(self.values) < 5: @@ -1868,3 +1970,107 @@ def _values_txt(self) -> str: return f"values: [{self.values[0]:0.4g}, {self.values[1]:0.4g}, ..., {self.values[-1]:0.4g}]" else: return "" # raise NotImplementedError() + + @staticmethod + def _parse_interp_time( + old_time: pd.DatetimeIndex, new_time: Any + ) -> pd.DatetimeIndex: + if isinstance(new_time, pd.DatetimeIndex): + t_out_index = new_time + elif hasattr(new_time, "time"): + t_out_index = pd.DatetimeIndex(new_time.time) + else: + # offset = pd.tseries.offsets.DateOffset(seconds=new_time) # This seems identical, but doesn't work with slicing + offset = pd.Timedelta(seconds=new_time) + t_out_index = pd.date_range( + start=old_time[0], end=old_time[-1], freq=offset + ) + + return t_out_index + + @staticmethod + def _time_by_agg_axis( + time: pd.DatetimeIndex, axis: int | Sequence[int] + ) -> pd.DatetimeIndex: + """New DatetimeIndex after aggregating over time axis""" + if axis == 0 or (isinstance(axis, Sequence) and 0 in axis): + time = pd.DatetimeIndex([time[0]]) + + return time + + @staticmethod + def _get_time_idx_list(time: pd.DatetimeIndex, steps): + """Find list of idx in DatetimeIndex""" + + return _get_time_idx_list(time, steps) + + @staticmethod + def _n_selected_timesteps(time: Sized, k: slice | Sized) -> int: + return _n_selected_timesteps(time, k) + + @staticmethod + def _is_boolean_mask(x) -> bool: + if hasattr(x, "dtype"): # isinstance(x, (np.ndarray, DataArray)): + return x.dtype == np.dtype("bool") + return False + + @staticmethod + def _get_by_boolean_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray: + if data.shape != mask.shape: + return data[np.broadcast_to(mask, data.shape)] + return 
data[mask] + + @staticmethod + def _set_by_boolean_mask(data: np.ndarray, mask: np.ndarray, value) -> None: + if data.shape != mask.shape: + data[np.broadcast_to(mask, data.shape)] = value + else: + data[mask] = value + + @staticmethod + def _parse_time(time) -> pd.DatetimeIndex: + """Allow anything that we can create a DatetimeIndex from""" + if time is None: + time = [pd.Timestamp(2018, 1, 1)] # TODO is this the correct epoch? + if isinstance(time, str) or (not isinstance(time, Iterable)): + time = [time] + + if not isinstance(time, pd.DatetimeIndex): + index = pd.DatetimeIndex(time) + else: + index = time + + if not index.is_monotonic_increasing: + raise ValueError( + "Time must be monotonic increasing (only equal or increasing) instances." + ) + assert isinstance(index, pd.DatetimeIndex) + return index + + @staticmethod + def _parse_axis(data_shape, dims, axis) -> int | Tuple[int]: + # TODO change to return tuple always + # axis = 0 if axis == "time" else axis + if (axis == "spatial") or (axis == "space"): + if len(data_shape) == 1: + if dims[0][0] == "t": + raise ValueError(f"space axis cannot be selected from dims {dims}") + return 0 + if "frequency" in dims or "directions" in dims: + space_name = "node" if "node" in dims else "element" + return dims.index(space_name) + else: + axis = 1 if (len(data_shape) == 2) else tuple(range(1, len(data_shape))) + if axis is None: + axis = 0 if (len(data_shape) == 1) else tuple(range(0, len(data_shape))) + + if isinstance(axis, str): + axis = "time" if axis == "t" else axis + if axis in dims: + return dims.index(axis) + else: + raise ValueError( + f"axis argument '{axis}' not supported! Must be None, int, list of int or 'time' or 'space'" + ) + + return axis diff --git a/mikeio/dataset/_dataset.py b/mikeio/dataset/_dataset.py index 75757a50f..7217bbf18 100644 --- a/mikeio/dataset/_dataset.py +++ b/mikeio/dataset/_dataset.py @@ -1,27 +1,30 @@ from __future__ import annotations from pathlib import Path import warnings -from copy import deepcopy from datetime import datetime +from copy import deepcopy from typing import ( Iterable, + Iterator, List, Mapping, - Optional, + MutableMapping, Sequence, Tuple, - MutableMapping, Any, overload, Hashable, - Set - + Set, + TYPE_CHECKING, ) import numpy as np import pandas as pd -from mikecore.DfsFile import DfsSimpleType # type: ignore +from mikecore.DfsFile import DfsSimpleType + +if TYPE_CHECKING: + import xarray from ._dataarray import DataArray from ._data_utils import _to_safe_name, _get_time_idx_list, _n_selected_timesteps @@ -41,7 +44,7 @@ from ._data_plot import _DatasetPlotter -class Dataset(MutableMapping): +class Dataset: """Dataset containing one or more DataArrays with common geometry and time Most often obtained by reading a dfs file. 
But can also be @@ -84,7 +87,7 @@ class Dataset(MutableMapping): def __init__( self, - data: Mapping[str, DataArray] | Iterable[DataArray] | Sequence[np.ndarray], + data: Mapping[str, DataArray] | Sequence[DataArray] | Sequence[np.ndarray], time=None, items=None, geometry=None, @@ -96,11 +99,11 @@ def __init__( data = self._create_dataarrays( data=data, time=time, items=items, geometry=geometry, zn=zn, dims=dims ) # type: ignore - self._data_vars = self._init_from_DataArrays(data, validate=validate) + self._data_vars = self._init_from_DataArrays(data, validate=validate) # type: ignore self.plot = _DatasetPlotter(self) @staticmethod - def _is_DataArrays(data): + def _is_DataArrays(data: Any) -> bool: """Check if input is Sequence/Mapping of DataArrays""" if isinstance(data, (Dataset, DataArray)): return True @@ -125,7 +128,7 @@ def _create_dataarrays( geometry=None, zn=None, dims=None, - ): + ) -> Mapping[str, DataArray]: if not isinstance(data, Iterable): data = [data] items = Dataset._parse_items(items, len(data)) @@ -138,7 +141,9 @@ def _create_dataarrays( ) return data_vars - def _init_from_DataArrays(self, data, validate=True) -> MutableMapping[str, DataArray]: + def _init_from_DataArrays( + self, data: Sequence[DataArray] | Mapping[str, DataArray], validate: bool = True + ) -> MutableMapping[str, DataArray]: """Initialize Dataset object with Iterable of DataArrays""" data_vars = self._DataArrays_as_mapping(data) @@ -147,7 +152,7 @@ def _init_from_DataArrays(self, data, validate=True) -> MutableMapping[str, Data for da in data_vars.values(): first._is_compatible(da, raise_error=True) - self._check_all_different_ids(data_vars.values()) + self._check_all_different_ids(list(data_vars.values())) # TODO is it necessary to keep track of item names? self.__itemattr: Set[str] = set() @@ -162,30 +167,29 @@ def values(self): "Dataset has no property 'values' - use to_numpy() instead or maybe you were looking for DataArray.values?" ) - # remove values and keys from dir to avoid confusion - def __dir__(self): - keys = sorted(list(super().__dir__()) + list(self.__dict__.keys())) - return set([d for d in keys if d not in ("values", "keys")]) - @staticmethod def _modify_list(lst: Iterable[str]) -> List[str]: modified_list = [] count_dict = {} - + for item in lst: if item not in count_dict: modified_list.append(item) count_dict[item] = 2 else: - warnings.warn(f"Duplicate item name: {item}. Renaming to {item}_{count_dict[item]}") + warnings.warn( + f"Duplicate item name: {item}. 
Renaming to {item}_{count_dict[item]}" + ) modified_item = f"{item}_{count_dict[item]}" modified_list.append(modified_item) count_dict[item] += 1 - + return modified_list @staticmethod - def _parse_items(items, n_items_data): + def _parse_items( + items: None | Sequence[ItemInfo | EUMType | str], n_items_data: int + ) -> List[ItemInfo]: if items is None: # default Undefined items item_infos = [ItemInfo(f"Item_{j+1}") for j in range(n_items_data)] @@ -208,20 +212,20 @@ def _parse_items(items, n_items_data): item_names = Dataset._modify_list(item_names) for it, item_name in zip(item_infos, item_names): it.name = item_name - + return item_infos @staticmethod - def _DataArrays_as_mapping(data): + def _DataArrays_as_mapping( + data: DataArray | Sequence[DataArray] | Mapping[str, DataArray] + ) -> MutableMapping[str, DataArray]: """Create dict of DataArrays if necessary""" if isinstance(data, Mapping): - if isinstance(data, Dataset): - return data - data = Dataset._validate_item_names_and_keys( + data_vars = Dataset._validate_item_names_and_keys( data ) # TODO is this necessary? - _ = Dataset._unique_item_names(data.values()) - return data + _ = Dataset._unique_item_names(data_vars.values()) + return data_vars if isinstance(data, DataArray): data = [data] @@ -251,7 +255,7 @@ def _unique_item_names(das: Sequence[DataArray]) -> List[str]: return item_names @staticmethod - def _check_all_different_ids(das): + def _check_all_different_ids(das: Sequence[DataArray]) -> None: """Are all the DataArrays different objects or are some referring to the same""" ids = np.zeros(len(das), dtype=np.int64) ids_val = np.zeros(len(das), dtype=np.int64) @@ -277,7 +281,7 @@ def _check_all_different_ids(das): Dataset._id_of_DataArrays_equal(das[jj[0]], das[jj[1]]) @staticmethod - def _id_of_DataArrays_equal(da1, da2): + def _id_of_DataArrays_equal(da1: DataArray, da2: DataArray) -> None: """Check if two DataArrays are actually the same object""" if id(da1) == id(da2): raise ValueError( @@ -288,7 +292,7 @@ def _id_of_DataArrays_equal(da1, da2): f"DataArrays {da1.name} and {da2.name} refer to the same data! Create a copy first." 
) - def _check_already_present(self, new_da): + def _check_already_present(self, new_da: DataArray) -> None: """Is the DataArray already present in the Dataset?""" for da in self: self._id_of_DataArrays_equal(da, new_da) @@ -303,24 +307,24 @@ def time(self) -> pd.DatetimeIndex: return list(self)[0].time @time.setter - def time(self, new_time): + def time(self, new_time) -> None: for da in self: da.time = new_time @property - def start_time(self): + def start_time(self) -> datetime: """First time instance (as datetime)""" # TODO: use pd.Timestamp instead return self.time[0].to_pydatetime() @property - def end_time(self): + def end_time(self) -> datetime: """Last time instance (as datetime)""" # TODO: use pd.Timestamp instead return self.time[-1].to_pydatetime() @property - def timestep(self) -> Optional[float]: + def timestep(self) -> float | None: """Time step in seconds if equidistant (and at least two time instances); otherwise None """ @@ -351,7 +355,7 @@ def n_timesteps(self) -> int: return len(self.time) @property - def items(self): + def items(self) -> List[ItemInfo]: """ItemInfo for each of the DataArrays as a list""" return [x.item for x in self] @@ -361,7 +365,7 @@ def n_items(self) -> int: return len(self._data_vars) @property - def names(self): + def names(self) -> List[str]: """Name of each of the DataArrays as a list""" return [da.name for da in self] @@ -379,29 +383,29 @@ def dims(self): return self[0].dims @property - def shape(self): + def shape(self) -> Tuple[int, ...]: """Shape of each DataArray""" return self[0].shape @property - def deletevalue(self): + def deletevalue(self) -> float: """File delete value""" return self[0].deletevalue @property - def geometry(self): + def geometry(self) -> Any: """Geometry of each DataArray""" return self[0].geometry @property - def _zn(self) -> np.ndarray: + def _zn(self) -> np.ndarray | None: return self[0]._zn # TODO: remove this @property def n_elements(self) -> int: """Number of spatial elements/points""" - n_elem = np.prod(self.shape) + n_elem = int(np.prod(self.shape)) if self.n_timesteps > 1: n_elem = int(n_elem / self.n_timesteps) return n_elem @@ -473,16 +477,16 @@ def create_data_array(self, data, item=None) -> DataArray: # TODO: delete this? @staticmethod - def create_empty_data(n_items=1, n_timesteps=1, n_elements=None, shape=None): + def create_empty_data(n_items: int = 1, n_timesteps: int = 1, n_elements: int | None = None, shape: Tuple[int, ...] 
| None = None): # type: ignore data = [] if shape is None: if n_elements is None: raise ValueError("n_elements and shape cannot both be None") else: - shape = n_elements + shape = n_elements # type: ignore if np.isscalar(shape): - shape = [shape] - dati = np.empty(shape=(n_timesteps, *shape)) + shape = [shape] # type: ignore + dati = np.empty(shape=(n_timesteps, *shape)) # type: ignore dati[:] = np.nan for _ in range(n_items): data.append(dati.copy()) @@ -490,23 +494,16 @@ def create_empty_data(n_items=1, n_timesteps=1, n_elements=None, shape=None): # ============= Dataset is (almost) a MutableMapping =========== - def __len__(self): + def __len__(self) -> int: return len(self._data_vars) - def __iter__(self): + def __iter__(self) -> Iterator[DataArray]: yield from self._data_vars.values() - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: self.__set_or_insert_item(key, value, insert=False) - def __set_or_insert_item(self, key, value, insert=False): - if not isinstance(value, DataArray): - try: - value = DataArray(value) - # TODO: warn that this is not the preferred way! - except TypeError: - raise ValueError("Input could not be interpreted as a DataArray") - + def __set_or_insert_item(self, key, value: DataArray, insert=False) -> None: if len(self) > 0: self[0]._is_compatible(value) @@ -547,7 +544,7 @@ def __set_or_insert_item(self, key, value, insert=False): self._data_vars[key] = value self._set_name_attr(key, value) - def insert(self, key, value: DataArray): + def insert(self, key: int, value: DataArray) -> None: """Insert DataArray in a specific position Parameters @@ -560,12 +557,7 @@ def insert(self, key, value: DataArray): """ self.__set_or_insert_item(key, value, insert=True) - if isinstance(key, slice): - s = self.time.slice_indexer(key.start, key.stop) - time_steps = list(range(s.start, s.stop)) - return self.isel(time_steps, axis=0) - - def remove(self, key: int | str): + def remove(self, key: int | str) -> None: """Remove DataArray from Dataset Parameters @@ -579,16 +571,7 @@ def remove(self, key: int | str): """ self.__delitem__(key) - def popitem(self): - """Pop first DataArray from Dataset - - See also - -------- - pop - """ - return self.pop(0) - - def rename(self, mapper: Mapping[str, str], inplace=False): + def rename(self, mapper: Mapping[str, str], inplace=False) -> "Dataset": """Rename items (DataArrays) in Dataset Parameters @@ -623,13 +606,13 @@ def rename(self, mapper: Mapping[str, str], inplace=False): return ds - def _set_name_attr(self, name: str, value: DataArray): + def _set_name_attr(self, name: str, value: DataArray) -> None: name = _to_safe_name(name) if name not in self.__itemattr: self.__itemattr.add(name) # keep track of what we insert setattr(self, name, value) - def _del_name_attr(self, name: str): + def _del_name_attr(self, name: str) -> None: name = _to_safe_name(name) if name in self.__itemattr: self.__itemattr.remove(name) @@ -644,9 +627,7 @@ def __getitem__(self, key: Hashable | int) -> DataArray: def __getitem__(self, key: Iterable[Hashable]) -> "Dataset": ... 
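(A short, hedged illustration of the `__getitem__` overloads declared above; the file path and item names are stand-ins, and the time string assumes the file covers that date.)

```python
import mikeio

ds = mikeio.read("HD2D.dfsu")  # placeholder path

one_item = ds["Surface elevation"]                   # str key  -> DataArray
same_item = ds[0]                                    # int key  -> DataArray
two_items = ds[["Surface elevation", "U velocity"]]  # list key -> Dataset
first_day = ds["1985-08-06"]                         # time-like str -> time subset as Dataset
```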
- def __getitem__(self, key) -> DataArray | "Dataset": - # select time steps if ( isinstance(key, Sequence) and not isinstance(key, str) @@ -699,7 +680,7 @@ def __getitem__(self, key) -> DataArray | "Dataset": raise TypeError(f"indexing with a {type(key)} is not (yet) supported") - def _is_slice_time_slice(self, s): + def _is_slice_time_slice(self, s: slice) -> bool: if (s.start is None) and (s.stop is None): return False if s.start is not None: @@ -710,7 +691,11 @@ def _is_slice_time_slice(self, s): return False return True - def _is_key_time(self, key): + def _is_key_time(self, key): # type: ignore + if isinstance(key, slice): + return False + if isinstance(key, (int, float)): + return False if isinstance(key, str) and key in self.names: return False if isinstance(key, str) and len(key) > 0 and key[0].isnumeric(): @@ -718,9 +703,10 @@ def _is_key_time(self, key): return True if isinstance(key, (datetime, np.datetime64, pd.Timestamp)): return True + return False - def _multi_indexing_attempted(self, key) -> bool: + def _multi_indexing_attempted(self, key: Any) -> bool: # find out if user is attempting ds[2, :, 1] or similar (not allowed) # this is not bullet-proof, but a good estimate if not isinstance(key, tuple): @@ -758,15 +744,14 @@ def _key_to_str(self, key: Any) -> Any: return key.name raise TypeError(f"indexing with type {type(key)} is not supported") - def __delitem__(self, key): - + def __delitem__(self, key: Hashable | int) -> None: key = self._key_to_str(key) self._data_vars.__delitem__(key) self._del_name_attr(key) # ============ select/interp ============= - def isel(self, idx=None, axis=0, **kwargs): + def isel(self, idx=None, axis=0, **kwargs) -> "Dataset": """Return a new Dataset whose data is given by integer indexing along the specified dimension(s). 
@@ -883,10 +868,10 @@ def sel( def interp( self, *, - time: Optional[pd.DatetimeIndex | "DataArray"] = None, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, + time: pd.DatetimeIndex | "DataArray" | None = None, + x: float | None = None, + y: float | None = None, + z: float | None = None, n_nearest: int = 3, **kwargs, ) -> "Dataset": @@ -947,7 +932,6 @@ def interp( if isinstance( self.geometry, GeometryFM2D ): # TODO remove this when all geometries implements the same method - interpolant = self.geometry.get_2d_interpolant( xy, n_nearest=n_nearest, **kwargs ) @@ -974,7 +958,7 @@ def __dataset_read_item_time_func( return data, time - def extract_track(self, track, method="nearest", dtype=np.float32): + def extract_track(self, track, method="nearest", dtype=np.float32) -> "Dataset": """ Extract data along a moving track @@ -1000,6 +984,10 @@ def extract_track(self, track, method="nearest", dtype=np.float32): item_numbers = list(range(self.n_items)) time_steps = list(range(self.n_timesteps)) + assert self.start_time is not None + assert self.end_time is not None + assert self.timestep is not None + return _extract_track( deletevalue=self.deletevalue, start_time=self.start_time, @@ -1018,12 +1006,12 @@ def extract_track(self, track, method="nearest", dtype=np.float32): def interp_time( self, - dt: Optional[float | pd.DatetimeIndex | "Dataset" | DataArray] = None, + dt: float | pd.DatetimeIndex | "Dataset" | DataArray | None = None, *, - freq: Optional[str] = None, - method="linear", - extrapolate=True, - fill_value=np.nan, + freq: str | None = None, + method: str = "linear", + extrapolate: bool = True, + fill_value: float = np.nan, ) -> "Dataset": """Temporal interpolation @@ -1159,9 +1147,12 @@ def interp_like( # ============= Combine/concat =========== - def _append_items(self, other, copy=True): + def _append_items( + self, other: DataArray | "Dataset", copy: bool = True + ) -> "Dataset": if isinstance(other, DataArray): other = other._to_dataset() + assert isinstance(other, Dataset) item_names = {item.name for item in self.items} other_names = {item.name for item in other.items} @@ -1278,7 +1269,7 @@ def _concat_time(self, other, copy=True) -> "Dataset": newdata, time=newtime, items=ds.items, geometry=ds.geometry, zn=zn ) - def _check_all_items_match(self, other): + def _check_all_items_match(self, other: "Dataset") -> None: if self.n_items != other.n_items: raise ValueError( f"Number of items must match ({self.n_items} and {other.n_items})" @@ -1299,9 +1290,7 @@ def _check_all_items_match(self, other): # ============ aggregate ============= - def aggregate( - self, axis=0, func=np.nanmean, **kwargs - ) -> "Dataset": + def aggregate(self, axis=0, func=np.nanmean, **kwargs) -> "Dataset": """Aggregate along an axis Parameters @@ -1331,7 +1320,7 @@ def aggregate( dims=self.dims, zn=self._zn, ) - + return Dataset([da], validate=False) else: res = { @@ -1341,7 +1330,7 @@ def aggregate( return Dataset(data=res, validate=False) @staticmethod - def _agg_item_from_items(items, name): + def _agg_item_from_items(items: Sequence[ItemInfo], name: str) -> ItemInfo: it_type = ( items[0].type if all([it.type == items[0].type for it in items]) @@ -1410,10 +1399,7 @@ def nanquantile(self, q, *, axis=0, **kwargs) -> "Dataset": """ return self._quantile(q, axis=axis, func=np.nanquantile, **kwargs) - def _quantile( - self, q, *, axis=0, func=np.quantile, **kwargs - ) -> "Dataset": - + def _quantile(self, q, *, axis=0, func=np.quantile, **kwargs) -> "Dataset": if axis == 
"items": if self.n_items <= 1: return self # or raise ValueError? @@ -1689,7 +1675,7 @@ def _add_dataset(self, other, sign=1.0) -> "Dataset": newds[j].values = data[j] # type: ignore return newds - def _check_datasets_match(self, other): + def _check_datasets_match(self, other: "Dataset") -> None: if self.n_items != other.n_items: raise ValueError( f"Number of items must match ({self.n_items} and {other.n_items})" @@ -1805,7 +1791,6 @@ def to_dfs(self, filename, **kwargs): if isinstance( self.geometry, (GeometryPoint2D, GeometryPoint3D, GeometryUndefined) ): - if self.ndim == 0: # Not very common, but still... self._validate_extension(filename, ".dfs0") self._to_dfs0(filename, **kwargs) @@ -1833,42 +1818,42 @@ def to_dfs(self, filename, **kwargs): ) @staticmethod - def _validate_extension(filename, valid_extension): + def _validate_extension(filename: str | Path, valid_extension: str) -> None: path = Path(filename) ext = path.suffix.lower() if ext != valid_extension: raise ValueError(f"File extension must be {valid_extension}") - def _to_dfs0(self, filename, **kwargs): + def _to_dfs0(self, filename: str | Path, **kwargs: Any) -> None: from ..dfs._dfs0 import _write_dfs0 dtype = kwargs.get("dtype", DfsSimpleType.Float) _write_dfs0(filename, self, dtype=dtype) - def _to_dfs2(self, filename): + def _to_dfs2(self, filename: str | Path) -> None: # assumes Grid2D geometry from ..dfs._dfs2 import write_dfs2 write_dfs2(filename, self) - def _to_dfs3(self, filename): + def _to_dfs3(self, filename: str | Path) -> None: # assumes Grid3D geometry from ..dfs._dfs3 import write_dfs3 write_dfs3(filename, self) - def _to_dfs1(self, filename): + def _to_dfs1(self, filename: str | Path) -> None: from ..dfs._dfs1 import write_dfs1 - write_dfs1(filename=filename,ds=self) + write_dfs1(filename=filename, ds=self) - def _to_dfsu(self, filename): + def _to_dfsu(self, filename: str | Path) -> None: from ..dfsu._dfsu import _write_dfsu _write_dfsu(filename, self) - def to_xarray(self): + def to_xarray(self) -> "xarray.Dataset": """Export to xarray.Dataset""" import xarray diff --git a/mikeio/dfs/_dfs.py b/mikeio/dfs/_dfs.py index 7cfab977c..effba0852 100644 --- a/mikeio/dfs/_dfs.py +++ b/mikeio/dfs/_dfs.py @@ -3,7 +3,7 @@ from abc import abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import List, Optional, Tuple, Sequence +from typing import List, Tuple, Sequence import numpy as np import pandas as pd from tqdm import tqdm, trange @@ -26,9 +26,9 @@ from ..spatial import GeometryUndefined from .._time import DateTimeSelector + @dataclass class DfsHeader: - n_items: int n_timesteps: int start_time: datetime @@ -82,7 +82,7 @@ def _fuzzy_item_search( def _valid_item_numbers( dfsItemInfo: List[DfsDynamicItemInfo], - items: Optional[str | int | List[int] | List[str]] = None, + items: str | int | Sequence[int | str] | None = None, ignore_first: bool = False, ) -> List[int]: start_idx = 1 if ignore_first else 0 @@ -127,7 +127,6 @@ def _valid_item_numbers( def _valid_timesteps(dfsFileInfo: DfsFileInfo, time_steps) -> Tuple[bool, List[int]]: - time_axis = dfsFileInfo.TimeAxis single_time_selected = False @@ -149,10 +148,11 @@ def _valid_timesteps(dfsFileInfo: DfsFileInfo, time_steps) -> Tuple[bool, List[i time_step_file = time_axis.TimeStep if time_step_file <= 0: - if nt > 1: - raise ValueError(f"Time step must be a positive number. Time step in the file is {time_step_file} seconds.") - + raise ValueError( + f"Time step must be a positive number. 
Time step in the file is {time_step_file} seconds." + ) + warnings.warn( f"Time step is {time_step_file} seconds. This must be a positive number. Setting to 1 second." ) @@ -214,7 +214,7 @@ def _item_numbers_by_name( def _get_item_info( dfsItemInfo: List[DfsDynamicItemInfo], - item_numbers: Optional[List[int]] = None, + item_numbers: List[int] | None = None, ignore_first: bool = False, ) -> ItemInfoList: """Read DFS ItemInfo for specific item numbers @@ -243,7 +243,6 @@ def _get_item_info( def _write_dfs_data(*, dfs: DfsFile, ds: Dataset, n_spatial_dims: int) -> None: - deletevalue = dfs.FileInfo.DeleteValueFloat # ds.deletevalue has_no_time = "time" not in ds.dims if ds.is_equidistant: @@ -253,7 +252,6 @@ def _write_dfs_data(*, dfs: DfsFile, ds: Dataset, n_spatial_dims: int) -> None: for i in range(ds.n_timesteps): for item in range(ds.n_items): - if has_no_time: d = ds[item].values else: @@ -346,7 +344,6 @@ def read( for i, it in enumerate(tqdm(time_steps, disable=not self.show_progress)): for item in range(n_items): - itemdata = self._dfs.ReadItemTimeStep(item_numbers[item] + 1, int(it)) src = itemdata.Data @@ -410,7 +407,6 @@ def _write( title, keep_open=False, ): - assert isinstance(ds, Dataset) neq_datetimes = None @@ -421,7 +417,6 @@ def _write( title=title, data=ds, dt=dt, coordinate=coordinate ) - shape = np.shape(data[0]) t_offset = 0 if len(shape) == self._ndim else 1 @@ -464,7 +459,6 @@ def _write( for i in trange(header.n_timesteps, disable=not self.show_progress): for item in range(header.n_items): - d = data[item][i] if t_offset == 1 else data[item] d = d.copy() # to avoid modifying the input d[np.isnan(d)] = deletevalue @@ -482,19 +476,17 @@ def _write( return self def append(self, data: Dataset) -> None: - warnings.warn(FutureWarning("append() is deprecated.")) if not data.dims == ("time", "y", "x"): raise NotImplementedError( - "Append is only available for 2D files with dims ('time', 'y', 'x')" - ) + "Append is only available for 2D files with dims ('time', 'y', 'x')" + ) deletevalue = self._dfs.FileInfo.DeleteValueFloat # -1.0000000031710769e-30 for i in trange(data.n_timesteps, disable=not self.show_progress): for da in data: - values = da.to_numpy() d = values[i] d = d.copy() # to avoid modifying the input @@ -503,8 +495,7 @@ def append(self, data: Dataset) -> None: d = d.reshape(data.shape[1:]) darray = d.reshape(d.size, 1)[:, 0] self._dfs.WriteItemTimeStepNext(0, darray.astype(np.float32)) - - + def __enter__(self): return self @@ -515,8 +506,9 @@ def close(self): "Finalize write for a dfs file opened with `write(...,keep_open=True)`" self._dfs.Close() - def _write_handle_common_arguments(self, *, title: Optional[str], data: Dataset, coordinate, dt: Optional[float] = None): - + def _write_handle_common_arguments( + self, *, title: str | None, data: Dataset, coordinate, dt: float | None = None + ): if title is None: self._title = "" @@ -544,7 +536,9 @@ def _write_handle_common_arguments(self, *, title: Optional[str], data: Dataset, else: self._override_coordinates = True - assert isinstance(data, Dataset), "data must be supplied in the form of a mikeio.Dataset" + assert isinstance( + data, Dataset + ), "data must be supplied in the form of a mikeio.Dataset" items = data.items start_time = data.time[0] @@ -558,14 +552,17 @@ def _write_handle_common_arguments(self, *, title: Optional[str], data: Dataset, if n_timesteps > 1: warnings.warn("No timestep supplied. 
Using 1s.") - if items is None: - items = [ItemInfo(f"Item {i+1}") for i in range(self._n_items)] - - header = DfsHeader(n_items=n_items, n_timesteps=n_timesteps, dt=dt, start_time=start_time, coordinates=coordinate, items=items) + header = DfsHeader( + n_items=n_items, + n_timesteps=n_timesteps, + dt=dt, + start_time=start_time, + coordinates=coordinate, + items=items, + ) return header, data def _setup_header(self, filename: str, header: DfsHeader): - system_start_time = header.start_time self._builder.SetDataType(0) diff --git a/mikeio/dfs/_dfs1.py b/mikeio/dfs/_dfs1.py index 5534832fb..45560f97d 100644 --- a/mikeio/dfs/_dfs1.py +++ b/mikeio/dfs/_dfs1.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pathlib import Path from mikecore.DfsFactory import DfsBuilder, DfsFactory @@ -15,12 +16,12 @@ from ..spatial import Grid1D -def write_dfs1(filename: str, ds: Dataset, title="") -> None: +def write_dfs1(filename: str | Path, ds: Dataset, title="") -> None: dfs = _write_dfs1_header(filename, ds, title) _write_dfs_data(dfs=dfs, ds=ds, n_spatial_dims=1) -def _write_dfs1_header(filename, ds: Dataset, title="") -> DfsFile: +def _write_dfs1_header(filename: str | Path, ds: Dataset, title="") -> DfsFile: builder = DfsBuilder.Create(title, "mikeio", __dfs_version__) builder.SetDataType(0) @@ -54,7 +55,7 @@ def _write_dfs1_header(filename, ds: Dataset, title="") -> DfsFile: ) try: - builder.CreateFile(filename) + builder.CreateFile(str(filename)) except IOError: print("cannot create dfs file: ", filename) diff --git a/mikeio/dfs/_dfs2.py b/mikeio/dfs/_dfs2.py index 40730ff76..278556add 100644 --- a/mikeio/dfs/_dfs2.py +++ b/mikeio/dfs/_dfs2.py @@ -1,5 +1,7 @@ +from __future__ import annotations from copy import deepcopy -from typing import List, Tuple, Optional +from pathlib import Path +from typing import List, Tuple import warnings import numpy as np @@ -25,12 +27,12 @@ from ..spatial import Grid2D -def write_dfs2(filename: str, ds: Dataset, title="") -> None: +def write_dfs2(filename: str | Path, ds: Dataset, title="") -> None: dfs = _write_dfs2_header(filename, ds, title) _write_dfs_data(dfs=dfs, ds=ds, n_spatial_dims=2) -def _write_dfs2_header(filename, ds: Dataset, title="") -> DfsFile: +def _write_dfs2_header(filename: str | Path, ds: Dataset, title="") -> DfsFile: builder = DfsBuilder.Create(title, "mikeio", __dfs_version__) builder.SetDataType(0) @@ -82,7 +84,7 @@ def _write_dfs2_header(filename, ds: Dataset, title="") -> DfsFile: ) try: - builder.CreateFile(filename) + builder.CreateFile(str(filename)) except IOError: print("cannot create dfs file: ", filename) @@ -104,7 +106,6 @@ def _write_dfs2_spatial_axis(builder, factory, geometry): class Dfs2(_Dfs123): - _ndim = 2 def __init__(self, filename=None, type: str = "horizontal"): @@ -239,7 +240,6 @@ def read( for i, it in enumerate(tqdm(time_steps, disable=not self.show_progress)): for item in range(n_items): - itemdata = self._dfs.ReadItemTimeStep(item_numbers[item] + 1, int(it)) d = itemdata.Data @@ -284,13 +284,14 @@ def write( self, filename: str, data: Dataset, - dt: Optional[float] = None, - title: Optional[str]=None, - keep_open: bool =False, + dt: float | None = None, + title: str | None = None, + keep_open: bool = False, ): - # this method is deprecated - warnings.warn(FutureWarning("Dfs2.write() is deprecated, use Dataset.to_dfs() instead")) + warnings.warn( + FutureWarning("Dfs2.write() is deprecated, use Dataset.to_dfs() instead") + ) filename = str(filename) diff --git a/mikeio/dfs/_dfs3.py b/mikeio/dfs/_dfs3.py 
index 9fecd3576..3e7220c66 100644 --- a/mikeio/dfs/_dfs3.py +++ b/mikeio/dfs/_dfs3.py @@ -1,14 +1,16 @@ +from __future__ import annotations +from pathlib import Path from typing import Tuple import numpy as np import pandas as pd -from mikecore.DfsBuilder import DfsBuilder # type: ignore -from mikecore.DfsFactory import DfsFactory # type: ignore -from mikecore.DfsFile import DfsFile, DfsSimpleType # type: ignore -from mikecore.DfsFileFactory import DfsFileFactory # type: ignore -from mikecore.eum import eumQuantity, eumUnit # type: ignore -from mikecore.Projections import Cartography # type: ignore +from mikecore.DfsBuilder import DfsBuilder +from mikecore.DfsFactory import DfsFactory +from mikecore.DfsFile import DfsFile, DfsSimpleType +from mikecore.DfsFileFactory import DfsFileFactory +from mikecore.eum import eumQuantity, eumUnit +from mikecore.Projections import Cartography from .. import __dfs_version__ from ..dataset import Dataset @@ -23,12 +25,12 @@ from ..spatial import Grid3D -def write_dfs3(filename: str, ds: Dataset, title="") -> None: +def write_dfs3(filename: str| Path, ds: Dataset, title="") -> None: dfs = _write_dfs3_header(filename, ds, title) _write_dfs_data(dfs=dfs, ds=ds, n_spatial_dims=3) -def _write_dfs3_header(filename, ds: Dataset, title="") -> DfsFile: +def _write_dfs3_header(filename: str | Path, ds: Dataset, title="") -> DfsFile: builder = DfsBuilder.Create(title, "mikeio", __dfs_version__) builder.SetDataType(0) @@ -75,7 +77,7 @@ def _write_dfs3_header(filename, ds: Dataset, title="") -> DfsFile: ) try: - builder.CreateFile(filename) + builder.CreateFile(str(filename)) except IOError: print("cannot create dfs file: ", filename) diff --git a/mikeio/dfsu/_dfsu.py b/mikeio/dfsu/_dfsu.py index 7c89fd84c..ade38e05f 100644 --- a/mikeio/dfsu/_dfsu.py +++ b/mikeio/dfsu/_dfsu.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime, timedelta from functools import wraps -from typing import Collection, List, Optional, Tuple +from typing import Collection, List, Tuple import numpy as np import pandas as pd @@ -40,8 +40,7 @@ from .._track import _extract_track -def _write_dfsu(filename: str, data: Dataset): - +def _write_dfsu(filename: str | Path, data: Dataset) -> None: filename = str(filename) if len(data.time) == 1: @@ -91,6 +90,7 @@ def _write_dfsu(filename: str, data: Dataset): for i in range(n_time_steps): if geometry.is_layered: if "time" in data.dims: + assert data._zn is not None zn = data._zn[i] else: zn = data._zn @@ -572,7 +572,6 @@ def plot( ax=None, add_colorbar=True, ): - warnings.warn( FutureWarning( "Dfsu.plot() have been deprecated, please use DataArray.plot() instead" @@ -696,7 +695,7 @@ def _read( *, items=None, time=None, - elements: Optional[Collection[int]] = None, + elements: Collection[int] | None = None, area=None, x=None, y=None, @@ -705,7 +704,6 @@ def _read( error_bad_data=True, fill_bad_data_value=np.nan, ) -> Dataset: - if dtype not in [np.float32, np.float64]: raise ValueError("Invalid data type. 
Choose np.float32 or np.float64") @@ -761,7 +759,6 @@ def _read( for i in trange(n_steps, disable=not self.show_progress): it = time_steps[i] for item in range(n_items): - dfs, d = _read_item_time_step( dfs=dfs, filename=self._filename, @@ -1182,7 +1179,6 @@ def _write( return self except Exception as e: - print(e) self._dfs.Close() os.remove(filename) @@ -1244,13 +1240,12 @@ def to_mesh(self, outfilename): class Dfsu2DH(_Dfsu): - def read( self, *, items=None, time=None, - elements: Optional[Collection[int]] = None, + elements: Collection[int] | None = None, area=None, x=None, y=None, @@ -1291,17 +1286,19 @@ def read( Dataset A Dataset with data dimensions [t,elements] """ - - return self._read(items=items, - time=time, - elements=elements, - area=area, - x=x, - y=y, - keepdims=keepdims, - dtype=dtype, - error_bad_data=error_bad_data, - fill_bad_data_value=fill_bad_data_value) + + return self._read( + items=items, + time=time, + elements=elements, + area=area, + x=x, + y=y, + keepdims=keepdims, + dtype=dtype, + error_bad_data=error_bad_data, + fill_bad_data_value=fill_bad_data_value, + ) def _dfs_read_item_time_func(self, item: int, step: int): dfs = DfsuFile.Open(self._filename) diff --git a/mikeio/dfsu/_layered.py b/mikeio/dfsu/_layered.py index 7a320f8d5..289b1ebb1 100644 --- a/mikeio/dfsu/_layered.py +++ b/mikeio/dfsu/_layered.py @@ -1,4 +1,5 @@ -from typing import Collection, Optional +from __future__ import annotations +from typing import Collection from functools import wraps import numpy as np @@ -96,7 +97,7 @@ def read( *, items=None, time=None, - elements: Optional[Collection[int]] = None, + elements: Collection[int] | None = None, area=None, x=None, y=None, @@ -209,7 +210,6 @@ def read( for i in trange(n_steps, disable=not self.show_progress): it = time_steps[i] for item in range(n_items): - dfs, d = _read_item_time_step( dfs=dfs, filename=self._filename, diff --git a/mikeio/eum/_eum.py b/mikeio/eum/_eum.py index 9cb01b1f0..6556bbb68 100644 --- a/mikeio/eum/_eum.py +++ b/mikeio/eum/_eum.py @@ -19,7 +19,7 @@ from __future__ import annotations import warnings from enum import IntEnum -from typing import Dict, List, Sequence +from typing import Dict, List, Sequence, Literal import pandas as pd from mikecore.DfsFile import DataValueType, DfsDynamicItemInfo @@ -1429,8 +1429,9 @@ class ItemInfo: """ def __init__( - self, name=None, itemtype=None, unit=None, data_value_type="Instantaneous" - ): + self, name: str | EUMType | None =None, itemtype: EUMType | EUMUnit| None =None, unit: EUMUnit | None=None, + data_value_type:Literal["Instantaneous","Accumulated","StepAccumulated","MeanStepBackWard"]="Instantaneous" + ) -> None: # Handle arguments in the wrong place if isinstance(name, EUMType): @@ -1476,7 +1477,7 @@ def __init__( if not isinstance(name, str): raise ValueError("Invalid name, name should be a string") - self.name = name + self.name : str = name def __eq__(self, other): if not isinstance(other, ItemInfo): diff --git a/mikeio/generic.py b/mikeio/generic.py index 00d4bfd0f..5a76e6e65 100644 --- a/mikeio/generic.py +++ b/mikeio/generic.py @@ -5,20 +5,33 @@ from copy import deepcopy from datetime import datetime, timedelta from shutil import copyfile -from typing import Iterable, List, Optional, Sequence +from typing import Iterable, List, Sequence, Tuple, Union import numpy as np +from numpy.typing import NDArray import pandas as pd -from mikecore.DfsBuilder import DfsBuilder # type: ignore -from mikecore.DfsFile import DfsDynamicItemInfo, DfsFile # type: ignore -from 
mikecore.DfsFileFactory import DfsFileFactory # type: ignore -from mikecore.eum import eumQuantity # type: ignore +from mikecore.DfsBuilder import DfsBuilder +from mikecore.DfsFile import ( + DfsDynamicItemInfo, + DfsFile, + DfsEqTimeAxis, + DfsNonEqTimeAxis, + DfsEqCalendarAxis, + DfsNonEqCalendarAxis, +) +from mikecore.DfsFileFactory import DfsFileFactory +from mikecore.eum import eumQuantity from tqdm import tqdm, trange from . import __dfs_version__ from .dfs._dfs import _get_item_info, _valid_item_numbers from .eum import ItemInfo + +TimeAxis = Union[ + DfsEqTimeAxis, DfsNonEqTimeAxis, DfsEqCalendarAxis, DfsNonEqCalendarAxis +] + show_progress = True @@ -52,15 +65,14 @@ class _ChunkInfo: """ def __init__(self, n_data: int, n_chunks: int): - self.n_data = n_data self.n_chunks = n_chunks - def __repr__(self): + def __repr__(self) -> str: return f"_ChunkInfo(n_chunks={self.n_chunks}, n_data={self.n_data}, chunk_size={self.chunk_size})" @property - def chunk_size(self): + def chunk_size(self) -> int: """number of data points per chunk""" return math.ceil(self.n_data / self.n_chunks) @@ -68,7 +80,7 @@ def stop(self, start: int) -> int: """Return the stop index for a chunk""" return min(start + self.chunk_size, self.n_data) - def chunk_end(self, start): + def chunk_end(self, start: int) -> int: """Return the end index for a chunk""" e2 = self.stop(start) return self.chunk_size - ((start + self.chunk_size) - e2) @@ -80,7 +92,7 @@ def from_dfs( """Calculate chunk info based on # of elements in dfs file and selected buffer size""" n_time_steps = dfs.FileInfo.TimeAxis.NumberOfTimeSteps - n_data_all = np.sum([dfs.ItemInfo[i].ElementCount for i in item_numbers]) + n_data_all: int = np.sum([dfs.ItemInfo[i].ElementCount for i in item_numbers]) mem_need = 8 * n_time_steps * n_data_all # n_items * n_chunks = math.ceil(mem_need / buffer_size) n_data = n_data_all // len(item_numbers) @@ -91,9 +103,9 @@ def from_dfs( def _clone( infilename: str | pathlib.Path, outfilename: str | pathlib.Path, - start_time=None, - timestep=None, - items=None, + start_time: datetime | None = None, + timestep: float | None = None, + items: Sequence[int | str | DfsDynamicItemInfo] | None = None, ) -> DfsFile: """Clone a dfs file @@ -196,7 +208,7 @@ def scale( outfilename: str | pathlib.Path, offset: float = 0.0, factor: float = 1.0, - items: Optional[List[str] | List[int]] = None, + items: Sequence[int | str] | None = None, ) -> None: """Apply scaling to any dfs file @@ -228,7 +240,6 @@ def scale( for timestep in trange(n_time_steps, disable=not show_progress): for item in range(n_items): - itemdata = dfs.ReadItemTimeStep(item_numbers[item] + 1, timestep) time = itemdata.Time d = itemdata.Data @@ -248,7 +259,7 @@ def fill_corrupt( infilename: str | pathlib.Path, outfilename: str | pathlib.Path, fill_value: float = np.nan, - items: Optional[List[str] | List[int]] = None, + items: Sequence[str | int] | None = None, ) -> None: """ Replace corrupt (unreadable) data with fill_value, default delete value. 
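
The typed `_ChunkInfo.from_dfs` above decides how many chunks to process based on a byte budget; the sketch below reproduces that arithmetic in isolation (assuming 8 bytes per value and the same element count for every item), so the effect of `buffer_size` can be checked without opening a dfs file. It is an illustration, not the library routine itself.

import math
from typing import Tuple

def plan_chunks(n_time_steps: int, n_elements_per_item: int, n_items: int,
                buffer_size: float = 1.0e9) -> Tuple[int, int]:
    # total memory needed to hold all items and all timesteps as float64
    mem_need = 8 * n_time_steps * n_elements_per_item * n_items
    # number of chunks such that one chunk fits within the buffer
    n_chunks = math.ceil(mem_need / buffer_size)
    # number of spatial data points handled per chunk
    chunk_size = math.ceil(n_elements_per_item / n_chunks)
    return n_chunks, chunk_size

# 10_000 timesteps x 50_000 elements x 2 items of float64 is ~8 GB,
# so the default 1 GB buffer gives 8 chunks of 6_250 elements each.
assert plan_chunks(10_000, 50_000, 2) == (8, 6250)
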
@@ -282,7 +293,6 @@ def fill_corrupt( for timestep in trange(n_time_steps, disable=not show_progress): for item in range(n_items): - itemdata = dfs_i.ReadItemTimeStep(item_numbers[item] + 1, timestep) if itemdata is not None: time = itemdata.Time @@ -339,7 +349,6 @@ def sum( for timestep in trange(n_time_steps): for item in range(n_items): - itemdata_a = dfs_i_a.ReadItemTimeStep(item + 1, timestep) d_a = itemdata_a.Data d_a[d_a == deletevalue] = np.nan @@ -394,7 +403,6 @@ def diff( for timestep in trange(n_time_steps): for item in range(n_items): - itemdata_a = dfs_i_a.ReadItemTimeStep(item + 1, timestep) d_a = itemdata_a.Data d_a[d_a == deletevalue] = np.nan @@ -419,7 +427,7 @@ def diff( def concat( infilenames: Sequence[str | pathlib.Path], outfilename: str | pathlib.Path, - keep="last", + keep: str = "last", ) -> None: """Concatenates files along the time axis @@ -451,7 +459,6 @@ def concat( current_time = datetime(1, 1, 1) # beginning of time... for i, infilename in enumerate(tqdm(infilenames, disable=not show_progress)): - dfs_i = DfsFileFactory.DfsGenericOpen(str(infilename)) t_axis = dfs_i.FileInfo.TimeAxis n_time_steps = t_axis.NumberOfTimeSteps @@ -466,7 +473,6 @@ def concat( current_time = start_time if keep == "last": - if i < (len(infilenames) - 1): dfs_n = DfsFileFactory.DfsGenericOpen(str(infilenames[i + 1])) nf = dfs_n.FileInfo.TimeAxis.StartDateTime @@ -476,14 +482,12 @@ def concat( dfs_n.Close() for timestep in range(n_time_steps): - current_time = start_time + timedelta(seconds=timestep * dt) if i < (len(infilenames) - 1): if current_time >= next_start_time: break for item in range(n_items): - itemdata = dfs_i.ReadItemTimeStep(item + 1, timestep) d = itemdata.Data @@ -493,7 +497,6 @@ def concat( dfs_i.Close() if keep == "first": - if ( i == 0 ): # all timesteps in first file are kept (is there a more efficient way to do this without the loop?) 
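
Since `concat` now has typed `keep` and path arguments, a short usage sketch may help; the dfs0 file names are hypothetical. With keep="last" (the default) the timesteps of an earlier file that fall inside the next file's period are skipped, so the later file wins in the overlap; keep="first" resolves the overlap the other way.

import mikeio.generic

# hypothetical monthly files with a small overlap at the month boundary
mikeio.generic.concat(
    infilenames=["january.dfs0", "february.dfs0"],
    outfilename="jan_feb.dfs0",
    keep="last",  # overlapping timesteps taken from february.dfs0
)
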
@@ -501,7 +504,6 @@ def concat( current_time = start_time + timedelta(seconds=timestep * dt) for item in range(n_items): - itemdata = dfs_i.ReadItemTimeStep(item + 1, timestep) d = itemdata.Data @@ -523,7 +525,6 @@ def concat( current_time = start_time + timedelta(seconds=timestep * dt) for item in range(n_items): - itemdata = dfs_i.ReadItemTimeStep(item + 1, timestep) d = itemdata.Data @@ -541,10 +542,10 @@ def concat( def extract( infilename: str | pathlib.Path, outfilename: str | pathlib.Path, - start=0, - end=-1, - step=1, - items=None, + start: int = 0, + end: int = -1, + step: int = 1, + items: Sequence[int | str] | None = None, ) -> None: """Extract timesteps and/or items to a new dfs file @@ -580,9 +581,9 @@ def extract( is_layered_dfsu = dfs_i.ItemInfo[0].Name == "Z coordinate" file_start_new, start_step, start_sec, end_step, end_sec = _parse_start_end( - dfs_i, start, end + dfs_i.FileInfo.TimeAxis, start, end ) - timestep = _parse_step(dfs_i, step) + timestep = _parse_step(dfs_i.FileInfo.TimeAxis, step) item_numbers = _valid_item_numbers( dfs_i.ItemInfo, items, ignore_first=is_layered_dfsu ) @@ -629,18 +630,22 @@ def extract( dfs_o.Close() -def _parse_start_end(dfs_i, start, end): +def _parse_start_end( + time_axis: TimeAxis, + start: int | float | str | datetime, + end: int | float | str | datetime, +) -> Tuple[datetime | None, int, float, int, float]: # TODO better return type """Helper function for parsing start and end arguments""" - n_time_steps = dfs_i.FileInfo.TimeAxis.NumberOfTimeSteps - file_start_datetime = dfs_i.FileInfo.TimeAxis.StartDateTime - file_start_sec = dfs_i.FileInfo.TimeAxis.StartTimeOffset + n_time_steps = time_axis.NumberOfTimeSteps + file_start_datetime = time_axis.StartDateTime + file_start_sec = time_axis.StartTimeOffset start_sec = file_start_sec timespan = 0 - if dfs_i.FileInfo.TimeAxis.TimeAxisType == 3: - timespan = dfs_i.FileInfo.TimeAxis.TimeStep * (n_time_steps - 1) - elif dfs_i.FileInfo.TimeAxis.TimeAxisType == 4: - timespan = dfs_i.FileInfo.TimeAxis.TimeSpan + if time_axis.TimeAxisType == 3: + timespan = time_axis.TimeStep * (n_time_steps - 1) + elif time_axis.TimeAxisType == 4: + timespan = time_axis.TimeSpan else: raise ValueError("TimeAxisType not supported") @@ -695,26 +700,26 @@ def _parse_start_end(dfs_i, start, end): raise ValueError(f"end cannot be after end of file. 
end={end_sec} is invalid.") file_start_new = None - if dfs_i.FileInfo.TimeAxis.TimeAxisType == 3: - dt = dfs_i.FileInfo.TimeAxis.TimeStep + if time_axis.TimeAxisType == 3: + dt = time_axis.TimeStep if (start_sec > file_start_sec) and (start_step == 0): # we can find the coresponding step start_step = int((start_sec - file_start_sec) / dt) file_start_new = file_start_datetime + timedelta(seconds=start_step * dt) - elif dfs_i.FileInfo.TimeAxis.TimeAxisType == 4: + elif time_axis.TimeAxisType == 4: if start_sec > file_start_sec: file_start_new = file_start_datetime + timedelta(seconds=start_sec) return file_start_new, start_step, start_sec, end_step, end_sec -def _parse_step(dfs_i, step): +def _parse_step(time_axis: TimeAxis, step: int) -> float | None: """Helper function for parsing step argument""" if step == 1: timestep = None - elif dfs_i.FileInfo.TimeAxis.TimeAxisType == 3: - timestep = dfs_i.FileInfo.TimeAxis.TimeStep * step - elif dfs_i.FileInfo.TimeAxis.TimeAxisType == 4: + elif time_axis.TimeAxisType == 3: + timestep = time_axis.TimeStep * step + elif time_axis.TimeAxisType == 4: timestep = None else: raise ValueError("TimeAxisType not supported") @@ -724,8 +729,8 @@ def _parse_step(dfs_i, step): def avg_time( infilename: str | pathlib.Path, outfilename: str | pathlib.Path, - skipna=True, -): + skipna: bool = True, +) -> None: """Create a temporally averaged dfs file Parameters @@ -764,7 +769,6 @@ def avg_time( for timestep in trange(1, n_time_steps, disable=not show_progress): for item in range(n_items): - itemdata = dfs_i.ReadItemTimeStep(item_numbers[item] + 1, timestep) d = itemdata.Data has_value = d != deletevalue @@ -790,12 +794,12 @@ def avg_time( def quantile( infilename: str | pathlib.Path, outfilename: str | pathlib.Path, - q, + q: float | Sequence[float], *, - items=None, - skipna=True, - buffer_size=1.0e9, -): + items: Sequence[int | str] | None = None, + skipna: bool = True, + buffer_size: float = 1.0e9, +) -> None: """Create temporal quantiles of all items in dfs file Parameters @@ -843,8 +847,8 @@ def quantile( ci = _ChunkInfo.from_dfs(dfs_i, item_numbers, buffer_size) - qvec = [q] if np.isscalar(q) else q - qtxt = [f"Quantile {q}" for q in qvec] + qvec: Sequence[float] = [q] if isinstance(q, float) else q + qtxt = [f"Quantile {q!r}" for q in qvec] core_items = [dfs_i.ItemInfo[i] for i in item_numbers] items = _get_repeated_items(core_items, prefixes=qtxt) @@ -904,7 +908,7 @@ def quantile( dfs_o.Close() -def _read_item(dfs: DfsFile, item: int, timestep: int) -> np.ndarray: +def _read_item(dfs: DfsFile, item: int, timestep: int) -> NDArray[np.float64]: """Read item data from dfs file Parameters diff --git a/mikeio/pfs/_pfssection.py b/mikeio/pfs/_pfssection.py index 5f85f6654..d0b43d021 100644 --- a/mikeio/pfs/_pfssection.py +++ b/mikeio/pfs/_pfssection.py @@ -1,6 +1,7 @@ +from __future__ import annotations from datetime import datetime from types import SimpleNamespace -from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence +from typing import Any, Callable, List, Mapping, MutableMapping, Sequence import pandas as pd @@ -175,10 +176,10 @@ def update_recursive(self, key, value): def search( self, - text: Optional[str] = None, + text: str | None = None, *, - key: Optional[str] = None, - section: Optional[str] = None, + key: str | None = None, + section: str | None = None, param=None, case: bool = False, ): @@ -302,7 +303,6 @@ def _write_with_func(self, func: Callable, level: int = 0, newline: str = "\n"): """ lvl_prefix = " " for k, v in 
self.items(): - # check for empty sections if v is None: func(f"{lvl_prefix * level}[{k}]{newline}") @@ -350,7 +350,6 @@ def _prepare_value_for_write(self, v): """ # some crude checks and corrections if isinstance(v, str): - if len(v) > 5 and not ("PROJ" in v or " dict: d[key] = value.to_dict() return d - def to_dataframe(self, prefix: Optional[str] = None) -> pd.DataFrame: + def to_dataframe(self, prefix: str | None = None) -> pd.DataFrame: """Output enumerated subsections to a DataFrame Parameters diff --git a/mikeio/spatial/_FM_geometry.py b/mikeio/spatial/_FM_geometry.py index ba5b0eafb..a3aacf6c4 100644 --- a/mikeio/spatial/_FM_geometry.py +++ b/mikeio/spatial/_FM_geometry.py @@ -2,7 +2,7 @@ import warnings from collections import namedtuple from functools import cached_property -from typing import Collection, Optional, List +from typing import Collection, List, Any import numpy as np from mikecore.DfsuFile import DfsuFileType # type: ignore @@ -121,7 +121,6 @@ def _get_ax(ax=None, figsize=None): return ax def _plot_FM_map(self, ax, **kwargs): - if "title" not in kwargs: kwargs["title"] = "Bathymetry" @@ -236,7 +235,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection: str = "LONG/LAT", dfsu_type=None, # TODO should this be mandatory? element_ids=None, node_ids=None, @@ -267,7 +266,6 @@ def __init__( self._reindex() def _check_elements(self, element_table, element_ids=None, validate=True): - if validate: max_node_id = self._node_ids.max() for i, e in enumerate(element_table): @@ -352,7 +350,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection: str = "LONG/LAT", dfsu_type=DfsuFileType.Dfsu2D, # Reasonable default? element_ids=None, node_ids=None, @@ -374,7 +372,6 @@ def __init__( self.plot = _GeometryFMPlotter(self) def __str__(self) -> str: - return f"{self.type_name} ({self.n_elements} elements, {self.n_nodes} nodes)" def __repr__(self): @@ -565,7 +562,7 @@ def get_2d_interpolant( n_nearest: int = 5, extrapolate: bool = False, p: int = 2, - radius: Optional[float] = None, + radius: float | None = None, ): """IDW interpolant for list of coordinates @@ -595,16 +592,16 @@ def get_2d_interpolant( if n_nearest == 1: weights = np.ones(dists.shape) if not extrapolate: - weights[~self.contains(xy)] = np.nan # type: ignore + weights[~self.contains(xy)] = np.nan # type: ignore elif n_nearest > 1: weights = get_idw_interpolant(dists, p=p) if not extrapolate: - weights[~self.contains(xy), :] = np.nan # type: ignore + weights[~self.contains(xy), :] = np.nan # type: ignore else: ValueError("n_nearest must be at least 1") if radius is not None: - weights[dists > radius] = np.nan # type: ignore + weights[dists > radius] = np.nan # type: ignore return ids, weights @@ -636,9 +633,8 @@ def interp2d(self, data, elem_ids, weights=None, shape=None): """ return interp2d(data, elem_ids, weights, shape) - def _find_n_nearest_2d_elements(self, x, y=None, n=1): - - # TODO + def _find_n_nearest_2d_elements(self, x, y=None, n=1) -> tuple[Any, Any]: + # TODO return arguments in the same order than cKDTree.query? 
if n > self.n_elements: raise ValueError( @@ -655,7 +651,6 @@ def _find_n_nearest_2d_elements(self, x, y=None, n=1): return elem_id, d def _find_element_2d(self, coords: np.ndarray): - points_outside = [] coords = np.atleast_2d(coords) @@ -711,7 +706,6 @@ def _find_element_2d(self, coords: np.ndarray): return ids def _find_single_element_2d(self, x: float, y: float) -> int: - nc = self.node_coordinates few_nearest, _ = self._find_n_nearest_2d_elements( @@ -1106,9 +1100,8 @@ def _nodes_to_geometry(self, nodes) -> "GeometryFM2D" | GeometryPoint2D: def elements_to_geometry( self, elements: int | Collection[int], keepdims=False ) -> "GeometryFM2D" | GeometryPoint2D: - - if isinstance(elements, (int,np.integer)): - sel_elements : List[int] = [elements] + if isinstance(elements, (int, np.integer)): + sel_elements: List[int] = [elements] else: sel_elements = list(elements) if len(sel_elements) == 1 and not keepdims: @@ -1236,7 +1229,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection: str = "LONG/LAT", dfsu_type=None, element_ids=None, node_ids=None, @@ -1264,7 +1257,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection: str = "LONG/LAT", dfsu_type=None, element_ids=None, node_ids=None, @@ -1314,9 +1307,7 @@ class GeometryFMAreaSpectrum(_GeometryFMSpectrum): def isel(self, idx=None, axis="elements"): return self.elements_to_geometry(elements=idx) - def elements_to_geometry( - self, elements, keepdims=False - ): + def elements_to_geometry(self, elements, keepdims=False): """export a selection of elements to new flexible file geometry Parameters ---------- @@ -1364,9 +1355,7 @@ class GeometryFMLineSpectrum(_GeometryFMSpectrum): def isel(self, idx=None, axis="node"): return self._nodes_to_geometry(nodes=idx) - def _nodes_to_geometry( - self, nodes - ): + def _nodes_to_geometry(self, nodes): """export a selection of nodes to new flexible file geometry Note: takes only the elements for which all nodes are selected Parameters diff --git a/mikeio/spatial/_FM_geometry_layered.py b/mikeio/spatial/_FM_geometry_layered.py index 38d135494..260dbac4e 100644 --- a/mikeio/spatial/_FM_geometry_layered.py +++ b/mikeio/spatial/_FM_geometry_layered.py @@ -22,7 +22,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection:str = "LONG/LAT", dfsu_type=DfsuFileType.Dfsu3DSigma, element_ids=None, node_ids=None, @@ -669,7 +669,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection:str = "LONG/LAT", dfsu_type=DfsuFileType.Dfsu3DSigma, element_ids=None, node_ids=None, @@ -749,7 +749,7 @@ def __init__( node_coordinates, element_table, codes=None, - projection=None, + projection:str = "LONG/LAT", dfsu_type=None, element_ids=None, node_ids=None, diff --git a/mikeio/spatial/_geometry.py b/mikeio/spatial/_geometry.py index 56ab92b34..b4e7098ae 100644 --- a/mikeio/spatial/_geometry.py +++ b/mikeio/spatial/_geometry.py @@ -8,13 +8,12 @@ class _Geometry(ABC): - def __init__(self, projection=None) -> None: + def __init__(self, projection: str = "LONG/LAT") -> None: - # TODO should projection be a mandatory argument? 
- if projection and not MapProjection.IsValid(projection): + if not MapProjection.IsValid(projection): raise ValueError(f"{projection=} is not a valid projection string") - self._projstr = projection if projection else "LONG/LAT" + self._projstr = projection @property def projection_string(self) -> str: @@ -48,7 +47,7 @@ def __repr__(self): class GeometryPoint2D(_Geometry): - def __init__(self, x: float, y: float, projection=None): + def __init__(self, x: float, y: float, projection:str = "LONG/LAT"): super().__init__(projection) self.x = x self.y = y @@ -76,7 +75,7 @@ def to_shapely(self): class GeometryPoint3D(_Geometry): - def __init__(self, x: float, y: float, z: float, projection=None): + def __init__(self, x: float, y: float, z: float, projection:str = "LONG/LAT"): super().__init__(projection) self.x = x diff --git a/mikeio/spatial/_grid_geometry.py b/mikeio/spatial/_grid_geometry.py index 7bb20bd21..d7e2d9b6f 100644 --- a/mikeio/spatial/_grid_geometry.py +++ b/mikeio/spatial/_grid_geometry.py @@ -1,8 +1,9 @@ from __future__ import annotations import warnings -from typing import Optional, Sequence, Tuple +from typing import Sequence, Tuple from dataclasses import dataclass import numpy as np +from numpy.typing import ArrayLike, NDArray from mikecore.Projections import Cartography # type: ignore @@ -108,7 +109,7 @@ def __init__( axis_name="x", ): """Create equidistant 1D spatial geometry""" - super().__init__(projection) + super().__init__(projection=projection) self._origin = (0.0, 0.0) if origin is None else (origin[0], origin[1]) assert len(self._origin) == 2, "origin must be a tuple of length 2" self._orientation = orientation @@ -139,7 +140,6 @@ def find_index(self, x: float, **kwargs) -> int: return int(np.argmin(d)) def get_spatial_interpolant(self, coords, **kwargs): - x = coords[0][0] # TODO accept list of points assert self.nx > 1, "Interpolation not possible for Grid1D with one point" @@ -227,9 +227,11 @@ def isel( else: coords = self._nc[idx, :] if len(coords) == 3: - return GeometryPoint3D(*coords) + x, y, z = coords + return GeometryPoint3D(x=x, y=y, z=z, projection=self.projection) else: - return GeometryPoint2D(*coords) + x, y = coords + return GeometryPoint2D(x=x, y=y, projection=self.projection) class _Grid2DPlotter: @@ -348,17 +350,17 @@ class Grid2D(_Geometry): def __init__( self, *, - x: Optional[Sequence[float]] = None, + x: Sequence[float] | None = None, x0: float = 0.0, - dx: Optional[float] = None, - nx: Optional[int] = None, - y: Optional[Sequence[float]] = None, + dx: float | None = None, + nx: int | None = None, + y: Sequence[float] | None = None, y0: float = 0.0, - dy: Optional[float] = None, - ny: Optional[int] = None, + dy: float | None = None, + ny: int | None = None, bbox=None, projection="NON-UTM", - origin: Optional[Tuple[float, float]] = None, + origin: Tuple[float, float] | None = None, orientation=0.0, axis_names=("x", "y"), is_spectral=False, @@ -404,7 +406,7 @@ def __init__( y: [55, 55.25, 55.5] (ny=3, dy=0.25) projection: LONG/LAT """ - super().__init__(projection) + super().__init__(projection=projection) self._shift_origin_on_write = origin is None # user-constructed self._origin = (0.0, 0.0) if origin is None else (origin[0], origin[1]) assert len(self._origin) == 2, "origin must be a tuple of length 2" @@ -680,7 +682,7 @@ def _shift_x0y0_to_origin(self): self._x0, self._y0 = 0.0, 0.0 self._origin = (self._origin[0] + x0, self._origin[1] + y0) - def contains(self, coords): + def contains(self, coords: ArrayLike) -> NDArray[np.bool_]: """test 
if a list of points are inside grid Parameters @@ -701,13 +703,13 @@ def contains(self, coords): yinside = (self.bbox.bottom <= y) & (y <= self.bbox.top) return xinside & yinside - def __contains__(self, pt) -> bool: + def __contains__(self, pt) -> NDArray[np.bool_]: return self.contains(pt) def find_index( self, - x: Optional[float] = None, - y: Optional[float] = None, + x: float | None = None, + y: float | None = None, coords=None, area=None, ): @@ -777,7 +779,7 @@ def _xy_to_index(self, xy): return ii, jj def _bbox_to_index( - self, bbox: Tuple[float,float,float,float] | BoundingBox + self, bbox: Tuple[float, float, float, float] | BoundingBox ) -> Tuple[range, range]: """Find subarea within this geometry""" if not (len(bbox) == 4): @@ -799,9 +801,7 @@ def _bbox_to_index( return i, j - def isel( - self, idx, axis: int | str - ) -> "Grid2D" | "Grid1D" | "GeometryUndefined": + def isel(self, idx, axis: int | str) -> "Grid2D | Grid1D | GeometryUndefined": """Return a new geometry as a subset of Grid2D along the given axis.""" if isinstance(axis, str): if axis == "y": @@ -835,7 +835,7 @@ def isel( else: raise ValueError(f"axis must be 0 or 1 (or 'x' or 'y'), not {axis}") - def _index_to_Grid2D(self, ii=None, jj=None): + def _index_to_Grid2D(self, ii=None, jj=None) -> "Grid2D | GeometryUndefined": ii = range(self.nx) if ii is None else ii jj = range(self.ny) if jj is None else jj assert len(ii) > 1 and len(jj) > 1, "Index must be at least len 2" @@ -845,8 +845,9 @@ def _index_to_Grid2D(self, ii=None, jj=None): if (np.any(di < 1) or not np.allclose(di, di[0])) or ( np.any(dj < 1) or not np.allclose(dj, dj[0]) ): - warnings.warn("Axis not equidistant! Will return GeometryUndefined()") - return GeometryUndefined() + # warnings.warn("Axis not equidistant! Will return GeometryUndefined()") + raise ValueError() + # return GeometryUndefined() else: dx = self.dx * di[0] dy = self.dy * dj[0] @@ -877,7 +878,6 @@ def _index_to_Grid2D(self, ii=None, jj=None): ) def _to_element_table(self, index_base=0): - elem_table = [] for elx in range(self.nx - 1): # each col @@ -1027,8 +1027,7 @@ def __init__( origin: Tuple[float, float] = (0.0, 0.0), orientation=0.0, ) -> None: - - super().__init__() + super().__init__(projection=projection) self._origin = (0.0, 0.0) if origin is None else (origin[0], origin[1]) assert len(self._origin) == 2, "origin must be a tuple of length 2" self._x0, self._dx, self._nx = _parse_grid_axis("x", x, x0, dx, nx) diff --git a/pyproject.toml b/pyproject.toml index 47ef2ba17..2b98ff95d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,11 @@ dev = ["pytest", "xarray", "netcdf4", "matplotlib", - "ruff"] + "ruff", + "mypy==1.6.1", + ] -test = ["pytest", "pytest-cov", "matplotlib!=3.5.0", "xarray","mypy","shapely","pyproj"] +test = ["pytest", "pytest-cov", "matplotlib!=3.5.0", "xarray","mypy==1.6.1","shapely","pyproj"] notebooks= [ "nbformat", @@ -81,6 +83,11 @@ warn_return_any = false allow_redefinition = true warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["mikeio.spatial", "mikeio.dataset"] +warn_return_any = true + [[tool.mypy.overrides]] module = [ "mikecore.*", diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 72441d7d7..4e8d38271 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -4,7 +4,8 @@ # TODO this file tests private methods, Options: 1. declare methods as public, 2. 
Test at a higher level of abstraction -from mikeio.dataset._data_utils import DataUtilsMixin as du +from mikeio import DataArray as du +from mikeio.dataset._data_utils import _to_safe_name def test_parse_time_None(): @@ -72,16 +73,12 @@ def test_parse_time_decreasing(): def test_safe_name_noop(): - good_name = "MSLP" - assert du._to_safe_name(good_name) == good_name + assert _to_safe_name(good_name) == good_name def test_safe_name_bad(): - - # fmt: off - bad_name = "MSLP., 1:st level\n 2nd chain" - safe_name = "MSLP_1_st_level_2nd_chain" - assert du._to_safe_name(bad_name) == safe_name - # fmt : on + bad_name = "MSLP., 1:st level\n 2nd chain" + safe_name = "MSLP_1_st_level_2nd_chain" + assert _to_safe_name(bad_name) == safe_name diff --git a/tests/test_dataarray.py b/tests/test_dataarray.py index 6056bf7e9..9c9d02e10 100644 --- a/tests/test_dataarray.py +++ b/tests/test_dataarray.py @@ -4,7 +4,7 @@ import pytest import mikeio -from mikeio import EUMType, EUMUnit, ItemInfo +from mikeio import EUMType, EUMUnit, ItemInfo, Mesh from mikeio.exceptions import OutsideModelDomainError @@ -372,8 +372,8 @@ def test_dataarray_init_dfsu2d(): nt = 10 time = pd.date_range(start="2000-01-01", freq="S", periods=nt) filename = "tests/testdata/north_sea_2.mesh" - dfs = mikeio.open(filename) - g = dfs.geometry + msh = Mesh(filename) + g = msh.geometry ne = g.n_elements # time-varying @@ -844,7 +844,7 @@ def test_modify_values_1d(da1): assert da1.values[4] == 12.0 # values is scalar, therefore copy by definition. Original is not changed. - da1.isel(4).values = 11.0 + da1.isel(4).values = 11.0 # TODO is the treatment of scalar sensible, i.e. consistent with xarray? assert da1.values[4] != 11.0 # fancy indexing will return copy! Original is *not* changed. diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 19c7a252e..228ccab09 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -104,30 +104,6 @@ def test_properties(ds1): assert isinstance(ds1.items[0], ItemInfo) -def test_pop(ds1): - da = ds1.pop("Foo") - assert len(ds1) == 1 - assert ds1.names == ["Bar"] - assert isinstance(da, mikeio.DataArray) - assert da.name == "Foo" - - ds1["Foo2"] = da # re-insert - assert len(ds1) == 2 - - da = ds1.pop(-1) - assert len(ds1) == 1 - assert ds1.names == ["Bar"] - assert isinstance(da, mikeio.DataArray) - assert da.name == "Foo2" - - -def test_popitem(ds1): - da = ds1.popitem() - assert len(ds1) == 1 - assert ds1.names == ["Bar"] - assert isinstance(da, mikeio.DataArray) - assert da.name == "Foo" - def test_insert(ds1): da = ds1[0].copy() @@ -139,12 +115,6 @@ def test_insert(ds1): assert ds1[-1] == da -def test_insert_wrong_type(ds1): - - with pytest.raises(ValueError): - ds1["Foo"] = "Bar" - - def test_insert_fail(ds1): da = ds1[0] with pytest.raises(ValueError, match="Cannot add the same object"): diff --git a/tests/test_generic.py b/tests/test_generic.py index 3c4ec7624..02fda0bc1 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -550,10 +550,10 @@ def test_dfs_ext_capitalisation(tmp_path): ds = mikeio.open(filename) ds = mikeio.read(filename) ds.to_dfs(tmp_path / "void.DFS0") - filename = "tests/testdata/odense_rough2.MESH" - ds = mikeio.open(filename) filename = "tests/testdata/oresund_vertical_slice2.DFSU" ds = mikeio.open(filename) + filename = "tests/testdata/odense_rough2.MESH" + ds = mikeio.open(filename) assert True
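
The rewritten tests above call `_to_safe_name` directly and pin its behaviour with an example value; the sketch below is one way such a sanitizer can be written that satisfies those expectations (replace every non-alphanumeric character with an underscore, then collapse runs of underscores). It is an illustration consistent with the tests, not necessarily the actual implementation in `mikeio.dataset._data_utils`.

import re

def to_safe_name(name: str) -> str:
    # replace anything that is not a letter or digit with "_"
    tmp = re.sub(r"[^0-9a-zA-Z]", "_", name)
    # collapse consecutive underscores into one
    return re.sub(r"_+", "_", tmp)

assert to_safe_name("MSLP") == "MSLP"
assert to_safe_name("MSLP., 1:st level\n 2nd chain") == "MSLP_1_st_level_2nd_chain"
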