Skip to content

Commit

Permalink
Merge pull request #613 from DHI/fix_types
Browse files Browse the repository at this point in the history
More static types
  • Loading branch information
ecomodeller authored Dec 29, 2023
2 parents 0b76745 + df2b4a9 commit 149b20c
Show file tree
Hide file tree
Showing 28 changed files with 719 additions and 670 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
"restructuredtext.confPath": "${workspaceFolder}\\docs",
"files.eol": "\n",
"python.formatting.provider": "black",
"editor.formatOnSave": true
}
17 changes: 13 additions & 4 deletions mikeio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations
from pathlib import Path
from platform import architecture
from typing import Sequence

# PEP0440 compatible formatted version, see:
# https://www.python.org/dev/peps/pep-0440/
Expand Down Expand Up @@ -43,8 +45,14 @@
from .xyz import read_xyz



def read(filename, *, items=None, time=None, keepdims=False, **kwargs) -> Dataset:
def read(
filename: str | Path,
*,
items: str | int | Sequence[str |int] | None = None,
time: int | str | slice | None = None,
keepdims: bool = False,
**kwargs,
) -> Dataset:
"""Read all or a subset of the data from a dfs file
All dfs files can be subsetted with the *items* and *time* arguments. But
Expand Down Expand Up @@ -127,7 +135,7 @@ def read(filename, *, items=None, time=None, keepdims=False, **kwargs) -> Datase
return dfs.read(items=items, time=time, keepdims=keepdims, **kwargs)


def open(filename: str, **kwargs):
def open(filename: str | Path, **kwargs):
"""Open a dfs/mesh file (and read the header)
The typical workflow for small dfs files is to read all data
Expand Down Expand Up @@ -177,6 +185,7 @@ def open(filename: str, **kwargs):

return reader_klass(filename, **kwargs)


__all__ = [
"DataArray",
"Dataset",
Expand All @@ -199,4 +208,4 @@ def open(filename: str, **kwargs):
"read_xyz",
"read",
"open",
]
]
13 changes: 9 additions & 4 deletions mikeio/_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations
from typing import Tuple, TYPE_CHECKING
import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
from .dataset import Dataset, DataArray

from .spatial import GeometryUndefined


# class Interpolation2D:
def get_idw_interpolant(distances, p=2):
def get_idw_interpolant(distances:NDArray[np.floating], p:float=2) -> NDArray[np.floating]:
"""IDW interpolant for 2d array of distances
https://pro.arcgis.com/en/pro-app/help/analysis/geostatistical-analyst/how-inverse-distance-weighted-interpolation-works.htm
Expand Down Expand Up @@ -40,7 +45,7 @@ def get_idw_interpolant(distances, p=2):
return weights


def interp2d(data, elem_ids, weights=None, shape=None):
def interp2d(data: NDArray[np.floating] | Dataset | DataArray, elem_ids: NDArray[np.integer], weights:NDArray[np.floating] | None=None, shape:Tuple[int,...]|None=None) -> NDArray[np.floating] | Dataset:
"""interp spatially in data (2d only)
Parameters
Expand Down Expand Up @@ -124,7 +129,7 @@ def interp2d(data, elem_ids, weights=None, shape=None):
return idatitem


