Skip to content

Commit

Permalink
fix: missed EMD_matlab with the numpy fix
Browse files Browse the repository at this point in the history
  • Loading branch information
laszukdawid committed Sep 11, 2024
1 parent bd64d2e commit bdf7c0b
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 29 deletions.
14 changes: 2 additions & 12 deletions PyEMD/EMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.interpolate import interp1d

from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip
from PyEMD.utils import deduce_common_type, get_timeline
from PyEMD.utils import get_timeline, unify_types

FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down Expand Up @@ -762,16 +762,6 @@ def check_imf(

return False

@staticmethod
def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Casts inputs (x, y) into a common numpy DTYPE."""
dtype = deduce_common_type(x.dtype, y.dtype)
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
y = y.astype(dtype)
return x, y

@staticmethod
def _normalize_time(t: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -815,7 +805,7 @@ def emd(self, S: np.ndarray, T: Optional[np.ndarray] = None, max_imf: int = -1)
T = self._normalize_time(T)

# Make sure same types are dealt
S, T = self._common_dtype(S, T)
S, T = unify_types(S, T)
self.DTYPE = S.dtype
N = len(S)

Expand Down
13 changes: 2 additions & 11 deletions PyEMD/EMD_matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from scipy.interpolate import interp1d

from PyEMD.splines import akima
from PyEMD.utils import deduce_common_type
from PyEMD.utils import unify_types


class EMD:
Expand Down Expand Up @@ -429,15 +429,6 @@ def stop_sifting(self, imf, envMax, envMin, mean, extNo):
return False

@staticmethod
def _common_dtype(x, y):
dtype = deduce_common_type([x.dtype, y.dtype], [])
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
y = y.astype(dtype)

return x, y

def emd(self, S, T=None, maxImf=None):
"""
Performs Empirical Mode Decomposition on signal S.
Expand Down Expand Up @@ -466,7 +457,7 @@ def emd(self, S, T=None, maxImf=None):
maxImf = -1

# Make sure same types are dealt
S, T = self._common_dtype(S, T)
S, T = unify_type(S, T)
self.DTYPE = S.dtype

Res = S.astype(self.DTYPE)
Expand Down
2 changes: 1 addition & 1 deletion PyEMD/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

__version__ = "1.6.2"
__version__ = "1.6.3"
logger = logging.getLogger("pyemd")

from PyEMD.CEEMDAN import CEEMDAN # noqa
Expand Down
6 changes: 3 additions & 3 deletions PyEMD/experimental/jitemd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from numba.types import float64, int64, unicode_type
from scipy.interpolate import Akima1DInterpolator, interp1d

from PyEMD.utils import deduce_common_type
# from PyEMD.utils import unify_types

FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down Expand Up @@ -802,7 +802,7 @@ def emd(
# T = _normalize_time(T)

# Make sure same types are dealt
# S, T = deduce_common_Types(S, T)
# S, T = unify_types(S, T)
MAX_ITERATION = config["MAX_ITERATION"]
FIXE = config["FIXE"]
FIXE_H = config["FIXE_H"]
Expand Down Expand Up @@ -992,7 +992,7 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype:
print("Input S.dtype: " + str(S.dtype))

# Prepare and run EMD
config = EmdConfig()
config = default_emd_config
imfs = emd(config, S, T, max_imf)
imfNo = imfs.shape[0]

Expand Down
9 changes: 8 additions & 1 deletion PyEMD/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from PyEMD.utils import deduce_common_type, get_timeline
from PyEMD.utils import deduce_common_type, get_timeline, unify_types


class MyTestCase(unittest.TestCase):
Expand Down Expand Up @@ -36,6 +36,13 @@ def test_deduce_common_types(self):
self.assertEqual(deduce_common_type(np.int32, np.int16), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64)

def test_unify_types(self):
x = np.array([1, 2, 3], dtype=np.int16)
y = np.array([1.1, 2.2, 3.3], dtype=np.float32)
x, y = unify_types(x, y)
self.assertEqual(x.dtype, np.float32)
self.assertEqual(y.dtype, np.float32)


if __name__ == "__main__":
Expand Down
11 changes: 10 additions & 1 deletion PyEMD/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Optional
from typing import Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -67,3 +67,12 @@ def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype:
else:
dtype = np.promote_types(xtype, ytype)
return dtype

def unify_types(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
dtype = deduce_common_type(x.dtype, y.dtype)
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
y = y.astype(dtype)

return x, y

0 comments on commit bdf7c0b

Please sign in to comment.