def _interp_itemstep(data, elem_ids, weights=None):
def _interp_itemstep(data: NDArray[np.floating], elem_ids: NDArray[np.integer], weights:NDArray[np.floating] | None =None) -> NDArray[np.floating]:
"""Interpolate a single item and time step
Parameters
Expand Down
67 changes: 37 additions & 30 deletions mikeio/_spectral.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from __future__ import annotations
from typing import Sequence, Tuple, Literal
import numpy as np
from matplotlib.axes import Axes
from matplotlib.projections.polar import PolarAxes
from numpy.typing import NDArray


def plot_2dspectrum(
spectrum,
frequencies,
directions,
plot_type="contourf",
title=None,
label=None,
cmap="Reds",
vmin=1e-5,
vmax=None,
r_as_periods=True,
rmin=None,
rmax=None,
levels=None,
figsize=(7, 7),
add_colorbar=True,
):
spectrum :NDArray[np.floating],
frequencies :NDArray[np.floating],
directions :NDArray[np.floating],
plot_type: str ="contourf",
title:str | None =None,
label:str | None=None,
cmap:str="Reds",
vmin:float | None =1e-5,
vmax:float | None =None,
r_as_periods:bool=True,
rmin:float | None=None,
rmax:float | None=None,
levels: int| Sequence[float] | None=None,
figsize: Tuple[float,float]=(7, 7),
add_colorbar:bool=True,
) -> Axes:
"""
Plot spectrum in polar coordinates
Expand Down Expand Up @@ -67,7 +72,7 @@ def plot_2dspectrum(
>>> dfs.plot_spectrum(spectrum, rmax=9, title="Wave spectrum T<9s")
"""

import matplotlib.pyplot as plt # type: ignore
import matplotlib.pyplot as plt

if (frequencies is None or len(frequencies) <= 1) and (
directions is None or len(directions) <= 1
Expand All @@ -83,13 +88,14 @@ def plot_2dspectrum(
spectrum = np.fliplr(spectrum)

fig = plt.figure(figsize=figsize)
ax = plt.subplot(111, polar=True)
ax = plt.subplot(111, polar=True) # type: ignore
assert isinstance(ax, PolarAxes)
ax.set_theta_direction(-1)
ax.set_theta_zero_location("N")

ddir = dirs[1] - dirs[0]

def is_circular(dir):
def is_circular(dir: NDArray[np.floating]) -> bool:
dir_diff = np.mod(dir[0], 2 * np.pi) - np.mod(dir[-1] + ddir, 2 * np.pi)
return np.abs(dir_diff) < 1e-6

Expand All @@ -116,9 +122,9 @@ def is_circular(dir):
if levels is None:
levels = 10
n_levels = 10
if np.isscalar(levels):
if isinstance(levels, int):
n_levels = levels
levels = np.linspace(vmin, vmax, n_levels)
levels = np.linspace(vmin, vmax, n_levels) # type: ignore

if plot_type != "shaded":
spectrum[spectrum < vmin] = np.nan
Expand All @@ -128,17 +134,17 @@ def is_circular(dir):
dirs, freq, spectrum.T, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax
)
elif plot_type == "contour":
colorax = ax.contour(
colorax = ax.contour( # type: ignore
dirs, freq, spectrum.T, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax
)
# ax.clabel(colorax, fmt="%1.2f", inline=1, fontsize=9)
if label is not None:
ax.set_title(label)

elif plot_type in ("patch", "shaded", "box"):
shading = "gouraud" if plot_type == "shaded" else "auto"
shading: Literal['flat', 'nearest', 'gouraud', 'auto'] = "gouraud" if plot_type == "shaded" else "auto"
ax.grid(False) # Remove major grid
colorax = ax.pcolormesh(
colorax = ax.pcolormesh( # type: ignore
dirs,
freq,
spectrum.T,
Expand All @@ -147,14 +153,14 @@ def is_circular(dir):
vmin=vmin,
vmax=vmax,
)
ax.grid("on")
ax.grid("on") # type: ignore
else:
raise ValueError(
f"plot_type '{plot_type}' not supported (contour, contourf, patch, shaded)"
)

# TODO: optional
ax.set_thetagrids(
ax.set_thetagrids( # type: ignore
[0.0, 45, 90.0, 135, 180.0, 225, 270.0, 315],
labels=["N", "N-E", "E", "S-E", "S", "S-W", "W", "N-W"],
)
Expand All @@ -171,9 +177,9 @@ def is_circular(dir):
# ax.set_xticks(dfs.directions, minor=True);

if rmin is not None:
ax.set_rmin(rmin)
ax.set_rmin(rmin) # type: ignore
if rmax is not None:
ax.set_rmax(rmax)
ax.set_rmax(rmax) # type: ignore

if add_colorbar:
cbar = fig.colorbar(colorax)
Expand All @@ -188,8 +194,9 @@ def is_circular(dir):
return ax


def calc_m0_from_spectrum(spec, f, dir=None, tail=True):
def calc_m0_from_spectrum(spec: NDArray[np.floating], f: NDArray[np.floating] | None, dir: NDArray[np.floating] | None =None, tail:bool=True) -> NDArray[np.floating]:
if f is None:
assert dir is not None
nd = len(dir)
dtheta = (dir[-1] - dir[0]) / (nd - 1)
return np.sum(spec, axis=-1) * dtheta * np.pi / 180.0
Expand All @@ -208,7 +215,7 @@ def calc_m0_from_spectrum(spec, f, dir=None, tail=True):
return m0


def _f_to_df(f):
def _f_to_df(f: NDArray[np.floating]) -> NDArray[np.floating]:
"""Frequency bins for equidistant or logrithmic frequency axis"""
if np.isclose(np.diff(f).min(), np.diff(f).max()):
# equidistant frequency bins
Expand Down
6 changes: 2 additions & 4 deletions mikeio/_time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from dataclasses import dataclass
from typing import List, Iterable, Optional
from typing import List, Iterable

import pandas as pd

Expand All @@ -14,9 +14,7 @@ class DateTimeSelector:

def isel(
self,
x: Optional[
int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice
] = None,
x: int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice | None = None,
) -> List[int]:
"""Select time steps from a pandas DatetimeIndex
Expand Down
5 changes: 2 additions & 3 deletions mikeio/_track.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
from pathlib import Path
from datetime import datetime
from typing import Callable, Sequence, Tuple

import numpy as np
Expand All @@ -15,8 +14,8 @@
def _extract_track(
*,
deletevalue: float,
start_time: datetime,
end_time: datetime,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
timestep: float,
geometry: GeometryFM2D,
track: str | Dataset | pd.DataFrame,
Expand Down
14 changes: 10 additions & 4 deletions mikeio/dataset/_data_plot.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from __future__ import annotations
from typing import Any, Tuple, TYPE_CHECKING

import numpy as np
from matplotlib.axes import Axes

from ..spatial._FM_utils import _plot_map, _plot_vertical_profile

from .._spectral import plot_2dspectrum

if TYPE_CHECKING:
from ..dataset import DataArray

class _DataArrayPlotter:
"""Context aware plotter (sensible plotting according to geometry)"""

def __init__(self, da) -> None:
def __init__(self, da: "DataArray") -> None:
self.da = da

def __call__(self, ax=None, figsize=None, **kwargs):
def __call__(self, ax:Axes | None=None, figsize=None, **kwargs):
"""Plot DataArray according to geometry
Parameters
Expand Down Expand Up @@ -59,7 +65,7 @@ def _get_fig_ax(ax=None, figsize=None):
fig = plt.gcf()
return fig, ax

def hist(self, ax=None, figsize=None, title=None, **kwargs):
def hist(self, ax: Axes| None=None, figsize:Tuple[float,float] | None=None, title: str | None=None, **kwargs: Any) -> Axes:
"""Plot DataArray as histogram (using ax.hist)
Parameters
Expand Down Expand Up @@ -91,7 +97,7 @@ def hist(self, ax=None, figsize=None, title=None, **kwargs):
ax.set_title(title)
return self._hist(ax, **kwargs)

def _hist(self, ax, **kwargs):
def _hist(self, ax: Axes, **kwargs):
result = ax.hist(self.da.values.ravel(), **kwargs)
ax.set_xlabel(self._label_txt())
return result
Expand Down
Loading

0 comments on commit 149b20c

Please sign in to comment.